hnswlib 0.6.1 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 293c9038e57db28357f77c753988b593ef921a2d8caf7234f4a547547580f2cb
4
- data.tar.gz: 668eba08220e29d970f886b91382a335834ac746d9a208d9add986f2ff21fbfc
3
+ metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
+ data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
5
5
  SHA512:
6
- metadata.gz: 1a748a2d3f8291453221b60b55f184b152f56aa35ec5a0830d7b6c6f82adb90e09d757737f0b0527e3404f30c61d7da7f8567c1a7d087045fada991e6152a333
7
- data.tar.gz: 4f5fbf6a8b14e4179862f3e9030bf8e08fa971e874e518141261e2f83942fdb96b68217adcb0dbd2a26b12b5684ef7b645fd7b7fa8025b03fc6e08694a14d9c9
6
+ metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
+ data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
1
+ ## [0.7.0] - 2023-03-04
2
+
3
+ - Update bundled hnswlib version to 0.7.0.
4
+ - Add support for replacing an element marked for deletion with a new element.
5
+ - Add support for a filtering function by label in the search_knn method of BruteforceSearch and HierarchicalNSW.
6
+
7
+ ## [0.6.2] - 2022-06-25
8
+
9
+ - Refactor code and configs with RuboCop and clang-format.
10
+ - Change to raise ArgumentError when a non-array object is given to the distance method.
11
+
1
12
  ## [0.6.1] - 2022-04-30
2
13
 
3
14
  - Change the `search_knn` method of `BruteforceSearch` to output a warning message instead of raising RuntimeError
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@
20
20
 
21
21
  VALUE rb_mHnswlib;
22
22
 
23
- extern "C"
24
- void Init_hnswlibext(void) {
23
+ extern "C" void Init_hnswlibext(void) {
25
24
  rb_mHnswlib = rb_define_module("Hnswlib");
26
25
  RbHnswlibL2Space::define_class(rb_mHnswlib);
27
26
  RbHnswlibInnerProductSpace::define_class(rb_mHnswlib);
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,9 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <new>
27
+ #include <vector>
28
+
26
29
  VALUE rb_cHnswlibL2Space;
27
30
  VALUE rb_cHnswlibInnerProductSpace;
28
31
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -64,20 +67,24 @@ private:
64
67
  static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
68
  rb_iv_set(self, "@dim", dim);
66
69
  hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
70
+ new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
68
71
  return Qnil;
69
72
  };
70
73
 
71
74
  static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
76
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
77
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
78
+ return Qnil;
79
+ }
73
80
  if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
74
81
  rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
75
82
  return Qnil;
76
83
  }
77
84
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
78
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
85
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
79
86
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
80
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
87
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
81
88
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
82
89
  const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
83
90
  ruby_xfree(vec_a);
@@ -136,20 +143,24 @@ private:
136
143
  static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
137
144
  rb_iv_set(self, "@dim", dim);
138
145
  hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
139
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
146
+ new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
140
147
  return Qnil;
141
148
  };
142
149
 
143
150
  static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
144
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
152
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
153
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
154
+ return Qnil;
155
+ }
145
156
  if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
146
157
  rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
147
158
  return Qnil;
148
159
  }
149
160
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
150
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
161
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
151
162
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
152
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
163
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
153
164
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
154
165
  const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
155
166
  ruby_xfree(vec_a);
@@ -172,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
172
183
  };
173
184
  // clang-format on
174
185
 
186
+ class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
187
+ public:
188
+ CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
189
+
190
+ bool operator()(hnswlib::labeltype id) {
191
+ VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
192
+ return result == Qtrue ? true : false;
193
+ }
194
+
195
+ private:
196
+ VALUE callback_;
197
+ };
198
+
175
199
  class RbHnswlibHierarchicalNSW {
176
200
  public:
177
201
  static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
@@ -198,17 +222,21 @@ public:
198
222
  rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
199
223
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
200
224
  rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
201
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
202
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
225
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
203
227
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
204
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
228
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
205
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
206
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
207
231
  rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
232
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
208
233
  rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
209
234
  rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
235
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
210
236
  rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
211
237
  rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
238
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
239
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
212
240
  rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
213
241
  return rb_cHnswlibHierarchicalNSW;
214
242
  };
