hnswlib 0.6.2 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/hnswlib/hnswlibext.cpp +1 -1
- data/ext/hnswlib/hnswlibext.hpp +194 -62
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +17 -10
- data/sig/hnswlib.rbs +6 -6
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
|
4
|
+
data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
|
7
|
+
data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
|
|
1
|
+
## [0.7.0] - 2023-03-04
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.7.0.
|
4
|
+
- Add support for replacing an element marked for deletion with a new element.
|
5
|
+
- Add support for a label-based filtering function in the search_knn method of BruteforceSearch and HierarchicalNSW.
|
6
|
+
|
1
7
|
## [0.6.2] - 2022-06-25
|
2
8
|
|
3
9
|
- Refactor codes and configs with RuboCop and clang-format.
|
data/ext/hnswlib/hnswlibext.cpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,9 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <new>
|
27
|
+
#include <vector>
|
28
|
+
|
26
29
|
VALUE rb_cHnswlibL2Space;
|
27
30
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
31
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -64,12 +67,12 @@ private:
|
|
64
67
|
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
68
|
rb_iv_set(self, "@dim", dim);
|
66
69
|
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
-
new (ptr) hnswlib::L2Space(
|
70
|
+
new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
|
68
71
|
return Qnil;
|
69
72
|
};
|
70
73
|
|
71
74
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
75
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
73
76
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
74
77
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
75
78
|
return Qnil;
|
@@ -79,9 +82,9 @@ private:
|
|
79
82
|
return Qnil;
|
80
83
|
}
|
81
84
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
82
|
-
for (
|
85
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
83
86
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
85
88
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
86
89
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
87
90
|
ruby_xfree(vec_a);
|
@@ -140,12 +143,12 @@ private:
|
|
140
143
|
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
141
144
|
rb_iv_set(self, "@dim", dim);
|
142
145
|
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
143
|
-
new (ptr) hnswlib::InnerProductSpace(
|
146
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
|
144
147
|
return Qnil;
|
145
148
|
};
|
146
149
|
|
147
150
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
148
|
-
const
|
151
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
149
152
|
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
150
153
|
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
151
154
|
return Qnil;
|
@@ -155,9 +158,9 @@ private:
|
|
155
158
|
return Qnil;
|
156
159
|
}
|
157
160
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
158
|
-
for (
|
161
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
159
162
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
161
164
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
162
165
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
163
166
|
ruby_xfree(vec_a);
|
@@ -180,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
180
183
|
};
|
181
184
|
// clang-format on
|
182
185
|
|
186
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
187
|
+
public:
|
188
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
189
|
+
|
190
|
+
bool operator()(hnswlib::labeltype id) {
|
191
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
192
|
+
return result == Qtrue ? true : false;
|
193
|
+
}
|
194
|
+
|
195
|
+
private:
|
196
|
+
VALUE callback_;
|
197
|
+
};
|
198
|
+
|
183
199
|
class RbHnswlibHierarchicalNSW {
|
184
200
|
public:
|
185
201
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -206,17 +222,21 @@ public:
|
|
206
222
|
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
207
223
|
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
208
224
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
209
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point),
|
210
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn),
|
225
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
|
226
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
|
211
227
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
212
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
228
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
|
213
229
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
214
230
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
215
231
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
232
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
|
216
233
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
217
234
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
235
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
|
218
236
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
219
237
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
238
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
|
239
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
|
220
240
|
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
221
241
|
return rb_cHnswlibHierarchicalNSW;
|
222
242
|
};
|
@@ -226,14 +246,15 @@ private:
|
|
226
246
|
|
227
247
|
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
228
248
|
VALUE kw_args = Qnil;
|
229
|
-
ID kw_table[
|
230
|
-
rb_intern("random_seed")};
|
231
|
-
VALUE kw_values[
|
249
|
+
ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
|
250
|
+
rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
|
251
|
+
VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
232
252
|
rb_scan_args(argc, argv, ":", &kw_args);
|
233
|
-
rb_get_kwargs(kw_args, kw_table, 2,
|
234
|
-
if (kw_values[2] == Qundef) kw_values[2] =
|
235
|
-
if (kw_values[3] == Qundef) kw_values[3] =
|
236
|
-
if (kw_values[4] == Qundef) kw_values[4] =
|
253
|
+
rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
|
254
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
|
255
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
|
256
|
+
if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
|
257
|
+
if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
|
237
258
|
|
238
259
|
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
239
260
|
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
@@ -256,6 +277,10 @@ private:
|
|
256
277
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
257
278
|
return Qnil;
|
258
279
|
}
|
280
|
+
if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
|
281
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
282
|
+
return Qnil;
|
283
|
+
}
|
259
284
|
|
260
285
|
rb_iv_set(self, "@space", kw_values[0]);
|
261
286
|
hnswlib::SpaceInterface<float>* space;
|
@@ -264,14 +289,15 @@ private:
|
|
264
289
|
} else {
|
265
290
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
266
291
|
}
|
267
|
-
const size_t max_elements = (
|
268
|
-
const size_t m = (
|
269
|
-
const size_t ef_construction = (
|
270
|
-
const size_t random_seed = (
|
292
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
293
|
+
const size_t m = NUM2SIZET(kw_values[2]);
|
294
|
+
const size_t ef_construction = NUM2SIZET(kw_values[3]);
|
295
|
+
const size_t random_seed = NUM2SIZET(kw_values[4]);
|
296
|
+
const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
|
271
297
|
|
272
298
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
273
299
|
try {
|
274
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
300
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
275
301
|
} catch (const std::runtime_error& e) {
|
276
302
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
277
303
|
return Qnil;
|
@@ -280,33 +306,63 @@ private:
|
|
280
306
|
return Qnil;
|
281
307
|
};
|
282
308
|
|
283
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
284
|
-
|
309
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
|
310
|
+
VALUE _arr, _idx, _replace_deleted;
|
311
|
+
VALUE kw_args = Qnil;
|
312
|
+
ID kw_table[1] = {rb_intern("replace_deleted")};
|
313
|
+
VALUE kw_values[1] = {Qundef};
|
314
|
+
|
315
|
+
rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
|
316
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
317
|
+
_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
285
318
|
|
286
|
-
|
319
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
320
|
+
|
321
|
+
if (!RB_TYPE_P(_arr, T_ARRAY)) {
|
287
322
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
288
323
|
return Qfalse;
|
289
324
|
}
|
290
|
-
if (!RB_INTEGER_TYPE_P(
|
325
|
+
if (!RB_INTEGER_TYPE_P(_idx)) {
|
291
326
|
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
292
327
|
return Qfalse;
|
293
328
|
}
|
294
|
-
if (dim != RARRAY_LEN(
|
329
|
+
if (dim != RARRAY_LEN(_arr)) {
|
295
330
|
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
296
331
|
return Qfalse;
|
297
332
|
}
|
333
|
+
if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
|
334
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
335
|
+
return Qfalse;
|
336
|
+
}
|
298
337
|
|
299
|
-
float*
|
300
|
-
for (
|
338
|
+
float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
|
339
|
+
for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
|
340
|
+
const size_t idx = NUM2SIZET(_idx);
|
341
|
+
const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
|
301
342
|
|
302
|
-
|
343
|
+
try {
|
344
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
|
345
|
+
} catch (const std::runtime_error& e) {
|
346
|
+
ruby_xfree(arr);
|
347
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
+
return Qfalse;
|
349
|
+
}
|
303
350
|
|
304
|
-
ruby_xfree(
|
351
|
+
ruby_xfree(arr);
|
305
352
|
return Qtrue;
|
306
353
|
};
|
307
354
|
|
308
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
309
|
-
|
355
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
356
|
+
VALUE arr, k, filter;
|
357
|
+
VALUE kw_args = Qnil;
|
358
|
+
ID kw_table[1] = {rb_intern("filter")};
|
359
|
+
VALUE kw_values[1] = {Qundef};
|
360
|
+
|
361
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
362
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
363
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
364
|
+
|
365
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
310
366
|
|
311
367
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
312
368
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -321,14 +377,24 @@ private:
|
|
321
377
|
return Qnil;
|
322
378
|
}
|
323
379
|
|
380
|
+
CustomFilterFunctor* filter_func = nullptr;
|
381
|
+
if (!NIL_P(filter)) {
|
382
|
+
try {
|
383
|
+
filter_func = new CustomFilterFunctor(filter);
|
384
|
+
} catch (const std::bad_alloc& e) {
|
385
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
386
|
+
return Qnil;
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
324
390
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
325
|
-
for (
|
391
|
+
for (size_t i = 0; i < dim; i++) {
|
326
392
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
327
393
|
}
|
328
394
|
|
329
395
|
std::priority_queue<std::pair<float, size_t>> result;
|
330
396
|
try {
|
331
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
397
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
332
398
|
} catch (const std::runtime_error& e) {
|
333
399
|
ruby_xfree(vec);
|
334
400
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -336,8 +402,9 @@ private:
|
|
336
402
|
}
|
337
403
|
|
338
404
|
ruby_xfree(vec);
|
405
|
+
if (filter_func) delete filter_func;
|
339
406
|
|
340
|
-
if (result.size() != (
|
407
|
+
if (result.size() != NUM2SIZET(k)) {
|
341
408
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
342
409
|
}
|
343
410
|
|
@@ -347,7 +414,7 @@ private:
|
|
347
414
|
while (!result.empty()) {
|
348
415
|
const std::pair<float, size_t>& result_tuple = result.top();
|
349
416
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
350
|
-
rb_ary_unshift(neighbors_arr,
|
417
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
351
418
|
result.pop();
|
352
419
|
}
|
353
420
|
|
@@ -364,8 +431,28 @@ private:
|
|
364
431
|
return Qnil;
|
365
432
|
};
|
366
433
|
|
367
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
434
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
435
|
+
VALUE _filename, _allow_replace_deleted;
|
436
|
+
VALUE kw_args = Qnil;
|
437
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
438
|
+
VALUE kw_values[1] = {Qundef};
|
439
|
+
|
440
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
441
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
442
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
443
|
+
|
444
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
445
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
446
|
+
return Qnil;
|
447
|
+
}
|
448
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
449
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
450
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
451
|
+
return Qnil;
|
452
|
+
}
|
453
|
+
|
368
454
|
std::string filename(StringValuePtr(_filename));
|
455
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
369
456
|
VALUE ivspace = rb_iv_get(self, "@space");
|
370
457
|
hnswlib::SpaceInterface<float>* space;
|
371
458
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -373,6 +460,7 @@ private:
|
|
373
460
|
} else {
|
374
461
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
375
462
|
}
|
463
|
+
|
376
464
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
377
465
|
if (index->data_level0_memory_) {
|
378
466
|
free(index->data_level0_memory_);
|
@@ -392,12 +480,15 @@ private:
|
|
392
480
|
delete index->visited_list_pool_;
|
393
481
|
index->visited_list_pool_ = nullptr;
|
394
482
|
}
|
483
|
+
|
395
484
|
try {
|
396
485
|
index->loadIndex(filename, space);
|
486
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
397
487
|
} catch (const std::runtime_error& e) {
|
398
488
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
399
489
|
return Qnil;
|
400
490
|
}
|
491
|
+
|
401
492
|
RB_GC_GUARD(_filename);
|
402
493
|
return Qnil;
|
403
494
|
};
|
@@ -405,7 +496,7 @@ private:
|
|
405
496
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
406
497
|
VALUE ret = Qnil;
|
407
498
|
try {
|
408
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
499
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
409
500
|
ret = rb_ary_new2(vec.size());
|
410
501
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
411
502
|
} catch (const std::runtime_error& e) {
|
@@ -417,13 +508,23 @@ private:
|
|
417
508
|
|
418
509
|
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
419
510
|
VALUE ret = rb_ary_new();
|
420
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret,
|
511
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
|
421
512
|
return ret;
|
422
513
|
};
|
423
514
|
|
424
515
|
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
425
516
|
try {
|
426
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((
|
517
|
+
get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
|
518
|
+
} catch (const std::runtime_error& e) {
|
519
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
520
|
+
return Qnil;
|
521
|
+
}
|
522
|
+
return Qnil;
|
523
|
+
};
|
524
|
+
|
525
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
526
|
+
try {
|
527
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
427
528
|
} catch (const std::runtime_error& e) {
|
428
529
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
429
530
|
return Qnil;
|
@@ -432,31 +533,42 @@ private:
|
|
432
533
|
};
|
433
534
|
|
434
535
|
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
435
|
-
if ((
|
536
|
+
if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
436
537
|
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
437
538
|
return Qnil;
|
438
539
|
}
|
439
540
|
try {
|
440
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((
|
541
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
|
441
542
|
} catch (const std::runtime_error& e) {
|
442
543
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
443
544
|
return Qnil;
|
545
|
+
} catch (const std::bad_alloc& e) {
|
546
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
547
|
+
return Qnil;
|
444
548
|
}
|
445
549
|
return Qnil;
|
446
550
|
};
|
447
551
|
|
448
552
|
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
449
|
-
get_hnsw_hierarchicalnsw(self)->
|
553
|
+
get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
|
450
554
|
return Qnil;
|
451
555
|
};
|
452
556
|
|
557
|
+
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
|
558
|
+
|
453
559
|
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
454
|
-
return
|
560
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
|
455
561
|
};
|
456
562
|
|
457
563
|
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
458
|
-
return
|
564
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
|
565
|
+
};
|
566
|
+
|
567
|
+
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
|
568
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
|
459
569
|
};
|
570
|
+
|
571
|
+
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
|
460
572
|
};
|
461
573
|
|
462
574
|
// clang-format off
|
@@ -500,7 +612,7 @@ public:
|
|
500
612
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
501
613
|
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
502
614
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
503
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
615
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
504
616
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
505
617
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
506
618
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -537,7 +649,7 @@ private:
|
|
537
649
|
} else {
|
538
650
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
539
651
|
}
|
540
|
-
const size_t max_elements = (
|
652
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
541
653
|
|
542
654
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
543
655
|
try {
|
@@ -551,7 +663,7 @@ private:
|
|
551
663
|
};
|
552
664
|
|
553
665
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
554
|
-
const
|
666
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
555
667
|
|
556
668
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
557
669
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -567,10 +679,10 @@ private:
|
|
567
679
|
}
|
568
680
|
|
569
681
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
570
|
-
for (
|
682
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
571
683
|
|
572
684
|
try {
|
573
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
685
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
574
686
|
} catch (const std::runtime_error& e) {
|
575
687
|
ruby_xfree(vec);
|
576
688
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -581,8 +693,17 @@ private:
|
|
581
693
|
return Qtrue;
|
582
694
|
};
|
583
695
|
|
584
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
585
|
-
|
696
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
697
|
+
VALUE arr, k, filter;
|
698
|
+
VALUE kw_args = Qnil;
|
699
|
+
ID kw_table[1] = {rb_intern("filter")};
|
700
|
+
VALUE kw_values[1] = {Qundef};
|
701
|
+
|
702
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
703
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
704
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
705
|
+
|
706
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
586
707
|
|
587
708
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
588
709
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -597,17 +718,28 @@ private:
|
|
597
718
|
return Qnil;
|
598
719
|
}
|
599
720
|
|
721
|
+
CustomFilterFunctor* filter_func = nullptr;
|
722
|
+
if (!NIL_P(filter)) {
|
723
|
+
try {
|
724
|
+
filter_func = new CustomFilterFunctor(filter);
|
725
|
+
} catch (const std::bad_alloc& e) {
|
726
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
727
|
+
return Qnil;
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
600
731
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
601
|
-
for (
|
732
|
+
for (size_t i = 0; i < dim; i++) {
|
602
733
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
603
734
|
}
|
604
735
|
|
605
736
|
std::priority_queue<std::pair<float, size_t>> result =
|
606
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
737
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
607
738
|
|
608
739
|
ruby_xfree(vec);
|
740
|
+
if (filter_func) delete filter_func;
|
609
741
|
|
610
|
-
if (result.size() != (
|
742
|
+
if (result.size() != NUM2SIZET(k)) {
|
611
743
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
612
744
|
}
|
613
745
|
|
@@ -617,7 +749,7 @@ private:
|
|
617
749
|
while (!result.empty()) {
|
618
750
|
const std::pair<float, size_t>& result_tuple = result.top();
|
619
751
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
620
|
-
rb_ary_unshift(neighbors_arr,
|
752
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
621
753
|
result.pop();
|
622
754
|
}
|
623
755
|
|
@@ -659,16 +791,16 @@ private:
|
|
659
791
|
};
|
660
792
|
|
661
793
|
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
662
|
-
get_hnsw_bruteforcesearch(self)->removePoint((
|
794
|
+
get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
|
663
795
|
return Qnil;
|
664
796
|
};
|
665
797
|
|
666
798
|
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
667
|
-
return
|
799
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
|
668
800
|
};
|
669
801
|
|
670
802
|
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
671
|
-
return
|
803
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
|
672
804
|
};
|
673
805
|
};
|
674
806
|
|