hnswlib 0.6.1 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 293c9038e57db28357f77c753988b593ef921a2d8caf7234f4a547547580f2cb
4
- data.tar.gz: 668eba08220e29d970f886b91382a335834ac746d9a208d9add986f2ff21fbfc
3
+ metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
4
+ data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
5
5
  SHA512:
6
- metadata.gz: 1a748a2d3f8291453221b60b55f184b152f56aa35ec5a0830d7b6c6f82adb90e09d757737f0b0527e3404f30c61d7da7f8567c1a7d087045fada991e6152a333
7
- data.tar.gz: 4f5fbf6a8b14e4179862f3e9030bf8e08fa971e874e518141261e2f83942fdb96b68217adcb0dbd2a26b12b5684ef7b645fd7b7fa8025b03fc6e08694a14d9c9
6
+ metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
7
+ data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
data/CHANGELOG.md CHANGED
@@ -1,3 +1,14 @@
1
+ ## [0.7.0] - 2023-03-04
2
+
3
+ - Update bundled hnswlib version to 0.7.0.
4
+ - Add support for replacing an element marked for deletion with a new element.
5
+ - Add support for a label-filtering function in the search_knn method of BruteforceSearch and HierarchicalNSW.
6
+
7
+ ## [0.6.2] - 2022-06-25
8
+
9
+ - Refactor code and configs with RuboCop and clang-format.
10
+ - Change the distance method to raise ArgumentError when a non-array object is given.
11
+
1
12
  ## [0.6.1] - 2022-04-30
2
13
 
3
14
  - Change the `search_knn` method of `BruteforceSearch` to output a warning message instead of raising RuntimeError
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -20,8 +20,7 @@
20
20
 
21
21
  VALUE rb_mHnswlib;
22
22
 
23
- extern "C"
24
- void Init_hnswlibext(void) {
23
+ extern "C" void Init_hnswlibext(void) {
25
24
  rb_mHnswlib = rb_define_module("Hnswlib");
26
25
  RbHnswlibL2Space::define_class(rb_mHnswlib);
27
26
  RbHnswlibInnerProductSpace::define_class(rb_mHnswlib);
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,9 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <new>
27
+ #include <vector>
28
+
26
29
  VALUE rb_cHnswlibL2Space;
27
30
  VALUE rb_cHnswlibInnerProductSpace;
28
31
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -64,20 +67,24 @@ private:
64
67
  static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
68
  rb_iv_set(self, "@dim", dim);
66
69
  hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
70
+ new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
68
71
  return Qnil;
69
72
  };
70
73
 
71
74
  static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
76
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
77
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
78
+ return Qnil;
79
+ }
73
80
  if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
74
81
  rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
75
82
  return Qnil;
76
83
  }
77
84
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
78
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
85
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
79
86
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
80
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
87
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
81
88
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
82
89
  const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
83
90
  ruby_xfree(vec_a);
@@ -136,20 +143,24 @@ private:
136
143
  static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
137
144
  rb_iv_set(self, "@dim", dim);
138
145
  hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
139
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
146
+ new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
140
147
  return Qnil;
141
148
  };
142
149
 
143
150
  static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
144
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
152
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
153
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
154
+ return Qnil;
155
+ }
145
156
  if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
146
157
  rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
147
158
  return Qnil;
148
159
  }
149
160
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
150
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
161
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
151
162
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
152
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
163
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
153
164
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
154
165
  const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
155
166
  ruby_xfree(vec_a);
@@ -172,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
172
183
  };
173
184
  // clang-format on
174
185
 
186
+ class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
187
+ public:
188
+ CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
189
+
190
+ bool operator()(hnswlib::labeltype id) {
191
+ VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
192
+ return result == Qtrue ? true : false;
193
+ }
194
+
195
+ private:
196
+ VALUE callback_;
197
+ };
198
+
175
199
  class RbHnswlibHierarchicalNSW {
176
200
  public:
177
201
  static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
@@ -198,17 +222,21 @@ public:
198
222
  rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
199
223
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
200
224
  rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
201
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
202
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
225
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
203
227
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
204
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
228
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
205
229
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
206
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
207
231
  rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
232
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
208
233
  rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
209
234
  rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
235
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
210
236
  rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
211
237
  rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
238
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
239
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
212
240
  rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
213
241
  return rb_cHnswlibHierarchicalNSW;
214
242
  };