@@ -218,14 +246,15 @@ private:
218
246
 
219
247
  static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
220
248
  VALUE kw_args = Qnil;
221
- ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
222
- rb_intern("random_seed")};
223
- VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
249
+ ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
+ rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
+ VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
224
252
  rb_scan_args(argc, argv, ":", &kw_args);
225
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
226
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
227
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
228
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
253
+ rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
+ if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
+ if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
229
258
 
230
259
  if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
231
260
  rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
@@ -248,6 +277,10 @@ private:
248
277
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
249
278
  return Qnil;
250
279
  }
280
+ if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
281
+ rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
+ return Qnil;
283
+ }
251
284
 
252
285
  rb_iv_set(self, "@space", kw_values[0]);
253
286
  hnswlib::SpaceInterface<float>* space;
@@ -256,14 +289,15 @@ private:
256
289
  } else {
257
290
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
258
291
  }
259
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
260
- const size_t m = (size_t)NUM2INT(kw_values[2]);
261
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
262
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
292
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
293
+ const size_t m = NUM2SIZET(kw_values[2]);
294
+ const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
+ const size_t random_seed = NUM2SIZET(kw_values[4]);
296
+ const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
263
297
 
264
298
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
265
299
  try {
266
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
300
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
267
301
  } catch (const std::runtime_error& e) {
268
302
  rb_raise(rb_eRuntimeError, "%s", e.what());
269
303
  return Qnil;
@@ -272,33 +306,63 @@ private:
272
306
  return Qnil;
273
307
  };
274
308
 
275
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
276
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
309
+ static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
310
+ VALUE _arr, _idx, _replace_deleted;
311
+ VALUE kw_args = Qnil;
312
+ ID kw_table[1] = {rb_intern("replace_deleted")};
313
+ VALUE kw_values[1] = {Qundef};
277
314
 
