hnswlib 0.6.2 → 0.8.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * hnswlib.rb is a Ruby binding for the Hnswlib.
3
3
  *
4
- * Copyright (c) 2021-2022 Atsushi Tatsuma
4
+ * Copyright (c) 2021-2023 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,11 @@
23
23
 
24
24
  #include <hnswlib.h>
25
25
 
26
+ #include <cmath>
27
+ #include <new>
28
+ #include <vector>
29
+
30
+ VALUE rb_mHnswlib;
26
31
  VALUE rb_cHnswlibL2Space;
27
32
  VALUE rb_cHnswlibInnerProductSpace;
28
33
  VALUE rb_cHnswlibHierarchicalNSW;
@@ -49,8 +54,8 @@ public:
49
54
  return ptr;
50
55
  };
51
56
 
52
- static VALUE define_class(VALUE rb_mHnswlib) {
53
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
57
+ static VALUE define_class(VALUE outer) {
58
+ rb_cHnswlibL2Space = rb_define_class_under(outer, "L2Space", rb_cObject);
54
59
  rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
55
60
  rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
56
61
  rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
@@ -64,12 +69,12 @@ private:
64
69
  static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
70
  rb_iv_set(self, "@dim", dim);
66
71
  hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
72
+ new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
68
73
  return Qnil;
69
74
  };
70
75
 
71
76
  static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
77
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
73
78
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
74
79
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
75
80
  return Qnil;
@@ -79,9 +84,9 @@ private:
79
84
  return Qnil;
80
85
  }
81
86
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
82
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
87
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
83
88
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
89
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
85
90
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
86
91
  const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
87
92
  ruby_xfree(vec_a);
@@ -125,8 +130,8 @@ public:
125
130
  return ptr;
126
131
  };
127
132
 
128
- static VALUE define_class(VALUE rb_mHnswlib) {
129
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
133
+ static VALUE define_class(VALUE outer) {
134
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(outer, "InnerProductSpace", rb_cObject);
130
135
  rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
131
136
  rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
132
137
  rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
@@ -140,12 +145,12 @@ private:
140
145
  static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
141
146
  rb_iv_set(self, "@dim", dim);
142
147
  hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
143
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
148
+ new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
144
149
  return Qnil;
145
150
  };
146
151
 
147
152
  static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
148
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
153
+ const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
149
154
  if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
150
155
  rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
151
156
  return Qnil;
@@ -155,9 +160,9 @@ private:
155
160
  return Qnil;
156
161
  }
157
162
  float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
158
- for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
163
+ for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
159
164
  float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
165
+ for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
161
166
  hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
162
167
  const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
163
168
  ruby_xfree(vec_a);
@@ -180,6 +185,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
180
185
  };
181
186
  // clang-format on
182
187
 
188
+ class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
189
+ public:
190
+ CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
191
+
192
+ bool operator()(hnswlib::labeltype id) {
193
+ VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
194
+ return result == Qtrue ? true : false;
195
+ }
196
+
197
+ private:
198
+ VALUE callback_;
199
+ };
200
+
183
201
  class RbHnswlibHierarchicalNSW {
184
202
  public:
185
203
  static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
@@ -202,21 +220,26 @@ public:
202
220
  return ptr;
203
221
  };
204
222
 
205
- static VALUE define_class(VALUE rb_mHnswlib) {
206
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
223
+ static VALUE define_class(VALUE outer) {
224
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(outer, "HierarchicalNSW", rb_cObject);
207
225
  rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
208
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
226
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_initialize), -1);
227
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "init_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init_index), -1);
228
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
229
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
211
230
  rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
231
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
213
232
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
214
233
  rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
215
234
  rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
235
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
216
236
  rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
217
237
  rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
238
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
218
239
  rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
219
240
  rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
241
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
242
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
220
243
  rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
221
244
  return rb_cHnswlibHierarchicalNSW;
222
245
  };
@@ -224,54 +247,91 @@ public:
224
247
  private:
225
248
  static const rb_data_type_t hnsw_hierarchicalnsw_type;
226
249
 
227
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
250
+ static VALUE _hnsw_hierarchicalnsw_initialize(int argc, VALUE* argv, VALUE self) {
228
251
  VALUE kw_args = Qnil;
229
- ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
230
- rb_intern("random_seed")};
231
- VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
252
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
253
+ VALUE kw_values[2] = {Qundef, Qundef};
232
254
  rb_scan_args(argc, argv, ":", &kw_args);