@@ -218,14 +246,15 @@ private:
218
246
 
219
247
  static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
220
248
  VALUE kw_args = Qnil;
221
- ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
222
- rb_intern("random_seed")};
223
- VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
249
+ ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
250
+ rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
251
+ VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
224
252
  rb_scan_args(argc, argv, ":", &kw_args);
225
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
226
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
227
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
228
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
253
+ rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
254
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
255
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
256
+ if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
257
+ if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
229
258
 
230
259
  if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
231
260
  rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
@@ -248,6 +277,10 @@ private:
248
277
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
249
278
  return Qnil;
250
279
  }
280
+ if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
281
+ rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
282
+ return Qnil;
283
+ }
251
284
 
252
285
  rb_iv_set(self, "@space", kw_values[0]);
253
286
  hnswlib::SpaceInterface<float>* space;
@@ -256,14 +289,15 @@ private:
256
289
  } else {
257
290
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
258
291
  }
259
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
260
- const size_t m = (size_t)NUM2INT(kw_values[2]);
261
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
262
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
292
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
293
+ const size_t m = NUM2SIZET(kw_values[2]);
294
+ const size_t ef_construction = NUM2SIZET(kw_values[3]);
295
+ const size_t random_seed = NUM2SIZET(kw_values[4]);
296
+ const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
263
297
 
264
298
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
265
299
  try {
266
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
300
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
267
301
  } catch (const std::runtime_error& e) {
268
302
  rb_raise(rb_eRuntimeError, "%s", e.what());
269
303
  return Qnil;
@@ -272,33 +306,63 @@ private:
272
306
  return Qnil;
273
307
  };
274
308
 
275
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
276
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
309
+ static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
310
+ VALUE _arr, _idx, _replace_deleted;
311
+ VALUE kw_args = Qnil;
312
+ ID kw_table[1] = {rb_intern("replace_deleted")};
313
+ VALUE kw_values[1] = {Qundef};
277
314
 
