hnswlib 0.6.2 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fc8a40172623a3439b17c8b1854017c845bf4dc23f4e1714e9427e4b6b701c9b
4
- data.tar.gz: 95fa48d791f000bf48809e7067efaba9a654ab0dd517c626f689c66af3668d0e
3
+ metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
+ data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
5
5
  SHA512:
6
- metadata.gz: 07c449031cc7afbf80803ae15c8094589a146b59b1852ea72e88fcd39067def572bbdd58277e3a85c47c537e8dc6eef645f9f404a2b874e4e217b559adc38054
7
- data.tar.gz: 4b71a7f4fe0ffb3c15124e32c9c8fbba9eb6f82e6f0b95f5d91dd35c6b10be8aeba25d7313eabeaa5078b051b3410195415843ecd600b58ec1b0b36532b7ae1d
6
+ metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
+ data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## [0.7.0] - 2023-03-04
2
+
3
+ - Update bundled hnswlib version to 0.7.0.
4
+ - Add support for replacing an element marked for deletion with a new element.
5
+ - Add support for a filtering function by label in the search_knn method of BruteforceSearch and HierarchicalNSW.
6
+
1
7
  ## [0.6.2] - 2022-06-25
2
8
 
3
9
  - Refactor codes and configs with RuboCop and clang-format.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,9 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <new>
27
+ #include <vector>
28
+
26
29
  VALUE rb_cHnswlibL2Space;
27
30
  VALUE rb_cHnswlibInnerProductSpace;
28
31
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -64,12 +67,12 @@ private:
64
67
  static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
68
  rb_iv_set(self, "@dim", dim);
66
69
  hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
70
+ new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
68
71
  return Qnil;
69
72
  };
70
73
 
71
74
  static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
73
76
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
74
77
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
75
78
  return Qnil;
@@ -79,9 +82,9 @@ private:
79
82
  return Qnil;
80
83
  }
81
84
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
82
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
85
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
83
86
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
87
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
85
88
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
86
89
  const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
87
90
  ruby_xfree(vec_a);
@@ -140,12 +143,12 @@ private:
140
143
  static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
141
144
  rb_iv_set(self, "@dim", dim);
142
145
  hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
143
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
146
+ new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
144
147
  return Qnil;
145
148
  };
146
149
 
147
150
  static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
148
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
149
152
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
150
153
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
151
154
  return Qnil;
@@ -155,9 +158,9 @@ private:
155
158
  return Qnil;
156
159
  }
157
160
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
158
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
161
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
159
162
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
163
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
161
164
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
162
165
  const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
163
166
  ruby_xfree(vec_a);
@@ -180,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
180
183
  };
181
184
  // clang-format on
182
185
 
186
+ class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
187
+ public:
188
+ CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
189
+
190
+ bool operator()(hnswlib::labeltype id) {
191
+ VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
192
+ return result == Qtrue ? true : false;
193
+ }
194
+
195
+ private:
196
+ VALUE callback_;
197
+ };
198
+
183
199
  class RbHnswlibHierarchicalNSW {
184
200
  public:
185
201
  static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
@@ -206,17 +222,21 @@ public:
206
222
  rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
207
223
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
208
224
  rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
225
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
211
227
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
228
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
213
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
214
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
215
231
  rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
232
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
216
233
  rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
217
234
  rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
235
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
218
236
  rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
219
237
  rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
238
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
239
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
220
240
  rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
221
241
  return rb_cHnswlibHierarchicalNSW;
222
242
  };
@@ -226,14 +246,15 @@ private:
226
246
 
227
247
  static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
228
248
  VALUE kw_args = Qnil;
229
- ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
230
- rb_intern("random_seed")};
231
- VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
249
+ ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
+ rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
+ VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
232
252
  rb_scan_args(argc, argv, ":", &kw_args);
233
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
234
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
235
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
236
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
253
+ rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
+ if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
+ if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
237
258
 
238
259
  if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
239
260
  rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
@@ -256,6 +277,10 @@ private:
256
277
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
257
278
  return Qnil;
258
279
  }
280
+ if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
281
+ rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
+ return Qnil;
283
+ }
259
284
 
260
285
  rb_iv_set(self, "@space", kw_values[0]);
261
286
  hnswlib::SpaceInterface<float>* space;
@@ -264,14 +289,15 @@ private:
264
289
  } else {
265
290
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
266
291
  }
267
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
268
- const size_t m = (size_t)NUM2INT(kw_values[2]);
269
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
270
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
292
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
293
+ const size_t m = NUM2SIZET(kw_values[2]);
294
+ const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
+ const size_t random_seed = NUM2SIZET(kw_values[4]);
296
+ const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
271
297
 
272
298
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
273
299
  try {
274
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
300
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
275
301
  } catch (const std::runtime_error& e) {
276
302
  rb_raise(rb_eRuntimeError, "%s", e.what());
277
303
  return Qnil;
@@ -280,33 +306,63 @@ private:
280
306
  return Qnil;
281
307
  };
282
308
 