233
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
234
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
235
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
236
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
255
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
237
256
 
238
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
239
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
240
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
257
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
258
+ rb_raise(rb_eTypeError, "expected space, String");
259
+ return Qnil;
260
+ }
261
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
262
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
263
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
241
264
  return Qnil;
242
265
  }
243
266
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
267
+ rb_raise(rb_eTypeError, "expected dim, Integer");
268
+ return Qnil;
269
+ }
270
+
271
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
272
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
273
+ } else {
274
+ rb_iv_set(self, "@space",
275
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
276
+ }
277
+
278
+ rb_iv_set(self, "@normalize", Qfalse);
279
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
280
+
281
+ return Qnil;
282
+ };
283
+
284
+ static VALUE _hnsw_hierarchicalnsw_init_index(int argc, VALUE* argv, VALUE self) {
285
+ VALUE kw_args = Qnil;
286
+ ID kw_table[5] = {rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"), rb_intern("random_seed"),
287
+ rb_intern("allow_replace_deleted")};
288
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
289
+ rb_scan_args(argc, argv, ":", &kw_args);
290
+ rb_get_kwargs(kw_args, kw_table, 1, 4, kw_values);
291
+ if (kw_values[1] == Qundef) kw_values[1] = SIZET2NUM(16);
292
+ if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(200);
293
+ if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(100);
294
+ if (kw_values[4] == Qundef) kw_values[4] = Qfalse;
295
+
296
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
244
297
  rb_raise(rb_eTypeError, "expected max_elements, Integer");
245
298
  return Qnil;
246
299
  }
247
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
300
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
248
301
  rb_raise(rb_eTypeError, "expected m, Integer");
249
302
  return Qnil;
250
303
  }
251
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
304
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
252
305
  rb_raise(rb_eTypeError, "expected ef_construction, Integer");
253
306
  return Qnil;
254
307
  }
255
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
308
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
256
309
  rb_raise(rb_eTypeError, "expected random_seed, Integer");
257
310
  return Qnil;
258
311
  }
312
+ if (!RB_TYPE_P(kw_values[4], T_TRUE) && !RB_TYPE_P(kw_values[4], T_FALSE)) {
313
+ rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
314
+ return Qnil;
315
+ }
259
316
 
260
- rb_iv_set(self, "@space", kw_values[0]);
261
- hnswlib::SpaceInterface<float>* space;
262
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
263
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
317
+ hnswlib::SpaceInterface<float>* space = nullptr;
318
+ VALUE ivspace = rb_iv_get(self, "@space");
319
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
320
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
264
321
  } else {
265
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
322
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
266
323
  }
267
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
268
- const size_t m = (size_t)NUM2INT(kw_values[2]);
269
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
270
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
324
+
325
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
326
+ const size_t m = NUM2SIZET(kw_values[1]);
327
+ const size_t ef_construction = NUM2SIZET(kw_values[2]);
328
+ const size_t random_seed = NUM2SIZET(kw_values[3]);
329
+ const bool allow_replace_deleted = kw_values[4] == Qtrue ? true : false;
271
330
 
272
331
  hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
273
332
  try {
274
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
333
+ ptr->~HierarchicalNSW();
334
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
275
335
  } catch (const std::runtime_error& e) {
276
336
  rb_raise(rb_eRuntimeError, "%s", e.what());
277
337
  return Qnil;
@@ -280,33 +340,72 @@ private:
280
340
  return Qnil;
281
341
  };
282
342
 
283
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
284
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
343
+ static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
344
+ VALUE _arr, _idx, _replace_deleted;
345
+ VALUE kw_args = Qnil;
346
+ ID kw_table[1] = {rb_intern("replace_deleted")};
347
+ VALUE kw_values[1] = {Qundef};
285
348
 
