hnswlib 0.6.2 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc8a40172623a3439b17c8b1854017c845bf4dc23f4e1714e9427e4b6b701c9b
4
- data.tar.gz: 95fa48d791f000bf48809e7067efaba9a654ab0dd517c626f689c66af3668d0e
3
+ metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
+ data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
5
5
  SHA512:
6
- metadata.gz: 07c449031cc7afbf80803ae15c8094589a146b59b1852ea72e88fcd39067def572bbdd58277e3a85c47c537e8dc6eef645f9f404a2b874e4e217b559adc38054
7
- data.tar.gz: 4b71a7f4fe0ffb3c15124e32c9c8fbba9eb6f82e6f0b95f5d91dd35c6b10be8aeba25d7313eabeaa5078b051b3410195415843ecd600b58ec1b0b36532b7ae1d
6
+ metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
+ data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## [0.7.0] - 2023-03-04
2
+
3
+ - Update bundled hnswlib version to 0.7.0.
4
+ - Add support for replacing an element marked for deletion with a new element.
5
+ - Add support for filtering search results by label with a filter function in the search_knn method of BruteforceSearch and HierarchicalNSW.
6
+
1
7
  ## [0.6.2] - 2022-06-25
2
8
 
3
9
  - Refactor codes and configs with RuboCop and clang-format.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,9 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <new>
27
+ #include <vector>
28
+
26
29
  VALUE rb_cHnswlibL2Space;
27
30
  VALUE rb_cHnswlibInnerProductSpace;
28
31
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -64,12 +67,12 @@ private:
64
67
  static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
68
  rb_iv_set(self, "@dim", dim);
66
69
  hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
70
+ new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
68
71
  return Qnil;
69
72
  };
70
73
 
71
74
  static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
73
76
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
74
77
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
75
78
  return Qnil;
@@ -79,9 +82,9 @@ private:
79
82
  return Qnil;
80
83
  }
81
84
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
82
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
85
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
83
86
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
87
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
85
88
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
86
89
  const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
87
90
  ruby_xfree(vec_a);
@@ -140,12 +143,12 @@ private:
140
143
  static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
141
144
  rb_iv_set(self, "@dim", dim);
142
145
  hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
143
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
146
+ new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
144
147
  return Qnil;
145
148
  };
146
149
 
147
150
  static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
148
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
149
152
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
150
153
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
151
154
  return Qnil;
@@ -155,9 +158,9 @@ private:
155
158
  return Qnil;
156
159
  }
157
160
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
158
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
161
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
159
162
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
163
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
161
164
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
162
165
  const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
163
166
  ruby_xfree(vec_a);
@@ -180,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
180
183
  };
181
184
  // clang-format on
182
185
 
186
+ class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
187
+ public:
188
+ explicit CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
189
+ // Keep a label when the Ruby callable returns a truthy value (anything except nil and false).
190
+ bool operator()(hnswlib::labeltype id) override {
191
+ VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
192
+ return RTEST(result);
193
+ }
194
+ private:
195
+ // Not marked for GC: the caller must keep the callable reachable while the search runs.
196
+ VALUE callback_;
197
+ };
198
+
183
199
  class RbHnswlibHierarchicalNSW {
184
200
  public:
185
201
  static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
@@ -206,17 +222,21 @@ public:
206
222
  rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
207
223
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
208
224
  rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
225
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
211
227
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
228
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
213
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
214
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
215
231
  rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
232
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
216
233
  rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
217
234
  rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
235
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
218
236
  rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
219
237
  rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
238
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
239
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
220
240
  rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
221
241
  return rb_cHnswlibHierarchicalNSW;
222
242
  };
@@ -226,14 +246,15 @@ private:
226
246
 
227
247
  static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
228
248
  VALUE kw_args = Qnil;
229
- ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
230
- rb_intern("random_seed")};
231
- VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
249
+ ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
+ rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
+ VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
232
252
  rb_scan_args(argc, argv, ":", &kw_args);
