diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs index b70264e363b..5c8a7f28b0e 100644 --- a/libs/host/Configuration/Options.cs +++ b/libs/host/Configuration/Options.cs @@ -680,6 +680,10 @@ public IEnumerable LuaAllowedFunctions [Option("vector-set-replay-task-count", Required = false, HelpText = "Configure how many replay tasks are used to replay VectorSet operations at the replica (default: 0 uses the machine CPU count)")] public int VectorSetReplayTaskCount { get; set; } + [IntRangeValidation(0, int.MaxValue, isRequired: false)] + [Option("vector-set-quantization-task-count", Required = false, HelpText = "Configure how many quantization tasks are used to optimize Vector Set operations (default: 0 uses the machine CPU count)")] + public int VectorSetQuantizationTaskCount { get; set; } + /// /// This property contains all arguments that were not parsed by the command line argument parser /// @@ -980,7 +984,8 @@ endpoint is IPEndPoint listenEp && clusterAnnounceEndpoint[0] is IPEndPoint anno ClusterReplicationReestablishmentTimeout = ClusterReplicationReestablishmentTimeout, ClusterReplicaResumeWithData = ClusterReplicaResumeWithData, EnableVectorSetPreview = EnableVectorSetPreview, - VectorSetReplayTaskCount = VectorSetReplayTaskCount + VectorSetReplayTaskCount = VectorSetReplayTaskCount, + VectorSetQuantizationTaskCount = VectorSetQuantizationTaskCount, }; } diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf index b331c3a3b4d..0c6ec8a2cd8 100644 --- a/libs/host/defaults.conf +++ b/libs/host/defaults.conf @@ -459,5 +459,8 @@ "EnableVectorSetPreview": false, /* Configure how many replay tasks are used to replay VectorSet operations at the replica (default: 0 uses the machine CPU count) */ - "VectorSetReplayTaskCount": 0 + "VectorSetReplayTaskCount": 0, + + /* Configure how many quantization tasks are used to optimize Vector Set operations (default: 0 uses the machine CPU count) */ + "VectorSetQuantizationTaskCount": 0 
} \ No newline at end of file diff --git a/libs/server/API/GarnetApi.cs b/libs/server/API/GarnetApi.cs index 435212cc982..7b682bf95d0 100644 --- a/libs/server/API/GarnetApi.cs +++ b/libs/server/API/GarnetApi.cs @@ -520,8 +520,8 @@ public unsafe GarnetStatus VectorSetRemove(ArgSlice key, ArgSlice element) => storageSession.VectorSetRemove(SpanByte.FromPinnedPointer(key.ptr, key.length), SpanByte.FromPinnedPointer(element.ptr, element.length)); /// - public unsafe GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) - => storageSession.VectorSetValueSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), valueType, values, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result, ref filterBitmap); + public unsafe GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, out ReadOnlySpan errorMessage, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) + => storageSession.VectorSetValueSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), valueType, values, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, out errorMessage, ref outputDistances, ref 
outputAttributes, out result, ref filterBitmap); /// public unsafe GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) diff --git a/libs/server/API/GarnetWatchApi.cs b/libs/server/API/GarnetWatchApi.cs index d60d8546f65..6f279dae391 100644 --- a/libs/server/API/GarnetWatchApi.cs +++ b/libs/server/API/GarnetWatchApi.cs @@ -650,10 +650,10 @@ public bool ResetScratchBuffer(int offset) #region Vector Sets /// - public GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) + public GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, out ReadOnlySpan errorMessage, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) { garnetApi.WATCH(key, StoreType.Main); - return garnetApi.VectorSetValueSimilarity(key, valueType, value, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result, ref filterBitmap); + return 
garnetApi.VectorSetValueSimilarity(key, valueType, value, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, out errorMessage, ref outputDistances, ref outputAttributes, out result, ref filterBitmap); } /// diff --git a/libs/server/API/IGarnetApi.cs b/libs/server/API/IGarnetApi.cs index ac4e370a762..10c76928f4e 100644 --- a/libs/server/API/IGarnetApi.cs +++ b/libs/server/API/IGarnetApi.cs @@ -2041,7 +2041,7 @@ public bool IterateObjectStore(ref TScanFunctions scanFunctions, /// Ids are encoded in as length prefixed blobs of bytes. /// Attributes are encoded in as length prefixed blobs of bytes. /// - GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap); + GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, out ReadOnlySpan errorMessage, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap); /// /// Perform a similarity search given an element already in the vector set and these parameters. 
diff --git a/libs/server/Databases/MultiDatabaseManager.cs b/libs/server/Databases/MultiDatabaseManager.cs index 20ad35a523c..5efee6da1ef 100644 --- a/libs/server/Databases/MultiDatabaseManager.cs +++ b/libs/server/Databases/MultiDatabaseManager.cs @@ -1088,6 +1088,7 @@ public override void RecoverVectorSets() for (var i = 0; i < activeDbIdsMapSize; i++) { var dbId = activeDbIdsMapSnapshot[i]; + databasesMapSnapshot[dbId].VectorManager.Initialize(); databasesMapSnapshot[dbId].VectorManager.ResumePostRecovery(); } } diff --git a/libs/server/Databases/SingleDatabaseManager.cs b/libs/server/Databases/SingleDatabaseManager.cs index 755e35f7c83..5e455e1dc39 100644 --- a/libs/server/Databases/SingleDatabaseManager.cs +++ b/libs/server/Databases/SingleDatabaseManager.cs @@ -432,6 +432,9 @@ private void SafeTruncateAOF(AofEntryType entryType, bool unsafeTruncateLog) /// public override void RecoverVectorSets() { + // Guarantee initialize has happened before we attempt to recover + defaultDatabase.VectorManager.Initialize(); + defaultDatabase.VectorManager.ResumePostRecovery(); } diff --git a/libs/server/Resp/Vector/DiskANNService.cs b/libs/server/Resp/Vector/DiskANNService.cs index dd3f81d5413..b2ff8df7693 100644 --- a/libs/server/Resp/Vector/DiskANNService.cs +++ b/libs/server/Resp/Vector/DiskANNService.cs @@ -31,12 +31,20 @@ public nint CreateIndex( delegate* unmanaged[Cdecl] readCallback, delegate* unmanaged[Cdecl] writeCallback, delegate* unmanaged[Cdecl] deleteCallback, - delegate* unmanaged[Cdecl] readModifyWriteCallback + delegate* unmanaged[Cdecl] readModifyWriteCallback, + out bool quantizationRequested ) { + // TODO: This needs to be set appropriately - requires DiskANN changes + quantizationRequested = false; + unsafe { - return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, distanceMetric, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + var ret = 
NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, distanceMetric, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + + Debug.Assert(ret != 0, "create_index failed, returning a null pointer - this shouldn't be possible"); + + return ret; } } @@ -51,40 +59,45 @@ public nint RecreateIndex( delegate* unmanaged[Cdecl] readCallback, delegate* unmanaged[Cdecl] writeCallback, delegate* unmanaged[Cdecl] deleteCallback, - delegate* unmanaged[Cdecl] readModifyWriteCallback + delegate* unmanaged[Cdecl] readModifyWriteCallback, + out bool quantizationRequested ) - => CreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetricType, readCallback, writeCallback, deleteCallback, readModifyWriteCallback); + => CreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetricType, readCallback, writeCallback, deleteCallback, readModifyWriteCallback, out quantizationRequested); public void DropIndex(ulong context, nint index) { NativeDiskANNMethods.drop_index(context, index); } - public bool Insert(ulong context, nint index, ReadOnlySpan id, VectorValueType vectorType, ReadOnlySpan vector, ReadOnlySpan attributes) + public bool Insert(ulong context, nint index, ReadOnlySpan id, ReadOnlySpan vector, int vectorElementCount, ReadOnlySpan attributes, out bool needsQuantization) { var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); var id_len = id.Length; var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); - int vector_len; - if (vectorType == VectorValueType.FP32) - { - vector_len = vector.Length / sizeof(float); - } - else if (vectorType == VectorValueType.XB8) - { - vector_len = vector.Length; - } - else + var attributes_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(attributes)); + var attributes_len = attributes.Length; + + var res = 
NativeDiskANNMethods.insert(context, index, (nint)id_data, (nuint)id_len, (nint)vector_data, (nuint)vectorElementCount, (nint)attributes_data, (nuint)attributes_len); + if (res == NativeDiskANNMethods.DiskANNInsertResult.False) { - throw new NotImplementedException($"{vectorType}"); + needsQuantization = false; + return false; } - var attributes_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(attributes)); - var attributes_len = attributes.Length; + needsQuantization = res == NativeDiskANNMethods.DiskANNInsertResult.QuantizationRequested; + return true; + } - return NativeDiskANNMethods.insert(context, index, (nint)id_data, (nuint)id_len, vectorType, (nint)vector_data, (nuint)vector_len, (nint)attributes_data, (nuint)attributes_len) == 1; + public bool BuildQuantizationTable(ulong context, nint index) + { + return NativeDiskANNMethods.build_quant_table(context, index) == 1; + } + + public void BackfillQuantizedVectors(ulong context, nint index, int taskIndex, int taskCount) + { + NativeDiskANNMethods.backfill_quant_vectors(context, index, (nuint)taskIndex, (nuint)taskCount); } public bool Remove(ulong context, nint index, ReadOnlySpan id) @@ -98,8 +111,8 @@ public bool Remove(ulong context, nint index, ReadOnlySpan id) public int SearchVector( ulong context, nint index, - VectorValueType vectorType, ReadOnlySpan vector, + int vectorElementCount, float delta, int searchExplorationFactor, ReadOnlySpan filter, @@ -110,20 +123,6 @@ out nint continuation ) { var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); - int vector_len; - - if (vectorType == VectorValueType.FP32) - { - vector_len = vector.Length / sizeof(float); - } - else if (vectorType == VectorValueType.XB8) - { - vector_len = vector.Length; - } - else - { - throw new NotImplementedException($"{vectorType}"); - } var filter_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); var filter_len = filter.Length; @@ -174,9 +173,8 @@ out nint continuation return 
NativeDiskANNMethods.search_vector( context, index, - vectorType, (nint)vector_data, - (nuint)vector_len, + (nuint)vectorElementCount, delta, searchExplorationFactor, (nint)filter_data, @@ -306,6 +304,13 @@ public bool CheckExternalIdValid(ulong context, nint index, ReadOnlySpan e public static partial class NativeDiskANNMethods { + public enum DiskANNInsertResult : byte + { + False = 0, + True = 1, + QuantizationRequested = 2, + } + const string DISKANN_GARNET = "diskann_garnet"; [LibraryImport(DISKANN_GARNET)] @@ -330,12 +335,11 @@ nint index ); [LibraryImport(DISKANN_GARNET)] - public static partial byte insert( + public static partial DiskANNInsertResult insert( ulong context, nint index, nint id_data, nuint id_len, - VectorValueType vector_value_type, nint vector_data, nuint vector_len, nint attribute_data, @@ -364,7 +368,6 @@ nuint attribute_len public static partial int search_vector( ulong context, nint index, - VectorValueType vector_value_type, nint vector_data, nuint vector_len, float delta, @@ -430,5 +433,19 @@ public static partial byte check_external_id_valid( nint external_id, nuint external_id_len ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte build_quant_table( + ulong context, + nint index + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial void backfill_quant_vectors( + ulong context, + nint index, + nuint task_index, + nuint task_count + ); } } \ No newline at end of file diff --git a/libs/server/Resp/Vector/RespServerSessionVectors.cs b/libs/server/Resp/Vector/RespServerSessionVectors.cs index a81c86a1d82..d0e9df227bc 100644 --- a/libs/server/Resp/Vector/RespServerSessionVectors.cs +++ b/libs/server/Resp/Vector/RespServerSessionVectors.cs @@ -125,7 +125,7 @@ private bool NetworkVADD(ref TGarnetApi storageApi) curIx++; } } - else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XU8"u8) || 
parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) // XB8 preserved for backwards compatibility, prefer XU8 { curIx++; if (curIx >= parseState.Count) @@ -135,14 +135,21 @@ private bool NetworkVADD(ref TGarnetApi storageApi) var asBytes = parseState.GetArgSliceByRef(curIx).Span; - vectorDims = asBytes.Length; - if (vectorDims > VectorManager.MaxVectorDimensions) + valueType = VectorValueType.XU8; + values = asBytes; /* NOTE(review): the XI8 branch below advances curIx after reading its vector argument but this branch does not, and vectorDims is no longer assigned here - confirm this is intended */ + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XI8"u8)) + { + curIx++; + if (curIx >= parseState.Count) { - return AbortWithErrorMessage($"ERR vector exceeds maximum of {VectorManager.MaxVectorDimensions} dimensions"); + return AbortWithWrongNumberOfArguments("VADD"); } + var asBytes = parseState.GetArgSliceByRef(curIx).Span; curIx++; - valueType = VectorValueType.XB8; + + valueType = VectorValueType.XI8; values = asBytes; } else @@ -150,6 +157,11 @@ private bool NetworkVADD(ref TGarnetApi storageApi) return AbortWithErrorMessage("ERR invalid vector specification"); } + if (vectorDims > VectorManager.MaxVectorDimensions) + { + return AbortWithErrorMessage($"ERR vector exceeds maximum of {VectorManager.MaxVectorDimensions} dimensions"); + } + if (reduceDim > vectorDims) { return AbortWithErrorMessage("ERR REDUCE dimension must be <= vector dimensions"); @@ -231,14 +243,50 @@ private bool NetworkVADD(ref TGarnetApi storageApi) continue; } - else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XPREQ8"u8)) + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XNOQUANT_U8"u8) || parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XPREQ8"u8)) // XPREQ8 kept for backwards compatibility, prefer XNOQUANT_U8 + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.XNoQuant_U8; + curIx++; + + continue; + } + else if
(parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XNOQUANT_I8"u8)) { if (quantType != null) { return AbortWithErrorMessage("Quantization specified multiple times"); } - quantType = VectorQuantType.XPreQ8; + quantType = VectorQuantType.XNoQuant_I8; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XBIN_I8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.XBin_I8; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XBIN_U8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.XBin_U8; curIx++; continue; @@ -389,17 +437,11 @@ private bool NetworkVADD(ref TGarnetApi storageApi) return true; } - if (quantType != VectorQuantType.XPreQ8 && quantType != VectorQuantType.NoQuant) - { - WriteError("ERR Unsupported quantization type"u8); - return true; - } - // We need to reject these HERE because validation during create_index is very awkward GarnetStatus res; VectorManagerResult result; ReadOnlySpan customErrMsg; - if (quantType == VectorQuantType.XPreQ8 && reduceDim != 0) + if (quantType is VectorQuantType.XBin_U8 or VectorQuantType.XBin_I8 or VectorQuantType.XNoQuant_U8 or VectorQuantType.XNoQuant_I8 && reduceDim != 0) /* the whole `is ... or ...` pattern is matched first, then combined with `&& reduceDim != 0` */ { result = VectorManagerResult.BadParams; res = GarnetStatus.OK; @@ -536,7 +578,7 @@ private bool NetworkVSIM(ref TGarnetApi storageApi) values = asBytes; curIx++; } - else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("XU8"u8) || kind.Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) // XB8 preserved for backwards compatibility, prefer XU8 { if (curIx >= parseState.Count) { @@ -544,15 +586,33 @@ private bool NetworkVSIM(ref TGarnetApi storageApi) } var asBytes =
parseState.GetArgSliceByRef(curIx).Span; - if (asBytes.Length > VectorManager.MaxVectorDimensions) { return AbortWithErrorMessage($"ERR vector exceeds maximum of {VectorManager.MaxVectorDimensions} dimensions"); } - valueType = VectorValueType.XB8; + curIx++; + + valueType = VectorValueType.XU8; values = asBytes; + } + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("XI8"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + if (asBytes.Length > VectorManager.MaxVectorDimensions) + { + return AbortWithErrorMessage($"ERR vector exceeds maximum of {VectorManager.MaxVectorDimensions} dimensions"); + } + curIx++; + + valueType = VectorValueType.XI8; + values = asBytes; } else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("VALUES"u8)) { @@ -815,17 +875,18 @@ private bool NetworkVSIM(ref TGarnetApi storageApi) var filterBitmapResult = SpanByteAndMemory.FromPinnedSpan(bitmapSpace); try { - GarnetStatus res; VectorManagerResult vectorRes; VectorIdFormat idFormat; + scoped ReadOnlySpan customErrMsg; if (!element.HasValue) { - res = storageApi.VectorSetValueSimilarity(key, valueType, ArgSlice.FromPinnedSpan(values), count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, ref distanceResult, ref attributeResult, out vectorRes, ref filterBitmapResult); + res = storageApi.VectorSetValueSimilarity(key, valueType, ArgSlice.FromPinnedSpan(values), count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, out customErrMsg, ref distanceResult, ref attributeResult, out vectorRes, ref filterBitmapResult); } else { res = storageApi.VectorSetElementSimilarity(key, element.Value, count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref 
idResult, out idFormat, ref distanceResult, ref attributeResult, out vectorRes, ref filterBitmapResult); + customErrMsg = default; } if (res == GarnetStatus.NOTFOUND) @@ -967,6 +1028,15 @@ private bool NetworkVSIM(ref TGarnetApi storageApi) } } } + else if (vectorRes == VectorManagerResult.BadParams) + { + if (customErrMsg.IsEmpty) + { + return AbortWithErrorMessage("ERR asked quantization mismatch with existing vector set"u8); + } + + return AbortWithErrorMessage(customErrMsg); + } else { throw new GarnetException($"Unexpected {nameof(VectorManagerResult)}: {vectorRes}"); @@ -1232,7 +1302,10 @@ private bool NetworkVINFO(ref TGarnetApi storageApi) VectorQuantType.NoQuant => "f32"u8, VectorQuantType.Bin => "bin"u8, VectorQuantType.Q8 => "q8"u8, - VectorQuantType.XPreQ8 => "xpreq8"u8, + VectorQuantType.XNoQuant_U8 => "xnoquant_u8"u8, + VectorQuantType.XNoQuant_I8 => "xnoquant_i8"u8, + VectorQuantType.XBin_I8 => "xbin_i8"u8, + VectorQuantType.XBin_U8 => "xbin_u8"u8, _ => throw new GarnetException($"Invalid VectorQuantType: {quantType}"), }; diff --git a/libs/server/Resp/Vector/VectorManager.Cleanup.cs b/libs/server/Resp/Vector/VectorManager.Cleanup.cs index ae2bdd3c6f2..cfd06f8a8c8 100644 --- a/libs/server/Resp/Vector/VectorManager.Cleanup.cs +++ b/libs/server/Resp/Vector/VectorManager.Cleanup.cs @@ -79,7 +79,7 @@ public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata re private readonly Channel cleanupTaskChannel; private readonly Task cleanupTask; - private readonly Func getCleanupSession; + private readonly Func getTempSession; private async Task RunCleanupTaskAsync() { @@ -102,7 +102,7 @@ private async Task RunCleanupTaskAsync() } // TODO: this doesn't work with non-RESP impls... which maybe we don't care about? 
- using var cleanupSession = (RespServerSession)getCleanupSession(); + using var cleanupSession = (RespServerSession)getTempSession(); if (cleanupSession.activeDbId != dbId && !cleanupSession.TrySwitchActiveDatabaseSession(dbId)) { throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed"); diff --git a/libs/server/Resp/Vector/VectorManager.ElementData.cs b/libs/server/Resp/Vector/VectorManager.ElementData.cs new file mode 100644 index 00000000000..62b6b952622 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.ElementData.cs @@ -0,0 +1,339 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Methods for converting data as received by Garnet into formats (and alignment) that DiskANN expects. + /// + public sealed partial class VectorManager + { + /// + /// A vector's element data that has been prepared to be passed to DiskANN. + /// + /// This means the data is pinned, correctly aligned, and converted to the expected format for the target index. + /// + internal readonly ref struct PreparedVectorData : IDisposable + { + private readonly GCHandle pin; + private readonly byte[] rentedArray; + + /// + /// Vector data to pass to DiskANN. + /// + public readonly ReadOnlySpan ReadOnlySpan { get; } + + /// + /// Count of elements in <see cref="ReadOnlySpan"/>. + /// + /// This is not the same as <see cref="ReadOnlySpan{T}.Length"/> as the data might represent + /// something other than bytes. + /// + public readonly int ElementCount { get; } + + /// + /// Create a <see cref="PreparedVectorData"/> with an already pinned span. + /// + internal PreparedVectorData(ReadOnlySpan data, int count) + { + pin = default; + rentedArray = default; + ReadOnlySpan = data; + ElementCount = count; + } + + /// + /// Create a <see cref="PreparedVectorData"/> with an array and a gc pin handle.
+ /// + internal PreparedVectorData(GCHandle pin, byte[] data, int count) + { + this.pin = pin; + rentedArray = data; + ReadOnlySpan = data; /* NOTE(review): this span covers the whole array; callers pass arrays from ArrayPool.Rent, which may be longer than the payload written - presumably consumers rely on ElementCount, confirm */ + ElementCount = count; + } + + /// + public void Dispose() + { + if (rentedArray != null) + { + pin.Free(); + ArrayPool.Shared.Return(rentedArray); + } + } + } + + /// + /// Ensure the provided vector data is converted and aligned for passing to DiskANN. + /// + /// Quantizers have an internal "native" format they expect all vectors to be passed in. + /// + /// <see cref="VectorQuantType.NoQuant"/> -> <see cref="VectorValueType.FP32"/> + /// <see cref="VectorQuantType.Q8"/> -> <see cref="VectorValueType.FP32"/> + /// <see cref="VectorQuantType.Bin"/> -> <see cref="VectorValueType.FP32"/> + /// <see cref="VectorQuantType.XNoQuant_U8"/> -> <see cref="VectorValueType.XU8"/> + /// <see cref="VectorQuantType.XBin_U8"/> -> <see cref="VectorValueType.XU8"/> + /// <see cref="VectorQuantType.XNoQuant_I8"/> -> <see cref="VectorValueType.XI8"/> + /// <see cref="VectorQuantType.XBin_I8"/> -> <see cref="VectorValueType.XI8"/> + /// + /// + /// Even if the formats match, the data must also be aligned to the element's native alignment (i.e. for <see cref="VectorValueType.FP32"/> that's 4 bytes, for <see cref="VectorValueType.XU8"/> it's 1 byte). + /// + private static PreparedVectorData PrepareVectorData(VectorQuantType quantType, VectorValueType valueType, ReadOnlySpan providedData, out ReadOnlySpan error) + { + switch (quantType) + { + // All Redis compatible quantizers expect F32 vectors + case VectorQuantType.NoQuant: + case VectorQuantType.Q8: + case VectorQuantType.Bin: + error = default; + switch (valueType) + { + case VectorValueType.FP32: return ConvertF32ForAlignment(providedData); + case VectorValueType.XI8: return ConvertI8ToF32(providedData); + case VectorValueType.XU8: return ConvertU8ToF32(providedData); + + case VectorValueType.Invalid: + default: throw new InvalidOperationException($"Unexpected VectorValueType: {valueType}"); + } + // XNoQuant_U8 and XBin_U8 expect U8 vectors + case VectorQuantType.XNoQuant_U8: + case VectorQuantType.XBin_U8: + switch (valueType) + { + case VectorValueType.FP32: return ConvertF32ToU8(providedData, out error); + case VectorValueType.XI8: return ConvertI8ToU8(providedData, out error); + case VectorValueType.XU8: + error = default; + return new(providedData, providedData.Length); + + case VectorValueType.Invalid: + default: throw new InvalidOperationException($"Unexpected VectorValueType: {valueType}"); + } + // XNoQuant_I8 and XBin_I8 expect I8 vectors +
case VectorQuantType.XNoQuant_I8: + case VectorQuantType.XBin_I8: + switch (valueType) + { + case VectorValueType.FP32: return ConvertF32ToI8(providedData, out error); + case VectorValueType.XI8: + error = default; + return new(providedData, providedData.Length); + case VectorValueType.XU8: return ConvertU8ToI8(providedData, out error); + + case VectorValueType.Invalid: + default: throw new InvalidOperationException($"Unexpected VectorValueType: {valueType}"); + } + + case VectorQuantType.Invalid: + default: throw new InvalidOperationException($"Unexpected VectorQuantType: {quantType}"); + } + + // Copy provided data (which is assumed to have floats in it) to a pinned and aligned buffer if needed + static PreparedVectorData ConvertF32ForAlignment(ReadOnlySpan providedData) + { + var numElements = providedData.Length / sizeof(float); + + unsafe + { + // Already aligned, pass it on down + var isAligned = ((nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(providedData)) % sizeof(float)) == 0; + if (isAligned) + { + return new(providedData, numElements); + } + } + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(providedData.Length); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + unsafe + { + Debug.Assert(((nint)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(toCopyTo)) % sizeof(float)) == 0, "byte[] should be word aligned, or we're in trouble"); + } + + providedData.CopyTo(toCopyTo); + + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have signed byte in it) to a pinned buffer full of floats + static PreparedVectorData ConvertI8ToF32(ReadOnlySpan providedData) + { + var asI8 = MemoryMarshal.Cast(providedData); + + var numElements = providedData.Length; + var numBytes = numElements * sizeof(float); + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numBytes); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + unsafe + { + 
Debug.Assert(((nint)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(toCopyTo)) % sizeof(float)) == 0, "byte[] should be word aligned, or we're in trouble"); + } + + var asFloats = MemoryMarshal.Cast(toCopyTo.AsSpan()); + + // Do the actual copy, which expands sbyte -> float + for (var i = 0; i < numElements; i++) + { + asFloats[i] = asI8[i]; + } + + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have unsigned bytes in it) to a pinned buffer full of floats + static PreparedVectorData ConvertU8ToF32(ReadOnlySpan providedData) + { + var numElements = providedData.Length; + var numBytes = numElements * sizeof(float); + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numBytes); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + unsafe + { + Debug.Assert(((nint)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(toCopyTo)) % sizeof(float)) == 0, "byte[] should be word aligned, or we're in trouble"); + } + + var asFloats = MemoryMarshal.Cast(toCopyTo.AsSpan()); + + // Do the actual copy, which expands byte -> float + for (var i = 0; i < numElements; i++) + { + asFloats[i] = providedData[i]; + } + + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have floats in it) to a pinned buffer full of unsigned bytes + static PreparedVectorData ConvertF32ToU8(ReadOnlySpan providedData, out ReadOnlySpan error) + { + var asFloats = MemoryMarshal.Cast(providedData); + var numElements = asFloats.Length; + + // Validate vector can convert from f32 -> u8 without issue + if (asFloats.ContainsAnyExceptInRange(byte.MinValue, byte.MaxValue)) + { + error = "Vector contains element that is < 0 or > 255, operation will lose precision"u8; + return default; + } + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numElements); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + // Do the actual copy, which 
truncates f32 -> u8 + for (var i = 0; i < numElements; i++) + { + var f32 = asFloats[i]; + toCopyTo[i] = (byte)f32; /* NOTE(review): the range check above does not reject in-range fractional values, so this cast silently drops the fractional part - confirm intended */ + } + + error = default; + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have signed bytes in it) to a pinned buffer full of unsigned bytes + static PreparedVectorData ConvertI8ToU8(ReadOnlySpan providedData, out ReadOnlySpan error) + { + var asI8 = MemoryMarshal.Cast(providedData); + var numElements = asI8.Length; + + // Validate vector can convert from i8 -> u8 without issue + if (asI8.ContainsAnyInRange(sbyte.MinValue, (sbyte)-1)) + { + error = "Vector contains element that is < 0, operation will lose precision"u8; + return default; + } + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numElements); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + // Do the actual copy, which truncates i8 -> u8 + for (var i = 0; i < numElements; i++) + { + var i8 = asI8[i]; + toCopyTo[i] = (byte)i8; + } + + error = default; + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have floats in it) to a pinned buffer full of signed bytes + static PreparedVectorData ConvertF32ToI8(ReadOnlySpan providedData, out ReadOnlySpan error) + { + var asFloats = MemoryMarshal.Cast(providedData); + var numElements = asFloats.Length; + + // Validate vector can convert from f32 -> i8 without issue + if (asFloats.ContainsAnyExceptInRange(sbyte.MinValue, sbyte.MaxValue)) + { + error = "Vector contains element that is < -128 or > 127, operation will lose precision"u8; + return default; + } + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numElements); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + var asSBytes = MemoryMarshal.Cast(toCopyTo.AsSpan()); + + // Do the actual copy, which truncates f32 -> i8 + for (var i = 0; i < numElements; i++) + { + var f32 = asFloats[i]; + asSBytes[i] = (sbyte)f32; /* NOTE(review): in-range fractional values pass the check above and lose their fractional part here - confirm intended */ + } + +
error = default; + return new(pin, toCopyTo, numElements); + } + + // Convert provided data (which is assumed to have unsigned bytes in it) to a pinned buffer full of signed bytes + static PreparedVectorData ConvertU8ToI8(ReadOnlySpan providedData, out ReadOnlySpan error) + { + var numElements = providedData.Length; + + // Validate vector can convert from u8 -> i8 without issue + if (providedData.ContainsAnyInRange((byte)128, byte.MaxValue)) + { + error = "Vector contains element that is > 127, operation will lose precision"u8; + return default; + } + + // Need to copy to an aligned buffer + var toCopyTo = ArrayPool.Shared.Rent(numElements); + var pin = GCHandle.Alloc(toCopyTo, GCHandleType.Pinned); + + var asI8 = MemoryMarshal.Cast(toCopyTo.AsSpan()); + + // Do the actual copy, which truncates i8 -> u8 + for (var i = 0; i < numElements; i++) + { + var u8 = providedData[i]; + asI8[i] = (sbyte)u8; + } + + error = default; + return new(pin, toCopyTo, numElements); + } + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Locking.cs b/libs/server/Resp/Vector/VectorManager.Locking.cs index 718232a9e5e..10b823b0a12 100644 --- a/libs/server/Resp/Vector/VectorManager.Locking.cs +++ b/libs/server/Resp/Vector/VectorManager.Locking.cs @@ -167,9 +167,10 @@ internal ReadVectorLock ReadVectorIndex(StorageSession storageSession, ref SpanB input.arg1 = RecreateIndexArg; nint newlyAllocatedIndex; + bool requestQuantization; unsafe { - newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr, out requestQuantization); } input.header.cmd = RespCommand.VADD; @@ -210,6 
+211,12 @@ internal ReadVectorLock ReadVectorIndex(StorageSession storageSession, ref SpanB if (writeRes == GarnetStatus.OK) { + // Post recreate the index might already need quantization - if so, queue it up + if (requestQuantization) + { + _ = quantizationChannel.Writer.TryWrite(new(key.ToByteArray(), QuantizationStep.BuildQuantizationTable, 0)); + } + // Try again so we don't hold an exclusive lock while performing a search vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); continue; @@ -317,6 +324,7 @@ out GarnetStatus status ulong indexContext; nint newlyAllocatedIndex; + bool requestQuantization; if (needsRecreate) { ReadIndex(indexSpan, out indexContext, out var dims, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var distanceMetric, out _, out _); @@ -325,7 +333,7 @@ out GarnetStatus status unsafe { - newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr, out requestQuantization); } input.parseState.EnsureCapacity(12); @@ -357,7 +365,7 @@ out GarnetStatus status unsafe { - newlyAllocatedIndex = Service.CreateIndex(indexContext, dims, reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + newlyAllocatedIndex = Service.CreateIndex(indexContext, dims, reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr, out requestQuantization); } input.parseState.EnsureCapacity(12); @@ -417,6 +425,12 @@ out GarnetStatus status if (writeRes == 
GarnetStatus.OK) { + // Post (re)create the index might already need quantization - if so, queue it up + if (requestQuantization) + { + _ = quantizationChannel.Writer.TryWrite(new(key.ToByteArray(), QuantizationStep.BuildQuantizationTable, 0)); + } + // Try again so we don't hold an exclusive lock while adding a vector (which might be time consuming) vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); continue; diff --git a/libs/server/Resp/Vector/VectorManager.Migration.cs b/libs/server/Resp/Vector/VectorManager.Migration.cs index 69235d18aca..209749ac083 100644 --- a/libs/server/Resp/Vector/VectorManager.Migration.cs +++ b/libs/server/Resp/Vector/VectorManager.Migration.cs @@ -173,9 +173,10 @@ public void HandleMigratedIndexKey( var distanceMetricArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref distanceMetric, 1))); nint newlyAllocatedIndex; + bool requestQuantization; unsafe { - newlyAllocatedIndex = Service.RecreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + newlyAllocatedIndex = Service.RecreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, distanceMetric, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr, out requestQuantization); } var ctxArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref context, 1))); @@ -207,6 +208,12 @@ public void HandleMigratedIndexKey( UpdateContextMetadata(ref ActiveThreadSession.vectorContext); + // Post recreation the index might already need quantization - if so, queue it up + if (requestQuantization) + { + _ = quantizationChannel.Writer.TryWrite(new(key.ToByteArray(), QuantizationStep.BuildQuantizationTable, 0)); + } + // For REPLICAs which are following, we need to fake up a write ReplicateMigratedIndexKey(ref ActiveThreadSession.basicContext, ref key, ref value, 
context, logger); } diff --git a/libs/server/Resp/Vector/VectorManager.Quantization.cs b/libs/server/Resp/Vector/VectorManager.Quantization.cs new file mode 100644 index 00000000000..33afc4c200c --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Quantization.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Tsavorite.core; + +namespace Garnet.server +{ + public partial class VectorManager + { + /// + /// Different steps of quantization process. + /// + private enum QuantizationStep + { + Invalid = 0, + + /// + /// Build the quantization table - only one task can do this per Vector Set Index. + /// + BuildQuantizationTable, + + /// + /// Backfill quantized vectors - many tasks can do this concurrently for a Vector Set Index. + /// + BackfillQuantizedVectors, + } + + private readonly record struct QuantizationState(ReadOnlyMemory Key, QuantizationStep Step, int StepIndex); + + private readonly Channel quantizationChannel; + private readonly Task[] quantizationTasks; + + /// + /// Populate with running tasks for handling any quantization requests. 
+ /// + public void StartQuantizationTasks() + { + for (var i = 0; i < quantizationTasks.Length; i++) + { + quantizationTasks[i] = QuantizationTaskAsync(this, quantizationChannel.Reader, quantizationChannel.Writer); + } + + static async Task QuantizationTaskAsync(VectorManager self, ChannelReader reader, ChannelWriter writer) + { + // Force async + await Task.Yield(); + + while (await reader.WaitToReadAsync().ConfigureAwait(false)) + { + using var session = (RespServerSession)self.getTempSession(); + + Span indexSpan = new byte[IndexSizeBytes]; + + while (reader.TryRead(out var state)) + { + try + { + unsafe + { + fixed (byte* keyPtr = state.Key.Span) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, state.Key.Length); + + // Dummy command, we just need something Vector Set-y + RawStringInput input = default; + input.header.cmd = RespCommand.VSIM; + + using (self.ReadVectorIndex(session.storageSession, ref keySpan, ref input, indexSpan, out var res)) + { + if (res != GarnetStatus.OK) + { + // Index was dropped before quantization request could be processed, ignore request + continue; + } + + ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out var indexPtr, out _); + + switch (state.Step) + { + case QuantizationStep.BuildQuantizationTable: + if (self.Service.BuildQuantizationTable(context, indexPtr)) + { + // Schedule backfill after quantization table is available + for (var i = 0; i < self.quantizationTasks.Length; i++) + { + _ = writer.TryWrite(new(state.Key, QuantizationStep.BackfillQuantizedVectors, i)); + } + } + + break; + + case QuantizationStep.BackfillQuantizedVectors: + self.Service.BackfillQuantizedVectors(context, indexPtr, state.StepIndex, self.quantizationTasks.Length); + break; + default: + self.logger?.LogError("Unexpected step: {step}", state.Step); + break; + } + } + } + } + } + catch (Exception ex) + { + self.logger?.LogError(ex, "During Vector Set quantization"); + } + } + } + } + } + } +} \ No newline at end 
of file diff --git a/libs/server/Resp/Vector/VectorManager.Replication.cs b/libs/server/Resp/Vector/VectorManager.Replication.cs index 5f5249a66ba..4b7a08b9f58 100644 --- a/libs/server/Resp/Vector/VectorManager.Replication.cs +++ b/libs/server/Resp/Vector/VectorManager.Replication.cs @@ -446,7 +446,7 @@ static unsafe void ApplyVectorSetAdd(VectorManager self, StorageSession storageS { Debug.Assert(status == GarnetStatus.OK, "Replication should only occur when an add is successful, so index must exist"); - var addRes = self.TryAdd(indexSpan, element.AsReadOnlySpan(), valueType, values.AsReadOnlySpan(), attributes.AsReadOnlySpan(), reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, out _); + var addRes = self.TryAdd(ref key, indexSpan, element.AsReadOnlySpan(), valueType, values.AsReadOnlySpan(), attributes.AsReadOnlySpan(), reduceDims, quantizer, buildExplorationFactor, numLinks, distanceMetric, out _); if (addRes != VectorManagerResult.OK) { diff --git a/libs/server/Resp/Vector/VectorManager.cs b/libs/server/Resp/Vector/VectorManager.cs index 0bf96c6e67c..f6f3f649f6d 100644 --- a/libs/server/Resp/Vector/VectorManager.cs +++ b/libs/server/Resp/Vector/VectorManager.cs @@ -127,6 +127,8 @@ private static void EnsureIdBufferSize(ref SpanByteAndMemory buffer, int retriev /// public bool IsEnabled { get; } + private bool initialized; + /// /// Unique id for this . 
/// @@ -138,7 +140,7 @@ private static void EnsureIdBufferSize(ref SpanByteAndMemory buffer, int retriev private readonly int dbId; - public VectorManager(int dbId, GarnetServerOptions serverOptions, Func getCleanupSession, ILoggerFactory loggerFactory) + public VectorManager(int dbId, GarnetServerOptions serverOptions, Func getTempSession, ILoggerFactory loggerFactory) { this.dbId = dbId; @@ -161,10 +163,18 @@ public VectorManager(int dbId, GarnetServerOptions serverOptions, Func(new() { SingleWriter = false, SingleReader = true, AllowSynchronousContinuations = false }); cleanupTask = RunCleanupTaskAsync(); + quantizationChannel = Channel.CreateUnbounded(new() { SingleWriter = false, SingleReader = false, AllowSynchronousContinuations = false }); + + if (serverOptions.VectorSetQuantizationTaskCount < 0 || serverOptions.VectorSetQuantizationTaskCount > Environment.ProcessorCount) + throw new GarnetException($"VectorSetQuantizationTaskCount should be in range [0,{Environment.ProcessorCount}]!"); + var vectorSetQuantizationTaskCount = serverOptions.VectorSetQuantizationTaskCount == 0 ? 
Environment.ProcessorCount : serverOptions.VectorSetQuantizationTaskCount; + quantizationTasks = new Task[vectorSetQuantizationTaskCount]; + Array.Fill(quantizationTasks, Task.CompletedTask); + logger?.LogInformation("Created VectorManager"); } @@ -173,9 +183,11 @@ public VectorManager(int dbId, GarnetServerOptions serverOptions, Func public void Initialize() { - if (!IsEnabled) return; + if (!IsEnabled || initialized) return; + + initialized = true; - using var session = (RespServerSession)getCleanupSession(); + using var session = (RespServerSession)getTempSession(); if (session.activeDbId != dbId && !session.TrySwitchActiveDatabaseSession(dbId)) { throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed"); @@ -210,6 +222,7 @@ public void Initialize() } } + StartQuantizationTasks(); } /// @@ -219,7 +232,7 @@ public void ResumePostRecovery() { if (!IsEnabled) return; - using var session = (RespServerSession)getCleanupSession(); + using var session = (RespServerSession)getTempSession(); ref var ctx = ref session.storageSession.vectorContext; @@ -358,6 +371,11 @@ public void Dispose() cleanupTaskChannel.Writer.Complete(); AsyncUtils.BlockingWait(cleanupTaskChannel.Reader.Completion); AsyncUtils.BlockingWait(cleanupTask); + + // drain quantization task + _ = quantizationChannel.Writer.TryComplete(); + while (quantizationChannel.Reader.TryRead(out _)) { } + AsyncUtils.BlockingWait(Task.WhenAll(quantizationTasks)); } private static void CompletePending(ref Status status, ref SpanByte output, ref TContext ctx) @@ -379,6 +397,7 @@ private static void CompletePending(ref Status status, ref SpanByte ou /// /// Result of the operation. 
internal VectorManagerResult TryAdd( + scoped ref SpanByte key, scoped ReadOnlySpan indexValue, ReadOnlySpan element, VectorValueType valueType, @@ -398,22 +417,7 @@ out ReadOnlySpan errorMsg ReadIndex(indexValue, out var context, out var dimensions, out var reduceDims, out var quantType, out _, out var numLinks, out var distanceMetric, out var indexPtr, out _); - var valueDims = CalculateValueDimensions(valueType, values); - - if (dimensions != valueDims) - { - // Matching Redis behavior - errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {dimensions}"); - return VectorManagerResult.BadParams; - } - - if (providedReduceDims == 0 && reduceDims != 0) - { - // Matching Redis behavior, which is definitely a bit weird here - errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {reduceDims}"); - return VectorManagerResult.BadParams; - } - else if (providedReduceDims != 0 && providedReduceDims != reduceDims) + if (providedReduceDims != 0 && providedReduceDims != reduceDims) { return VectorManagerResult.BadParams; } @@ -436,18 +440,47 @@ out ReadOnlySpan errorMsg return VectorManagerResult.BadParams; } - var insert = - Service.Insert( - context, - indexPtr, - element, - valueType, - values, - attributes - ); + bool insert; + bool needsQuantization; + using (var vectorData = PrepareVectorData(quantType, valueType, values, out errorMsg)) + { + if (!errorMsg.IsEmpty) + { + return VectorManagerResult.BadParams; + } + + if (vectorData.ElementCount != dimensions) + { + errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {vectorData.ElementCount} but set has {dimensions}"); + return VectorManagerResult.BadParams; + } + + if (providedReduceDims == 0 && reduceDims != 0) + { + // Matching Redis behavior, which is definitely a bit weird here + errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {vectorData.ElementCount} but set has {reduceDims}"); + 
return VectorManagerResult.BadParams; + } + + insert = + Service.Insert( + context, + indexPtr, + element, + vectorData.ReadOnlySpan, + vectorData.ElementCount, + attributes, + out needsQuantization + ); + } if (insert) { + if (needsQuantization) + { + _ = this.quantizationChannel.Writer.TryWrite(new(key.ToByteArray(), QuantizationStep.BuildQuantizationTable, 0)); + } + return VectorManagerResult.OK; } @@ -544,17 +577,18 @@ internal Status TryDeleteVectorSet(StorageSession storageSession, ref SpanByte k /// Perform a similarity search given a vector to compare against. /// internal VectorManagerResult ValueSimilarity( - ReadOnlySpan indexValue, + scoped ReadOnlySpan indexValue, VectorValueType valueType, - ReadOnlySpan values, + scoped ReadOnlySpan values, int count, float delta, int searchExplorationFactor, - ReadOnlySpan filter, + scoped ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, + scoped out ReadOnlySpan errorMsg, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, ref SpanByteAndMemory filterBitmap @@ -564,13 +598,6 @@ ref SpanByteAndMemory filterBitmap ReadIndex(indexValue, out var context, out var dimensions, out _, out var quantType, out _, out _, out _, out var indexPtr, out _); - var valueDims = CalculateValueDimensions(valueType, values); - if (dimensions != valueDims) - { - outputIdFormat = VectorIdFormat.Invalid; - return VectorManagerResult.BadParams; - } - // When a filter is present, over-retrieve candidates from DiskANN so that // post-filtering has enough results to fill the requested count. 
// @@ -591,25 +618,46 @@ ref SpanByteAndMemory filterBitmap EnsureDistanceBufferSize(ref outputDistances, retrieveCount); EnsureIdBufferSize(ref outputIds, retrieveCount); - var found = - Service.SearchVector( - context, - indexPtr, - valueType, - values, - delta, - effectiveEF, - filter, - maxFilteringEffort, - outputIds, - outputDistances, - out var continuation - ); + int found; + nint continuation; + using (var vectorData = PrepareVectorData(quantType, valueType, values, out var tempErrorMsg)) + { + if (!tempErrorMsg.IsEmpty) + { + // Have to copy for scoping reasons - it's an error path, so we'll just eat the perf hit for now + errorMsg = tempErrorMsg.ToArray(); + outputIdFormat = VectorIdFormat.Invalid; + return VectorManagerResult.BadParams; + } + + if (dimensions != vectorData.ElementCount) + { + outputIdFormat = VectorIdFormat.Invalid; + errorMsg = default; + return VectorManagerResult.BadParams; + } + + found = + Service.SearchVector( + context, + indexPtr, + vectorData.ReadOnlySpan, + vectorData.ElementCount, + delta, + effectiveEF, + filter, + maxFilteringEffort, + outputIds, + outputDistances, + out continuation + ); + } if (found < 0) { logger?.LogWarning("Error indicating response from vector service {found}", found); outputIdFormat = VectorIdFormat.Invalid; + errorMsg = default; return VectorManagerResult.BadParams; } @@ -633,7 +681,7 @@ out var continuation filterBitmap = new SpanByteAndMemory(MemoryPool.Shared.Rent(requiredBitmapBytes), requiredBitmapBytes); } - ApplyPostFilter(filter, found, outputAttributes.AsReadOnlySpan(), filterBitmap.AsSpan(), ActiveThreadSession.scratchBufferBuilder); + _ = ApplyPostFilter(filter, found, outputAttributes.AsReadOnlySpan(), filterBitmap.AsSpan(), ActiveThreadSession.scratchBufferBuilder); } if (continuation != 0) @@ -647,13 +695,7 @@ out var continuation // Default assumption is length prefixed outputIdFormat = VectorIdFormat.I32LengthPrefixed; - if (quantType == VectorQuantType.XPreQ8) - { - // But in this 
special case, we force them to be 4-byte ids - //outputIdFormat = VectorIdFormat.FixedI32; - outputIdFormat = VectorIdFormat.I32LengthPrefixed; - } - + errorMsg = default; return VectorManagerResult.OK; } @@ -736,7 +778,7 @@ out var continuation filterBitmap = new SpanByteAndMemory(MemoryPool.Shared.Rent(requiredBitmapBytes), requiredBitmapBytes); } - ApplyPostFilter(filter, found, outputAttributes.AsReadOnlySpan(), filterBitmap.AsSpan(), ActiveThreadSession.scratchBufferBuilder); + _ = ApplyPostFilter(filter, found, outputAttributes.AsReadOnlySpan(), filterBitmap.AsSpan(), ActiveThreadSession.scratchBufferBuilder); } if (continuation != 0) @@ -750,13 +792,6 @@ out var continuation // Default assumption is length prefixed outputIdFormat = VectorIdFormat.I32LengthPrefixed; - if (quantType == VectorQuantType.XPreQ8) - { - // But in this special case, we force them to be 4-byte ids - //outputIdFormat = VectorIdFormat.FixedI32; - outputIdFormat = VectorIdFormat.I32LengthPrefixed; - } - return VectorManagerResult.OK; } @@ -923,22 +958,38 @@ internal bool TryGetEmbedding(ReadOnlySpan indexValue, ReadOnlySpan var into = MemoryMarshal.Cast(outputDistances.AsSpan()); var from = asBytes.AsReadOnlySpan(); - if (quantType == VectorQuantType.NoQuant) - { - var fromFloat = MemoryMarshal.Cast(from); - fromFloat.CopyTo(into); - } - else if (quantType == VectorQuantType.XPreQ8) - { - for (var i = 0; i < asBytes.Length; i++) - { - into[i] = from[i]; - } - } - else + + // Internal vector format differs depend on the selected quantizer, so do that mapping as needed + switch (quantType) { - // TODO: Handle Q8 and BIN as they are implemented - throw new NotImplementedException($"Unexpected quantization: {quantType}"); + // All Redis quantizers store F32s + case VectorQuantType.Bin: + case VectorQuantType.Q8: + case VectorQuantType.NoQuant: + MemoryMarshal.Cast(from).CopyTo(into); + break; + + // XNoQuant_I8 & XBin_I8 stores _signed_ bytes + case VectorQuantType.XNoQuant_I8: + case 
VectorQuantType.XBin_I8: + for (var i = 0; i < from.Length; i++) + { + into[i] = (sbyte)from[i]; + } + break; + + // XNoQuant_I8 & NoQuant_U8 stores unsigned bytes + case VectorQuantType.XNoQuant_U8: + case VectorQuantType.XBin_U8: + for (var i = 0; i < from.Length; i++) + { + into[i] = from[i]; + } + break; + + case VectorQuantType.Invalid: + default: + throw new InvalidOperationException($"Unexpected VectorQuantType: {quantType}"); } // Vector might have been deleted, so check that after getting data @@ -950,25 +1001,6 @@ internal bool TryGetEmbedding(ReadOnlySpan indexValue, ReadOnlySpan } } - /// - /// Determine the dimensions of a vector given its and its raw data. - /// - internal static uint CalculateValueDimensions(VectorValueType valueType, ReadOnlySpan values) - { - if (valueType == VectorValueType.FP32) - { - return (uint)(values.Length / sizeof(float)); - } - else if (valueType == VectorValueType.XB8) - { - return (uint)(values.Length); - } - else - { - throw new NotImplementedException($"{valueType}"); - } - } - [Conditional("DEBUG")] private static void AssertHaveStorageSession() { diff --git a/libs/server/Servers/GarnetServerOptions.cs b/libs/server/Servers/GarnetServerOptions.cs index ad3c0db7846..0e6142938fc 100644 --- a/libs/server/Servers/GarnetServerOptions.cs +++ b/libs/server/Servers/GarnetServerOptions.cs @@ -550,6 +550,11 @@ public class GarnetServerOptions : ServerOptions /// public int VectorSetReplayTaskCount = 0; + /// + /// Configure how many quantization tasks are used to optimize Vector Set operations (default: 0 uses the machine CPU count). 
+ /// + public int VectorSetQuantizationTaskCount = 0; + /// /// Get the directory name for database checkpoints /// diff --git a/libs/server/Storage/Session/MainStore/VectorStoreOps.cs b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs index 9073d73fa83..7091b56a172 100644 --- a/libs/server/Storage/Session/MainStore/VectorStoreOps.cs +++ b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs @@ -13,7 +13,7 @@ namespace Garnet.server /// /// This controls the mapping of vector elements to how they're actually stored. /// - public enum VectorQuantType + public enum VectorQuantType : int { Invalid = 0, @@ -22,23 +22,45 @@ public enum VectorQuantType /// /// Vectors stored as is with no quantization. /// - NoQuant, + NoQuant = 1, /// /// Vectors stored as binary (1 bit). /// - Bin, + Bin = 2, /// /// Vectors stored as bytes (8 bits). /// - Q8, + Q8 = 3, // Extended quantizations /// - /// Vectors stored as bytes (8 bits). XPREQ8 is a non-Redis extension, stands for: - /// eXtension PREcalculated Quantization 8-bit - requests no quantization on pre-calculated [0, 255] values + /// Vectors stored as bytes (8 bits unsigned). XNoQuant_U8 is a non-Redis extension, stands for: + /// eXtension No Quantization Unsigned integer 8 bits + /// + /// XPREQ8 aliases to this. /// - XPreQ8, + XNoQuant_U8 = 4, + + /// + /// Vectors stored as bytes (8 bits signed). XNoQuant_I8 is a non-Redis extension, stands for: + /// eXtension No Quantization Integer 8 bits + /// + /// XPREQ8 aliases to this. + /// + XNoQuant_I8 = 5, + + /// + /// Vectors stored as bytes (8 bits signed). XBin_I8 is a non-Redis extension, stands for: + /// eXtension Binary quantized Integer 8 bits + /// + XBin_I8 = 6, + + /// + /// Vectors stored as bytes (8 bits unsigned). XBin_U8 is a non-Redis extension, stands for: + /// eXtension Binary quantized Unsigned integer 8 bits + /// + XBin_U8 = 7, } /// @@ -53,14 +75,22 @@ public enum VectorValueType : int /// /// Floats (FP32). 
/// - FP32, + FP32 = 1, // Extended formats /// - /// Bytes (8 bit). + /// Bytes (8 bit), unsigned. XU8 is a non-Redis extensions, stands for: + /// eXtension Unsigned-integer 8 bits + /// + /// XB8 aliases to this. /// - XB8, + XU8 = 2, + + /// + /// Bytes (8 bit), signed. + /// + XI8 = 3, } /// @@ -97,17 +127,17 @@ public enum VectorDistanceMetricType : int /// /// Inner product /// - InnerProduct, + InnerProduct = 1, /// /// Squared Euclidean (L2-Squared) /// - L2, + L2 = 2, /// - /// Normalized Cosine Similarity + /// Normalized Cosine Similarity. XCosine_Normalized /// - XCosine_Normalized, + XCosine_Normalized = 3, } /// @@ -121,7 +151,13 @@ sealed partial class StorageSession : IDisposable [SkipLocalsInit] public unsafe GarnetStatus VectorSetAdd(SpanByte key, int reduceDims, VectorValueType valueType, ArgSlice values, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, VectorDistanceMetricType distanceMetric, out VectorManagerResult result, out ReadOnlySpan errorMsg) { - var dims = VectorManager.CalculateValueDimensions(valueType, values.ReadOnlySpan); + var dims = + valueType switch + { + VectorValueType.FP32 => (uint)(values.ReadOnlySpan.Length / sizeof(float)), + VectorValueType.XI8 or VectorValueType.XU8 => (uint)values.ReadOnlySpan.Length, + _ => throw new InvalidOperationException($"Unexpected VectorValueType: {valueType}"), + }; ; var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dims, 1))); var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); @@ -151,7 +187,7 @@ public unsafe GarnetStatus VectorSetAdd(SpanByte key, int reduceDims, VectorValu // After a successful read we add the vector while holding a shared lock // That lock prevents deletion, but everything else can proceed in parallel - result = vectorManager.TryAdd(indexSpan, element.ReadOnlySpan, valueType, values.ReadOnlySpan, 
attributes.ReadOnlySpan, (uint)reduceDims, quantizer, (uint)buildExplorationFactor, (uint)numLinks, distanceMetric, out errorMsg); + result = vectorManager.TryAdd(ref key, indexSpan, element.ReadOnlySpan, valueType, values.ReadOnlySpan, attributes.ReadOnlySpan, (uint)reduceDims, quantizer, (uint)buildExplorationFactor, (uint)numLinks, distanceMetric, out errorMsg); if (result == VectorManagerResult.OK) { @@ -200,7 +236,7 @@ public unsafe GarnetStatus VectorSetRemove(SpanByte key, SpanByte element) /// Perform a similarity search on an existing Vector Set given a vector as a bunch of floats. /// [SkipLocalsInit] - public unsafe GarnetStatus VectorSetValueSimilarity(SpanByte key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) + public unsafe GarnetStatus VectorSetValueSimilarity(SpanByte key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, out ReadOnlySpan errorMsg, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result, ref SpanByteAndMemory filterBitmap) { parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); @@ -215,10 +251,11 @@ public unsafe GarnetStatus VectorSetValueSimilarity(SpanByte key, VectorValueTyp { result = VectorManagerResult.Invalid; outputIdFormat = VectorIdFormat.Invalid; + errorMsg = default; return status; } - result = vectorManager.ValueSimilarity(indexSpan, valueType, values.ReadOnlySpan, count, delta, searchExplorationFactor, filter, 
maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, ref filterBitmap); + result = vectorManager.ValueSimilarity(indexSpan, valueType, values.ReadOnlySpan, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, out errorMsg, ref outputDistances, ref outputAttributes, ref filterBitmap); return GarnetStatus.OK; } diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs index 5eb42adbd66..c1c4e56e6a5 100644 --- a/test/Garnet.test.cluster/ClusterTestContext.cs +++ b/test/Garnet.test.cluster/ClusterTestContext.cs @@ -425,6 +425,7 @@ public GarnetServer CreateInstance( bool useAcl = false, bool asyncReplay = false, int vectorSetReplayTaskCount = 0, + int vectorSetQuantizationTaskCount = 0, EndPoint clusterAnnounceEndpoint = null, X509CertificateCollection certificates = null, ServerCredential clusterCreds = new ServerCredential()) @@ -461,7 +462,8 @@ public GarnetServer CreateInstance( authPassword: clusterCreds.password, certificates: certificates, clusterAnnounceEndpoint: clusterAnnounceEndpoint, - vectorSetReplayTaskCount: vectorSetReplayTaskCount); + vectorSetReplayTaskCount: vectorSetReplayTaskCount, + vectorSetQuantizationTaskCount: vectorSetQuantizationTaskCount); return new GarnetServer(opts, loggerFactory); } diff --git a/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs index 769b6fe792b..534b17c87f8 100644 --- a/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs +++ b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs @@ -102,11 +102,13 @@ public virtual void TearDown() context?.TearDown(); } - // TODO: restore BIN and Q8 when implemented + // TODO: restore BIN, Q8, XBIN_I8 when implemented [Test] - [TestCase("XB8", "XPREQ8")] - [TestCase("XB8", "NOQUANT")] - [TestCase("FP32", "XPREQ8")] + 
[TestCase("XU8", "XNOQUANT_U8")] + [TestCase("XU8", "NOQUANT")] + [TestCase("XI8", "XNOQUANT_U8")] + [TestCase("XI8", "NOQUANT")] + [TestCase("FP32", "XNOQUANT_U8")] [TestCase("FP32", "NOQUANT")] public async Task BasicVADDReplicatesAsync(string vectorFormat, string quantizer) { @@ -127,7 +129,7 @@ public async Task BasicVADDReplicatesAsync(string vectorFormat, string quantizer ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); byte[] vectorAddData; - if (vectorFormatParsed == VectorValueType.XB8) + if (vectorFormatParsed == VectorValueType.XU8) { vectorAddData = new byte[75]; vectorAddData[0] = 1; @@ -136,6 +138,17 @@ public async Task BasicVADDReplicatesAsync(string vectorFormat, string quantizer vectorAddData[i] = (byte)(vectorAddData[i - 1] + 1); } } + if (vectorFormatParsed == VectorValueType.XI8) + { + var sbytes = new sbyte[75]; + sbytes[0] = 1; + for (var i = 1; i < sbytes.Length; i++) + { + sbytes[i] = (sbyte)(sbytes[i - 1] + 1); + } + + vectorAddData = MemoryMarshal.Cast(sbytes).ToArray(); + } else if (vectorFormatParsed == VectorValueType.FP32) { var floats = new float[75]; @@ -157,7 +170,7 @@ public async Task BasicVADDReplicatesAsync(string vectorFormat, string quantizer ClassicAssert.AreEqual(1, addRes); byte[] vectorSimData; - if (vectorFormatParsed == VectorValueType.XB8) + if (vectorFormatParsed == VectorValueType.XU8) { vectorSimData = new byte[75]; vectorSimData[0] = 2; @@ -166,6 +179,17 @@ public async Task BasicVADDReplicatesAsync(string vectorFormat, string quantizer vectorSimData[i] = (byte)(vectorSimData[i - 1] + 1); } } + else if (vectorFormatParsed == VectorValueType.XI8) + { + var sbytes = new sbyte[75]; + sbytes[0] = 2; + for (var i = 1; i < sbytes.Length; i++) + { + sbytes[i] = (sbyte)(sbytes[i - 1] + 1); + } + + vectorSimData = MemoryMarshal.Cast(sbytes).ToArray(); + } else if (vectorFormatParsed == VectorValueType.FP32) { var floats = new float[75]; diff --git 
a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs index 383f798f3b1..31f5fecbb39 100644 --- a/test/Garnet.test/DiskANN/DiskANNServiceTests.cs +++ b/test/Garnet.test/DiskANN/DiskANNServiceTests.cs @@ -170,7 +170,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); - var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XNoQuant_U8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); Span id = [0, 1, 2, 3]; Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); @@ -179,8 +179,8 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength // Insert unsafe { - var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); - ClassicAssert.AreEqual(1, insertRes); + var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); + ClassicAssert.AreEqual(NativeDiskANNMethods.DiskANNInsertResult.True, insertRes); } // Check valid initially @@ -365,7 +365,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var deleteFuncPtr = 
Marshal.GetFunctionPointerForDelegate(deleteDel); var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); - var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XNoQuant_U8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); Span id = [0, 1, 2, 3]; Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); @@ -374,8 +374,8 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength // Insert unsafe { - var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); - ClassicAssert.AreEqual(1, insertRes); + var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); + ClassicAssert.AreEqual(NativeDiskANNMethods.DiskANNInsertResult.True, insertRes); } Span filter = []; @@ -391,7 +391,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var numRes = NativeDiskANNMethods.search_vector( Context, rawIndex, - VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, 1f, outputDistances.Length, // SearchExplorationFactor must >= Count (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, 
0, @@ -410,7 +410,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength { NativeDiskANNMethods.drop_index(Context, rawIndex); - rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XNoQuant_U8, VectorDistanceMetricType.L2, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); } // Search value @@ -424,7 +424,7 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var numRes = NativeDiskANNMethods.search_vector( Context, rawIndex, - VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, 1f, outputDistances.Length, // SearchExplorationFactor must >= Count (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, 0, @@ -486,10 +486,10 @@ unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength var insertRes = NativeDiskANNMethods.insert( Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id2)), (nuint)id2.Length, - VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem2)), (nuint)elem2.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem2)), (nuint)elem2.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr2)), (nuint)attr2.Length ); - ClassicAssert.AreEqual(1, insertRes); + ClassicAssert.AreEqual(NativeDiskANNMethods.DiskANNInsertResult.True, insertRes); } GC.KeepAlive(deleteDel); diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs index 524991d6336..71e7c33b403 100644 --- a/test/Garnet.test/GarnetServerConfigTests.cs +++ b/test/Garnet.test/GarnetServerConfigTests.cs @@ -1057,6 +1057,106 @@ 
public void MinimumPageSizeWithVectorSetPreview()
             }
         }
 