286
- if (!RB_TYPE_P(arr, T_ARRAY)) {
349
+ rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
350
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
351
+ _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
352
+
353
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
354
+
355
+ if (!RB_TYPE_P(_arr, T_ARRAY)) {
287
356
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
288
357
  return Qfalse;
289
358
  }
290
- if (!RB_INTEGER_TYPE_P(idx)) {
359
+ if (!RB_INTEGER_TYPE_P(_idx)) {
291
360
  rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
292
361
  return Qfalse;
293
362
  }
294
- if (dim != RARRAY_LEN(arr)) {
363
+ if (dim != RARRAY_LEN(_arr)) {
295
364
  rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
296
365
  return Qfalse;
297
366
  }
367
+ if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
368
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
369
+ return Qfalse;
370
+ }
298
371
 
299
372
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
300
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
373
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
374
+ const size_t idx = NUM2SIZET(_idx);
375
+ const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
376
+
377
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
378
+ float norm = 0.0;
379
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
380
+ norm = std::sqrt(std::fabs(norm));
381
+ if (norm >= 0.0) {
382
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
383
+ }
384
+ }
301
385
 
302
- get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
386
+ try {
387
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, idx, replace_deleted);
388
+ } catch (const std::runtime_error& e) {
389
+ ruby_xfree(vec);
390
+ rb_raise(rb_eRuntimeError, "%s", e.what());
391
+ return Qfalse;
392
+ }
303
393
 
304
394
  ruby_xfree(vec);
305
395
  return Qtrue;
306
396
  };
307
397
 
308
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
309
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
398
+ static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
399
+ VALUE arr, k, filter;
400
+ VALUE kw_args = Qnil;
401
+ ID kw_table[1] = {rb_intern("filter")};
402
+ VALUE kw_values[1] = {Qundef};
403
+
404
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
405
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
406
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
407
+
408
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
310
409
 
311
410
  if (!RB_TYPE_P(arr, T_ARRAY)) {
312
411
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -321,14 +420,33 @@ private:
321
420
  return Qnil;
322
421
  }
323
422
 
423
+ CustomFilterFunctor* filter_func = nullptr;
424
+ if (!NIL_P(filter)) {
425
+ try {
426
+ filter_func = new CustomFilterFunctor(filter);
427
+ } catch (const std::bad_alloc& e) {
428
+ rb_raise(rb_eRuntimeError, "%s", e.what());
429
+ return Qnil;
430
+ }
431
+ }
432
+
324
433
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
325
- for (int i = 0; i < dim; i++) {
434
+ for (size_t i = 0; i < dim; i++) {
326
435
  vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
327
436
  }
328
437
 
438
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
439
+ float norm = 0.0;
440
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
441
+ norm = std::sqrt(std::fabs(norm));
442
+ if (norm >= 0.0) {
443
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
444
+ }
445
+ }
446
+
329
447
  std::priority_queue<std::pair<float, size_t>> result;
330
448
  try {
331
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
449
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
332
450
  } catch (const std::runtime_error& e) {
333
451
  ruby_xfree(vec);
334
452
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -336,8 +454,9 @@ private:
336
454
  }
337
455
 
338
456
  ruby_xfree(vec);
457
+ if (filter_func) delete filter_func;
339
458
 
340
- if (result.size() != (size_t)NUM2INT(k)) {
459
+ if (result.size() != NUM2SIZET(k)) {
341
460
  rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
342
461
  }
343
462
 
@@ -347,7 +466,7 @@ private:
347
466
  while (!result.empty()) {
348
467
  const std::pair<float, size_t>& result_tuple = result.top();
349
468
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
350
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
469
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
351
470
  result.pop();
352
471
  }
353
472
 
@@ -364,8 +483,28 @@ private:
364
483
  return Qnil;
365
484
  };
366
485
 
367
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
486
+ static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
487
+ VALUE _filename, _allow_replace_deleted;
488
+ VALUE kw_args = Qnil;
489
+ ID kw_table[1] = {rb_intern("allow_replace_deleted")};
490
+ VALUE kw_values[1] = {Qundef};
491
+
492
+ rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
493
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
494
+ _allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
495
+
496
+ if (!RB_TYPE_P(_filename, T_STRING)) {
497
+ rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
498
+ return Qnil;
499
+ }
500
+ if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
501
+ !RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
502
+ rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
503
+ return Qnil;
504
+ }
505
+
368
506
  std::string filename(StringValuePtr(_filename));
507
+ const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
369
508
  VALUE ivspace = rb_iv_get(self, "@space");
370
509
  hnswlib::SpaceInterface<float>* space;
371
510
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
@@ -373,6 +512,7 @@ private:
373
512
  } else {
374
513
  space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
375
514
  }
515
+
376
516
  hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
377
517
  if (index->data_level0_memory_) {
378
518
  free(index->data_level0_memory_);
@@ -392,12 +532,15 @@ private:
392
532
  delete index->visited_list_pool_;
393
533
  index->visited_list_pool_ = nullptr;
394
534
  }