283
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
284
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
309
+ static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
310
+ VALUE _arr, _idx, _replace_deleted;
311
+ VALUE kw_args = Qnil;
312
+ ID kw_table[1] = {rb_intern("replace_deleted")};
313
+ VALUE kw_values[1] = {Qundef};
314
+
315
+ rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
316
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
317
+ _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
285
318
 
286
- if (!RB_TYPE_P(arr, T_ARRAY)) {
319
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
320
+
321
+ if (!RB_TYPE_P(_arr, T_ARRAY)) {
287
322
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
288
323
  return Qfalse;
289
324
  }
290
- if (!RB_INTEGER_TYPE_P(idx)) {
325
+ if (!RB_INTEGER_TYPE_P(_idx)) {
291
326
  rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
292
327
  return Qfalse;
293
328
  }
294
- if (dim != RARRAY_LEN(arr)) {
329
+ if (dim != RARRAY_LEN(_arr)) {
295
330
  rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
296
331
  return Qfalse;
297
332
  }
333
+ if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
334
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
335
+ return Qfalse;
336
+ }
298
337
 
299
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
300
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
338
+ float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
+ for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
+ const size_t idx = NUM2SIZET(_idx);
341
+ const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
301
342
 
302
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
343
+ try {
344
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
345
+ } catch (const std::runtime_error& e) {
346
+ ruby_xfree(arr);
347
+ rb_raise(rb_eRuntimeError, "%s", e.what());
348
+ return Qfalse;
349
+ }
303
350
 
304
- ruby_xfree(vec);
351
+ ruby_xfree(arr);
305
352
  return Qtrue;
306
353
  };
307
354
 
308
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
309
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
355
+ static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
356
+ VALUE arr, k, filter;
357
+ VALUE kw_args = Qnil;
358
+ ID kw_table[1] = {rb_intern("filter")};
359
+ VALUE kw_values[1] = {Qundef};
360
+
361
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
362
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
363
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
364
+
365
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
310
366
 
311
367
  if (!RB_TYPE_P(arr, T_ARRAY)) {
312
368
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -321,14 +377,24 @@ private:
321
377
  return Qnil;
322
378
  }
323
379
 
380
+ CustomFilterFunctor* filter_func = nullptr;
381
+ if (!NIL_P(filter)) {
382
+ try {
383
+ filter_func = new CustomFilterFunctor(filter);
384
+ } catch (const std::bad_alloc& e) {
385
+ rb_raise(rb_eRuntimeError, "%s", e.what());
386
+ return Qnil;
387
+ }
388
+ }
389
+
324
390
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
325
- for (int i = 0; i < dim; i++) {
391
+ for (size_t i = 0; i < dim; i++) {
326
392
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
327
393
  }
328
394
 
329
395
  std::priority_queue<std::pair<float, size_t>> result;
330
396
  try {
331
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
397
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
332
398
  } catch (const std::runtime_error& e) {
333
399
  ruby_xfree(vec);
334
400
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -336,8 +402,9 @@ private:
336
402
  }
337
403
 
338
404
  ruby_xfree(vec);
405
+ if (filter_func) delete filter_func;
339
406
 
340
- if (result.size() != (size_t)NUM2INT(k)) {
407
+ if (result.size() != NUM2SIZET(k)) {
341
408
  rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
342
409
  }
343
410
 
@@ -347,7 +414,7 @@ private:
347
414
  while (!result.empty()) {
348
415
  const std::pair<float, size_t>& result_tuple = result.top();
349
416
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
350
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
417
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
351
418
  result.pop();
352
419
  }
353
420
 
@@ -364,8 +431,28 @@ private:
364
431
  return Qnil;
365
432
  };
366
433
 
367
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
434
+ static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
435
+ VALUE _filename, _allow_replace_deleted;
436
+ VALUE kw_args = Qnil;
437
+ ID kw_table[1] = {rb_intern("allow_replace_deleted")};
438
+ VALUE kw_values[1] = {Qundef};
439
+
440
+ rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
441
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
442
+ _allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
443
+
444
+ if (!RB_TYPE_P(_filename, T_STRING)) {
445
+ rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
446
+ return Qnil;
447
+ }
448
+ if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
449
+ !RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
450
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
451
+ return Qnil;
452
+ }
453
+
368
454
  std::string filename(StringValuePtr(_filename));
455
+ const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
369
456
  VALUE ivspace = rb_iv_get(self, "@space");
370
457
  hnswlib::SpaceInterface<float>* space;
371
458
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
@@ -373,6 +460,7 @@ private:
373
460
  } else {
374
461
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
375
462
  }
463
+
376
464
  hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
377
465
  if (index->data_level0_memory_) {
378
466
  free(index->data_level0_memory_);
@@ -392,12 +480,15 @@ private:
392
480
  delete index->visited_list_pool_;
393
481
  index->visited_list_pool_ = nullptr;
394
482
  }
483
+
395
484
  try {
396
485
  index->loadIndex(filename, space);
486
+ index->allow_replace_deleted_ = allow_replace_deleted;
397
487
  } catch (const std::runtime_error& e) {
398
488
  rb_raise(rb_eRuntimeError, "%s", e.what());
399
489
  return Qnil;
400
490
  }
491
+
401
492
  RB_GC_GUARD(_filename);