278
- if (!RB_TYPE_P(arr, T_ARRAY)) {
315
+ rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
316
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
317
+ _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
318
+
319
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
320
+
321
+ if (!RB_TYPE_P(_arr, T_ARRAY)) {
279
322
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
280
323
  return Qfalse;
281
324
  }
282
- if (!RB_INTEGER_TYPE_P(idx)) {
325
+ if (!RB_INTEGER_TYPE_P(_idx)) {
283
326
  rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
284
327
  return Qfalse;
285
328
  }
286
- if (dim != RARRAY_LEN(arr)) {
329
+ if (dim != RARRAY_LEN(_arr)) {
287
330
  rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
288
331
  return Qfalse;
289
332
  }
333
+ if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
334
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
335
+ return Qfalse;
336
+ }
290
337
 
291
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
292
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
338
+ float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
+ for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
+ const size_t idx = NUM2SIZET(_idx);
341
+ const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
293
342
 
294
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
343
+ try {
344
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
345
+ } catch (const std::runtime_error& e) {
346
+ ruby_xfree(arr);
347
+ rb_raise(rb_eRuntimeError, "%s", e.what());
348
+ return Qfalse;
349
+ }
295
350
 
296
- ruby_xfree(vec);
351
+ ruby_xfree(arr);
297
352
  return Qtrue;
298
353
  };
299
354
 
300
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
301
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
355
+ static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
356
+ VALUE arr, k, filter;
357
+ VALUE kw_args = Qnil;
358
+ ID kw_table[1] = {rb_intern("filter")};
359
+ VALUE kw_values[1] = {Qundef};
360
+
361
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
362
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
363
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
364
+
365
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
302
366
 
303
367
  if (!RB_TYPE_P(arr, T_ARRAY)) {
304
368
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -313,14 +377,24 @@ private:
313
377
  return Qnil;
314
378
  }
315
379
 
380
+ CustomFilterFunctor* filter_func = nullptr;
381
+ if (!NIL_P(filter)) {
382
+ try {
383
+ filter_func = new CustomFilterFunctor(filter);
384
+ } catch (const std::bad_alloc& e) {
385
+ rb_raise(rb_eRuntimeError, "%s", e.what());
386
+ return Qnil;
387
+ }
388
+ }
389
+
316
390
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
317
- for (int i = 0; i < dim; i++) {
391
+ for (size_t i = 0; i < dim; i++) {
318
392
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
319
393
  }
320
394
 
321
395
  std::priority_queue<std::pair<float, size_t>> result;
322
396
  try {
323
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
397
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
324
398
  } catch (const std::runtime_error& e) {
325
399
  ruby_xfree(vec);
326
400
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -328,8 +402,9 @@ private:
328
402
  }
329
403
 
330
404
  ruby_xfree(vec);
405
+ if (filter_func) delete filter_func;
331
406
 
332
- if (result.size() != (size_t)NUM2INT(k)) {
407
+ if (result.size() != NUM2SIZET(k)) {
333
408
  rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
334
409
  }
335
410
 
@@ -339,7 +414,7 @@ private:
339
414
  while (!result.empty()) {
340
415
  const std::pair<float, size_t>& result_tuple = result.top();
341
416
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
342
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
417
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
343
418
  result.pop();
344
419
  }
345
420
 
@@ -356,8 +431,28 @@ private:
356
431
  return Qnil;
357
432
  };
358
433
 
359
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
434
+ static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
435
+ VALUE _filename, _allow_replace_deleted;
436
+ VALUE kw_args = Qnil;
437
+ ID kw_table[1] = {rb_intern("allow_replace_deleted")};
438
+ VALUE kw_values[1] = {Qundef};
439
+
440
+ rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
441
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
442
+ _allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
443
+
444
+ if (!RB_TYPE_P(_filename, T_STRING)) {
445
+ rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
446
+ return Qnil;
447
+ }
448
+ if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
449
+ !RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
450
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
451
+ return Qnil;
452
+ }
453
+
360
454
  std::string filename(StringValuePtr(_filename));
455
+ const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
361
456
  VALUE ivspace = rb_iv_get(self, "@space");
362
457
  hnswlib::SpaceInterface<float>* space;
363
458
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
@@ -365,6 +460,7 @@ private:
365
460
  } else {
366
461
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
367
462
  }
463
+
368
464
  hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
369
465
  if (index->data_level0_memory_) {
370
466
  free(index->data_level0_memory_);
@@ -384,12 +480,15 @@ private:
384
480
  delete index->visited_list_pool_;
385
481
  index->visited_list_pool_ = nullptr;
386
482
  }
483
+
387
484
  try {
388
485
  index->loadIndex(filename, space);
486
+ index->allow_replace_deleted_ = allow_replace_deleted;
389
487
  } catch (const std::runtime_error& e) {
390
488
  rb_raise(rb_eRuntimeError, "%s", e.what());
391
489
  return Qnil;
392
490
  }
491
+
393
492
  RB_GC_GUARD(_filename);
394
493
  return Qnil;
395
494
  };
@@ -397,7 +496,7 @@ private:
397
496
  static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
398
497
  VALUE ret = Qnil;
399
498
  try {
400
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
499
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
401
500
  ret = rb_ary_new2(vec.size());
402
501
  for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
403
502
  } catch (const std::runtime_error& e) {
@@ -409,13 +508,23 @@ private:
409
508
 
410
509
  static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
411
510
  VALUE ret = rb_ary_new();
412
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
511
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
413
512
  return ret;
414
513
  };
415
514
 
416
515
  static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
417
516
  try {
418
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
517
+ get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
518
+ } catch (const std::runtime_error& e) {
519
+ rb_raise(rb_eRuntimeError, "%s", e.what());
520
+ return Qnil;
521
+ }
522
+ return Qnil;
523
+ };
524
+
525
+ static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
526
+ try {
527
+ get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
419
528
  } catch (const std::runtime_error& e) {
420
529
  rb_raise(rb_eRuntimeError, "%s", e.what());
421
530
  return Qnil;
@@ -424,31 +533,42 @@ private:
424
533
  };