233
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
234
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
235
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
236
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
253
+ rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
+ if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
+ if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
237
258
 
238
259
  if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
239
260
  rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
@@ -256,6 +277,10 @@ private:
256
277
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
257
278
  return Qnil;
258
279
  }
280
+ if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
281
+ rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
+ return Qnil;
283
+ }
259
284
 
260
285
  rb_iv_set(self, "@space", kw_values[0]);
261
286
  hnswlib::SpaceInterface<float>* space;
@@ -264,14 +289,15 @@ private:
264
289
  } else {
265
290
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
266
291
  }
267
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
268
- const size_t m = (size_t)NUM2INT(kw_values[2]);
269
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
270
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
292
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
293
+ const size_t m = NUM2SIZET(kw_values[2]);
294
+ const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
+ const size_t random_seed = NUM2SIZET(kw_values[4]);
296
+ const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
271
297
 
272
298
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
273
299
  try {
274
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
300
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
275
301
  } catch (const std::runtime_error& e) {
276
302
  rb_raise(rb_eRuntimeError, "%s", e.what());
277
303
  return Qnil;
@@ -280,33 +306,63 @@ private:
280
306
  return Qnil;
281
307
  };
282
308
 
283
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
284
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
309
+ static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
310
+ VALUE _arr, _idx, _replace_deleted;
311
+ VALUE kw_args = Qnil;
312
+ ID kw_table[1] = {rb_intern("replace_deleted")};
313
+ VALUE kw_values[1] = {Qundef};
314
+
315
+ rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
316
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
317
+ _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
285
318
 
286
- if (!RB_TYPE_P(arr, T_ARRAY)) {
319
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
320
+
321
+ if (!RB_TYPE_P(_arr, T_ARRAY)) {
287
322
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
288
323
  return Qfalse;
289
324
  }
290
- if (!RB_INTEGER_TYPE_P(idx)) {
325
+ if (!RB_INTEGER_TYPE_P(_idx)) {
291
326
  rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
292
327
  return Qfalse;
293
328
  }
294
- if (dim != RARRAY_LEN(arr)) {
329
+ if (dim != RARRAY_LEN(_arr)) {
295
330
  rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
296
331
  return Qfalse;
297
332
  }
333
+ if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
334
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
335
+ return Qfalse;
336
+ }
298
337
 
299
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
300
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
338
+ float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
+ for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
+ const size_t idx = NUM2SIZET(_idx);
341
+ const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
301
342
 
302
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
343
+ try {
344
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
345
+ } catch (const std::runtime_error& e) {
346
+ ruby_xfree(arr);
347
+ rb_raise(rb_eRuntimeError, "%s", e.what());
348
+ return Qfalse;
349
+ }
303
350
 
304
- ruby_xfree(vec);
351
+ ruby_xfree(arr);
305
352
  return Qtrue;
306
353
  };
307
354
 
308
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
309
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
355
+ static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
356
+ VALUE arr, k, filter;
357
+ VALUE kw_args = Qnil;
358
+ ID kw_table[1] = {rb_intern("filter")};
359
+ VALUE kw_values[1] = {Qundef};
360
+
361
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
362
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
363
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
364
+
365
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
310
366
 
311
367
  if (!RB_TYPE_P(arr, T_ARRAY)) {
312
368
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -321,14 +377,19 @@ private:
321
377
  return Qnil;
322
378
  }
323
379
 
380
+ // Keep the functor on the stack: a Ruby exception (from NUM2DBL below, the filter callable,
381
+ // or the rb_raise in the catch) leaves via longjmp and would leak a heap-allocated functor.
382
+ CustomFilterFunctor filter_obj(filter);
383
+ CustomFilterFunctor* filter_func = NIL_P(filter) ? nullptr : &filter_obj;
384
+
324
385
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
325
- for (int i = 0; i < dim; i++) {
386
+ for (size_t i = 0; i < dim; i++) {
326
387
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
327
388
  }
328
389
 
329
390
  std::priority_queue<std::pair<float, size_t>> result;
330
391
  try {
331
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
392
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
332
393
  } catch (const std::runtime_error& e) {
333
394
  ruby_xfree(vec);
334
395
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -336,8 +397,8 @@ private:
336
397
  }
337
398
 
338
399
  ruby_xfree(vec);
339
400
 
340
- if (result.size() != (size_t)NUM2INT(k)) {
401
+ if (result.size() != NUM2SIZET(k)) {
341
402
  rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
342
403
  }
343
404
 
@@ -347,7 +414,7 @@ private:
347
414
  while (!result.empty()) {
348
415
  const std::pair<float, size_t>& result_tuple = result.top();
349
416
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
350
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
417
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
351
418
  result.pop();
352
419
  }
353
420
 
@@ -364,8 +431,28 @@ private:
364
431
  return Qnil;
365
432
  };
366
433
 
367
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
434
+ static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
435
+ VALUE _filename, _allow_replace_deleted;
436
+ VALUE kw_args = Qnil;
437
+ ID kw_table[1] = {rb_intern("allow_replace_deleted")};
438
+ VALUE kw_values[1] = {Qundef};
439
+
440
+ rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
441
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
442
+ _allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
443
+
444
+ if (!RB_TYPE_P(_filename, T_STRING)) {
445
+ rb_raise(rb_eArgError, "Expect filename to be Ruby String.");
446
+ return Qnil;
447
+ }
448
+ if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
449
+ !RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
450
+ rb_raise(rb_eArgError, "Expect allow_replace_deleted to be Boolean.");
451
+ return Qnil;
452
+ }
453
+
368
454
  std::string filename(StringValuePtr(_filename));
455
+ const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
369
456
  VALUE ivspace = rb_iv_get(self, "@space");
370
457
  hnswlib::SpaceInterface<float>* space;
371
458
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
@@ -373,6 +460,7 @@ private:
373
460
  } else {
374
461
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
375
462
  }