402
493
  return Qnil;
403
494
  };
@@ -405,7 +496,7 @@ private:
405
496
  static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
406
497
  VALUE ret = Qnil;
407
498
  try {
408
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
499
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
409
500
  ret = rb_ary_new2(vec.size());
410
501
  for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
411
502
  } catch (const std::runtime_error& e) {
@@ -417,13 +508,23 @@ private:
417
508
 
418
509
  static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
419
510
  VALUE ret = rb_ary_new();
420
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
511
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
421
512
  return ret;
422
513
  };
423
514
 
424
515
  static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
425
516
  try {
426
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
517
+ get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
518
+ } catch (const std::runtime_error& e) {
519
+ rb_raise(rb_eRuntimeError, "%s", e.what());
520
+ return Qnil;
521
+ }
522
+ return Qnil;
523
+ };
524
+
525
+ static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
526
+ try {
527
+ get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
427
528
  } catch (const std::runtime_error& e) {
428
529
  rb_raise(rb_eRuntimeError, "%s", e.what());
429
530
  return Qnil;
@@ -432,31 +533,42 @@ private:
432
533
  };
433
534
 
434
535
  static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
435
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
536
+ if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
436
537
  rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
437
538
  return Qnil;
438
539
  }
439
540
  try {
440
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
541
+ get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
441
542
  } catch (const std::runtime_error& e) {
442
543
  rb_raise(rb_eRuntimeError, "%s", e.what());
443
544
  return Qnil;
545
+ } catch (const std::bad_alloc& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
444
548
  }
445
549
  return Qnil;
446
550
  };
447
551
 
448
552
  static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
449
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
553
+ get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
450
554
  return Qnil;
451
555
  };
452
556
 
557
+ static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
558
+
453
559
  static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
454
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
560
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
455
561
  };
456
562
 
457
563
  static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
458
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
564
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
565
+ };
566
+
567
+ static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
568
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
459
569
  };
570
+
571
+ static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
460
572
  };
461
573
 
462
574
  // clang-format off
@@ -500,7 +612,7 @@ public:
500
612
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
501
613
  rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
502
614
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
503
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
615
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
504
616
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
505
617
  rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
506
618
  rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
@@ -537,7 +649,7 @@ private:
537
649
  } else {
538
650
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
539
651
  }
540
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
652
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
541
653
 
542
654
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
543
655
  try {
@@ -551,7 +663,7 @@ private:
551
663
  };
552
664
 
553
665
  static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
554
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
666
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
555
667
 
556
668
  if (!RB_TYPE_P(arr, T_ARRAY)) {
557
669
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
@@ -567,10 +679,10 @@ private:
567
679
  }
568
680
 
569
681
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
570
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
682
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
571
683
 
572
684
  try {
573
- get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
685
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
574
686
  } catch (const std::runtime_error& e) {
575
687
  ruby_xfree(vec);
576
688
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -581,8 +693,17 @@ private:
581
693
  return Qtrue;
582
694
  };
583
695
 
584
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
585
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
696
+ static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
697
+ VALUE arr, k, filter;
698
+ VALUE kw_args = Qnil;
699
+ ID kw_table[1] = {rb_intern("filter")};
700
+ VALUE kw_values[1] = {Qundef};
701
+
702
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
703
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
704
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
705
+
706
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
586
707
 
587
708
  if (!RB_TYPE_P(arr, T_ARRAY)) {
588
709
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -597,17 +718,28 @@ private:
597
718
  return Qnil;
598
719
  }
599
720
 
721
+ CustomFilterFunctor* filter_func = nullptr;
722
+ if (!NIL_P(filter)) {
723
+ try {
724
+ filter_func = new CustomFilterFunctor(filter);
725
+ } catch (const std::bad_alloc& e) {
726
+ rb_raise(rb_eRuntimeError, "%s", e.what());
727
+ return Qnil;
728
+ }
729
+ }
730
+
600
731
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
601
- for (int i = 0; i < dim; i++) {
732
+ for (size_t i = 0; i < dim; i++) {
602
733
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
603
734
  }
604
735
 
605
736
  std::priority_queue<std::pair<float, size_t>> result =
606
- get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
737
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
607
738
 
608
739
  ruby_xfree(vec);
740
+ if (filter_func) delete filter_func;
609
741
 
610
- if (result.size() != (size_t)NUM2INT(k)) {
742
+ if (result.size() != NUM2SIZET(k)) {
611
743
  rb_warning("Cannot return as many search results as the requested number of neighbors.");
612
744
  }
613
745
 
@@ -617,7 +749,7 @@ private:
617
749
  while (!result.empty()) {
618
750
  const std::pair<float, size_t>& result_tuple = result.top();
619
751
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
620
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
752
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
621
753
  result.pop();
622
754
  }
623
755
 
@@ -659,16 +791,16 @@ private:
659
791
  };
660
792
 
661
793
  static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
662
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
794
+ get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
663
795
  return Qnil;
664
796
  };
665
797
 
666
798
  static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
667
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
799
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
668
800
  };
669
801
 
670
802
  static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
671
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
803
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
672
804
  };
673
805
  };
674
806