425
534
 
426
535
  static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
427
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
536
+ if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
428
537
  rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
429
538
  return Qnil;
430
539
  }
431
540
  try {
432
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
541
+ get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
433
542
  } catch (const std::runtime_error& e) {
434
543
  rb_raise(rb_eRuntimeError, "%s", e.what());
435
544
  return Qnil;
545
+ } catch (const std::bad_alloc& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
436
548
  }
437
549
  return Qnil;
438
550
  };
439
551
 
440
552
  static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
441
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
553
+ get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
442
554
  return Qnil;
443
555
  };
444
556
 
557
+ static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
558
+
445
559
  static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
446
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
560
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
447
561
  };
448
562
 
449
563
  static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
450
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
564
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
565
+ };
566
+
567
+ static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
568
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
451
569
  };
570
+
571
+ static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
452
572
  };
453
573
 
454
574
  // clang-format off
@@ -492,7 +612,7 @@ public:
492
612
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
493
613
  rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
494
614
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
495
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
615
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
496
616
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
497
617
  rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
498
618
  rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
@@ -529,7 +649,7 @@ private:
529
649
  } else {
530
650
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
531
651
  }
532
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
652
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
533
653
 
534
654
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
535
655
  try {
@@ -543,7 +663,7 @@ private:
543
663
  };
544
664
 
545
665
  static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
546
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
666
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
547
667
 
548
668
  if (!RB_TYPE_P(arr, T_ARRAY)) {
549
669
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
@@ -559,10 +679,10 @@ private:
559
679
  }
560
680
 
561
681
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
562
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
682
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
563
683
 
564
684
  try {
565
- get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
685
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
566
686
  } catch (const std::runtime_error& e) {
567
687
  ruby_xfree(vec);
568
688
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -573,8 +693,17 @@ private:
573
693
  return Qtrue;
574
694
  };
575
695
 
576
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
577
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
696
+ static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
697
+ VALUE arr, k, filter;
698
+ VALUE kw_args = Qnil;
699
+ ID kw_table[1] = {rb_intern("filter")};
700
+ VALUE kw_values[1] = {Qundef};
701
+
702
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
703
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
704
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
705
+
706
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
578
707
 
579
708
  if (!RB_TYPE_P(arr, T_ARRAY)) {
580
709
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -589,17 +718,28 @@ private:
589
718
  return Qnil;
590
719
  }
591
720
 
721
+ CustomFilterFunctor* filter_func = nullptr;
722
+ if (!NIL_P(filter)) {
723
+ try {
724
+ filter_func = new CustomFilterFunctor(filter);
725
+ } catch (const std::bad_alloc& e) {
726
+ rb_raise(rb_eRuntimeError, "%s", e.what());
727
+ return Qnil;
728
+ }
729
+ }
730
+
592
731
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
593
- for (int i = 0; i < dim; i++) {
732
+ for (size_t i = 0; i < dim; i++) {
594
733
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
595
734
  }
596
735
 
597
736
  std::priority_queue<std::pair<float, size_t>> result =
598
- get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
737
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
599
738
 
600
739
  ruby_xfree(vec);
740
+ if (filter_func) delete filter_func;
601
741
 
602
- if (result.size() != (size_t)NUM2INT(k)) {
742
+ if (result.size() != NUM2SIZET(k)) {
603
743
  rb_warning("Cannot return as many search results as the requested number of neighbors.");
604
744
  }
605
745
 
@@ -609,7 +749,7 @@ private:
609
749
  while (!result.empty()) {
610
750
  const std::pair<float, size_t>& result_tuple = result.top();
611
751
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
612
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
752
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
613
753
  result.pop();
614
754
  }
615
755
 
@@ -651,16 +791,16 @@ private:
651
791
  };
652
792
 
653
793
  static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
654
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
794
+ get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
655
795
  return Qnil;
656
796
  };
657
797
 
658
798
  static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
659
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
799
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
660
800
  };
661
801
 
662
802
  static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
663
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
803
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
664
804
  };
665
805
  };
666
806