hnswlib 0.6.2 → 0.8.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +14 -7
- data/ext/hnswlib/hnswlibext.cpp +1 -3
- data/ext/hnswlib/hnswlibext.hpp +326 -93
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +21 -18
- data/sig/hnswlib.rbs +9 -7
- metadata +3 -3
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,11 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <cmath>
|
27
|
+
#include <new>
|
28
|
+
#include <vector>
|
29
|
+
|
30
|
+
VALUE rb_mHnswlib;
|
26
31
|
VALUE rb_cHnswlibL2Space;
|
27
32
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
33
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -49,8 +54,8 @@ public:
|
|
49
54
|
return ptr;
|
50
55
|
};
|
51
56
|
|
52
|
-
static VALUE define_class(VALUE
|
53
|
-
rb_cHnswlibL2Space = rb_define_class_under(
|
57
|
+
static VALUE define_class(VALUE outer) {
|
58
|
+
rb_cHnswlibL2Space = rb_define_class_under(outer, "L2Space", rb_cObject);
|
54
59
|
rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
|
55
60
|
rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
|
56
61
|
rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
|
@@ -64,12 +69,12 @@ private:
|
|
64
69
|
// Ruby `initialize(dim)` for L2Space: stores `dim` in @dim, then constructs the
// native hnswlib::L2Space at the pointer returned by get_hnsw_l2space(self).
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
  rb_iv_set(self, "@dim", dim);
  hnswlib::L2Space* const storage = get_hnsw_l2space(self);
  // @dim was assigned from `dim` just above, so converting `dim` yields the same value as reading it back.
  const size_t n_dims = NUM2SIZET(dim);
  new (storage) hnswlib::L2Space(n_dims);
  return Qnil;
};
|
70
75
|
|
71
76
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
77
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
73
78
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
74
79
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
75
80
|
return Qnil;
|
@@ -79,9 +84,9 @@ private:
|
|
79
84
|
return Qnil;
|
80
85
|
}
|
81
86
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
82
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
83
88
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (
|
89
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
85
90
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
86
91
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
87
92
|
ruby_xfree(vec_a);
|
@@ -125,8 +130,8 @@ public:
|
|
125
130
|
return ptr;
|
126
131
|
};
|
127
132
|
|
128
|
-
static VALUE define_class(VALUE
|
129
|
-
rb_cHnswlibInnerProductSpace = rb_define_class_under(
|
133
|
+
static VALUE define_class(VALUE outer) {
|
134
|
+
rb_cHnswlibInnerProductSpace = rb_define_class_under(outer, "InnerProductSpace", rb_cObject);
|
130
135
|
rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
|
131
136
|
rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
|
132
137
|
rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
|
@@ -140,12 +145,12 @@ private:
|
|
140
145
|
// Ruby `initialize(dim)` for InnerProductSpace: stores `dim` in @dim, then constructs the
// native hnswlib::InnerProductSpace at the pointer returned by get_hnsw_ipspace(self).
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
  rb_iv_set(self, "@dim", dim);
  hnswlib::InnerProductSpace* const storage = get_hnsw_ipspace(self);
  // @dim was assigned from `dim` just above, so converting `dim` yields the same value as reading it back.
  const size_t n_dims = NUM2SIZET(dim);
  new (storage) hnswlib::InnerProductSpace(n_dims);
  return Qnil;
};
|
146
151
|
|
147
152
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
148
|
-
const
|
153
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
149
154
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
150
155
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
151
156
|
return Qnil;
|
@@ -155,9 +160,9 @@ private:
|
|
155
160
|
return Qnil;
|
156
161
|
}
|
157
162
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
158
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
159
164
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (
|
165
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
161
166
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
162
167
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
163
168
|
ruby_xfree(vec_a);
|
@@ -180,6 +185,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
180
185
|
};
|
181
186
|
// clang-format on
|
182
187
|
|
188
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
189
|
+
public:
|
190
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
191
|
+
|
192
|
+
bool operator()(hnswlib::labeltype id) {
|
193
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
194
|
+
return result == Qtrue ? true : false;
|
195
|
+
}
|
196
|
+
|
197
|
+
private:
|
198
|
+
VALUE callback_;
|
199
|
+
};
|
200
|
+
|
183
201
|
class RbHnswlibHierarchicalNSW {
|
184
202
|
public:
|
185
203
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -202,21 +220,26 @@ public:
|
|
202
220
|
return ptr;
|
203
221
|
};
|
204
222
|
|
205
|
-
static VALUE define_class(VALUE
|
206
|
-
rb_cHnswlibHierarchicalNSW = rb_define_class_under(
|
223
|
+
// Defines the "HierarchicalNSW" class under `outer`, stores it in the global
// rb_cHnswlibHierarchicalNSW, and attaches the allocator, the instance methods,
// and a read-only `space` attribute. Returns the class object.
static VALUE define_class(VALUE outer) {
  rb_cHnswlibHierarchicalNSW = rb_define_class_under(outer, "HierarchicalNSW", rb_cObject);
  const VALUE klass = rb_cHnswlibHierarchicalNSW;

  rb_define_alloc_func(klass, hnsw_hierarchicalnsw_alloc);

  // Methods registered with arity -1 receive (argc, argv, self).
  rb_define_method(klass, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_initialize), -1);
  rb_define_method(klass, "init_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init_index), -1);
  rb_define_method(klass, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
  rb_define_method(klass, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
  rb_define_method(klass, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);

  // Methods taking exactly one argument.
  rb_define_method(klass, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
  rb_define_method(klass, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
  rb_define_method(klass, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
  rb_define_method(klass, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
  rb_define_method(klass, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
  rb_define_method(klass, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);

  // Methods taking no arguments.
  rb_define_method(klass, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
  rb_define_method(klass, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
  rb_define_method(klass, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
  rb_define_method(klass, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
  rb_define_method(klass, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
  rb_define_method(klass, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);

  rb_define_attr(klass, "space", 1, 0);
  return klass;
};
|
@@ -224,54 +247,91 @@ public:
|
|
224
247
|
private:
|
225
248
|
static const rb_data_type_t hnsw_hierarchicalnsw_type;
|
226
249
|
|
227
|
-
static VALUE
|
250
|
+
// Ruby `initialize(space:, dim:)`.
// Validates the two required keywords, stores a new L2Space (for "l2") or
// InnerProductSpace (for "ip" and "cosine") in @space, and sets @normalize to
// true only for "cosine".
// Raises TypeError when `space` is not a String or `dim` is not an Integer, and
// ArgumentError when `space` is not one of the three accepted names.
static VALUE _hnsw_hierarchicalnsw_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE opts = Qnil;
  ID keys[2] = {rb_intern("space"), rb_intern("dim")};
  VALUE vals[2] = {Qundef, Qundef};
  rb_scan_args(argc, argv, ":", &opts);
  rb_get_kwargs(opts, keys, 2, 0, vals);

  if (!RB_TYPE_P(vals[0], T_STRING)) {
    rb_raise(rb_eTypeError, "expected space, String");
    return Qnil;
  }

  // Classify the space name once, before any further Ruby calls are made.
  const char* const space_name = StringValueCStr(vals[0]);
  const bool is_l2 = strcmp(space_name, "l2") == 0;
  const bool is_ip = strcmp(space_name, "ip") == 0;
  const bool is_cosine = strcmp(space_name, "cosine") == 0;
  if (!(is_l2 || is_ip || is_cosine)) {
    rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
    return Qnil;
  }

  if (!RB_INTEGER_TYPE_P(vals[1])) {
    rb_raise(rb_eTypeError, "expected dim, Integer");
    return Qnil;
  }

  // Every name other than "l2" is backed by the inner-product space.
  const char* const space_class_name = is_l2 ? "L2Space" : "InnerProductSpace";
  const VALUE space_class = rb_const_get(rb_mHnswlib, rb_intern(space_class_name));
  rb_iv_set(self, "@space", rb_funcall(space_class, rb_intern("new"), 1, vals[1]));

  rb_iv_set(self, "@normalize", is_cosine ? Qtrue : Qfalse);

  return Qnil;
};
|
283
|
+
|
284
|
+
static VALUE _hnsw_hierarchicalnsw_init_index(int argc, VALUE* argv, VALUE self) {
|
285
|
+
VALUE kw_args = Qnil;
|
286
|
+
ID kw_table[5] = {rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"), rb_intern("random_seed"),
|
287
|
+
rb_intern("allow_replace_deleted")};
|
288
|
+
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
289
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
290
|
+
rb_get_kwargs(kw_args, kw_table, 1, 4, kw_values);
|
291
|
+
if (kw_values[1] == Qundef) kw_values[1] = SIZET2NUM(16);
|
292
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(200);
|
293
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(100);
|
294
|
+
if (kw_values[4] == Qundef) kw_values[4] = Qfalse;
|
295
|
+
|
296
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
244
297
|
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
245
298
|
return Qnil;
|
246
299
|
}
|
247
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
300
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
248
301
|
rb_raise(rb_eTypeError, "expected m, Integer");
|
249
302
|
return Qnil;
|
250
303
|
}
|
251
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
304
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
252
305
|
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
253
306
|
return Qnil;
|
254
307
|
}
|
255
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
308
|
+
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
256
309
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
257
310
|
return Qnil;
|
258
311
|
}
|
312
|
+
if (!RB_TYPE_P(kw_values[4], T_TRUE) && !RB_TYPE_P(kw_values[4], T_FALSE)) {
|
313
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
314
|
+
return Qnil;
|
315
|
+
}
|
259
316
|
|
260
|
-
|
261
|
-
|
262
|
-
if (rb_obj_is_instance_of(
|
263
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(
|
317
|
+
hnswlib::SpaceInterface<float>* space = nullptr;
|
318
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
319
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
320
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
264
321
|
} else {
|
265
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(
|
322
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
266
323
|
}
|
267
|
-
|
268
|
-
const size_t
|
269
|
-
const size_t
|
270
|
-
const size_t
|
324
|
+
|
325
|
+
const size_t max_elements = NUM2SIZET(kw_values[0]);
|
326
|
+
const size_t m = NUM2SIZET(kw_values[1]);
|
327
|
+
const size_t ef_construction = NUM2SIZET(kw_values[2]);
|
328
|
+
const size_t random_seed = NUM2SIZET(kw_values[3]);
|
329
|
+
const bool allow_replace_deleted = kw_values[4] == Qtrue ? true : false;
|
271
330
|
|
272
331
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
273
332
|
try {
|
274
|
-
|
333
|
+
ptr->~HierarchicalNSW();
|
334
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
275
335
|
} catch (const std::runtime_error& e) {
|
276
336
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
277
337
|
return Qnil;
|
@@ -280,33 +340,72 @@ private:
|
|
280
340
|
return Qnil;
|
281
341
|
};
|
282
342
|
|
283
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
284
|
-
|
343
|
+
// Ruby `add_point(arr, idx, replace_deleted: false)`.
// Converts `arr` to a float buffer of length @space.@dim, L2-normalizes it when
// @normalize is true, and inserts it into the index under label `idx`.
// Returns Qtrue on success.
// Raises ArgumentError for a non-Array point, a non-Integer index, a length
// mismatch, or a non-Boolean replace_deleted; std::runtime_error thrown by
// addPoint is re-raised as RuntimeError.
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
  VALUE _arr, _idx, _replace_deleted;
  VALUE kw_args = Qnil;
  ID kw_table[1] = {rb_intern("replace_deleted")};
  VALUE kw_values[1] = {Qundef};

  rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
  _replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;

  const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));

  if (!RB_TYPE_P(_arr, T_ARRAY)) {
    rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
    return Qfalse;
  }
  if (!RB_INTEGER_TYPE_P(_idx)) {
    rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
    return Qfalse;
  }
  if (dim != static_cast<size_t>(RARRAY_LEN(_arr))) {
    rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
    return Qfalse;
  }
  if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
    rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
    return Qfalse;
  }

  // Convert the scalar arguments before allocating: NUM2SIZET can raise, and a
  // raise after the allocation would skip ruby_xfree and leak the buffer.
  const size_t idx = NUM2SIZET(_idx);
  const bool replace_deleted = _replace_deleted == Qtrue;

  float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
  // NOTE(review): NUM2DBL raises for a non-numeric element, which would still
  // skip the ruby_xfree below.
  for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));

  if (rb_iv_get(self, "@normalize") == Qtrue) {
    float norm = 0.0f;
    for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
    norm = std::sqrt(norm);
    // A zero vector cannot be normalized; leave it unchanged instead of
    // dividing by zero and storing NaN components.
    if (norm > 0.0f) {
      for (size_t i = 0; i < dim; i++) vec[i] /= norm;
    }
  }

  try {
    get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, idx, replace_deleted);
  } catch (const std::runtime_error& e) {
    ruby_xfree(vec);
    rb_raise(rb_eRuntimeError, "%s", e.what());
    return Qfalse;
  }

  ruby_xfree(vec);
  return Qtrue;
};
|
307
397
|
|
308
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
309
|
-
|
398
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
399
|
+
VALUE arr, k, filter;
|
400
|
+
VALUE kw_args = Qnil;
|
401
|
+
ID kw_table[1] = {rb_intern("filter")};
|
402
|
+
VALUE kw_values[1] = {Qundef};
|
403
|
+
|
404
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
405
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
406
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
407
|
+
|
408
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
310
409
|
|
311
410
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
312
411
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -321,14 +420,33 @@ private:
|
|
321
420
|
return Qnil;
|
322
421
|
}
|
323
422
|
|
423
|
+
CustomFilterFunctor* filter_func = nullptr;
|
424
|
+
if (!NIL_P(filter)) {
|
425
|
+
try {
|
426
|
+
filter_func = new CustomFilterFunctor(filter);
|
427
|
+
} catch (const std::bad_alloc& e) {
|
428
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
|
+
return Qnil;
|
430
|
+
}
|
431
|
+
}
|
432
|
+
|
324
433
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
325
|
-
for (
|
434
|
+
for (size_t i = 0; i < dim; i++) {
|
326
435
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
327
436
|
}
|
328
437
|
|
438
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
439
|
+
float norm = 0.0;
|
440
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
441
|
+
norm = std::sqrt(std::fabs(norm));
|
442
|
+
if (norm >= 0.0) {
|
443
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
444
|
+
}
|
445
|
+
}
|
446
|
+
|
329
447
|
std::priority_queue<std::pair<float, size_t>> result;
|
330
448
|
try {
|
331
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
449
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
332
450
|
} catch (const std::runtime_error& e) {
|
333
451
|
ruby_xfree(vec);
|
334
452
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -336,8 +454,9 @@ private:
|
|
336
454
|
}
|
337
455
|
|
338
456
|
ruby_xfree(vec);
|
457
|
+
if (filter_func) delete filter_func;
|
339
458
|
|
340
|
-
if (result.size() != (
|
459
|
+
if (result.size() != NUM2SIZET(k)) {
|
341
460
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
342
461
|
}
|
343
462
|
|
@@ -347,7 +466,7 @@ private:
|
|
347
466
|
while (!result.empty()) {
|
348
467
|
const std::pair<float, size_t>& result_tuple = result.top();
|
349
468
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
350
|
-
rb_ary_unshift(neighbors_arr,
|
469
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
351
470
|
result.pop();
|
352
471
|
}
|
353
472
|
|
@@ -364,8 +483,28 @@ private:
|
|
364
483
|
return Qnil;
|
365
484
|
};
|
366
485
|
|
367
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
486
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
487
|
+
VALUE _filename, _allow_replace_deleted;
|
488
|
+
VALUE kw_args = Qnil;
|
489
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
490
|
+
VALUE kw_values[1] = {Qundef};
|
491
|
+
|
492
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
493
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
494
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
495
|
+
|
496
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
497
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
498
|
+
return Qnil;
|
499
|
+
}
|
500
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
501
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
502
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
503
|
+
return Qnil;
|
504
|
+
}
|
505
|
+
|
368
506
|
std::string filename(StringValuePtr(_filename));
|
507
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
369
508
|
VALUE ivspace = rb_iv_get(self, "@space");
|
370
509
|
hnswlib::SpaceInterface<float>* space;
|
371
510
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -373,6 +512,7 @@ private:
|
|
373
512
|
} else {
|
374
513
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
375
514
|
}
|
515
|
+
|
376
516
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
377
517
|
if (index->data_level0_memory_) {
|
378
518
|
free(index->data_level0_memory_);
|
@@ -392,12 +532,15 @@ private:
|
|
392
532
|
delete index->visited_list_pool_;
|
393
533
|
index->visited_list_pool_ = nullptr;
|
394
534
|
}
|
535
|
+
|
395
536
|
try {
|
396
537
|
index->loadIndex(filename, space);
|
538
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
397
539
|
} catch (const std::runtime_error& e) {
|
398
540
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
399
541
|
return Qnil;
|
400
542
|
}
|
543
|
+
|
401
544
|
RB_GC_GUARD(_filename);
|
402
545
|
return Qnil;
|
403
546
|
};
|
@@ -405,7 +548,7 @@ private:
|
|
405
548
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
406
549
|
VALUE ret = Qnil;
|
407
550
|
try {
|
408
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
551
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
409
552
|
ret = rb_ary_new2(vec.size());
|
410
553
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
411
554
|
} catch (const std::runtime_error& e) {
|
@@ -417,13 +560,23 @@ private:
|
|
417
560
|
|
418
561
|
// Ruby `get_ids`: returns a new Array holding the key of every entry in the
// index's label_lookup_ container, in that container's iteration order.
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
  VALUE ret = rb_ary_new();
  // Bind by const reference so each entry is read in place rather than copied.
  for (const auto& kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
    rb_ary_push(ret, SIZET2NUM(kv.first));
  }
  return ret;
};
|
423
566
|
|
424
567
|
// Ruby `mark_deleted(idx)`: forwards the label to markDelete on the native index.
// A std::runtime_error thrown by markDelete is re-raised as RuntimeError.
// Always returns nil on success.
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
  hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  const size_t label = NUM2SIZET(idx);
  try {
    index->markDelete(label);
  } catch (const std::runtime_error& err) {
    // rb_raise does not return, so nothing follows it in this handler.
    rb_raise(rb_eRuntimeError, "%s", err.what());
  }
  return Qnil;
};
|
576
|
+
|
577
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
578
|
+
try {
|
579
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
427
580
|
} catch (const std::runtime_error& e) {
|
428
581
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
582
|
return Qnil;
|
@@ -432,31 +585,42 @@ private:
|
|
432
585
|
};
|
433
586
|
|
434
587
|
// Ruby `resize_index(new_max_elements)`.
// Raises ArgumentError when the requested capacity is below the current element
// count; otherwise calls resizeIndex. std::runtime_error and std::bad_alloc
// thrown by resizeIndex are both re-raised as RuntimeError. Returns nil.
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
  const size_t requested = NUM2SIZET(new_max_elements);
  hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);

  if (requested < index->cur_element_count) {
    rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
    return Qnil;
  }

  try {
    index->resizeIndex(requested);
  } catch (const std::runtime_error& err) {
    rb_raise(rb_eRuntimeError, "%s", err.what());
    return Qnil;
  } catch (const std::bad_alloc& err) {
    rb_raise(rb_eRuntimeError, "%s", err.what());
    return Qnil;
  }
  return Qnil;
};
|
447
603
|
|
448
604
|
// Ruby `set_ef(ef)`: converts `ef` to size_t and passes it to setEf on the native index.
// Returns nil.
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
  hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  const size_t ef_value = NUM2SIZET(ef);
  index->setEf(ef_value);
  return Qnil;
};
|
452
608
|
|
609
|
+
// Ruby `get_ef`: returns the native index's ef_ member as a Ruby Integer.
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) {
  const hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  return SIZET2NUM(index->ef_);
};
|
610
|
+
|
453
611
|
// Ruby `max_elements`: returns the native index's max_elements_ member as a Ruby Integer.
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
  const hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  return SIZET2NUM(index->max_elements_);
};
|
456
614
|
|
457
615
|
// Ruby `current_count`: returns the native index's cur_element_count member as a Ruby Integer.
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
  const hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  return SIZET2NUM(index->cur_element_count);
};
|
618
|
+
|
619
|
+
// Ruby `ef_construction`: returns the native index's ef_construction_ member as a Ruby Integer.
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
  const hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  return SIZET2NUM(index->ef_construction_);
};
|
622
|
+
|
623
|
+
// Ruby `m`: returns the native index's M_ member as a Ruby Integer.
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) {
  const hnswlib::HierarchicalNSW<float>* const index = get_hnsw_hierarchicalnsw(self);
  return SIZET2NUM(index->M_);
};
|
460
624
|
};
|
461
625
|
|
462
626
|
// clang-format off
|
@@ -495,12 +659,13 @@ public:
|
|
495
659
|
return ptr;
|
496
660
|
};
|
497
661
|
|
498
|
-
static VALUE define_class(VALUE
|
499
|
-
rb_cHnswlibBruteforceSearch = rb_define_class_under(
|
662
|
+
static VALUE define_class(VALUE outer) {
|
663
|
+
rb_cHnswlibBruteforceSearch = rb_define_class_under(outer, "BruteforceSearch", rb_cObject);
|
500
664
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
501
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(
|
665
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_initialize), -1);
|
666
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "init_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init_index), -1);
|
502
667
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
503
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
668
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
504
669
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
505
670
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
506
671
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -513,34 +678,66 @@ public:
|
|
513
678
|
private:
|
514
679
|
static const rb_data_type_t hnsw_bruteforcesearch_type;
|
515
680
|
|
516
|
-
static VALUE
|
681
|
+
// Ruby `initialize(space:, dim:)` for BruteforceSearch.
// Validates the two required keywords, stores a new L2Space (for "l2") or
// InnerProductSpace (for "ip" and "cosine") in @space, and sets @normalize to
// true only for "cosine".
// Raises TypeError when `space` is not a String or `dim` is not an Integer, and
// ArgumentError when `space` is not one of the three accepted names.
static VALUE _hnsw_bruteforcesearch_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
  VALUE kw_values[2] = {Qundef, Qundef};
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
    rb_raise(rb_eTypeError, "expected space, String");
    return Qnil;
  }

  // Classify the space name once instead of re-reading the Ruby string for every comparison.
  const char* const space_name = StringValueCStr(kw_values[0]);
  const bool is_l2 = strcmp(space_name, "l2") == 0;
  const bool is_ip = strcmp(space_name, "ip") == 0;
  const bool is_cosine = strcmp(space_name, "cosine") == 0;
  if (!is_l2 && !is_ip && !is_cosine) {
    rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eTypeError, "expected dim, Integer");
    return Qnil;
  }

  // The unused local `hnswlib::SpaceInterface<float>* space` has been removed:
  // this function only stores the Ruby space object in @space.
  if (is_l2) {
    rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
  } else {
    rb_iv_set(self, "@space",
              rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
  }

  rb_iv_set(self, "@normalize", is_cosine ? Qtrue : Qfalse);

  return Qnil;
};
|
715
|
+
|
716
|
+
static VALUE _hnsw_bruteforcesearch_init_index(int argc, VALUE* argv, VALUE self) {
|
717
|
+
VALUE kw_args = Qnil;
|
718
|
+
ID kw_table[1] = {rb_intern("max_elements")};
|
719
|
+
VALUE kw_values[1] = {Qundef};
|
720
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
721
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
722
|
+
|
723
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
724
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
725
|
+
return Qnil;
|
726
|
+
}
|
727
|
+
|
728
|
+
hnswlib::SpaceInterface<float>* space = nullptr;
|
729
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
730
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
731
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
537
732
|
} else {
|
538
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(
|
733
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
539
734
|
}
|
540
|
-
|
735
|
+
|
736
|
+
const size_t max_elements = NUM2SIZET(kw_values[0]);
|
541
737
|
|
542
738
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
543
739
|
try {
|
740
|
+
ptr->~BruteforceSearch();
|
544
741
|
new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
|
545
742
|
} catch (const std::runtime_error& e) {
|
546
743
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -551,7 +748,7 @@ private:
|
|
551
748
|
};
|
552
749
|
|
553
750
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
554
|
-
const
|
751
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
555
752
|
|
556
753
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
557
754
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -567,10 +764,19 @@ private:
|
|
567
764
|
}
|
568
765
|
|
569
766
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
570
|
-
for (
|
767
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
768
|
+
|
769
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
770
|
+
float norm = 0.0;
|
771
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
772
|
+
norm = std::sqrt(std::fabs(norm));
|
773
|
+
if (norm >= 0.0) {
|
774
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
775
|
+
}
|
776
|
+
}
|
571
777
|
|
572
778
|
try {
|
573
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
779
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
574
780
|
} catch (const std::runtime_error& e) {
|
575
781
|
ruby_xfree(vec);
|
576
782
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -581,8 +787,17 @@ private:
|
|
581
787
|
return Qtrue;
|
582
788
|
};
|
583
789
|
|
584
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
585
|
-
|
790
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
791
|
+
VALUE arr, k, filter;
|
792
|
+
VALUE kw_args = Qnil;
|
793
|
+
ID kw_table[1] = {rb_intern("filter")};
|
794
|
+
VALUE kw_values[1] = {Qundef};
|
795
|
+
|
796
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
797
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
798
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
799
|
+
|
800
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
586
801
|
|
587
802
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
588
803
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -597,17 +812,35 @@ private:
|
|
597
812
|
return Qnil;
|
598
813
|
}
|
599
814
|
|
815
|
+
CustomFilterFunctor* filter_func = nullptr;
|
816
|
+
if (!NIL_P(filter)) {
|
817
|
+
try {
|
818
|
+
filter_func = new CustomFilterFunctor(filter);
|
819
|
+
} catch (const std::bad_alloc& e) {
|
820
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
821
|
+
return Qnil;
|
822
|
+
}
|
823
|
+
}
|
824
|
+
|
600
825
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
601
|
-
for (
|
602
|
-
|
826
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
827
|
+
|
828
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
829
|
+
float norm = 0.0;
|
830
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
831
|
+
norm = std::sqrt(std::fabs(norm));
|
832
|
+
if (norm >= 0.0) {
|
833
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
834
|
+
}
|
603
835
|
}
|
604
836
|
|
605
837
|
std::priority_queue<std::pair<float, size_t>> result =
|
606
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
838
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
607
839
|
|
608
840
|
ruby_xfree(vec);
|
841
|
+
if (filter_func) delete filter_func;
|
609
842
|
|
610
|
-
if (result.size() != (
|
843
|
+
if (result.size() != NUM2SIZET(k)) {
|
611
844
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
612
845
|
}
|
613
846
|
|
@@ -617,7 +850,7 @@ private:
|
|
617
850
|
while (!result.empty()) {
|
618
851
|
const std::pair<float, size_t>& result_tuple = result.top();
|
619
852
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
620
|
-
rb_ary_unshift(neighbors_arr,
|
853
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
621
854
|
result.pop();
|
622
855
|
}
|
623
856
|
|
@@ -659,16 +892,16 @@ private:
|
|
659
892
|
};
|
660
893
|
|
661
894
|
// Ruby `remove_point(idx)`: converts `idx` to size_t and passes it to removePoint
// on the native brute-force index. Returns nil.
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
  hnswlib::BruteforceSearch<float>* const index = get_hnsw_bruteforcesearch(self);
  const size_t label = NUM2SIZET(idx);
  index->removePoint(label);
  return Qnil;
};
|
665
898
|
|
666
899
|
// Ruby `max_elements`: returns the native brute-force index's maxelements_ member as a Ruby Integer.
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
  const hnswlib::BruteforceSearch<float>* const index = get_hnsw_bruteforcesearch(self);
  return SIZET2NUM(index->maxelements_);
};
|
669
902
|
|
670
903
|
// Ruby `current_count`: returns the native brute-force index's cur_element_count member as a Ruby Integer.
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
  const hnswlib::BruteforceSearch<float>* const index = get_hnsw_bruteforcesearch(self);
  return SIZET2NUM(index->cur_element_count);
};
|
673
906
|
};
|
674
907
|
|