hnswlib 0.6.1 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/hnswlib/hnswlibext.cpp +2 -3
- data/ext/hnswlib/hnswlibext.hpp +202 -62
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +17 -10
- data/sig/hnswlib.rbs +6 -6
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
|
4
|
+
data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
|
7
|
+
data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,14 @@
|
|
1
|
+
## [0.7.0] - 2023-03-04
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.7.0.
|
4
|
+
- Add support for replacing an element marked for deletion with a new element.
|
5
|
+
- Add support for a filtering function by label in the search_knn method of BruteforceSearch and HierarchicalNSW.
|
6
|
+
|
7
|
+
## [0.6.2] - 2022-06-25
|
8
|
+
|
9
|
+
- Refactor codes and configs with RuboCop and clang-format.
|
10
|
+
- Change to raise ArgumentError when non-array object is given to distance method.
|
11
|
+
|
1
12
|
## [0.6.1] - 2022-04-30
|
2
13
|
|
3
14
|
- Change the `search_knn` method of `BruteforceSearch` to output warning message instead of raising RuntimeError
|
data/ext/hnswlib/hnswlibext.cpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -20,8 +20,7 @@
|
|
20
20
|
|
21
21
|
VALUE rb_mHnswlib;
|
22
22
|
|
23
|
-
extern "C"
|
24
|
-
void Init_hnswlibext(void) {
|
23
|
+
extern "C" void Init_hnswlibext(void) {
|
25
24
|
rb_mHnswlib = rb_define_module("Hnswlib");
|
26
25
|
RbHnswlibL2Space::define_class(rb_mHnswlib);
|
27
26
|
RbHnswlibInnerProductSpace::define_class(rb_mHnswlib);
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,9 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <new>
|
27
|
+
#include <vector>
|
28
|
+
|
26
29
|
VALUE rb_cHnswlibL2Space;
|
27
30
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
31
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -64,20 +67,24 @@ private:
|
|
64
67
|
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
68
|
rb_iv_set(self, "@dim", dim);
|
66
69
|
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
-
new (ptr) hnswlib::L2Space(
|
70
|
+
new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
|
68
71
|
return Qnil;
|
69
72
|
};
|
70
73
|
|
71
74
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
75
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
76
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
77
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
78
|
+
return Qnil;
|
79
|
+
}
|
73
80
|
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
74
81
|
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
75
82
|
return Qnil;
|
76
83
|
}
|
77
84
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
78
|
-
for (
|
85
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
79
86
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
80
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
81
88
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
82
89
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
83
90
|
ruby_xfree(vec_a);
|
@@ -136,20 +143,24 @@ private:
|
|
136
143
|
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
137
144
|
rb_iv_set(self, "@dim", dim);
|
138
145
|
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
139
|
-
new (ptr) hnswlib::InnerProductSpace(
|
146
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
|
140
147
|
return Qnil;
|
141
148
|
};
|
142
149
|
|
143
150
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
144
|
-
const
|
151
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
152
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
153
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
154
|
+
return Qnil;
|
155
|
+
}
|
145
156
|
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
146
157
|
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
147
158
|
return Qnil;
|
148
159
|
}
|
149
160
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
150
|
-
for (
|
161
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
151
162
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
152
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
153
164
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
154
165
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
155
166
|
ruby_xfree(vec_a);
|
@@ -172,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
172
183
|
};
|
173
184
|
// clang-format on
|
174
185
|
|
186
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
187
|
+
public:
|
188
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
189
|
+
|
190
|
+
bool operator()(hnswlib::labeltype id) {
|
191
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
192
|
+
return result == Qtrue ? true : false;
|
193
|
+
}
|
194
|
+
|
195
|
+
private:
|
196
|
+
VALUE callback_;
|
197
|
+
};
|
198
|
+
|
175
199
|
class RbHnswlibHierarchicalNSW {
|
176
200
|
public:
|
177
201
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -198,17 +222,21 @@ public:
|
|
198
222
|
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
199
223
|
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
200
224
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
201
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point),
|
202
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn),
|
225
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
|
226
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
|
203
227
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
204
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
228
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
|
205
229
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
206
230
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
207
231
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
232
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
|
208
233
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
209
234
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
235
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
|
210
236
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
211
237
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
238
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
|
239
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
|
212
240
|
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
213
241
|
return rb_cHnswlibHierarchicalNSW;
|
214
242
|
};
|
@@ -218,14 +246,15 @@ private:
|
|
218
246
|
|
219
247
|
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
220
248
|
VALUE kw_args = Qnil;
|
221
|
-
ID kw_table[
|
222
|
-
rb_intern("random_seed")};
|
223
|
-
VALUE kw_values[
|
249
|
+
ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
|
250
|
+
rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
|
251
|
+
VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
224
252
|
rb_scan_args(argc, argv, ":", &kw_args);
|
225
|
-
rb_get_kwargs(kw_args, kw_table, 2,
|
226
|
-
if (kw_values[2] == Qundef) kw_values[2] =
|
227
|
-
if (kw_values[3] == Qundef) kw_values[3] =
|
228
|
-
if (kw_values[4] == Qundef) kw_values[4] =
|
253
|
+
rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
|
254
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
|
255
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
|
256
|
+
if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
|
257
|
+
if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
|
229
258
|
|
230
259
|
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
231
260
|
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
@@ -248,6 +277,10 @@ private:
|
|
248
277
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
249
278
|
return Qnil;
|
250
279
|
}
|
280
|
+
if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
|
281
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
282
|
+
return Qnil;
|
283
|
+
}
|
251
284
|
|
252
285
|
rb_iv_set(self, "@space", kw_values[0]);
|
253
286
|
hnswlib::SpaceInterface<float>* space;
|
@@ -256,14 +289,15 @@ private:
|
|
256
289
|
} else {
|
257
290
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
258
291
|
}
|
259
|
-
const size_t max_elements = (
|
260
|
-
const size_t m = (
|
261
|
-
const size_t ef_construction = (
|
262
|
-
const size_t random_seed = (
|
292
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
293
|
+
const size_t m = NUM2SIZET(kw_values[2]);
|
294
|
+
const size_t ef_construction = NUM2SIZET(kw_values[3]);
|
295
|
+
const size_t random_seed = NUM2SIZET(kw_values[4]);
|
296
|
+
const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
|
263
297
|
|
264
298
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
265
299
|
try {
|
266
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
300
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
267
301
|
} catch (const std::runtime_error& e) {
|
268
302
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
269
303
|
return Qnil;
|
@@ -272,33 +306,63 @@ private:
|
|
272
306
|
return Qnil;
|
273
307
|
};
|
274
308
|
|
275
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
276
|
-
|
309
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
|
310
|
+
VALUE _arr, _idx, _replace_deleted;
|
311
|
+
VALUE kw_args = Qnil;
|
312
|
+
ID kw_table[1] = {rb_intern("replace_deleted")};
|
313
|
+
VALUE kw_values[1] = {Qundef};
|
277
314
|
|
278
|
-
|
315
|
+
rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
|
316
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
317
|
+
_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
318
|
+
|
319
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
320
|
+
|
321
|
+
if (!RB_TYPE_P(_arr, T_ARRAY)) {
|
279
322
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
280
323
|
return Qfalse;
|
281
324
|
}
|
282
|
-
if (!RB_INTEGER_TYPE_P(
|
325
|
+
if (!RB_INTEGER_TYPE_P(_idx)) {
|
283
326
|
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
284
327
|
return Qfalse;
|
285
328
|
}
|
286
|
-
if (dim != RARRAY_LEN(
|
329
|
+
if (dim != RARRAY_LEN(_arr)) {
|
287
330
|
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
288
331
|
return Qfalse;
|
289
332
|
}
|
333
|
+
if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
|
334
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
335
|
+
return Qfalse;
|
336
|
+
}
|
290
337
|
|
291
|
-
float*
|
292
|
-
for (
|
338
|
+
float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
|
339
|
+
for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
|
340
|
+
const size_t idx = NUM2SIZET(_idx);
|
341
|
+
const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
|
293
342
|
|
294
|
-
|
343
|
+
try {
|
344
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
|
345
|
+
} catch (const std::runtime_error& e) {
|
346
|
+
ruby_xfree(arr);
|
347
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
+
return Qfalse;
|
349
|
+
}
|
295
350
|
|
296
|
-
ruby_xfree(
|
351
|
+
ruby_xfree(arr);
|
297
352
|
return Qtrue;
|
298
353
|
};
|
299
354
|
|
300
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
301
|
-
|
355
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
356
|
+
VALUE arr, k, filter;
|
357
|
+
VALUE kw_args = Qnil;
|
358
|
+
ID kw_table[1] = {rb_intern("filter")};
|
359
|
+
VALUE kw_values[1] = {Qundef};
|
360
|
+
|
361
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
362
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
363
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
364
|
+
|
365
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
302
366
|
|
303
367
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
304
368
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -313,14 +377,24 @@ private:
|
|
313
377
|
return Qnil;
|
314
378
|
}
|
315
379
|
|
380
|
+
CustomFilterFunctor* filter_func = nullptr;
|
381
|
+
if (!NIL_P(filter)) {
|
382
|
+
try {
|
383
|
+
filter_func = new CustomFilterFunctor(filter);
|
384
|
+
} catch (const std::bad_alloc& e) {
|
385
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
386
|
+
return Qnil;
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
316
390
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
317
|
-
for (
|
391
|
+
for (size_t i = 0; i < dim; i++) {
|
318
392
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
319
393
|
}
|
320
394
|
|
321
395
|
std::priority_queue<std::pair<float, size_t>> result;
|
322
396
|
try {
|
323
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
397
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
324
398
|
} catch (const std::runtime_error& e) {
|
325
399
|
ruby_xfree(vec);
|
326
400
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -328,8 +402,9 @@ private:
|
|
328
402
|
}
|
329
403
|
|
330
404
|
ruby_xfree(vec);
|
405
|
+
if (filter_func) delete filter_func;
|
331
406
|
|
332
|
-
if (result.size() != (
|
407
|
+
if (result.size() != NUM2SIZET(k)) {
|
333
408
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
334
409
|
}
|
335
410
|
|
@@ -339,7 +414,7 @@ private:
|
|
339
414
|
while (!result.empty()) {
|
340
415
|
const std::pair<float, size_t>& result_tuple = result.top();
|
341
416
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
342
|
-
rb_ary_unshift(neighbors_arr,
|
417
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
343
418
|
result.pop();
|
344
419
|
}
|
345
420
|
|
@@ -356,8 +431,28 @@ private:
|
|
356
431
|
return Qnil;
|
357
432
|
};
|
358
433
|
|
359
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
434
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
435
|
+
VALUE _filename, _allow_replace_deleted;
|
436
|
+
VALUE kw_args = Qnil;
|
437
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
438
|
+
VALUE kw_values[1] = {Qundef};
|
439
|
+
|
440
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
441
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
442
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
443
|
+
|
444
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
445
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
446
|
+
return Qnil;
|
447
|
+
}
|
448
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
449
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
450
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
451
|
+
return Qnil;
|
452
|
+
}
|
453
|
+
|
360
454
|
std::string filename(StringValuePtr(_filename));
|
455
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
361
456
|
VALUE ivspace = rb_iv_get(self, "@space");
|
362
457
|
hnswlib::SpaceInterface<float>* space;
|
363
458
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -365,6 +460,7 @@ private:
|
|
365
460
|
} else {
|
366
461
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
367
462
|
}
|
463
|
+
|
368
464
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
369
465
|
if (index->data_level0_memory_) {
|
370
466
|
free(index->data_level0_memory_);
|
@@ -384,12 +480,15 @@ private:
|
|
384
480
|
delete index->visited_list_pool_;
|
385
481
|
index->visited_list_pool_ = nullptr;
|
386
482
|
}
|
483
|
+
|
387
484
|
try {
|
388
485
|
index->loadIndex(filename, space);
|
486
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
389
487
|
} catch (const std::runtime_error& e) {
|
390
488
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
391
489
|
return Qnil;
|
392
490
|
}
|
491
|
+
|
393
492
|
RB_GC_GUARD(_filename);
|
394
493
|
return Qnil;
|
395
494
|
};
|
@@ -397,7 +496,7 @@ private:
|
|
397
496
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
398
497
|
VALUE ret = Qnil;
|
399
498
|
try {
|
400
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
499
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
401
500
|
ret = rb_ary_new2(vec.size());
|
402
501
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
403
502
|
} catch (const std::runtime_error& e) {
|
@@ -409,13 +508,23 @@ private:
|
|
409
508
|
|
410
509
|
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
411
510
|
VALUE ret = rb_ary_new();
|
412
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret,
|
511
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
|
413
512
|
return ret;
|
414
513
|
};
|
415
514
|
|
416
515
|
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
417
516
|
try {
|
418
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((
|
517
|
+
get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
|
518
|
+
} catch (const std::runtime_error& e) {
|
519
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
520
|
+
return Qnil;
|
521
|
+
}
|
522
|
+
return Qnil;
|
523
|
+
};
|
524
|
+
|
525
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
526
|
+
try {
|
527
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
419
528
|
} catch (const std::runtime_error& e) {
|
420
529
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
421
530
|
return Qnil;
|
@@ -424,31 +533,42 @@ private:
|
|
424
533
|
};
|
425
534
|
|
426
535
|
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
427
|
-
if ((
|
536
|
+
if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
428
537
|
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
429
538
|
return Qnil;
|
430
539
|
}
|
431
540
|
try {
|
432
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((
|
541
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
|
433
542
|
} catch (const std::runtime_error& e) {
|
434
543
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
435
544
|
return Qnil;
|
545
|
+
} catch (const std::bad_alloc& e) {
|
546
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
547
|
+
return Qnil;
|
436
548
|
}
|
437
549
|
return Qnil;
|
438
550
|
};
|
439
551
|
|
440
552
|
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
441
|
-
get_hnsw_hierarchicalnsw(self)->
|
553
|
+
get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
|
442
554
|
return Qnil;
|
443
555
|
};
|
444
556
|
|
557
|
+
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
|
558
|
+
|
445
559
|
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
446
|
-
return
|
560
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
|
447
561
|
};
|
448
562
|
|
449
563
|
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
450
|
-
return
|
564
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
|
565
|
+
};
|
566
|
+
|
567
|
+
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
|
568
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
|
451
569
|
};
|
570
|
+
|
571
|
+
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
|
452
572
|
};
|
453
573
|
|
454
574
|
// clang-format off
|
@@ -492,7 +612,7 @@ public:
|
|
492
612
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
493
613
|
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
494
614
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
495
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
615
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
496
616
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
497
617
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
498
618
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -529,7 +649,7 @@ private:
|
|
529
649
|
} else {
|
530
650
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
531
651
|
}
|
532
|
-
const size_t max_elements = (
|
652
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
533
653
|
|
534
654
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
535
655
|
try {
|
@@ -543,7 +663,7 @@ private:
|
|
543
663
|
};
|
544
664
|
|
545
665
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
546
|
-
const
|
666
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
547
667
|
|
548
668
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
549
669
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -559,10 +679,10 @@ private:
|
|
559
679
|
}
|
560
680
|
|
561
681
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
562
|
-
for (
|
682
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
563
683
|
|
564
684
|
try {
|
565
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
685
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
566
686
|
} catch (const std::runtime_error& e) {
|
567
687
|
ruby_xfree(vec);
|
568
688
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -573,8 +693,17 @@ private:
|
|
573
693
|
return Qtrue;
|
574
694
|
};
|
575
695
|
|
576
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
577
|
-
|
696
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
697
|
+
VALUE arr, k, filter;
|
698
|
+
VALUE kw_args = Qnil;
|
699
|
+
ID kw_table[1] = {rb_intern("filter")};
|
700
|
+
VALUE kw_values[1] = {Qundef};
|
701
|
+
|
702
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
703
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
704
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
705
|
+
|
706
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
578
707
|
|
579
708
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
580
709
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -589,17 +718,28 @@ private:
|
|
589
718
|
return Qnil;
|
590
719
|
}
|
591
720
|
|
721
|
+
CustomFilterFunctor* filter_func = nullptr;
|
722
|
+
if (!NIL_P(filter)) {
|
723
|
+
try {
|
724
|
+
filter_func = new CustomFilterFunctor(filter);
|
725
|
+
} catch (const std::bad_alloc& e) {
|
726
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
727
|
+
return Qnil;
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
592
731
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
593
|
-
for (
|
732
|
+
for (size_t i = 0; i < dim; i++) {
|
594
733
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
595
734
|
}
|
596
735
|
|
597
736
|
std::priority_queue<std::pair<float, size_t>> result =
|
598
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
737
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
599
738
|
|
600
739
|
ruby_xfree(vec);
|
740
|
+
if (filter_func) delete filter_func;
|
601
741
|
|
602
|
-
if (result.size() != (
|
742
|
+
if (result.size() != NUM2SIZET(k)) {
|
603
743
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
604
744
|
}
|
605
745
|
|
@@ -609,7 +749,7 @@ private:
|
|
609
749
|
while (!result.empty()) {
|
610
750
|
const std::pair<float, size_t>& result_tuple = result.top();
|
611
751
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
612
|
-
rb_ary_unshift(neighbors_arr,
|
752
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
613
753
|
result.pop();
|
614
754
|
}
|
615
755
|
|
@@ -651,16 +791,16 @@ private:
|
|
651
791
|
};
|
652
792
|
|
653
793
|
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
654
|
-
get_hnsw_bruteforcesearch(self)->removePoint((
|
794
|
+
get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
|
655
795
|
return Qnil;
|
656
796
|
};
|
657
797
|
|
658
798
|
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
659
|
-
return
|
799
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
|
660
800
|
};
|
661
801
|
|
662
802
|
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
663
|
-
return
|
803
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
|
664
804
|
};
|
665
805
|
};
|
666
806
|
|