463
+
376
464
  hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
377
465
  if (index->data_level0_memory_) {
378
466
  free(index->data_level0_memory_);
@@ -392,12 +480,15 @@ private:
392
480
  delete index->visited_list_pool_;
393
481
  index->visited_list_pool_ = nullptr;
394
482
  }
483
+ // Assign before loadIndex: hnswlib 0.7.0's loadIndex queues marked-deleted slots for reuse only when this flag is already set.
395
484
  try {
485
+ index->allow_replace_deleted_ = allow_replace_deleted;
396
486
  index->loadIndex(filename, space);
397
487
  } catch (const std::runtime_error& e) {
398
488
  rb_raise(rb_eRuntimeError, "%s", e.what());
399
489
  return Qnil;
400
490
  }
491
+
401
492
  RB_GC_GUARD(_filename);
402
493
  return Qnil;
403
494
  };
@@ -405,7 +496,7 @@ private:
405
496
  static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
406
497
  VALUE ret = Qnil;
407
498
  try {
408
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
499
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
409
500
  ret = rb_ary_new2(vec.size());
410
501
  for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
411
502
  } catch (const std::runtime_error& e) {
@@ -417,13 +508,23 @@ private:
417
508
 
418
509
  static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
419
510
  VALUE ret = rb_ary_new();
420
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
511
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
421
512
  return ret;
422
513
  };
423
514
 
424
515
  static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
425
516
  try {
426
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
517
+ get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
518
+ } catch (const std::runtime_error& e) {
519
+ rb_raise(rb_eRuntimeError, "%s", e.what());
520
+ return Qnil;
521
+ }
522
+ return Qnil;
523
+ };
524
+
525
+ static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
526
+ try {
527
+ get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
427
528
  } catch (const std::runtime_error& e) {
428
529
  rb_raise(rb_eRuntimeError, "%s", e.what());
429
530
  return Qnil;
@@ -432,31 +533,42 @@ private:
432
533
  };