+        [Test]
+        public void VectorSetQuantizationTaskCount()
+        {
+            // Command line args: --vector-set-quantization-task-count must parse as an int in [0, int.MaxValue]
+            {
+                // Default accepted (switch omitted => 0)
+                {
+                    var args = Array.Empty<string>();
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(0, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Switch is accepted
+                {
+                    var args = new[] { "--vector-set-quantization-task-count", "1" };
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(1, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Zero is accepted (explicit lower bound)
+                {
+                    var args = new[] { "--vector-set-quantization-task-count", "0" };
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(0, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Invalid (non-numeric) rejected
+                {
+                    var args = new[] { "--vector-set-quantization-task-count", "foo" };
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out _, out _, out _, out _);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+
+                // Too low (negative) rejected
+                {
+                    var args = new[] { "--vector-set-quantization-task-count", "-1" };
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out _, out _, out _, out _);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+
+                // Too high rejected (2147483648 == int.MaxValue + 1)
+                {
+                    var args = new[] { "--vector-set-quantization-task-count", "2147483648" };
+                    var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out _, out _, out _, out _);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+            }
+
+            // JSON args: the same cases, supplied through the VectorSetQuantizationTaskCount config field
+            {
+                // Default accepted (field omitted => 0)
+                {
+                    const string JSON = @"{ }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(0, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Field is accepted
+                {
+                    const string JSON = @"{ ""VectorSetQuantizationTaskCount"": 1 }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(1, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Zero is accepted (explicit lower bound)
+                {
+                    const string JSON = @"{ ""VectorSetQuantizationTaskCount"": 0 }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsTrue(parseSuccessful);
+                    ClassicAssert.AreEqual(0, options.VectorSetQuantizationTaskCount);
+                }
+
+                // Invalid (non-numeric) rejected
+                {
+                    const string JSON = @"{ ""VectorSetQuantizationTaskCount"": ""foo"" }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+
+                // Too low (negative) rejected
+                {
+                    const string JSON = @"{ ""VectorSetQuantizationTaskCount"": -1 }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+
+                // Too high rejected (2147483648 == int.MaxValue + 1)
+                {
+                    const string JSON = @"{ ""VectorSetQuantizationTaskCount"": 2147483648 }";
+                    var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully);
+                    ClassicAssert.IsFalse(parseSuccessful);
+                }
+            }
+        }
+
         /// <summary>
         /// Import a garnet.conf file with the given contents
         /// </summary>
diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs
index 77a37b5e889..7a5a0db9d42 100644
--- a/test/Garnet.test/TestUtils.cs
+++ b/test/Garnet.test/TestUtils.cs
@@ -682,6 +682,7 @@ public static GarnetServerOptions 
GetGarnetServerOptions( string clusterAnnounceHostname = null, bool enableVectorSetPreview = true, int vectorSetReplayTaskCount = 0, + int vectorSetQuantizationTaskCount = 0, int threadPoolMinIOCompletionThreads = 0) { if (useAzureStorage) @@ -808,6 +809,7 @@ public static GarnetServerOptions GetGarnetServerOptions( ReplicaSyncTimeout = replicaSyncTimeout <= 0 ? Timeout.InfiniteTimeSpan : TimeSpan.FromSeconds(replicaSyncTimeout), EnableVectorSetPreview = enableVectorSetPreview, VectorSetReplayTaskCount = vectorSetReplayTaskCount, + VectorSetQuantizationTaskCount = vectorSetQuantizationTaskCount, ExpiredObjectCollectionFrequencySecs = expiredObjectCollectionFrequencySecs, ThreadPoolMinIOCompletionThreads = threadPoolMinIOCompletionThreads, }; diff --git a/website/docs/dev/vector-sets.md b/website/docs/dev/vector-sets.md index 8e0893cd02a..332485ddf6e 100644 --- a/website/docs/dev/vector-sets.md +++ b/website/docs/dev/vector-sets.md @@ -397,10 +397,10 @@ Garnet calls into the following DiskANN functions: - [x] `nint create_index(ulong context, uint dimensions, uint reduceDims, VectorQuantType quantType, uint buildExplorationFactor, uint numLinks, nint readCallback, nint writeCallback, nint deleteCallback, nint readModifyWriteCallback)` - [x] `void drop_index(ulong context, nint index)` - - [x] `byte insert(ulong context, nint index, nint id_data, nuint id_len, VectorValueType vector_value_type, nint vector_data, nuint vector_len, nint attribute_data, nuint attribute_len)` + - [x] `byte insert(ulong context, nint index, nint id_data, nuint id_len, nint vector_data, nuint vector_len, nint attribute_data, nuint attribute_len)` - [x] `byte remove(ulong context, nint index, nint id_data, nuint id_len)` - [ ] `byte set_attribute(ulong context, nint index, nint id_data, nuint id_len, nint attribute_data, nuint attribute_len)` - - [x] `int search_vector(ulong context, nint index, VectorValueType vector_value_type, nint vector_data, nuint vector_len, float delta, int 
search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` + - [x] `int search_vector(ulong context, nint index, nint vector_data, nuint vector_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` - [x] `int search_element(ulong context, nint index, nint id_data, nuint id_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` - [ ] `int continue_search(ulong context, nint index, nint continuation, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint new_continuation)` - [ ] `ulong card(ulong context, nint index)` @@ -409,7 +409,7 @@ Garnet calls into the following DiskANN functions: Some non-obvious subtleties: - The number of results _requested_ from `search_vector` and `search_element` is indicated by `output_distances_len` - `output_distances_len` is the number of _floats_ in `output_distances`, not bytes - - When inserting, if `vector_value_type == FP32` then `vector_len` is the number of _floats_ in `vector_data`, otherwise it is the number of bytes + - When inserting and searching, the `VectorQuantType` used to create the index defines the expected format of `vector_data` and whether `vector_len` is counting bytes, floats, etc. 
- `byte` returning functions are effectively returning booleans, `0 == false` and `1 == true` - `index` is always a pointer created by DiskANN and returned from `create_index` - `context` is always the `Context` value created by Garnet and stored in [`Index`](#indexes) for a Vector Set, this implies it is always a non-0 multiple of 8