278
- if (!RB_TYPE_P(arr, T_ARRAY)) {
315
+ rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
316
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
317
+ _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
318
+
319
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
320
+
321
+ if (!RB_TYPE_P(_arr, T_ARRAY)) {
279
322
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
280
323
  return Qfalse;
281
324
  }
282
- if (!RB_INTEGER_TYPE_P(idx)) {
325
+ if (!RB_INTEGER_TYPE_P(_idx)) {
283
326
  rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
284
327
  return Qfalse;
285
328
  }
286
- if (dim != RARRAY_LEN(arr)) {
329
+ if (dim != RARRAY_LEN(_arr)) {
287
330
  rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
288
331
  return Qfalse;
289
332
  }
333
+ if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
334
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
335
+ return Qfalse;
336
+ }
290
337
 
291
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
292
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
338
+ float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
339
+ for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
340
+ const size_t idx = NUM2SIZET(_idx);
341
+ const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
293
342
 
294
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
343
+ try {
344
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
345
+ } catch (const std::runtime_error& e) {
346
+ ruby_xfree(arr);
347
+ rb_raise(rb_eRuntimeError, "%s", e.what());
348
+ return Qfalse;
349
+ }
295
350
 
296
- ruby_xfree(vec);
351
+ ruby_xfree(arr);
297
352
  return Qtrue;
298
353
  };
299
354
 
300
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
301
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
355
+ static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
356
+ VALUE arr, k, filter;
357
+ VALUE kw_args = Qnil;
358
+ ID kw_table[1] = {rb_intern("filter")};
359
+ VALUE kw_values[1] = {Qundef};
360
+
361
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
362
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
363
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
364
+
365
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
302
366
 
303
367
  if (!RB_TYPE_P(arr, T_ARRAY)) {
304
368
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -313,14 +377,24 @@ private:
313
377
  return Qnil;
314
378
  }
315
379
 
380
+ CustomFilterFunctor* filter_func = nullptr;
381
+ if (!NIL_P(filter)) {
382
+ try {
383
+ filter_func = new CustomFilterFunctor(filter);
384
+ } catch (const std::bad_alloc& e) {
385
+ rb_raise(rb_eRuntimeError, "%s", e.what());
386
+ return Qnil;
387
+ }
388
+ }
389
+
316
390
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
317
- for (int i = 0; i < dim; i++) {
391
+ for (size_t i = 0; i < dim; i++) {
318
392
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
319
393
  }
320
394
 
321
395
  std::priority_queue<std::pair<float, size_t>> result;
322
396
  try {
323
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
397
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
324
398
  } catch (const std::runtime_error& e) {
325
399
  ruby_xfree(vec);
326
400
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -328,8 +402,9 @@ private:
328
402
  }
329
403
 
330
404
  ruby_xfree(vec);
405
+ if (filter_func) delete filter_func;
331
406
 
332
- if (result.size() != (size_t)NUM2INT(k)) {
407
+ if (result.size() != NUM2SIZET(k)) {
333
408
  rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
334
409
  }
335
410
 
@@ -339,7 +414,7 @@ private:
339
414
  while (!result.empty()) {
340
415
  const std::pair<float, size_t>& result_tuple = result.top();
341
416
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
342
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
417
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
343
418
  result.pop();
344
419
  }
345
420
 
@@ -356,8 +431,28 @@ private:
356
431
  return Qnil;
357
432
  };
358
433
 
359
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
434
+ static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
435
+ VALUE _filename, _allow_replace_deleted;
436
+ VALUE kw_args = Qnil;
437
+ ID kw_table[1] = {rb_intern("allow_replace_deleted")};
438
+ VALUE kw_values[1] = {Qundef};
439
+
440
+ rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
441
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
442
+ _allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
443
+
444
+ if (!RB_TYPE_P(_filename, T_STRING)) {
445
+ rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
446
+ return Qnil;
447
+ }
448
+ if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
449
+ !RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
450
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
451
+ return Qnil;
452
+ }
453
+
360
454
  std::string filename(StringValuePtr(_filename));
455
+ const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
361
456
  VALUE ivspace = rb_iv_get(self, "@space");
362
457
  hnswlib::SpaceInterface<float>* space;
363
458
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
@@ -365,6 +460,7 @@ private:
365
460
  } else {
366
461
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
367
462
  }
463
+
368
464
  hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
369
465
  if (index->data_level0_memory_) {
370
466
  free(index->data_level0_memory_);
@@ -384,12 +480,15 @@ private:
384
480
  delete index->visited_list_pool_;
385
481
  index->visited_list_pool_ = nullptr;
386
482
  }
483
+
387
484
  try {
388
485
  index->loadIndex(filename, space);
486
+ index->allow_replace_deleted_ = allow_replace_deleted;
389
487
  } catch (const std::runtime_error& e) {
390
488
  rb_raise(rb_eRuntimeError, "%s", e.what());
391
489
  return Qnil;
392
490
  }
491
+
393
492
  RB_GC_GUARD(_filename);
394
493
  return Qnil;
395
494
  };
@@ -397,7 +496,7 @@ private:
397
496
  static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
398
497
  VALUE ret = Qnil;
399
498
  try {
400
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
499
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
401
500
  ret = rb_ary_new2(vec.size());
402
501
  for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
403
502
  } catch (const std::runtime_error& e) {
@@ -409,13 +508,23 @@ private:
409
508
 
410
509
  static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
411
510
  VALUE ret = rb_ary_new();
412
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
511
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
413
512
  return ret;
414
513
  };
415
514
 
416
515
  static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
417
516
  try {
418
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
517
+ get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
518
+ } catch (const std::runtime_error& e) {
519
+ rb_raise(rb_eRuntimeError, "%s", e.what());
520
+ return Qnil;
521
+ }
522
+ return Qnil;
523
+ };
524
+
525
+ static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
526
+ try {
527
+ get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
419
528
  } catch (const std::runtime_error& e) {
420
529
  rb_raise(rb_eRuntimeError, "%s", e.what());
421
530
  return Qnil;
@@ -424,31 +533,42 @@ private:
424
533
  };
