hnswlib 0.6.2 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/hnswlib/hnswlibext.cpp +1 -1
- data/ext/hnswlib/hnswlibext.hpp +194 -62
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +17 -10
- data/sig/hnswlib.rbs +6 -6
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
|
4
|
+
data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
|
7
|
+
data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1
|
+
## [0.7.0] - 2023-03-04
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.7.0.
|
4
|
+
- Add support for replacing an element marked for deletion with a new element.
|
5
|
+
- Add support for a label-filtering function in the search_knn method of BruteforceSearch and HierarchicalNSW.
|
6
|
+
|
1
7
|
## [0.6.2] - 2022-06-25
|
2
8
|
|
3
9
|
- Refactor codes and configs with RuboCop and clang-format.
|
data/ext/hnswlib/hnswlibext.cpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,9 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <new>
|
27
|
+
#include <vector>
|
28
|
+
|
26
29
|
VALUE rb_cHnswlibL2Space;
|
27
30
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
31
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -64,12 +67,12 @@ private:
|
|
64
67
|
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
68
|
rb_iv_set(self, "@dim", dim);
|
66
69
|
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
-
new (ptr) hnswlib::L2Space(
|
70
|
+
new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
|
68
71
|
return Qnil;
|
69
72
|
};
|
70
73
|
|
71
74
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
75
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
73
76
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
74
77
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
75
78
|
return Qnil;
|
@@ -79,9 +82,9 @@ private:
|
|
79
82
|
return Qnil;
|
80
83
|
}
|
81
84
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
82
|
-
for (
|
85
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
83
86
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
85
88
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
86
89
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
87
90
|
ruby_xfree(vec_a);
|
@@ -140,12 +143,12 @@ private:
|
|
140
143
|
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
141
144
|
rb_iv_set(self, "@dim", dim);
|
142
145
|
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
143
|
-
new (ptr) hnswlib::InnerProductSpace(
|
146
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
|
144
147
|
return Qnil;
|
145
148
|
};
|
146
149
|
|
147
150
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
148
|
-
const
|
151
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
149
152
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
150
153
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
151
154
|
return Qnil;
|
@@ -155,9 +158,9 @@ private:
|
|
155
158
|
return Qnil;
|
156
159
|
}
|
157
160
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
158
|
-
for (
|
161
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
159
162
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
161
164
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
162
165
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
163
166
|
ruby_xfree(vec_a);
|
@@ -180,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
180
183
|
};
|
181
184
|
// clang-format on
|
182
185
|
|
186
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
187
|
+
public:
|
188
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
189
|
+
|
190
|
+
bool operator()(hnswlib::labeltype id) {
|
191
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
192
|
+
return result == Qtrue ? true : false;
|
193
|
+
}
|
194
|
+
|
195
|
+
private:
|
196
|
+
VALUE callback_;
|
197
|
+
};
|
198
|
+
|
183
199
|
class RbHnswlibHierarchicalNSW {
|
184
200
|
public:
|
185
201
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -206,17 +222,21 @@ public:
|
|
206
222
|
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
207
223
|
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
208
224
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
209
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point),
|
210
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn),
|
225
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
|
226
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
|
211
227
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
212
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
228
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
|
213
229
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
214
230
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
215
231
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
232
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
|
216
233
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
217
234
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
235
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
|
218
236
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
219
237
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
238
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
|
239
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
|
220
240
|
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
221
241
|
return rb_cHnswlibHierarchicalNSW;
|
222
242
|
};
|
@@ -226,14 +246,15 @@ private:
|
|
226
246
|
|
227
247
|
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
228
248
|
VALUE kw_args = Qnil;
|
229
|
-
ID kw_table[
|
230
|
-
rb_intern("random_seed")};
|
231
|
-
VALUE kw_values[
|
249
|
+
ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
|
250
|
+
rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
|
251
|
+
VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
232
252
|
rb_scan_args(argc, argv, ":", &kw_args);
|
233
|
-
rb_get_kwargs(kw_args, kw_table, 2,
|
234
|
-
if (kw_values[2] == Qundef) kw_values[2] =
|
235
|
-
if (kw_values[3] == Qundef) kw_values[3] =
|
236
|
-
if (kw_values[4] == Qundef) kw_values[4] =
|
253
|
+
rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
|
254
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
|
255
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
|
256
|
+
if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
|
257
|
+
if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
|
237
258
|
|
238
259
|
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
239
260
|
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
@@ -256,6 +277,10 @@ private:
|
|
256
277
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
257
278
|
return Qnil;
|
258
279
|
}
|
280
|
+
if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
|
281
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
282
|
+
return Qnil;
|
283
|
+
}
|
259
284
|
|
260
285
|
rb_iv_set(self, "@space", kw_values[0]);
|
261
286
|
hnswlib::SpaceInterface<float>* space;
|
@@ -264,14 +289,15 @@ private:
|
|
264
289
|
} else {
|
265
290
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
266
291
|
}
|
267
|
-
const size_t max_elements = (
|
268
|
-
const size_t m = (
|
269
|
-
const size_t ef_construction = (
|
270
|
-
const size_t random_seed = (
|
292
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
293
|
+
const size_t m = NUM2SIZET(kw_values[2]);
|
294
|
+
const size_t ef_construction = NUM2SIZET(kw_values[3]);
|
295
|
+
const size_t random_seed = NUM2SIZET(kw_values[4]);
|
296
|
+
const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
|
271
297
|
|
272
298
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
273
299
|
try {
|
274
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
300
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
275
301
|
} catch (const std::runtime_error& e) {
|
276
302
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
277
303
|
return Qnil;
|
@@ -280,33 +306,63 @@ private:
|
|
280
306
|
return Qnil;
|
281
307
|
};
|
282
308
|
|
283
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
284
|
-
|
309
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
|
310
|
+
VALUE _arr, _idx, _replace_deleted;
|
311
|
+
VALUE kw_args = Qnil;
|
312
|
+
ID kw_table[1] = {rb_intern("replace_deleted")};
|
313
|
+
VALUE kw_values[1] = {Qundef};
|
314
|
+
|
315
|
+
rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
|
316
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
317
|
+
_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
285
318
|
|
286
|
-
|
319
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
320
|
+
|
321
|
+
if (!RB_TYPE_P(_arr, T_ARRAY)) {
|
287
322
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
288
323
|
return Qfalse;
|
289
324
|
}
|
290
|
-
if (!RB_INTEGER_TYPE_P(
|
325
|
+
if (!RB_INTEGER_TYPE_P(_idx)) {
|
291
326
|
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
292
327
|
return Qfalse;
|
293
328
|
}
|
294
|
-
if (dim != RARRAY_LEN(
|
329
|
+
if (dim != RARRAY_LEN(_arr)) {
|
295
330
|
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
296
331
|
return Qfalse;
|
297
332
|
}
|
333
|
+
if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
|
334
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
335
|
+
return Qfalse;
|
336
|
+
}
|
298
337
|
|
299
|
-
float*
|
300
|
-
for (
|
338
|
+
float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
|
339
|
+
for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
|
340
|
+
const size_t idx = NUM2SIZET(_idx);
|
341
|
+
const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
|
301
342
|
|
302
|
-
|
343
|
+
try {
|
344
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
|
345
|
+
} catch (const std::runtime_error& e) {
|
346
|
+
ruby_xfree(arr);
|
347
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
+
return Qfalse;
|
349
|
+
}
|
303
350
|
|
304
|
-
ruby_xfree(
|
351
|
+
ruby_xfree(arr);
|
305
352
|
return Qtrue;
|
306
353
|
};
|
307
354
|
|
308
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
309
|
-
|
355
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
356
|
+
VALUE arr, k, filter;
|
357
|
+
VALUE kw_args = Qnil;
|
358
|
+
ID kw_table[1] = {rb_intern("filter")};
|
359
|
+
VALUE kw_values[1] = {Qundef};
|
360
|
+
|
361
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
362
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
363
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
364
|
+
|
365
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
310
366
|
|
311
367
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
312
368
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -321,14 +377,24 @@ private:
|
|
321
377
|
return Qnil;
|
322
378
|
}
|
323
379
|
|
380
|
+
CustomFilterFunctor* filter_func = nullptr;
|
381
|
+
if (!NIL_P(filter)) {
|
382
|
+
try {
|
383
|
+
filter_func = new CustomFilterFunctor(filter);
|
384
|
+
} catch (const std::bad_alloc& e) {
|
385
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
386
|
+
return Qnil;
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
324
390
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
325
|
-
for (
|
391
|
+
for (size_t i = 0; i < dim; i++) {
|
326
392
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
327
393
|
}
|
328
394
|
|
329
395
|
std::priority_queue<std::pair<float, size_t>> result;
|
330
396
|
try {
|
331
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
397
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
332
398
|
} catch (const std::runtime_error& e) {
|
333
399
|
ruby_xfree(vec);
|
334
400
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -336,8 +402,9 @@ private:
|
|
336
402
|
}
|
337
403
|
|
338
404
|
ruby_xfree(vec);
|
405
|
+
if (filter_func) delete filter_func;
|
339
406
|
|
340
|
-
if (result.size() != (
|
407
|
+
if (result.size() != NUM2SIZET(k)) {
|
341
408
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
342
409
|
}
|
343
410
|
|
@@ -347,7 +414,7 @@ private:
|
|
347
414
|
while (!result.empty()) {
|
348
415
|
const std::pair<float, size_t>& result_tuple = result.top();
|
349
416
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
350
|
-
rb_ary_unshift(neighbors_arr,
|
417
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
351
418
|
result.pop();
|
352
419
|
}
|
353
420
|
|
@@ -364,8 +431,28 @@ private:
|
|
364
431
|
return Qnil;
|
365
432
|
};
|
366
433
|
|
367
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
434
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
435
|
+
VALUE _filename, _allow_replace_deleted;
|
436
|
+
VALUE kw_args = Qnil;
|
437
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
438
|
+
VALUE kw_values[1] = {Qundef};
|
439
|
+
|
440
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
441
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
442
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
443
|
+
|
444
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
445
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
446
|
+
return Qnil;
|
447
|
+
}
|
448
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
449
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
450
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
451
|
+
return Qnil;
|
452
|
+
}
|
453
|
+
|
368
454
|
std::string filename(StringValuePtr(_filename));
|
455
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
369
456
|
VALUE ivspace = rb_iv_get(self, "@space");
|
370
457
|
hnswlib::SpaceInterface<float>* space;
|
371
458
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -373,6 +460,7 @@ private:
|
|
373
460
|
} else {
|
374
461
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
375
462
|
}
|
463
|
+
|
376
464
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
377
465
|
if (index->data_level0_memory_) {
|
378
466
|
free(index->data_level0_memory_);
|
@@ -392,12 +480,15 @@ private:
|
|
392
480
|
delete index->visited_list_pool_;
|
393
481
|
index->visited_list_pool_ = nullptr;
|
394
482
|
}
|
483
|
+
|
395
484
|
try {
|
396
485
|
index->loadIndex(filename, space);
|
486
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
397
487
|
} catch (const std::runtime_error& e) {
|
398
488
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
399
489
|
return Qnil;
|
400
490
|
}
|
491
|
+
|
401
492
|
RB_GC_GUARD(_filename);
|
402
493
|
return Qnil;
|
403
494
|
};
|
@@ -405,7 +496,7 @@ private:
|
|
405
496
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
406
497
|
VALUE ret = Qnil;
|
407
498
|
try {
|
408
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
499
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
409
500
|
ret = rb_ary_new2(vec.size());
|
410
501
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
411
502
|
} catch (const std::runtime_error& e) {
|
@@ -417,13 +508,23 @@ private:
|
|
417
508
|
|
418
509
|
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
419
510
|
VALUE ret = rb_ary_new();
|
420
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret,
|
511
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
|
421
512
|
return ret;
|
422
513
|
};
|
423
514
|
|
424
515
|
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
425
516
|
try {
|
426
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((
|
517
|
+
get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
|
518
|
+
} catch (const std::runtime_error& e) {
|
519
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
520
|
+
return Qnil;
|
521
|
+
}
|
522
|
+
return Qnil;
|
523
|
+
};
|
524
|
+
|
525
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
526
|
+
try {
|
527
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
427
528
|
} catch (const std::runtime_error& e) {
|
428
529
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
530
|
return Qnil;
|
@@ -432,31 +533,42 @@ private:
|
|
432
533
|
};
|
433
534
|
|
434
535
|
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
435
|
-
if ((
|
536
|
+
if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
436
537
|
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
437
538
|
return Qnil;
|
438
539
|
}
|
439
540
|
try {
|
440
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((
|
541
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
|
441
542
|
} catch (const std::runtime_error& e) {
|
442
543
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
443
544
|
return Qnil;
|
545
|
+
} catch (const std::bad_alloc& e) {
|
546
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
547
|
+
return Qnil;
|
444
548
|
}
|
445
549
|
return Qnil;
|
446
550
|
};
|
447
551
|
|
448
552
|
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
449
|
-
get_hnsw_hierarchicalnsw(self)->
|
553
|
+
get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
|
450
554
|
return Qnil;
|
451
555
|
};
|
452
556
|
|
557
|
+
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
|
558
|
+
|
453
559
|
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
454
|
-
return
|
560
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
|
455
561
|
};
|
456
562
|
|
457
563
|
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
458
|
-
return
|
564
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
|
565
|
+
};
|
566
|
+
|
567
|
+
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
|
568
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
|
459
569
|
};
|
570
|
+
|
571
|
+
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
|
460
572
|
};
|
461
573
|
|
462
574
|
// clang-format off
|
@@ -500,7 +612,7 @@ public:
|
|
500
612
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
501
613
|
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
502
614
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
503
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
615
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
504
616
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
505
617
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
506
618
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -537,7 +649,7 @@ private:
|
|
537
649
|
} else {
|
538
650
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
539
651
|
}
|
540
|
-
const size_t max_elements = (
|
652
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
541
653
|
|
542
654
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
543
655
|
try {
|
@@ -551,7 +663,7 @@ private:
|
|
551
663
|
};
|
552
664
|
|
553
665
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
554
|
-
const
|
666
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
555
667
|
|
556
668
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
557
669
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -567,10 +679,10 @@ private:
|
|
567
679
|
}
|
568
680
|
|
569
681
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
570
|
-
for (
|
682
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
571
683
|
|
572
684
|
try {
|
573
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
685
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
574
686
|
} catch (const std::runtime_error& e) {
|
575
687
|
ruby_xfree(vec);
|
576
688
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -581,8 +693,17 @@ private:
|
|
581
693
|
return Qtrue;
|
582
694
|
};
|
583
695
|
|
584
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
585
|
-
|
696
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
697
|
+
VALUE arr, k, filter;
|
698
|
+
VALUE kw_args = Qnil;
|
699
|
+
ID kw_table[1] = {rb_intern("filter")};
|
700
|
+
VALUE kw_values[1] = {Qundef};
|
701
|
+
|
702
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
703
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
704
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
705
|
+
|
706
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
586
707
|
|
587
708
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
588
709
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -597,17 +718,28 @@ private:
|
|
597
718
|
return Qnil;
|
598
719
|
}
|
599
720
|
|
721
|
+
CustomFilterFunctor* filter_func = nullptr;
|
722
|
+
if (!NIL_P(filter)) {
|
723
|
+
try {
|
724
|
+
filter_func = new CustomFilterFunctor(filter);
|
725
|
+
} catch (const std::bad_alloc& e) {
|
726
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
727
|
+
return Qnil;
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
600
731
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
601
|
-
for (
|
732
|
+
for (size_t i = 0; i < dim; i++) {
|
602
733
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
603
734
|
}
|
604
735
|
|
605
736
|
std::priority_queue<std::pair<float, size_t>> result =
|
606
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
737
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
607
738
|
|
608
739
|
ruby_xfree(vec);
|
740
|
+
if (filter_func) delete filter_func;
|
609
741
|
|
610
|
-
if (result.size() != (
|
742
|
+
if (result.size() != NUM2SIZET(k)) {
|
611
743
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
612
744
|
}
|
613
745
|
|
@@ -617,7 +749,7 @@ private:
|
|
617
749
|
while (!result.empty()) {
|
618
750
|
const std::pair<float, size_t>& result_tuple = result.top();
|
619
751
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
620
|
-
rb_ary_unshift(neighbors_arr,
|
752
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
621
753
|
result.pop();
|
622
754
|
}
|
623
755
|
|
@@ -659,16 +791,16 @@ private:
|
|
659
791
|
};
|
660
792
|
|
661
793
|
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
662
|
-
get_hnsw_bruteforcesearch(self)->removePoint((
|
794
|
+
get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
|
663
795
|
return Qnil;
|
664
796
|
};
|
665
797
|
|
666
798
|
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
667
|
-
return
|
799
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
|
668
800
|
};
|
669
801
|
|
670
802
|
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
671
|
-
return
|
803
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
|
672
804
|
};
|
673
805
|
};
|
674
806
|
|