Skip to content

Commit 967eb54

Browse files
committed
Improve: Harden mini-float testing across SDKs
1 parent 91c08cc commit 967eb54

20 files changed

Lines changed: 428 additions & 155 deletions

File tree

README.md

Lines changed: 39 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ Linux • macOS • Windows • iOS • Android • WebAssembly •
5151
- ✅ Simple and extensible [single C++11 header][usearch-header] __library__.
5252
-[Trusted](#integrations) by giants like Google and DBs like [ClickHouse][clickhouse-docs] & [DuckDB][duckdb-docs].
5353
-[SIMD][simd]-optimized and [user-defined metrics](#user-defined-functions) with JIT compilation.
54-
- ✅ Hardware-agnostic `f16` & `i8` - [half-precision & quarter-precision support](#memory-efficiency-downcasting-and-quantization).
54+
- ✅ Hardware-agnostic `bf16`, `e5m2`, & `i8` - [half-precision & quarter-precision support](#memory-efficiency-downcasting-and-quantization).
5555
-[View large indexes from disk](#serialization--serving-index-from-disk) without loading into RAM.
5656
- ✅ Heterogeneous lookups, renaming/relabeling, and on-the-fly deletions.
5757
- ✅ Binary Tanimoto and Sorensen coefficients for [Genomics and Chemistry applications](#usearch--rdkit--molecular-search).
@@ -138,7 +138,7 @@ The default storage/quantization level is hardware-dependant for efficiency, but
138138
index = Index(
139139
ndim=3, # Define the number of dimensions in input vectors
140140
metric='cos', # Choose 'l2sq', 'ip', 'haversine' or other metric, default = 'cos'
141-
dtype='bf16', # Store as 'f64', 'f32', 'f16', 'i8', 'b1'..., default = None
141+
dtype='bf16', # Store as 'f64', 'f32', 'bf16', 'f16', 'e5m2', 'e4m3', 'e3m2', 'e2m3', 'u8', 'i8', 'b1'..., default = None
142142
connectivity=16, # Optional: Limit number of neighbors per graph node
143143
expansion_add=128, # Optional: Control the recall of indexing
144144
expansion_search=64, # Optional: Control the quality of the search
@@ -251,7 +251,7 @@ assert!(
251251
Training a quantization model and dimension-reduction is a common approach to accelerate vector search.
252252
Those, however, are only sometimes reliable, can significantly affect the statistical properties of your data, and require regular adjustments if your distribution shifts.
253253
Instead, we have focused on high-precision arithmetic over low-precision downcasted vectors.
254-
The same index, and `add` and `search` operations will automatically down-cast or up-cast between `f64_t`, `f32_t`, `f16_t`, `i8_t`, and single-bit `b1x8_t` representations.
254+
The same index, and `add` and `search` operations will automatically down-cast or up-cast between `f64_t`, `f32_t`, `bf16_t`, `f16_t`, `e5m2_t`, `e4m3_t`, `e3m2_t`, `e2m3_t`, `u8_t`, `i8_t`, and single-bit `b1x8_t` representations.
255255
You can use the following command to check, if hardware acceleration is enabled:
256256

257257
```sh
@@ -261,7 +261,9 @@ $ python -c 'from usearch.index import Index; print(Index(ndim=166, metric="tani
261261
> ice
262262
```
263263

264-
In most cases, it's recommended to use half-precision floating-point numbers on modern hardware.
264+
In most cases, `bf16` is recommended for modern CPUs.
265+
For even smaller footprints, USearch supports IEEE & MX-compatible Float8 (`e5m2` and `e4m3`) and Float6 (`e3m2` and `e2m3`) formats.
266+
You can pass pre-quantized buffers from [NumKong](https://github.com/ashvardanian/numkong) with the explicit `dtype=` parameter on `add` and `search`, or let USearch handle the quantization internally from higher-precision inputs.
265267
When quantization is enabled, the "get"-like functions won't be able to recover the original data, so you may want to replicate the original vectors elsewhere.
266268
When quantizing to `i8_t` integers, note that it's only valid for cosine-like metrics.
267269
As part of the quantization process, the vectors are normalized to unit length and later scaled to [-127, 127] range to occupy the full 8-bit range.
@@ -479,46 +481,41 @@ The Haversine distance is available out of the box, but you can also define more
479481
from numba import cfunc, types, carray
480482
import math
481483

482-
# Define the dimension as 2 for latitude and longitude
483484
ndim = 2
485+
semi_major, flattening = 6378137.0, 1 / 298.257223563
486+
semi_minor = (1 - flattening) * semi_major
487+
488+
def vincenty_distance(first_ptr, second_ptr):
489+
first, second = carray(first_ptr, ndim), carray(second_ptr, ndim)
490+
lat1, lon1, lat2, lon2 = first[0], first[1], second[0], second[1]
491+
diff_lon = lon2 - lon1
492+
rlat1, rlat2 = math.atan((1 - flattening) * math.tan(lat1)), math.atan((1 - flattening) * math.tan(lat2))
493+
sin_rlat1, cos_rlat1 = math.sin(rlat1), math.cos(rlat1)
494+
sin_rlat2, cos_rlat2 = math.sin(rlat2), math.cos(rlat2)
495+
lon_on_sphere = diff_lon
496+
for _ in range(100):
497+
sin_lon, cos_lon = math.sin(lon_on_sphere), math.cos(lon_on_sphere)
498+
sin_ang = math.sqrt((cos_rlat2 * sin_lon) ** 2 + (cos_rlat1 * sin_rlat2 - sin_rlat1 * cos_rlat2 * cos_lon) ** 2)
499+
if sin_ang == 0: return 0.0
500+
cos_ang = sin_rlat1 * sin_rlat2 + cos_rlat1 * cos_rlat2 * cos_lon
501+
ang = math.atan2(sin_ang, cos_ang)
502+
sin_az = cos_rlat1 * cos_rlat2 * sin_lon / sin_ang
503+
cos2_az = 1 - sin_az ** 2
504+
cos2_mid = cos_ang - 2 * sin_rlat1 * sin_rlat2 / cos2_az if cos2_az != 0 else 0.0
505+
corr = flattening / 16 * cos2_az * (4 + flattening * (4 - 3 * cos2_az))
506+
prev = lon_on_sphere
507+
lon_on_sphere = diff_lon + (1 - corr) * flattening * (
508+
sin_az * (ang + corr * sin_ang * (cos2_mid + corr * cos_ang * (-1 + 2 * cos2_mid ** 2))))
509+
if abs(lon_on_sphere - prev) <= 1e-12: break
510+
else:
511+
return float('nan')
512+
u_sq = cos2_az * (semi_major ** 2 - semi_minor ** 2) / (semi_minor ** 2)
513+
ca = 1 + u_sq / 16384 * (4096 + u_sq * (-768 + u_sq * (320 - 175 * u_sq)))
514+
cb = u_sq / 1024 * (256 + u_sq * (-128 + u_sq * (74 - 47 * u_sq)))
515+
delta = cb * sin_ang * (cos2_mid + cb / 4 * (cos_ang * (-1 + 2 * cos2_mid ** 2)
516+
- cb / 6 * cos2_mid * (-3 + 4 * sin_ang ** 2) * (-3 + 4 * cos2_mid ** 2)))
517+
return semi_minor * ca * (ang - delta) / 1000.0
484518

485-
# Signature for the custom metric
486-
signature = types.float32(
487-
types.CPointer(types.float32),
488-
types.CPointer(types.float32))
489-
490-
# WGS-84 ellipsoid parameters
491-
a = 6378137.0 # major axis in meters
492-
f = 1 / 298.257223563 # flattening
493-
b = (1 - f) * a # minor axis
494-
495-
def vincenty_distance(a_ptr, b_ptr):
496-
a_array = carray(a_ptr, ndim)
497-
b_array = carray(b_ptr, ndim)
498-
lat1, lon1, lat2, lon2 = a_array[0], a_array[1], b_array[0], b_array[1]
499-
L, U1, U2 = lon2 - lon1, math.atan((1 - f) * math.tan(lat1)), math.atan((1 - f) * math.tan(lat2))
500-
sinU1, cosU1, sinU2, cosU2 = math.sin(U1), math.cos(U1), math.sin(U2), math.cos(U2)
501-
lambda_, iterLimit = L, 100
502-
while iterLimit > 0:
503-
iterLimit -= 1
504-
sinLambda, cosLambda = math.sin(lambda_), math.cos(lambda_)
505-
sinSigma = math.sqrt((cosU2 * sinLambda) ** 2 + (cosU1 * sinU2 - sinU1 * cosU2 * cosLambda) ** 2)
506-
if sinSigma == 0: return 0.0 # Co-incident points
507-
cosSigma, sigma = sinU1 * sinU2 + cosU1 * cosU2 * cosLambda, math.atan2(sinSigma, cosSigma)
508-
sinAlpha, cos2Alpha = cosU1 * cosU2 * sinLambda / sinSigma, 1 - (cosU1 * cosU2 * sinLambda / sinSigma) ** 2
509-
cos2SigmaM = cosSigma - 2 * sinU1 * sinU2 / cos2Alpha if not math.isnan(cosSigma - 2 * sinU1 * sinU2 / cos2Alpha) else 0 # Equatorial line
510-
C = f / 16 * cos2Alpha * (4 + f * (4 - 3 * cos2Alpha))
511-
lambda_, lambdaP = L + (1 - C) * f * (sinAlpha * (sigma + C * sinSigma * (cos2SigmaM + C * cosSigma * (-1 + 2 * cos2SigmaM ** 2)))), lambda_
512-
if abs(lambda_ - lambdaP) <= 1e-12: break
513-
if iterLimit == 0: return float('nan') # formula failed to converge
514-
u2 = cos2Alpha * (a ** 2 - b ** 2) / (b ** 2)
515-
A = 1 + u2 / 16384 * (4096 + u2 * (-768 + u2 * (320 - 175 * u2)))
516-
B = u2 / 1024 * (256 + u2 * (-128 + u2 * (74 - 47 * u2)))
517-
deltaSigma = B * sinSigma * (cos2SigmaM + B / 4 * (cosSigma * (-1 + 2 * cos2SigmaM ** 2) - B / 6 * cos2SigmaM * (-3 + 4 * sinSigma ** 2) * (-3 + 4 * cos2SigmaM ** 2)))
518-
s = b * A * (sigma - deltaSigma)
519-
return s / 1000.0 # Distance in kilometers
520-
521-
# Example usage:
522519
index = Index(ndim=ndim, metric=CompiledMetric(
523520
pointer=vincenty_distance.address,
524521
kind=MetricKind.Haversine,

c/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ int main() {
1919
usearch_error_t error = NULL;
2020
usearch_init_options_t opts = {
2121
.metric_kind = usearch_metric_cos_k,
22-
.scalar_kind = usearch_scalar_f16_k,
22+
.scalar_kind = usearch_scalar_f16_k, // or f32_k, bf16_k, e5m2_k, e4m3_k, e3m2_k, e2m3_k, i8_k, u8_k
2323
.dimensions = dimensions,
2424
.expansion_add = 0, // for defaults
2525
.expansion_search = 0 // for defaults

c/test.c

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,46 @@ void test_view(size_t const collection_size, size_t const dimensions) {
448448
printf("Test: View - PASSED\n");
449449
}
450450

451+
void test_mini_float_quantizations(size_t const collection_size, size_t const dimensions) {
452+
printf("Test: Mini-float quantizations... %zu vectors, %zu dimensions\n", collection_size, dimensions);
453+
usearch_scalar_kind_t kinds[] = {
454+
usearch_scalar_e5m2_k,
455+
usearch_scalar_e4m3_k,
456+
usearch_scalar_e3m2_k,
457+
usearch_scalar_e2m3_k,
458+
};
459+
float* data = create_vectors(collection_size, dimensions);
460+
usearch_key_t* keys = (usearch_key_t*)malloc(collection_size * sizeof(usearch_key_t));
461+
float* distances = (float*)malloc(collection_size * sizeof(float));
462+
expect(keys && distances, "Failed to allocate memory");
463+
464+
for (size_t k = 0; k < sizeof(kinds) / sizeof(kinds[0]); ++k) {
465+
usearch_error_t error = NULL;
466+
usearch_init_options_t opts = create_options(dimensions);
467+
opts.quantization = kinds[k];
468+
usearch_index_t index = usearch_init(&opts, &error);
469+
expect(!error, error);
470+
usearch_reserve(index, collection_size, &error);
471+
expect(!error, error);
472+
for (size_t i = 0; i < collection_size; ++i) {
473+
usearch_add(index, (usearch_key_t)i, data + i * dimensions, usearch_scalar_f32_k, &error);
474+
expect(!error, error);
475+
}
476+
expect_eq(usearch_size(index, &error), collection_size, error);
477+
for (size_t i = 0; i < collection_size; ++i) {
478+
size_t found =
479+
usearch_search(index, data + i * dimensions, usearch_scalar_f32_k, 1, keys, distances, &error);
480+
expect(!error, error);
481+
expect(found >= 1, "Vector not found");
482+
}
483+
usearch_free(index, &error);
484+
}
485+
free(data);
486+
free(keys);
487+
free(distances);
488+
printf("Test: Mini-float quantizations - PASSED\n");
489+
}
490+
451491
int main(int argc, char const* argv[]) {
452492
install_crash_handlers();
453493
printf("Running tests...\n");
@@ -464,6 +504,7 @@ int main(int argc, char const* argv[]) {
464504
test_remove_vector(collection_sizes[index], dimensions[jdx]);
465505
test_save_load(collection_sizes[index], dimensions[jdx]);
466506
test_view(collection_sizes[index], dimensions[jdx]);
507+
test_mini_float_quantizations(collection_sizes[index], dimensions[jdx]);
467508
}
468509
}
469510

csharp/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using Cloud.Unum.USearch;
1414

1515
using var index = new USearchIndex(
1616
metricKind: MetricKind.Cos, // Choose cosine metric
17-
quantization: ScalarKind.Float32, // Only quantization to Float32, Float64, Int8 is currently supported
17+
quantization: ScalarKind.Float32, // or Float64, BFloat16, Float16, E5M2, E4M3, E3M2, E2M3, Int8, UInt8
1818
dimensions: 3, // Define the number of dimensions in input vectors
1919
connectivity: 16, // How frequent should the connections in the graph be, optional
2020
expansionAdd: 128, // Control the recall of indexing, optional

csharp/src/Cloud.Unum.USearch.Tests/USearchIndexTests.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,31 @@ public void Add_FloatVector_UpdatesIndexOptions()
165165
}
166166
}
167167

168+
[Fact]
169+
public void Add_FloatVector_MiniFloatQuantizations()
170+
{
171+
ScalarKind[] kinds = { ScalarKind.E5M2, ScalarKind.E4M3, ScalarKind.E3M2, ScalarKind.E2M3 };
172+
foreach (var kind in kinds)
173+
{
174+
var indexOptions = new IndexOptions(
175+
metricKind: MetricKind.Cos,
176+
quantization: kind,
177+
dimensions: 64
178+
);
179+
var vector = GenerateFloatVector(64);
180+
using (var index = new USearchIndex(indexOptions))
181+
{
182+
index.Add(1, vector);
183+
Assert.True(index.Contains(1));
184+
Assert.Equal(1u, index.Size());
185+
186+
int found = index.Search(vector, 1, out ulong[] keys, out float[] distances);
187+
Assert.Equal(1, found);
188+
Assert.Equal(1UL, keys[0]);
189+
}
190+
}
191+
}
192+
168193
[Fact]
169194
public void Add_ByteVector_UpdatesIndexOptions()
170195
{

golang/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ func main() {
5858
vectorSize := 3
5959
vectorsCount := 100
6060
conf := usearch.DefaultConfig(uint(vectorSize))
61+
conf.Quantization = usearch.F32 // or BF16, F16, E5M2, E4M3, E3M2, E2M3, I8, U8
6162
index, err := usearch.NewIndex(conf)
6263
if err != nil {
6364
panic("Failed to create Index")

golang/lib.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ func (index *Index) FilteredSearch(query []float32, limit uint, handler *Filtere
769769
distances = make([]float32, limit)
770770
var errorMessage *C.char
771771
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_f32_k, (C.size_t)(limit),
772-
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
772+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler), //nolint:govet // handler is kept alive by the caller
773773
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
774774
runtime.KeepAlive(query)
775775
runtime.KeepAlive(keys)
@@ -856,7 +856,7 @@ func (index *Index) FilteredSearchUnsafe(query unsafe.Pointer, limit uint, handl
856856
distances = make([]float32, limit)
857857
var errorMessage *C.char
858858
resultCount := uint(C.usearch_filtered_search(index.handle, query, index.config.Quantization.CValue(), (C.size_t)(limit),
859-
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
859+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler), //nolint:govet // handler is kept alive by the caller
860860
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
861861
runtime.KeepAlive(query)
862862
runtime.KeepAlive(keys)
@@ -1041,7 +1041,7 @@ func (index *Index) FilteredSearchI8(query []int8, limit uint, handler *Filtered
10411041
distances = make([]float32, limit)
10421042
var errorMessage *C.char
10431043
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_i8_k, (C.size_t)(limit),
1044-
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
1044+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler), //nolint:govet // handler is kept alive by the caller
10451045
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
10461046
runtime.KeepAlive(query)
10471047
runtime.KeepAlive(keys)
@@ -1134,7 +1134,7 @@ func (index *Index) FilteredSearchU8(query []uint8, limit uint, handler *Filtere
11341134
distances = make([]float32, limit)
11351135
var errorMessage *C.char
11361136
resultCount := uint(C.usearch_filtered_search(index.handle, unsafe.Pointer(&query[0]), C.usearch_scalar_u8_k, (C.size_t)(limit),
1137-
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler),
1137+
(C.usearch_filtered_search_callback_t)(C.goFilteredSearchCallback), unsafe.Pointer(handler), //nolint:govet // handler is kept alive by the caller
11381138
(*C.usearch_key_t)(&keys[0]), (*C.usearch_distance_t)(&distances[0]), (*C.usearch_error_t)(&errorMessage)))
11391139
runtime.KeepAlive(query)
11401140
runtime.KeepAlive(keys)

golang/lib_test.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package usearch
22

33
import (
44
"errors"
5+
"fmt"
56
"io"
67
"math"
78
"runtime"
@@ -724,6 +725,32 @@ func TestQuantizationTypes(t *testing.T) {
724725
t.Fatalf("U8 Get returned wrong dimensions: got %d, expected 32", len(retrieved))
725726
}
726727
})
728+
729+
for _, qt := range []Quantization{E5M2, E4M3, E3M2, E2M3} {
730+
qt := qt
731+
t.Run(fmt.Sprintf("%v mini-float operations", qt), func(t *testing.T) {
732+
index := createTestIndex(t, 32, qt)
733+
defer func() {
734+
if err := index.Destroy(); err != nil {
735+
t.Errorf("Failed to destroy index: %v", err)
736+
}
737+
}()
738+
if err := index.Reserve(1); err != nil {
739+
t.Fatalf("Failed to reserve: %v", err)
740+
}
741+
vector := generateTestVector(32)
742+
if err := index.Add(1, vector); err != nil {
743+
t.Fatalf("Add failed: %v", err)
744+
}
745+
keys, _, err := index.Search(vector, 1)
746+
if err != nil {
747+
t.Fatalf("Search failed: %v", err)
748+
}
749+
if len(keys) == 0 || keys[0] != 1 {
750+
t.Fatalf("search results incorrect")
751+
}
752+
})
753+
}
727754
}
728755

729756
func TestUnsafeOperations(t *testing.T) {
@@ -1114,7 +1141,7 @@ func TestVersion(t *testing.T) {
11141141

11151142
func TestClear(t *testing.T) {
11161143
index := createTestIndex(t, 32, F32)
1117-
defer index.Destroy()
1144+
defer func() { _ = index.Destroy() }()
11181145

11191146
if err := index.Reserve(10); err != nil {
11201147
t.Fatalf("Failed to reserve capacity: %v", err)
@@ -1155,7 +1182,7 @@ func TestClear(t *testing.T) {
11551182

11561183
func TestCount(t *testing.T) {
11571184
index := createTestIndex(t, 32, F32)
1158-
defer index.Destroy()
1185+
defer func() { _ = index.Destroy() }()
11591186

11601187
if err := index.Reserve(10); err != nil {
11611188
t.Fatalf("Failed to reserve capacity: %v", err)
@@ -1188,7 +1215,7 @@ func TestCount(t *testing.T) {
11881215

11891216
func TestRename(t *testing.T) {
11901217
index := createTestIndex(t, 32, F32)
1191-
defer index.Destroy()
1218+
defer func() { _ = index.Destroy() }()
11921219

11931220
if err := index.Reserve(10); err != nil {
11941221
t.Fatalf("Failed to reserve capacity: %v", err)

java/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ long connectivity = index.connectivity();
104104

105105
## Multiple Data Types and Quantization
106106

107-
USearch supports hardware-agnostic `f64`, `f32`, and `i8` quantization for memory efficiency and performance optimization.
107+
USearch supports hardware-agnostic `f64`, `f32`, `bf16`, `f16`, `e5m2`, `e4m3`, `e3m2`, `e2m3`, `i8`, and `b1` quantization for memory efficiency and performance optimization.
108108

109109
```java
110110
// Double precision (f64) for highest accuracy

java/test/IndexTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,22 @@ public void testByteVectorWithInt8() {
229229
}
230230
}
231231

232+
@Test
233+
public void testMiniFloatQuantizations() {
234+
for (String quantization : new String[]{"e5m2", "e4m3", "e3m2", "e2m3"}) {
235+
try (Index index = new Index.Config()
236+
.metric("cos").dimensions(64).quantization(quantization).build()) {
237+
float[] vec = new float[64];
238+
for (int i = 0; i < 64; i++) vec[i] = (float) i * 0.1f;
239+
index.reserve(10);
240+
index.add(42, vec);
241+
242+
long[] keys = index.search(vec, 1);
243+
assertEquals("Self-match failed for " + quantization, 42L, keys[0]);
244+
}
245+
}
246+
}
247+
232248
@Test
233249
public void testGetIntoBufferMethods() {
234250
try (Index index = new Index.Config().metric("cos").dimensions(3).build()) {

0 commit comments

Comments
 (0)