425
534
 
426
535
  static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
427
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
536
+ if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
428
537
  rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
429
538
  return Qnil;
430
539
  }
431
540
  try {
432
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
541
+ get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
433
542
  } catch (const std::runtime_error& e) {
434
543
  rb_raise(rb_eRuntimeError, "%s", e.what());
435
544
  return Qnil;
545
+ } catch (const std::bad_alloc& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
436
548
  }
437
549
  return Qnil;
438
550
  };
439
551
 
440
552
  static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
441
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
553
+ get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
442
554
  return Qnil;
443
555
  };
444
556
 
557
+ static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
558
+
445
559
  static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
446
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
560
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
447
561
  };
448
562
 
449
563
  static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
450
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
564
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
565
+ };
566
+
567
+ static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
568
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
451
569
  };
570
+
571
+ static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
452
572
  };
453
573
 
454
574
  // clang-format off
@@ -492,7 +612,7 @@ public:
492
612
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
493
613
  rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
494
614
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
495
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
615
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
496
616
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
497
617
  rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
498
618
  rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
@@ -529,7 +649,7 @@ private:
529
649
  } else {
530
650
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
531
651
  }
532
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
652
+ const size_t max_elements = NUM2SIZET(kw_values[1]);
533
653
 
534
654
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
535
655
  try {
@@ -543,7 +663,7 @@ private:
543
663
  };
544
664
 
545
665
  static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
546
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
666
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
547
667
 
548
668
  if (!RB_TYPE_P(arr, T_ARRAY)) {
549
669
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
@@ -559,10 +679,10 @@ private:
559
679
  }
560
680
 
561
681
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
562
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
682
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
563
683
 
564
684
  try {
565
- get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
685
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
566
686
  } catch (const std::runtime_error& e) {
567
687
  ruby_xfree(vec);
568
688
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -573,8 +693,17 @@ private:
573
693
  return Qtrue;
574
694
  };
575
695
 
576
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
577
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
696
+ static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
697
+ VALUE arr, k, filter;
698
+ VALUE kw_args = Qnil;
699
+ ID kw_table[1] = {rb_intern("filter")};
700
+ VALUE kw_values[1] = {Qundef};
701
+
702
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
703
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
704
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
705
+
706
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
578
707
 
579
708
  if (!RB_TYPE_P(arr, T_ARRAY)) {
580
709
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -589,17 +718,28 @@ private:
589
718
  return Qnil;
590
719
  }
591
720
 
721
+ CustomFilterFunctor* filter_func = nullptr;
722
+ if (!NIL_P(filter)) {
723
+ try {
724
+ filter_func = new CustomFilterFunctor(filter);
725
+ } catch (const std::bad_alloc& e) {
726
+ rb_raise(rb_eRuntimeError, "%s", e.what());
727
+ return Qnil;
728
+ }
729
+ }
730
+
592
731
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
593
- for (int i = 0; i < dim; i++) {
732
+ for (size_t i = 0; i < dim; i++) {
594
733
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
595
734
  }
596
735
 
597
736
  std::priority_queue<std::pair<float, size_t>> result =
598
- get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
737
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
599
738
 
600
739
  ruby_xfree(vec);
740
+ if (filter_func) delete filter_func;
601
741
 
602
- if (result.size() != (size_t)NUM2INT(k)) {
742
+ if (result.size() != NUM2SIZET(k)) {
603
743
  rb_warning("Cannot return as many search results as the requested number of neighbors.");
604
744
  }
605
745
 
@@ -609,7 +749,7 @@ private:
609
749
  while (!result.empty()) {
610
750
  const std::pair<float, size_t>& result_tuple = result.top();
611
751
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
612
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
752
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
613
753
  result.pop();
614
754
  }
615
755
 
@@ -651,16 +791,16 @@ private:
651
791
  };
652
792
 
653
793
  static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
654
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
794
+ get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
655
795
  return Qnil;
656
796
  };
657
797
 
658
798
  static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
659
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
799
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
660
800
  };
661
801
 
662
802
  static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
663
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
803
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
664
804
  };
665
805
  };
666
806