433
534
 
434
535
  static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
435
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
536
+ if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
436
537
  rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
437
538
  return Qnil;
438
539
  }
439
540
  try {
440
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
541
+ get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
441
542
  } catch (const std::runtime_error& e) {
442
543
  rb_raise(rb_eRuntimeError, "%s", e.what());
443
544
  return Qnil;
545
+ } catch (const std::bad_alloc& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
444
548
  }
445
549
  return Qnil;
446
550
  };
447
551
 
448
552
  static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
449
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
553
+ get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
450
554
  return Qnil;
451
555
  };
452
556
 
557
+ static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
558
+
453
559
  static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
454
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
560
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
455
561
  };
456
562
 
457
563
  static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
458
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
564
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
565
+ };
566
+
567
+ static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
568
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
459
569
  };
570
+
571
+ static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
460
572
  };
461
573
 
462
574
  // clang-format off
@@ -500,7 +612,7 @@ public:
500
612
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
501
613
  rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
502
614
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
503
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
615
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
504
616
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
505
617
  rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
506
618
  rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
@@ -537,7 +649,7 @@ private:
537
649
  } else {
538
650
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
539
651
  }
540
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
652
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
541
653
 
542
654
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
543
655
  try {
@@ -551,7 +663,7 @@ private:
551
663
  };
552
664
 
553
665
  static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
554
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
666
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
555
667
 
556
668
  if (!RB_TYPE_P(arr, T_ARRAY)) {
557
669
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
@@ -567,10 +679,10 @@ private:
567
679
  }
568
680
 
569
681
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
570
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
682
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
571
683
 
572
684
  try {
573
- get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
685
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
574
686
  } catch (const std::runtime_error& e) {
575
687
  ruby_xfree(vec);
576
688
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -581,8 +693,17 @@ private:
581
693
  return Qtrue;
582
694
  };
583
695
 
584
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
585
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
696
+ static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
697
+ VALUE arr, k, filter;
698
+ VALUE kw_args = Qnil;
699
+ ID kw_table[1] = {rb_intern("filter")};
700
+ VALUE kw_values[1] = {Qundef};
701
+
702
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
703
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
704
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
705
+
706
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
586
707
 
587
708
  if (!RB_TYPE_P(arr, T_ARRAY)) {
588
709
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -597,17 +718,22 @@ private:
597
718
  return Qnil;
598
719
  }
599
720
 
721
+ // Keep the functor on the stack: a Ruby exception (from NUM2DBL below or from the filter callable)
722
+ // leaves via longjmp and would leak a heap-allocated functor.
723
+ CustomFilterFunctor filter_obj(filter);
724
+ CustomFilterFunctor* filter_func = NIL_P(filter) ? nullptr : &filter_obj;
725
+
600
726
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
601
- for (int i = 0; i < dim; i++) {
727
+ for (size_t i = 0; i < dim; i++) {
602
728
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
603
729
  }
604
730
 
605
731
  std::priority_queue<std::pair<float, size_t>> result =
606
- get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
732
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
607
733
 
608
734
  ruby_xfree(vec);
609
735
 
610
- if (result.size() != (size_t)NUM2INT(k)) {
736
+ if (result.size() != NUM2SIZET(k)) {
611
737
  rb_warning("Cannot return as many search results as the requested number of neighbors.");
612
738
  }
613
739
 
@@ -617,7 +749,7 @@ private:
617
749
  while (!result.empty()) {
618
750
  const std::pair<float, size_t>& result_tuple = result.top();
619
751
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
620
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
752
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
621
753
  result.pop();
622
754
  }
623
755
 
@@ -659,16 +791,16 @@ private:
659
791
  };
660
792
 
661
793
  static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
662
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
794
+ get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
663
795
  return Qnil;
664
796
  };
665
797
 
666
798
  static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
667
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
799
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
668
800
  };
669
801
 
670
802
  static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
671
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
803
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
672
804
  };
673
805
  };
674
806