535
+
395
536
  try {
396
537
  index->loadIndex(filename, space);
538
+ index->allow_replace_deleted_ = allow_replace_deleted;
397
539
  } catch (const std::runtime_error& e) {
398
540
  rb_raise(rb_eRuntimeError, "%s", e.what());
399
541
  return Qnil;
400
542
  }
543
+
401
544
  RB_GC_GUARD(_filename);
402
545
  return Qnil;
403
546
  };
@@ -405,7 +548,7 @@ private:
405
548
  static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
406
549
  VALUE ret = Qnil;
407
550
  try {
408
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
551
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
409
552
  ret = rb_ary_new2(vec.size());
410
553
  for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
411
554
  } catch (const std::runtime_error& e) {
@@ -417,13 +560,23 @@ private:
417
560
 
418
561
  static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
419
562
  VALUE ret = rb_ary_new();
420
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
563
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
421
564
  return ret;
422
565
  };
423
566
 
424
567
  static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
425
568
  try {
426
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
569
+ get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
570
+ } catch (const std::runtime_error& e) {
571
+ rb_raise(rb_eRuntimeError, "%s", e.what());
572
+ return Qnil;
573
+ }
574
+ return Qnil;
575
+ };
576
+
577
+ static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
578
+ try {
579
+ get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
427
580
  } catch (const std::runtime_error& e) {
428
581
  rb_raise(rb_eRuntimeError, "%s", e.what());
429
582
  return Qnil;
@@ -432,31 +585,42 @@ private:
432
585
  };
433
586
 
434
587
  static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
435
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
588
+ if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
436
589
  rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
437
590
  return Qnil;
438
591
  }
439
592
  try {
440
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
593
+ get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
441
594
  } catch (const std::runtime_error& e) {
442
595
  rb_raise(rb_eRuntimeError, "%s", e.what());
443
596
  return Qnil;
597
+ } catch (const std::bad_alloc& e) {
598
+ rb_raise(rb_eRuntimeError, "%s", e.what());
599
+ return Qnil;
444
600
  }
445
601
  return Qnil;
446
602
  };
447
603
 
448
604
  static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
449
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
605
+ get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
450
606
  return Qnil;
451
607
  };
452
608
 
609
+ static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
610
+
453
611
  static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
454
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
612
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
455
613
  };
456
614
 
457
615
  static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
458
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
616
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
617
+ };
618
+
619
+ static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
620
+ return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
459
621
  };
622
+
623
+ static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
460
624
  };
461
625
 
462
626
  // clang-format off
@@ -495,12 +659,13 @@ public:
495
659
  return ptr;
496
660
  };
497
661
 
498
- static VALUE define_class(VALUE rb_mHnswlib) {
499
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
662
+ static VALUE define_class(VALUE outer) {
663
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(outer, "BruteforceSearch", rb_cObject);
500
664
  rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
501
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
665
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_initialize), -1);
666
+ rb_define_method(rb_cHnswlibBruteforceSearch, "init_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init_index), -1);
502
667
  rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
503
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
668
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
504
669
  rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
505
670
  rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
506
671
  rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
@@ -513,34 +678,66 @@ public:
513
678
  private:
514
679
  static const rb_data_type_t hnsw_bruteforcesearch_type;
515
680
 
516
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
681
+ static VALUE _hnsw_bruteforcesearch_initialize(int argc, VALUE* argv, VALUE self) {
517
682
  VALUE kw_args = Qnil;
518
- ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
683
+ ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
519
684
  VALUE kw_values[2] = {Qundef, Qundef};
520
685
  rb_scan_args(argc, argv, ":", &kw_args);
521
686
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
522
687
 
523
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
524
- rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
525
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
688
+ if (!RB_TYPE_P(kw_values[0], T_STRING)) {
689
+ rb_raise(rb_eTypeError, "expected space, String");
690
+ return Qnil;
691
+ }
692
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
693
+ strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
694
+ rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
526
695
  return Qnil;
527
696
  }
528
697
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
529
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
698
+ rb_raise(rb_eTypeError, "expected dim, Integer");
530
699
  return Qnil;
531
700
  }
532
701
 
533
- rb_iv_set(self, "@space", kw_values[0]);
534
702
  hnswlib::SpaceInterface<float>* space;
