hnswlib 0.6.2 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +14 -7
- data/ext/hnswlib/hnswlibext.cpp +1 -3
- data/ext/hnswlib/hnswlibext.hpp +326 -93
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +21 -18
- data/sig/hnswlib.rbs +9 -7
- metadata +3 -3
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,11 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <cmath>
|
27
|
+
#include <new>
|
28
|
+
#include <vector>
|
29
|
+
|
30
|
+
VALUE rb_mHnswlib;
|
26
31
|
VALUE rb_cHnswlibL2Space;
|
27
32
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
33
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -49,8 +54,8 @@ public:
|
|
49
54
|
return ptr;
|
50
55
|
};
|
51
56
|
|
52
|
-
static VALUE define_class(VALUE
|
53
|
-
rb_cHnswlibL2Space = rb_define_class_under(
|
57
|
+
static VALUE define_class(VALUE outer) {
|
58
|
+
rb_cHnswlibL2Space = rb_define_class_under(outer, "L2Space", rb_cObject);
|
54
59
|
rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
|
55
60
|
rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
|
56
61
|
rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
|
@@ -64,12 +69,12 @@ private:
|
|
64
69
|
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
70
|
rb_iv_set(self, "@dim", dim);
|
66
71
|
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
-
new (ptr) hnswlib::L2Space(
|
72
|
+
new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
|
68
73
|
return Qnil;
|
69
74
|
};
|
70
75
|
|
71
76
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
77
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
73
78
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
74
79
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
75
80
|
return Qnil;
|
@@ -79,9 +84,9 @@ private:
|
|
79
84
|
return Qnil;
|
80
85
|
}
|
81
86
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
82
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
83
88
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (
|
89
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
85
90
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
86
91
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
87
92
|
ruby_xfree(vec_a);
|
@@ -125,8 +130,8 @@ public:
|
|
125
130
|
return ptr;
|
126
131
|
};
|
127
132
|
|
128
|
-
static VALUE define_class(VALUE
|
129
|
-
rb_cHnswlibInnerProductSpace = rb_define_class_under(
|
133
|
+
static VALUE define_class(VALUE outer) {
|
134
|
+
rb_cHnswlibInnerProductSpace = rb_define_class_under(outer, "InnerProductSpace", rb_cObject);
|
130
135
|
rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
|
131
136
|
rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
|
132
137
|
rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
|
@@ -140,12 +145,12 @@ private:
|
|
140
145
|
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
141
146
|
rb_iv_set(self, "@dim", dim);
|
142
147
|
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
143
|
-
new (ptr) hnswlib::InnerProductSpace(
|
148
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
|
144
149
|
return Qnil;
|
145
150
|
};
|
146
151
|
|
147
152
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
148
|
-
const
|
153
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
149
154
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
150
155
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
151
156
|
return Qnil;
|
@@ -155,9 +160,9 @@ private:
|
|
155
160
|
return Qnil;
|
156
161
|
}
|
157
162
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
158
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
159
164
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (
|
165
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
161
166
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
162
167
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
163
168
|
ruby_xfree(vec_a);
|
@@ -180,6 +185,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
180
185
|
};
|
181
186
|
// clang-format on
|
182
187
|
|
188
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
189
|
+
public:
|
190
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
191
|
+
|
192
|
+
bool operator()(hnswlib::labeltype id) {
|
193
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
194
|
+
return result == Qtrue ? true : false;
|
195
|
+
}
|
196
|
+
|
197
|
+
private:
|
198
|
+
VALUE callback_;
|
199
|
+
};
|
200
|
+
|
183
201
|
class RbHnswlibHierarchicalNSW {
|
184
202
|
public:
|
185
203
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -202,21 +220,26 @@ public:
|
|
202
220
|
return ptr;
|
203
221
|
};
|
204
222
|
|
205
|
-
static VALUE define_class(VALUE
|
206
|
-
rb_cHnswlibHierarchicalNSW = rb_define_class_under(
|
223
|
+
static VALUE define_class(VALUE outer) {
|
224
|
+
rb_cHnswlibHierarchicalNSW = rb_define_class_under(outer, "HierarchicalNSW", rb_cObject);
|
207
225
|
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
208
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(
|
209
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "
|
210
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "
|
226
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_initialize), -1);
|
227
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "init_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init_index), -1);
|
228
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
|
229
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
|
211
230
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
212
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
231
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
|
213
232
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
214
233
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
215
234
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
235
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
|
216
236
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
217
237
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
238
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
|
218
239
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
219
240
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
241
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
|
242
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
|
220
243
|
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
221
244
|
return rb_cHnswlibHierarchicalNSW;
|
222
245
|
};
|
@@ -224,54 +247,91 @@ public:
|
|
224
247
|
private:
|
225
248
|
static const rb_data_type_t hnsw_hierarchicalnsw_type;
|
226
249
|
|
227
|
-
static VALUE
|
250
|
+
static VALUE _hnsw_hierarchicalnsw_initialize(int argc, VALUE* argv, VALUE self) {
|
228
251
|
VALUE kw_args = Qnil;
|
229
|
-
ID kw_table[
|
230
|
-
|
231
|
-
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
252
|
+
ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
|
253
|
+
VALUE kw_values[2] = {Qundef, Qundef};
|
232
254
|
rb_scan_args(argc, argv, ":", &kw_args);
|
233
|
-
rb_get_kwargs(kw_args, kw_table, 2,
|
234
|
-
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
235
|
-
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
236
|
-
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
255
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
237
256
|
|
238
|
-
if (!(
|
239
|
-
|
240
|
-
|
257
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
258
|
+
rb_raise(rb_eTypeError, "expected space, String");
|
259
|
+
return Qnil;
|
260
|
+
}
|
261
|
+
if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
|
262
|
+
strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
|
263
|
+
rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
|
241
264
|
return Qnil;
|
242
265
|
}
|
243
266
|
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
267
|
+
rb_raise(rb_eTypeError, "expected dim, Integer");
|
268
|
+
return Qnil;
|
269
|
+
}
|
270
|
+
|
271
|
+
if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
|
272
|
+
rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
|
273
|
+
} else {
|
274
|
+
rb_iv_set(self, "@space",
|
275
|
+
rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
|
276
|
+
}
|
277
|
+
|
278
|
+
rb_iv_set(self, "@normalize", Qfalse);
|
279
|
+
if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
|
280
|
+
|
281
|
+
return Qnil;
|
282
|
+
};
|
283
|
+
|
284
|
+
static VALUE _hnsw_hierarchicalnsw_init_index(int argc, VALUE* argv, VALUE self) {
|
285
|
+
VALUE kw_args = Qnil;
|
286
|
+
ID kw_table[5] = {rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"), rb_intern("random_seed"),
|
287
|
+
rb_intern("allow_replace_deleted")};
|
288
|
+
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
289
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
290
|
+
rb_get_kwargs(kw_args, kw_table, 1, 4, kw_values);
|
291
|
+
if (kw_values[1] == Qundef) kw_values[1] = SIZET2NUM(16);
|
292
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(200);
|
293
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(100);
|
294
|
+
if (kw_values[4] == Qundef) kw_values[4] = Qfalse;
|
295
|
+
|
296
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
244
297
|
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
245
298
|
return Qnil;
|
246
299
|
}
|
247
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
300
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
248
301
|
rb_raise(rb_eTypeError, "expected m, Integer");
|
249
302
|
return Qnil;
|
250
303
|
}
|
251
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
304
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
252
305
|
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
253
306
|
return Qnil;
|
254
307
|
}
|
255
|
-
if (!RB_INTEGER_TYPE_P(kw_values[
|
308
|
+
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
256
309
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
257
310
|
return Qnil;
|
258
311
|
}
|
312
|
+
if (!RB_TYPE_P(kw_values[4], T_TRUE) && !RB_TYPE_P(kw_values[4], T_FALSE)) {
|
313
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
314
|
+
return Qnil;
|
315
|
+
}
|
259
316
|
|
260
|
-
|
261
|
-
|
262
|
-
if (rb_obj_is_instance_of(
|
263
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(
|
317
|
+
hnswlib::SpaceInterface<float>* space = nullptr;
|
318
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
319
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
320
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
264
321
|
} else {
|
265
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(
|
322
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
266
323
|
}
|
267
|
-
|
268
|
-
const size_t
|
269
|
-
const size_t
|
270
|
-
const size_t
|
324
|
+
|
325
|
+
const size_t max_elements = NUM2SIZET(kw_values[0]);
|
326
|
+
const size_t m = NUM2SIZET(kw_values[1]);
|
327
|
+
const size_t ef_construction = NUM2SIZET(kw_values[2]);
|
328
|
+
const size_t random_seed = NUM2SIZET(kw_values[3]);
|
329
|
+
const bool allow_replace_deleted = kw_values[4] == Qtrue ? true : false;
|
271
330
|
|
272
331
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
273
332
|
try {
|
274
|
-
|
333
|
+
ptr->~HierarchicalNSW();
|
334
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
275
335
|
} catch (const std::runtime_error& e) {
|
276
336
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
277
337
|
return Qnil;
|
@@ -280,33 +340,72 @@ private:
|
|
280
340
|
return Qnil;
|
281
341
|
};
|
282
342
|
|
283
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
284
|
-
|
343
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
|
344
|
+
VALUE _arr, _idx, _replace_deleted;
|
345
|
+
VALUE kw_args = Qnil;
|
346
|
+
ID kw_table[1] = {rb_intern("replace_deleted")};
|
347
|
+
VALUE kw_values[1] = {Qundef};
|
285
348
|
|
286
|
-
|
349
|
+
rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
|
350
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
351
|
+
_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
352
|
+
|
353
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
354
|
+
|
355
|
+
if (!RB_TYPE_P(_arr, T_ARRAY)) {
|
287
356
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
288
357
|
return Qfalse;
|
289
358
|
}
|
290
|
-
if (!RB_INTEGER_TYPE_P(
|
359
|
+
if (!RB_INTEGER_TYPE_P(_idx)) {
|
291
360
|
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
292
361
|
return Qfalse;
|
293
362
|
}
|
294
|
-
if (dim != RARRAY_LEN(
|
363
|
+
if (dim != RARRAY_LEN(_arr)) {
|
295
364
|
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
296
365
|
return Qfalse;
|
297
366
|
}
|
367
|
+
if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
|
368
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
369
|
+
return Qfalse;
|
370
|
+
}
|
298
371
|
|
299
372
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
300
|
-
for (
|
373
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
|
374
|
+
const size_t idx = NUM2SIZET(_idx);
|
375
|
+
const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
|
376
|
+
|
377
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
378
|
+
float norm = 0.0;
|
379
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
380
|
+
norm = std::sqrt(std::fabs(norm));
|
381
|
+
if (norm >= 0.0) {
|
382
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
383
|
+
}
|
384
|
+
}
|
301
385
|
|
302
|
-
|
386
|
+
try {
|
387
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, idx, replace_deleted);
|
388
|
+
} catch (const std::runtime_error& e) {
|
389
|
+
ruby_xfree(vec);
|
390
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
391
|
+
return Qfalse;
|
392
|
+
}
|
303
393
|
|
304
394
|
ruby_xfree(vec);
|
305
395
|
return Qtrue;
|
306
396
|
};
|
307
397
|
|
308
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
309
|
-
|
398
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
399
|
+
VALUE arr, k, filter;
|
400
|
+
VALUE kw_args = Qnil;
|
401
|
+
ID kw_table[1] = {rb_intern("filter")};
|
402
|
+
VALUE kw_values[1] = {Qundef};
|
403
|
+
|
404
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
405
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
406
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
407
|
+
|
408
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
310
409
|
|
311
410
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
312
411
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -321,14 +420,33 @@ private:
|
|
321
420
|
return Qnil;
|
322
421
|
}
|
323
422
|
|
423
|
+
CustomFilterFunctor* filter_func = nullptr;
|
424
|
+
if (!NIL_P(filter)) {
|
425
|
+
try {
|
426
|
+
filter_func = new CustomFilterFunctor(filter);
|
427
|
+
} catch (const std::bad_alloc& e) {
|
428
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
|
+
return Qnil;
|
430
|
+
}
|
431
|
+
}
|
432
|
+
|
324
433
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
325
|
-
for (
|
434
|
+
for (size_t i = 0; i < dim; i++) {
|
326
435
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
327
436
|
}
|
328
437
|
|
438
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
439
|
+
float norm = 0.0;
|
440
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
441
|
+
norm = std::sqrt(std::fabs(norm));
|
442
|
+
if (norm >= 0.0) {
|
443
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
444
|
+
}
|
445
|
+
}
|
446
|
+
|
329
447
|
std::priority_queue<std::pair<float, size_t>> result;
|
330
448
|
try {
|
331
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
449
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
332
450
|
} catch (const std::runtime_error& e) {
|
333
451
|
ruby_xfree(vec);
|
334
452
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -336,8 +454,9 @@ private:
|
|
336
454
|
}
|
337
455
|
|
338
456
|
ruby_xfree(vec);
|
457
|
+
if (filter_func) delete filter_func;
|
339
458
|
|
340
|
-
if (result.size() != (
|
459
|
+
if (result.size() != NUM2SIZET(k)) {
|
341
460
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
342
461
|
}
|
343
462
|
|
@@ -347,7 +466,7 @@ private:
|
|
347
466
|
while (!result.empty()) {
|
348
467
|
const std::pair<float, size_t>& result_tuple = result.top();
|
349
468
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
350
|
-
rb_ary_unshift(neighbors_arr,
|
469
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
351
470
|
result.pop();
|
352
471
|
}
|
353
472
|
|
@@ -364,8 +483,28 @@ private:
|
|
364
483
|
return Qnil;
|
365
484
|
};
|
366
485
|
|
367
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
486
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
487
|
+
VALUE _filename, _allow_replace_deleted;
|
488
|
+
VALUE kw_args = Qnil;
|
489
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
490
|
+
VALUE kw_values[1] = {Qundef};
|
491
|
+
|
492
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
493
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
494
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
495
|
+
|
496
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
497
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
498
|
+
return Qnil;
|
499
|
+
}
|
500
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
501
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
502
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
503
|
+
return Qnil;
|
504
|
+
}
|
505
|
+
|
368
506
|
std::string filename(StringValuePtr(_filename));
|
507
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
369
508
|
VALUE ivspace = rb_iv_get(self, "@space");
|
370
509
|
hnswlib::SpaceInterface<float>* space;
|
371
510
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -373,6 +512,7 @@ private:
|
|
373
512
|
} else {
|
374
513
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
375
514
|
}
|
515
|
+
|
376
516
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
377
517
|
if (index->data_level0_memory_) {
|
378
518
|
free(index->data_level0_memory_);
|
@@ -392,12 +532,15 @@ private:
|
|
392
532
|
delete index->visited_list_pool_;
|
393
533
|
index->visited_list_pool_ = nullptr;
|
394
534
|
}
|
535
|
+
|
395
536
|
try {
|
396
537
|
index->loadIndex(filename, space);
|
538
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
397
539
|
} catch (const std::runtime_error& e) {
|
398
540
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
399
541
|
return Qnil;
|
400
542
|
}
|
543
|
+
|
401
544
|
RB_GC_GUARD(_filename);
|
402
545
|
return Qnil;
|
403
546
|
};
|
@@ -405,7 +548,7 @@ private:
|
|
405
548
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
406
549
|
VALUE ret = Qnil;
|
407
550
|
try {
|
408
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
551
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
409
552
|
ret = rb_ary_new2(vec.size());
|
410
553
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
411
554
|
} catch (const std::runtime_error& e) {
|
@@ -417,13 +560,23 @@ private:
|
|
417
560
|
|
418
561
|
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
419
562
|
VALUE ret = rb_ary_new();
|
420
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret,
|
563
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
|
421
564
|
return ret;
|
422
565
|
};
|
423
566
|
|
424
567
|
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
425
568
|
try {
|
426
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((
|
569
|
+
get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
|
570
|
+
} catch (const std::runtime_error& e) {
|
571
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
572
|
+
return Qnil;
|
573
|
+
}
|
574
|
+
return Qnil;
|
575
|
+
};
|
576
|
+
|
577
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
578
|
+
try {
|
579
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
427
580
|
} catch (const std::runtime_error& e) {
|
428
581
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
582
|
return Qnil;
|
@@ -432,31 +585,42 @@ private:
|
|
432
585
|
};
|
433
586
|
|
434
587
|
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
435
|
-
if ((
|
588
|
+
if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
436
589
|
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
437
590
|
return Qnil;
|
438
591
|
}
|
439
592
|
try {
|
440
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((
|
593
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
|
441
594
|
} catch (const std::runtime_error& e) {
|
442
595
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
443
596
|
return Qnil;
|
597
|
+
} catch (const std::bad_alloc& e) {
|
598
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
599
|
+
return Qnil;
|
444
600
|
}
|
445
601
|
return Qnil;
|
446
602
|
};
|
447
603
|
|
448
604
|
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
449
|
-
get_hnsw_hierarchicalnsw(self)->
|
605
|
+
get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
|
450
606
|
return Qnil;
|
451
607
|
};
|
452
608
|
|
609
|
+
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
|
610
|
+
|
453
611
|
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
454
|
-
return
|
612
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
|
455
613
|
};
|
456
614
|
|
457
615
|
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
458
|
-
return
|
616
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
|
617
|
+
};
|
618
|
+
|
619
|
+
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
|
620
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
|
459
621
|
};
|
622
|
+
|
623
|
+
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
|
460
624
|
};
|
461
625
|
|
462
626
|
// clang-format off
|
@@ -495,12 +659,13 @@ public:
|
|
495
659
|
return ptr;
|
496
660
|
};
|
497
661
|
|
498
|
-
static VALUE define_class(VALUE
|
499
|
-
rb_cHnswlibBruteforceSearch = rb_define_class_under(
|
662
|
+
static VALUE define_class(VALUE outer) {
|
663
|
+
rb_cHnswlibBruteforceSearch = rb_define_class_under(outer, "BruteforceSearch", rb_cObject);
|
500
664
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
501
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(
|
665
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_initialize), -1);
|
666
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "init_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init_index), -1);
|
502
667
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
503
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
668
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
504
669
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
505
670
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
506
671
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -513,34 +678,66 @@ public:
|
|
513
678
|
private:
|
514
679
|
static const rb_data_type_t hnsw_bruteforcesearch_type;
|
515
680
|
|
516
|
-
static VALUE
|
681
|
+
static VALUE _hnsw_bruteforcesearch_initialize(int argc, VALUE* argv, VALUE self) {
|
517
682
|
VALUE kw_args = Qnil;
|
518
|
-
ID kw_table[2] = {rb_intern("space"), rb_intern("
|
683
|
+
ID kw_table[2] = {rb_intern("space"), rb_intern("dim")};
|
519
684
|
VALUE kw_values[2] = {Qundef, Qundef};
|
520
685
|
rb_scan_args(argc, argv, ":", &kw_args);
|
521
686
|
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
522
687
|
|
523
|
-
if (!(
|
524
|
-
|
525
|
-
|
688
|
+
if (!RB_TYPE_P(kw_values[0], T_STRING)) {
|
689
|
+
rb_raise(rb_eTypeError, "expected space, String");
|
690
|
+
return Qnil;
|
691
|
+
}
|
692
|
+
if (strcmp(StringValueCStr(kw_values[0]), "l2") != 0 && strcmp(StringValueCStr(kw_values[0]), "ip") != 0 &&
|
693
|
+
strcmp(StringValueCStr(kw_values[0]), "cosine") != 0) {
|
694
|
+
rb_raise(rb_eArgError, "expected space, 'l2', 'ip', or 'cosine' only");
|
526
695
|
return Qnil;
|
527
696
|
}
|
528
697
|
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
529
|
-
rb_raise(rb_eTypeError, "expected
|
698
|
+
rb_raise(rb_eTypeError, "expected dim, Integer");
|
530
699
|
return Qnil;
|
531
700
|
}
|
532
701
|
|
533
|
-
rb_iv_set(self, "@space", kw_values[0]);
|
534
702
|
hnswlib::SpaceInterface<float>* space;
|
535
|
-
if (
|
536
|
-
space
|
703
|
+
if (strcmp(StringValueCStr(kw_values[0]), "l2") == 0) {
|
704
|
+
rb_iv_set(self, "@space", rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("L2Space")), rb_intern("new"), 1, kw_values[1]));
|
705
|
+
} else {
|
706
|
+
rb_iv_set(self, "@space",
|
707
|
+
rb_funcall(rb_const_get(rb_mHnswlib, rb_intern("InnerProductSpace")), rb_intern("new"), 1, kw_values[1]));
|
708
|
+
}
|
709
|
+
|
710
|
+
rb_iv_set(self, "@normalize", Qfalse);
|
711
|
+
if (strcmp(StringValueCStr(kw_values[0]), "cosine") == 0) rb_iv_set(self, "@normalize", Qtrue);
|
712
|
+
|
713
|
+
return Qnil;
|
714
|
+
};
|
715
|
+
|
716
|
+
static VALUE _hnsw_bruteforcesearch_init_index(int argc, VALUE* argv, VALUE self) {
|
717
|
+
VALUE kw_args = Qnil;
|
718
|
+
ID kw_table[1] = {rb_intern("max_elements")};
|
719
|
+
VALUE kw_values[1] = {Qundef};
|
720
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
721
|
+
rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
|
722
|
+
|
723
|
+
if (!RB_INTEGER_TYPE_P(kw_values[0])) {
|
724
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
725
|
+
return Qnil;
|
726
|
+
}
|
727
|
+
|
728
|
+
hnswlib::SpaceInterface<float>* space = nullptr;
|
729
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
730
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
731
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
537
732
|
} else {
|
538
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(
|
733
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
539
734
|
}
|
540
|
-
|
735
|
+
|
736
|
+
const size_t max_elements = NUM2SIZET(kw_values[0]);
|
541
737
|
|
542
738
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
543
739
|
try {
|
740
|
+
ptr->~BruteforceSearch();
|
544
741
|
new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
|
545
742
|
} catch (const std::runtime_error& e) {
|
546
743
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -551,7 +748,7 @@ private:
|
|
551
748
|
};
|
552
749
|
|
553
750
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
554
|
-
const
|
751
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
555
752
|
|
556
753
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
557
754
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -567,10 +764,19 @@ private:
|
|
567
764
|
}
|
568
765
|
|
569
766
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
570
|
-
for (
|
767
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
768
|
+
|
769
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
770
|
+
float norm = 0.0;
|
771
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
772
|
+
norm = std::sqrt(std::fabs(norm));
|
773
|
+
if (norm >= 0.0) {
|
774
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
775
|
+
}
|
776
|
+
}
|
571
777
|
|
572
778
|
try {
|
573
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
779
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
574
780
|
} catch (const std::runtime_error& e) {
|
575
781
|
ruby_xfree(vec);
|
576
782
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -581,8 +787,17 @@ private:
|
|
581
787
|
return Qtrue;
|
582
788
|
};
|
583
789
|
|
584
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
585
|
-
|
790
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
791
|
+
VALUE arr, k, filter;
|
792
|
+
VALUE kw_args = Qnil;
|
793
|
+
ID kw_table[1] = {rb_intern("filter")};
|
794
|
+
VALUE kw_values[1] = {Qundef};
|
795
|
+
|
796
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
797
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
798
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
799
|
+
|
800
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
586
801
|
|
587
802
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
588
803
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -597,17 +812,35 @@ private:
|
|
597
812
|
return Qnil;
|
598
813
|
}
|
599
814
|
|
815
|
+
CustomFilterFunctor* filter_func = nullptr;
|
816
|
+
if (!NIL_P(filter)) {
|
817
|
+
try {
|
818
|
+
filter_func = new CustomFilterFunctor(filter);
|
819
|
+
} catch (const std::bad_alloc& e) {
|
820
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
821
|
+
return Qnil;
|
822
|
+
}
|
823
|
+
}
|
824
|
+
|
600
825
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
601
|
-
for (
|
602
|
-
|
826
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
827
|
+
|
828
|
+
if (rb_iv_get(self, "@normalize") == Qtrue) {
|
829
|
+
float norm = 0.0;
|
830
|
+
for (size_t i = 0; i < dim; i++) norm += vec[i] * vec[i];
|
831
|
+
norm = std::sqrt(std::fabs(norm));
|
832
|
+
if (norm >= 0.0) {
|
833
|
+
for (size_t i = 0; i < dim; i++) vec[i] /= norm;
|
834
|
+
}
|
603
835
|
}
|
604
836
|
|
605
837
|
std::priority_queue<std::pair<float, size_t>> result =
|
606
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
838
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
607
839
|
|
608
840
|
ruby_xfree(vec);
|
841
|
+
if (filter_func) delete filter_func;
|
609
842
|
|
610
|
-
if (result.size() != (
|
843
|
+
if (result.size() != NUM2SIZET(k)) {
|
611
844
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
612
845
|
}
|
613
846
|
|
@@ -617,7 +850,7 @@ private:
|
|
617
850
|
while (!result.empty()) {
|
618
851
|
const std::pair<float, size_t>& result_tuple = result.top();
|
619
852
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
620
|
-
rb_ary_unshift(neighbors_arr,
|
853
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
621
854
|
result.pop();
|
622
855
|
}
|
623
856
|
|
@@ -659,16 +892,16 @@ private:
|
|
659
892
|
};
|
660
893
|
|
661
894
|
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
662
|
-
get_hnsw_bruteforcesearch(self)->removePoint((
|
895
|
+
get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
|
663
896
|
return Qnil;
|
664
897
|
};
|
665
898
|
|
666
899
|
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
667
|
-
return
|
900
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
|
668
901
|
};
|
669
902
|
|
670
903
|
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
671
|
-
return
|
904
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
|
672
905
|
};
|
673
906
|
};
|
674
907
|
|