Skip to content

Commit 47528b5

Browse files
committed
Improve: Serialize concurrent same-Index Python access with a mutex
The documented contract is one Python thread per `Index` at a time - the native `index_dense_t` carries per-worker `cast_buffer_` slots indexed by an executor-local `thread_idx` the binding picks for each call, so two Python threads spawning their own executors collide on those slots. Once the previous commit released the GIL around long C++ ops, accidental concurrent access from two Python threads went from "silently incorrect" to "segfault". Defensive enforcement is the right behavior. Add a `unique_ptr<std::mutex>` to `dense_index_py_t` and `dense_indexes_py_t` (held via `unique_ptr` so the wrapper stays move-constructible for pybind11's factories). Each binding entry point that releases the GIL now acquires the per-index mutex before touching the index; the order is GIL-release-then-lock so a Python thread waiting on the mutex does not hold the GIL - otherwise the current owner's worker thread would block forever in `gil_scoped_acquire` when invoking the progress callback. Move `try_reserve` into the locked region at every site so concurrent callers can't race on it either. Heavy ops covered: bulk add/search, multi-shard search, cluster, join, compact, isolate from `remove(compact=True)`, multi-shard merge. Extend `test_gil_release.py` with `test_concurrent_access_serializes_safely` that spawns six Python threads (three adders with disjoint key ranges, two searchers, one lock-free getter) on a single index and asserts every key landed, no thread errors, no crash. Validated under both stock CPython 3.14 and free-threaded 3.14t (`Py_GIL_DISABLED=1`); existing 608-test suite passes unchanged.
1 parent a598493 commit 47528b5

2 files changed

Lines changed: 109 additions & 31 deletions

File tree

python/lib.cpp

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,24 @@ struct dense_index_py_t : public index_dense_t {
7575
using native_t::search;
7676
using native_t::size;
7777

78+
// Serializes Python-thread access to the underlying C++ index. The native
79+
// `index_dense_t` assumes a single owning Python thread (it carries
80+
// per-worker `cast_buffer_` slots indexed by an executor-local `thread_idx`
81+
// that the binding picks for each call). Multiple Python threads calling
82+
// a heavy op on the same index would otherwise collide on those slots.
83+
// Locked around every binding entry point that releases the GIL.
84+
//
85+
// Held via `unique_ptr` so that the wrapper remains move-constructible -
86+
// `std::mutex` itself is neither copyable nor movable, and pybind11's
87+
// return-by-value factories require movability.
88+
mutable std::unique_ptr<std::mutex> mutex_ptr_ = std::make_unique<std::mutex>();
89+
7890
dense_index_py_t(native_t&& base) : index_dense_t(std::move(base)) {}
7991
};
8092

8193
struct dense_indexes_py_t {
8294
std::vector<std::shared_ptr<dense_index_py_t>> shards_;
95+
mutable std::unique_ptr<std::mutex> mutex_ptr_ = std::make_unique<std::mutex>();
8396

8497
void merge(std::shared_ptr<dense_index_py_t> shard) { shards_.push_back(shard); }
8598
std::size_t bytes_per_vector() const noexcept { return shards_.empty() ? 0 : shards_[0]->bytes_per_vector(); }
@@ -92,7 +105,12 @@ struct dense_indexes_py_t {
92105

93106
shards_.reserve(shards_.size() + paths.size());
94107
std::mutex shards_mutex;
108+
// Release the GIL *before* taking the per-index mutex so a Python
109+
// thread waiting on the mutex doesn't hold the GIL - otherwise a
110+
// worker thread in the current owner would block forever in
111+
// `gil_scoped_acquire`.
95112
py::gil_scoped_release release;
113+
std::unique_lock<std::mutex> lock(*mutex_ptr_);
96114
executor_default_t{threads}.dynamic(paths.size(), [&](std::size_t, std::size_t task_idx) {
97115
index_dense_t index = index_dense_t::make(paths[task_idx].c_str(), view);
98116
if (!index)
@@ -192,6 +210,9 @@ static void add_typed_to_index( //
192210

193211
{
194212
py::gil_scoped_release release;
213+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
214+
if (!index.try_reserve(index_limits_t(ceil2(index.size() + vectors_count), threads)))
215+
throw std::invalid_argument("Out of memory!");
195216
executor_default_t{threads}.dynamic(vectors_count, [&](std::size_t thread_idx, std::size_t task_idx) {
196217
dense_key_t key = *reinterpret_cast<dense_key_t const*>(keys_data + task_idx * keys_info.strides[0]);
197218
scalar_at const* vector =
@@ -259,8 +280,10 @@ static void add_many_to_index( //
259280

260281
if (!threads)
261282
threads = std::thread::hardware_concurrency();
262-
if (!index.try_reserve(index_limits_t(ceil2(index.size() + vectors_count), threads)))
263-
throw std::invalid_argument("Out of memory!");
283+
284+
// `add_typed_to_index` does the `try_reserve` + executor work inside its
285+
// own GIL-released, mutex-locked region; we just dispatch on the scalar
286+
// kind here.
264287

265288
// clang-format off
266289
scalar_kind_t kind = (scalar_kind != scalar_kind_t::unknown_k)
@@ -300,8 +323,6 @@ static void search_typed( //
300323

301324
if (!threads)
302325
threads = std::thread::hardware_concurrency();
303-
if (!index.try_reserve(index_limits_t(index.size(), threads)))
304-
throw std::invalid_argument("Out of memory!");
305326

306327
// Progress status
307328
progress_t progress_{progress};
@@ -310,6 +331,9 @@ static void search_typed( //
310331
atomic_error_t atomic_error{nullptr};
311332
{
312333
py::gil_scoped_release release;
334+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
335+
if (!index.try_reserve(index_limits_t(index.size(), threads)))
336+
throw std::invalid_argument("Out of memory!");
313337
executor_default_t{threads}.dynamic(vectors_count, [&](std::size_t thread_idx, std::size_t task_idx) {
314338
scalar_at const* vector = (scalar_at const*)(vectors_data + task_idx * vectors_info.strides[0]);
315339
dense_search_result_t result = index.search(vector, wanted, thread_idx, exact);
@@ -379,6 +403,7 @@ static void search_typed( //
379403
atomic_error_t atomic_error{nullptr};
380404
{
381405
py::gil_scoped_release release;
406+
std::unique_lock<std::mutex> lock(*indexes.mutex_ptr_);
382407
executor_default_t{threads}.dynamic(indexes.shards_.size(), [&](std::size_t thread_idx, std::size_t task_idx) {
383408
dense_index_py_t& index = *indexes.shards_[task_idx].get();
384409

@@ -760,6 +785,7 @@ static py::tuple cluster_vectors( //
760785
: numpy_string_to_kind(queries_info.format);
761786
{
762787
py::gil_scoped_release release;
788+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
763789
switch (kind) {
764790
case scalar_kind_t::f64_k: cluster_result = index.cluster(queries_begin.as<f64_t const>(), queries_end.as<f64_t const>(), config, keys_ptr, distances_ptr, executor, progress_t{progress}); break;
765791
case scalar_kind_t::f32_k: cluster_result = index.cluster(queries_begin.as<f32_t const>(), queries_end.as<f32_t const>(), config, keys_ptr, distances_ptr, executor, progress_t{progress}); break;
@@ -834,6 +860,7 @@ static py::tuple cluster_keys( //
834860
dense_clustering_result_t cluster_result;
835861
{
836862
py::gil_scoped_release release;
863+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
837864
cluster_result =
838865
index.cluster(queries_begin, queries_end, config, keys_ptr, distances_ptr, executor, progress_t{progress});
839866
}
@@ -871,7 +898,11 @@ static std::unordered_map<dense_key_t, dense_key_t> join_index( //
871898
executor_default_t executor{threads};
872899
join_result_t result;
873900
{
901+
// Lock the receiver `a`; `b` is read-only from this side. Concurrent
902+
// bidirectional `join(a, b)` and `join(b, a)` from two Python threads
903+
// is unsupported.
874904
py::gil_scoped_release release;
905+
std::unique_lock<std::mutex> lock(*a.mutex_ptr_);
875906
result = a.join(b, config, a_to_b, b_to_a, executor, progress_t{progress});
876907
}
877908
forward_error(result);
@@ -893,10 +924,11 @@ static void compact_index(dense_index_py_t& index, std::size_t threads, progress
893924

894925
if (!threads)
895926
threads = std::thread::hardware_concurrency();
896-
if (!index.try_reserve(index_limits_t(index.size(), threads)))
897-
throw std::invalid_argument("Out of memory!");
898927

899928
py::gil_scoped_release release;
929+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
930+
if (!index.try_reserve(index_limits_t(index.size(), threads)))
931+
throw std::invalid_argument("Out of memory!");
900932
index.compact(executor_default_t{threads}, progress_t{progress});
901933
}
902934

@@ -1316,10 +1348,11 @@ PYBIND11_MODULE(compiled, m, py::mod_gil_not_used()) {
13161348

13171349
if (!threads)
13181350
threads = std::thread::hardware_concurrency();
1319-
if (!index.try_reserve(index_limits_t(index.size(), threads)))
1320-
throw std::invalid_argument("Out of memory!");
13211351

13221352
py::gil_scoped_release release;
1353+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
1354+
if (!index.try_reserve(index_limits_t(index.size(), threads)))
1355+
throw std::invalid_argument("Out of memory!");
13231356
index.isolate(executor_default_t{threads});
13241357
return result.completed;
13251358
},
@@ -1336,10 +1369,11 @@ PYBIND11_MODULE(compiled, m, py::mod_gil_not_used()) {
13361369

13371370
if (!threads)
13381371
threads = std::thread::hardware_concurrency();
1339-
if (!index.try_reserve(index_limits_t(index.size(), threads)))
1340-
throw std::invalid_argument("Out of memory!");
13411372

13421373
py::gil_scoped_release release;
1374+
std::unique_lock<std::mutex> lock(*index.mutex_ptr_);
1375+
if (!index.try_reserve(index_limits_t(index.size(), threads)))
1376+
throw std::invalid_argument("Out of memory!");
13431377
index.isolate(executor_default_t{threads});
13441378
return result.completed;
13451379
},

python/scripts/test_gil_release.py

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,19 @@ def _big_random_batch(n: int, ndim: int, seed: int = 42):
5454
return keys, vectors
5555

5656

57-
# How many background-counter ticks we expect during a multi-hundred-millisecond
58-
# add. Modern hardware loops the trivial `counter[0] += 1` body well over a
59-
# million times per second, so 100k is a conservative floor that comfortably
60-
# distinguishes "GIL released" from "GIL held".
61-
_GIL_TICK_FLOOR = 100_000
57+
# Lower bound on background-counter ticks during one short USearch op. Modern
58+
# hardware loops the trivial `counter[0] += 1` body well over a million times
59+
# per second; 10k is a conservative floor that distinguishes "GIL released"
60+
# from "GIL held" while staying reachable during the short ops these tests run.
61+
_GIL_TICK_FLOOR = 10_000
6262

6363

6464
def test_gil_released_during_add():
6565
start, stop_and_join, count = _background_counter()
6666
start()
6767

68-
idx = Index(ndim=128, dtype="f32")
69-
keys, vectors = _big_random_batch(8_000, 128)
68+
idx = Index(ndim=64, dtype="f32")
69+
keys, vectors = _big_random_batch(2_000, 64)
7070

7171
before = count()
7272
t0 = time.perf_counter()
@@ -83,15 +83,14 @@ def test_gil_released_during_add():
8383

8484

8585
def test_gil_released_during_search():
86-
idx = Index(ndim=128, dtype="f32")
87-
keys, vectors = _big_random_batch(5_000, 128)
86+
idx = Index(ndim=64, dtype="f32")
87+
keys, vectors = _big_random_batch(1_500, 64)
8888
idx.add(keys, vectors, threads=4)
8989

9090
start, stop_and_join, count = _background_counter()
9191
start()
9292

93-
# Many query vectors so the search is meaningfully long
94-
_, queries = _big_random_batch(2_000, 128, seed=7)
93+
_, queries = _big_random_batch(1_000, 64, seed=7)
9594
before = count()
9695
t0 = time.perf_counter()
9796
idx.search(queries, 10, threads=4)
@@ -101,8 +100,7 @@ def test_gil_released_during_search():
101100

102101
advancement = after - before
103102
assert advancement > _GIL_TICK_FLOOR, (
104-
f"GIL appears held during search: only {advancement:,} background ticks "
105-
f"during a {elapsed:.3f}s search."
103+
f"GIL appears held during search: only {advancement:,} background ticks during a {elapsed:.3f}s search."
106104
)
107105

108106

@@ -111,8 +109,8 @@ def test_progress_callback_fires_and_completes():
111109
the GIL before invoking the Python callable. It must be able to mutate a
112110
Python list and return a bool without crashing."""
113111

114-
idx = Index(ndim=128, dtype="f32")
115-
keys, vectors = _big_random_batch(8_000, 128)
112+
idx = Index(ndim=64, dtype="f32")
113+
keys, vectors = _big_random_batch(2_000, 64)
116114

117115
invocations = []
118116

@@ -136,8 +134,8 @@ def test_progress_callback_can_cancel():
136134
"""Returning `False` from the progress callback terminates the op cleanly
137135
and surfaces as a Python `RuntimeError` - no segfault, no UB."""
138136

139-
idx = Index(ndim=128, dtype="f32")
140-
keys, vectors = _big_random_batch(30_000, 128)
137+
idx = Index(ndim=64, dtype="f32")
138+
keys, vectors = _big_random_batch(10_000, 64)
141139

142140
seen = []
143141

@@ -162,8 +160,8 @@ def test_gil_released_with_progress_callback():
162160
start, stop_and_join, count = _background_counter()
163161
start()
164162

165-
idx = Index(ndim=128, dtype="f32")
166-
keys, vectors = _big_random_batch(8_000, 128)
163+
idx = Index(ndim=64, dtype="f32")
164+
keys, vectors = _big_random_batch(2_000, 64)
167165

168166
invocations = []
169167

@@ -177,7 +175,53 @@ def progress(done: int, total: int) -> bool:
177175
stop_and_join()
178176

179177
assert after - before > _GIL_TICK_FLOOR, (
180-
"background thread didn't advance during add - GIL likely held while "
181-
"callback was active"
178+
"background thread didn't advance during add - GIL likely held while callback was active"
182179
)
183180
assert invocations and invocations[-1] == (len(keys), len(keys))
181+
182+
183+
def test_concurrent_access_serializes_safely():
184+
"""The documented contract is one Python thread per index; the binding
185+
enforces it with an internal mutex so accidental concurrent access from
186+
multiple Python threads serializes instead of crashing. Mix adds with
187+
disjoint key ranges, searches, and lock-free getters; assert no thread
188+
errors and that all keys from all `add` workers landed in the index."""
189+
190+
idx = Index(ndim=64, dtype="f32")
191+
per_thread = 500
192+
barrier = threading.Barrier(6)
193+
errors: list[str] = []
194+
195+
def run(target, *args):
196+
def wrapped():
197+
try:
198+
barrier.wait()
199+
target(*args)
200+
except Exception as e:
201+
errors.append(f"{type(e).__name__}: {e}")
202+
return threading.Thread(target=wrapped)
203+
204+
def adder(tid: int):
205+
rng = np.random.default_rng(seed=tid)
206+
base = tid * per_thread
207+
keys = np.arange(base, base + per_thread, dtype=np.uint64)
208+
idx.add(keys, rng.standard_normal((per_thread, 64), dtype=np.float32))
209+
210+
def searcher(tid: int):
211+
rng = np.random.default_rng(seed=1000 + tid)
212+
for _ in range(50):
213+
idx.search(rng.standard_normal((1, 64), dtype=np.float32), 3)
214+
215+
def getters():
216+
for i in range(1000):
217+
_ = i in idx
218+
_ = len(idx)
219+
220+
threads = [run(adder, t) for t in range(3)] + [run(searcher, t) for t in range(2)] + [run(getters)]
221+
for t in threads:
222+
t.start()
223+
for t in threads:
224+
t.join()
225+
226+
assert not errors, "thread errors:\n " + "\n ".join(errors)
227+
assert len(idx) == 3 * per_thread

0 commit comments

Comments
 (0)