535
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
536
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
703
+ if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
704
+ rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
705
+ } else {
706
+ rb_iv_set(self, "@space",
707
+ rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
708
+ }
709
+
710
+ rb_iv_set(self, "@normalize", Qfalse);
711
+ if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
712
+
713
+ return Qnil;
714
+ };
715
+
716
+ static VALUE _hnsw_bruteforcesearch_init_index(int argc, VALUE* argv, VALUE self) {
717
+ VALUE kw_args = Qnil;
718
+ ID kw_table[1] = {rb_intern("max_elements")};
719
+ VALUE kw_values[1] = {Qundef};
720
+ rb_scan_args(argc, argv, ":", &kw_args);
721
+ rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
722
+
723
+ if (!RB_INTEGER_TYPE_P(kw_values[0])) {
724
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
725
+ return Qnil;
726
+ }
727
+
728
+ hnswlib::SpaceInterface<float>* space = nullptr;
729
+ VALUE ivspace = rb_iv_get(self, "@space");
730
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
731
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
537
732
  } else {
538
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
733
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
539
734
  }
540
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
735
+
736
+ const size_t max_elements = NUM2SIZET(kw_values[0]);
541
737
 
542
738
  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
543
739
  try {
740
+ ptr->~BruteforceSearch();
544
741
  new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
545
742
  } catch (const std::runtime_error& e) {
546
743
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -551,7 +748,7 @@ private:
551
748
  };
552
749
 
553
750
  static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
554
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
751
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
555
752
 
556
753
  if (!RB_TYPE_P(arr, T_ARRAY)) {
557
754
  rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
@@ -567,10 +764,19 @@ private:
567
764
  }
568
765
 
569
766
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
570
- for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
767
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
768
+
769
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
770
+ float norm = 0.0;
771
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
772
+ norm = std::sqrt(std::fabs(norm));
773
+ if (norm >= 0.0) {
774
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
775
+ }
776
+ }
571
777
 
572
778
  try {
573
- get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
779
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
574
780
  } catch (const std::runtime_error& e) {
575
781
  ruby_xfree(vec);
576
782
  rb_raise(rb_eRuntimeError, "%s", e.what());
@@ -581,8 +787,17 @@ private:
581
787
  return Qtrue;
582
788
  };
583
789
 
584
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
585
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
790
+ static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
791
+ VALUE arr, k, filter;
792
+ VALUE kw_args = Qnil;
793
+ ID kw_table[1] = {rb_intern("filter")};
794
+ VALUE kw_values[1] = {Qundef};
795
+
796
+ rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
797
+ rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
798
+ filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
799
+
800
+ const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
586
801
 
587
802
  if (!RB_TYPE_P(arr, T_ARRAY)) {
588
803
  rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
@@ -597,17 +812,35 @@ private:
597
812
  return Qnil;
598
813
  }
599
814
 
815
+ CustomFilterFunctor* filter_func = nullptr;
816
+ if (!NIL_P(filter)) {
817
+ try {
818
+ filter_func = new CustomFilterFunctor(filter);
819
+ } catch (const std::bad_alloc& e) {
820
+ rb_raise(rb_eRuntimeError, "%s", e.what());
821
+ return Qnil;
822
+ }
823
+ }
824
+
600
825
  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
601
- for (int i = 0; i < dim; i++) {
602
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
826
+ for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
827
+
828
+ if (rb_iv_get(self, "@normalize") == Qtrue) {
829
+ float norm = 0.0;
830
+ for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
831
+ norm = std::sqrt(std::fabs(norm));
832
+ if (norm >= 0.0) {
833
+ for (size_t i = 0; i < dim; i++) vec[i] /= norm;
834
+ }
603
835
  }
604
836
 
605
837
  std::priority_queue<std::pair<float, size_t>> result =
606
- get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
838
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
607
839
 
608
840
  ruby_xfree(vec);
841
+ if (filter_func) delete filter_func;
609
842
 
610
- if (result.size() != (size_t)NUM2INT(k)) {
843
+ if (result.size() != NUM2SIZET(k)) {
611
844
  rb_warning("Cannot return as many search results as the requested number of neighbors.");
612
845
  }
613
846
 
@@ -617,7 +850,7 @@ private:
617
850
  while (!result.empty()) {
618
851
  const std::pair<float, size_t>& result_tuple = result.top();
619
852
  rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
620
- rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
853
+ rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
621
854
  result.pop();
622
855
  }
623
856
 
@@ -659,16 +892,16 @@ private:
659
892
  };
660
893
 
661
894
  static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
662
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
895
+ get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
663
896
  return Qnil;
664
897
  };
665
898
 
666
899
  static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
667
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
900
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
668
901
  };
669
902
 
670
903
  static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
671
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
904
+ return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
672
905
  };
673
906
  };
674
907