hnswlib 0.6.1 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/ext/hnswlib/hnswlibext.cpp +2 -3
- data/ext/hnswlib/hnswlibext.hpp +202 -62
- data/ext/hnswlib/src/bruteforce.h +142 -131
- data/ext/hnswlib/src/hnswalg.h +1028 -964
- data/ext/hnswlib/src/hnswlib.h +74 -66
- data/ext/hnswlib/src/space_ip.h +299 -299
- data/ext/hnswlib/src/space_l2.h +268 -273
- data/ext/hnswlib/src/visited_list_pool.h +54 -55
- data/lib/hnswlib/version.rb +2 -2
- data/lib/hnswlib.rb +17 -10
- data/sig/hnswlib.rbs +6 -6
- metadata +4 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a564835f983f6c07a04d62e198aac0b2a80eef8eaa784a14d3d7bdc5dadaa962
|
4
|
+
data.tar.gz: 98cc90158fbe92a012a6e0f945a1c82cd3bbfc1ae973cf1c4c1eccf95c249fed
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: ada8080314124e768aba55385f92c98a8e1206932c25a3d93cff2eaf8659d44375622a22226bb96379c7377055d8840703aa8dac25b49440f1a862f5c9b45444
|
7
|
+
data.tar.gz: 6ce0ee55fd1d0174bd1d06377eb1a5ed3cf875ea1225da0320ba6e4b003bdd18d73d00f126fc17484dc8c2b45621a2cfe86c74134f6c227b4b1364b631977d69
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,14 @@
|
|
1
|
+
## [0.7.0] - 2023-03-04
|
2
|
+
|
3
|
+
- Update bundled hnswlib version to 0.7.0.
|
4
|
+
- Add support for replacing an element marked for deletion with a new element.
|
5
|
+
- Add support for a filtering function by label in the search_knn method of BruteforceSearch and HierarchicalNSW.
|
6
|
+
|
7
|
+
## [0.6.2] - 2022-06-25
|
8
|
+
|
9
|
+
- Refactor code and configs with RuboCop and clang-format.
|
10
|
+
- Change to raise ArgumentError when a non-array object is given to the distance method.
|
11
|
+
|
1
12
|
## [0.6.1] - 2022-04-30
|
2
13
|
|
3
14
|
- Change the `search_knn` method of `BruteforceSearch` to output a warning message instead of raising RuntimeError.
|
data/ext/hnswlib/hnswlibext.cpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -20,8 +20,7 @@
|
|
20
20
|
|
21
21
|
VALUE rb_mHnswlib;
|
22
22
|
|
23
|
-
extern "C"
|
24
|
-
void Init_hnswlibext(void) {
|
23
|
+
extern "C" void Init_hnswlibext(void) {
|
25
24
|
rb_mHnswlib = rb_define_module("Hnswlib");
|
26
25
|
RbHnswlibL2Space::define_class(rb_mHnswlib);
|
27
26
|
RbHnswlibInnerProductSpace::define_class(rb_mHnswlib);
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
/**
|
2
2
|
* hnswlib.rb is a Ruby binding for the Hnswlib.
|
3
3
|
*
|
4
|
-
* Copyright (c) 2021-
|
4
|
+
* Copyright (c) 2021-2023 Atsushi Tatsuma
|
5
5
|
*
|
6
6
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
* you may not use this file except in compliance with the License.
|
@@ -23,6 +23,9 @@
|
|
23
23
|
|
24
24
|
#include <hnswlib.h>
|
25
25
|
|
26
|
+
#include <new>
|
27
|
+
#include <vector>
|
28
|
+
|
26
29
|
VALUE rb_cHnswlibL2Space;
|
27
30
|
VALUE rb_cHnswlibInnerProductSpace;
|
28
31
|
VALUE rb_cHnswlibHierarchicalNSW;
|
@@ -64,20 +67,24 @@ private:
|
|
64
67
|
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
68
|
rb_iv_set(self, "@dim", dim);
|
66
69
|
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
-
new (ptr) hnswlib::L2Space(
|
70
|
+
new (ptr) hnswlib::L2Space(NUM2SIZET(rb_iv_get(self, "@dim")));
|
68
71
|
return Qnil;
|
69
72
|
};
|
70
73
|
|
71
74
|
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
-
const
|
75
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
76
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
77
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
78
|
+
return Qnil;
|
79
|
+
}
|
73
80
|
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
74
81
|
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
75
82
|
return Qnil;
|
76
83
|
}
|
77
84
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
78
|
-
for (
|
85
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
79
86
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
80
|
-
for (
|
87
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
81
88
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
82
89
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
83
90
|
ruby_xfree(vec_a);
|
@@ -136,20 +143,24 @@ private:
|
|
136
143
|
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
137
144
|
rb_iv_set(self, "@dim", dim);
|
138
145
|
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
139
|
-
new (ptr) hnswlib::InnerProductSpace(
|
146
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2SIZET(rb_iv_get(self, "@dim")));
|
140
147
|
return Qnil;
|
141
148
|
};
|
142
149
|
|
143
150
|
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
144
|
-
const
|
151
|
+
const size_t dim = NUM2SIZET(rb_iv_get(self, "@dim"));
|
152
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
153
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
154
|
+
return Qnil;
|
155
|
+
}
|
145
156
|
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
146
157
|
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
147
158
|
return Qnil;
|
148
159
|
}
|
149
160
|
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
150
|
-
for (
|
161
|
+
for (size_t i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
151
162
|
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
152
|
-
for (
|
163
|
+
for (size_t i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
153
164
|
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
154
165
|
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
155
166
|
ruby_xfree(vec_a);
|
@@ -172,6 +183,19 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
172
183
|
};
|
173
184
|
// clang-format on
|
174
185
|
|
186
|
+
class CustomFilterFunctor : public hnswlib::BaseFilterFunctor {
|
187
|
+
public:
|
188
|
+
CustomFilterFunctor(const VALUE& callback) : callback_(callback) {}
|
189
|
+
|
190
|
+
bool operator()(hnswlib::labeltype id) {
|
191
|
+
VALUE result = rb_funcall(callback_, rb_intern("call"), 1, SIZET2NUM(id));
|
192
|
+
return result == Qtrue ? true : false;
|
193
|
+
}
|
194
|
+
|
195
|
+
private:
|
196
|
+
VALUE callback_;
|
197
|
+
};
|
198
|
+
|
175
199
|
class RbHnswlibHierarchicalNSW {
|
176
200
|
public:
|
177
201
|
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
@@ -198,17 +222,21 @@ public:
|
|
198
222
|
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
199
223
|
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
200
224
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
201
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point),
|
202
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn),
|
225
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), -1);
|
226
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), -1);
|
203
227
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
204
|
-
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
228
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), -1);
|
205
229
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
206
230
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
207
231
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
232
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "unmark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_unmark_deleted), 1);
|
208
233
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
209
234
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
235
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ef), 0);
|
210
236
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
211
237
|
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
238
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "ef_construction", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_ef_construction), 0);
|
239
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "m", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_m), 0);
|
212
240
|
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
213
241
|
return rb_cHnswlibHierarchicalNSW;
|
214
242
|
};
|
@@ -218,14 +246,15 @@ private:
|
|
218
246
|
|
219
247
|
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
220
248
|
VALUE kw_args = Qnil;
|
221
|
-
ID kw_table[
|
222
|
-
rb_intern("random_seed")};
|
223
|
-
VALUE kw_values[
|
249
|
+
ID kw_table[6] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"),
|
250
|
+
rb_intern("ef_construction"), rb_intern("random_seed"), rb_intern("allow_replace_deleted")};
|
251
|
+
VALUE kw_values[6] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
|
224
252
|
rb_scan_args(argc, argv, ":", &kw_args);
|
225
|
-
rb_get_kwargs(kw_args, kw_table, 2,
|
226
|
-
if (kw_values[2] == Qundef) kw_values[2] =
|
227
|
-
if (kw_values[3] == Qundef) kw_values[3] =
|
228
|
-
if (kw_values[4] == Qundef) kw_values[4] =
|
253
|
+
rb_get_kwargs(kw_args, kw_table, 2, 4, kw_values);
|
254
|
+
if (kw_values[2] == Qundef) kw_values[2] = SIZET2NUM(16);
|
255
|
+
if (kw_values[3] == Qundef) kw_values[3] = SIZET2NUM(200);
|
256
|
+
if (kw_values[4] == Qundef) kw_values[4] = SIZET2NUM(100);
|
257
|
+
if (kw_values[5] == Qundef) kw_values[5] = Qfalse;
|
229
258
|
|
230
259
|
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
231
260
|
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
@@ -248,6 +277,10 @@ private:
|
|
248
277
|
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
249
278
|
return Qnil;
|
250
279
|
}
|
280
|
+
if (!RB_TYPE_P(kw_values[5], T_TRUE) && !RB_TYPE_P(kw_values[5], T_FALSE)) {
|
281
|
+
rb_raise(rb_eTypeError, "expected allow_replace_deleted, Boolean");
|
282
|
+
return Qnil;
|
283
|
+
}
|
251
284
|
|
252
285
|
rb_iv_set(self, "@space", kw_values[0]);
|
253
286
|
hnswlib::SpaceInterface<float>* space;
|
@@ -256,14 +289,15 @@ private:
|
|
256
289
|
} else {
|
257
290
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
258
291
|
}
|
259
|
-
const size_t max_elements = (
|
260
|
-
const size_t m = (
|
261
|
-
const size_t ef_construction = (
|
262
|
-
const size_t random_seed = (
|
292
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
293
|
+
const size_t m = NUM2SIZET(kw_values[2]);
|
294
|
+
const size_t ef_construction = NUM2SIZET(kw_values[3]);
|
295
|
+
const size_t random_seed = NUM2SIZET(kw_values[4]);
|
296
|
+
const bool allow_replace_deleted = kw_values[5] == Qtrue ? true : false;
|
263
297
|
|
264
298
|
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
265
299
|
try {
|
266
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
300
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed, allow_replace_deleted);
|
267
301
|
} catch (const std::runtime_error& e) {
|
268
302
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
269
303
|
return Qnil;
|
@@ -272,33 +306,63 @@ private:
|
|
272
306
|
return Qnil;
|
273
307
|
};
|
274
308
|
|
275
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(
|
276
|
-
|
309
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(int argc, VALUE* argv, VALUE self) {
|
310
|
+
VALUE _arr, _idx, _replace_deleted;
|
311
|
+
VALUE kw_args = Qnil;
|
312
|
+
ID kw_table[1] = {rb_intern("replace_deleted")};
|
313
|
+
VALUE kw_values[1] = {Qundef};
|
277
314
|
|
278
|
-
|
315
|
+
rb_scan_args(argc, argv, "2:", &_arr, &_idx, &kw_args);
|
316
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
317
|
+
_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
318
|
+
|
319
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
320
|
+
|
321
|
+
if (!RB_TYPE_P(_arr, T_ARRAY)) {
|
279
322
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
280
323
|
return Qfalse;
|
281
324
|
}
|
282
|
-
if (!RB_INTEGER_TYPE_P(
|
325
|
+
if (!RB_INTEGER_TYPE_P(_idx)) {
|
283
326
|
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
284
327
|
return Qfalse;
|
285
328
|
}
|
286
|
-
if (dim != RARRAY_LEN(
|
329
|
+
if (dim != RARRAY_LEN(_arr)) {
|
287
330
|
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
288
331
|
return Qfalse;
|
289
332
|
}
|
333
|
+
if (!RB_TYPE_P(_replace_deleted, T_TRUE) && !RB_TYPE_P(_replace_deleted, T_FALSE)) {
|
334
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
335
|
+
return Qfalse;
|
336
|
+
}
|
290
337
|
|
291
|
-
float*
|
292
|
-
for (
|
338
|
+
float* arr = (float*)ruby_xmalloc(dim * sizeof(float));
|
339
|
+
for (size_t i = 0; i < dim; i++) arr[i] = (float)NUM2DBL(rb_ary_entry(_arr, i));
|
340
|
+
const size_t idx = NUM2SIZET(_idx);
|
341
|
+
const bool replace_deleted = _replace_deleted == Qtrue ? true : false;
|
293
342
|
|
294
|
-
|
343
|
+
try {
|
344
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)arr, idx, replace_deleted);
|
345
|
+
} catch (const std::runtime_error& e) {
|
346
|
+
ruby_xfree(arr);
|
347
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
+
return Qfalse;
|
349
|
+
}
|
295
350
|
|
296
|
-
ruby_xfree(
|
351
|
+
ruby_xfree(arr);
|
297
352
|
return Qtrue;
|
298
353
|
};
|
299
354
|
|
300
|
-
static VALUE _hnsw_hierarchicalnsw_search_knn(
|
301
|
-
|
355
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(int argc, VALUE* argv, VALUE self) {
|
356
|
+
VALUE arr, k, filter;
|
357
|
+
VALUE kw_args = Qnil;
|
358
|
+
ID kw_table[1] = {rb_intern("filter")};
|
359
|
+
VALUE kw_values[1] = {Qundef};
|
360
|
+
|
361
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
362
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
363
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
364
|
+
|
365
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
302
366
|
|
303
367
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
304
368
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -313,14 +377,24 @@ private:
|
|
313
377
|
return Qnil;
|
314
378
|
}
|
315
379
|
|
380
|
+
CustomFilterFunctor* filter_func = nullptr;
|
381
|
+
if (!NIL_P(filter)) {
|
382
|
+
try {
|
383
|
+
filter_func = new CustomFilterFunctor(filter);
|
384
|
+
} catch (const std::bad_alloc& e) {
|
385
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
386
|
+
return Qnil;
|
387
|
+
}
|
388
|
+
}
|
389
|
+
|
316
390
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
317
|
-
for (
|
391
|
+
for (size_t i = 0; i < dim; i++) {
|
318
392
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
319
393
|
}
|
320
394
|
|
321
395
|
std::priority_queue<std::pair<float, size_t>> result;
|
322
396
|
try {
|
323
|
-
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (
|
397
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
324
398
|
} catch (const std::runtime_error& e) {
|
325
399
|
ruby_xfree(vec);
|
326
400
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -328,8 +402,9 @@ private:
|
|
328
402
|
}
|
329
403
|
|
330
404
|
ruby_xfree(vec);
|
405
|
+
if (filter_func) delete filter_func;
|
331
406
|
|
332
|
-
if (result.size() != (
|
407
|
+
if (result.size() != NUM2SIZET(k)) {
|
333
408
|
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
334
409
|
}
|
335
410
|
|
@@ -339,7 +414,7 @@ private:
|
|
339
414
|
while (!result.empty()) {
|
340
415
|
const std::pair<float, size_t>& result_tuple = result.top();
|
341
416
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
342
|
-
rb_ary_unshift(neighbors_arr,
|
417
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
343
418
|
result.pop();
|
344
419
|
}
|
345
420
|
|
@@ -356,8 +431,28 @@ private:
|
|
356
431
|
return Qnil;
|
357
432
|
};
|
358
433
|
|
359
|
-
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE
|
434
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(int argc, VALUE* argv, VALUE self) {
|
435
|
+
VALUE _filename, _allow_replace_deleted;
|
436
|
+
VALUE kw_args = Qnil;
|
437
|
+
ID kw_table[1] = {rb_intern("allow_replace_deleted")};
|
438
|
+
VALUE kw_values[1] = {Qundef};
|
439
|
+
|
440
|
+
rb_scan_args(argc, argv, "1:", &_filename, &kw_args);
|
441
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
442
|
+
_allow_replace_deleted = kw_values[0] != Qundef ? kw_values[0] : Qfalse;
|
443
|
+
|
444
|
+
if (!RB_TYPE_P(_filename, T_STRING)) {
|
445
|
+
rb_raise(rb_eArgError, "Expect filename to be Ruby Array.");
|
446
|
+
return Qnil;
|
447
|
+
}
|
448
|
+
if (!NIL_P(_allow_replace_deleted) && !RB_TYPE_P(_allow_replace_deleted, T_TRUE) &&
|
449
|
+
!RB_TYPE_P(_allow_replace_deleted, T_FALSE)) {
|
450
|
+
rb_raise(rb_eArgError, "Expect replace_deleted to be Boolean.");
|
451
|
+
return Qnil;
|
452
|
+
}
|
453
|
+
|
360
454
|
std::string filename(StringValuePtr(_filename));
|
455
|
+
const bool allow_replace_deleted = _allow_replace_deleted == Qtrue ? true : false;
|
361
456
|
VALUE ivspace = rb_iv_get(self, "@space");
|
362
457
|
hnswlib::SpaceInterface<float>* space;
|
363
458
|
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
@@ -365,6 +460,7 @@ private:
|
|
365
460
|
} else {
|
366
461
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
367
462
|
}
|
463
|
+
|
368
464
|
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
369
465
|
if (index->data_level0_memory_) {
|
370
466
|
free(index->data_level0_memory_);
|
@@ -384,12 +480,15 @@ private:
|
|
384
480
|
delete index->visited_list_pool_;
|
385
481
|
index->visited_list_pool_ = nullptr;
|
386
482
|
}
|
483
|
+
|
387
484
|
try {
|
388
485
|
index->loadIndex(filename, space);
|
486
|
+
index->allow_replace_deleted_ = allow_replace_deleted;
|
389
487
|
} catch (const std::runtime_error& e) {
|
390
488
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
391
489
|
return Qnil;
|
392
490
|
}
|
491
|
+
|
393
492
|
RB_GC_GUARD(_filename);
|
394
493
|
return Qnil;
|
395
494
|
};
|
@@ -397,7 +496,7 @@ private:
|
|
397
496
|
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
398
497
|
VALUE ret = Qnil;
|
399
498
|
try {
|
400
|
-
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((
|
499
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>(NUM2SIZET(idx));
|
401
500
|
ret = rb_ary_new2(vec.size());
|
402
501
|
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
403
502
|
} catch (const std::runtime_error& e) {
|
@@ -409,13 +508,23 @@ private:
|
|
409
508
|
|
410
509
|
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
411
510
|
VALUE ret = rb_ary_new();
|
412
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret,
|
511
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, SIZET2NUM(kv.first));
|
413
512
|
return ret;
|
414
513
|
};
|
415
514
|
|
416
515
|
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
417
516
|
try {
|
418
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((
|
517
|
+
get_hnsw_hierarchicalnsw(self)->markDelete(NUM2SIZET(idx));
|
518
|
+
} catch (const std::runtime_error& e) {
|
519
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
520
|
+
return Qnil;
|
521
|
+
}
|
522
|
+
return Qnil;
|
523
|
+
};
|
524
|
+
|
525
|
+
static VALUE _hnsw_hierarchicalnsw_unmark_deleted(VALUE self, VALUE idx) {
|
526
|
+
try {
|
527
|
+
get_hnsw_hierarchicalnsw(self)->unmarkDelete(NUM2SIZET(idx));
|
419
528
|
} catch (const std::runtime_error& e) {
|
420
529
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
421
530
|
return Qnil;
|
@@ -424,31 +533,42 @@ private:
|
|
424
533
|
};
|
425
534
|
|
426
535
|
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
427
|
-
if ((
|
536
|
+
if (NUM2SIZET(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
428
537
|
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
429
538
|
return Qnil;
|
430
539
|
}
|
431
540
|
try {
|
432
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((
|
541
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex(NUM2SIZET(new_max_elements));
|
433
542
|
} catch (const std::runtime_error& e) {
|
434
543
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
435
544
|
return Qnil;
|
545
|
+
} catch (const std::bad_alloc& e) {
|
546
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
547
|
+
return Qnil;
|
436
548
|
}
|
437
549
|
return Qnil;
|
438
550
|
};
|
439
551
|
|
440
552
|
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
441
|
-
get_hnsw_hierarchicalnsw(self)->
|
553
|
+
get_hnsw_hierarchicalnsw(self)->setEf(NUM2SIZET(ef));
|
442
554
|
return Qnil;
|
443
555
|
};
|
444
556
|
|
557
|
+
static VALUE _hnsw_hierarchicalnsw_get_ef(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_); };
|
558
|
+
|
445
559
|
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
446
|
-
return
|
560
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->max_elements_);
|
447
561
|
};
|
448
562
|
|
449
563
|
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
450
|
-
return
|
564
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->cur_element_count);
|
565
|
+
};
|
566
|
+
|
567
|
+
static VALUE _hnsw_hierarchicalnsw_ef_construction(VALUE self) {
|
568
|
+
return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->ef_construction_);
|
451
569
|
};
|
570
|
+
|
571
|
+
static VALUE _hnsw_hierarchicalnsw_m(VALUE self) { return SIZET2NUM(get_hnsw_hierarchicalnsw(self)->M_); };
|
452
572
|
};
|
453
573
|
|
454
574
|
// clang-format off
|
@@ -492,7 +612,7 @@ public:
|
|
492
612
|
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
493
613
|
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
494
614
|
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
495
|
-
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn),
|
615
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), -1);
|
496
616
|
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
497
617
|
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
498
618
|
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
@@ -529,7 +649,7 @@ private:
|
|
529
649
|
} else {
|
530
650
|
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
531
651
|
}
|
532
|
-
const size_t max_elements = (
|
652
|
+
const size_t max_elements = NUM2SIZET(kw_values[1]);
|
533
653
|
|
534
654
|
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
535
655
|
try {
|
@@ -543,7 +663,7 @@ private:
|
|
543
663
|
};
|
544
664
|
|
545
665
|
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
546
|
-
const
|
666
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
547
667
|
|
548
668
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
549
669
|
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
@@ -559,10 +679,10 @@ private:
|
|
559
679
|
}
|
560
680
|
|
561
681
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
562
|
-
for (
|
682
|
+
for (size_t i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
563
683
|
|
564
684
|
try {
|
565
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (
|
685
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, NUM2SIZET(idx));
|
566
686
|
} catch (const std::runtime_error& e) {
|
567
687
|
ruby_xfree(vec);
|
568
688
|
rb_raise(rb_eRuntimeError, "%s", e.what());
|
@@ -573,8 +693,17 @@ private:
|
|
573
693
|
return Qtrue;
|
574
694
|
};
|
575
695
|
|
576
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(
|
577
|
-
|
696
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(int argc, VALUE* argv, VALUE self) {
|
697
|
+
VALUE arr, k, filter;
|
698
|
+
VALUE kw_args = Qnil;
|
699
|
+
ID kw_table[1] = {rb_intern("filter")};
|
700
|
+
VALUE kw_values[1] = {Qundef};
|
701
|
+
|
702
|
+
rb_scan_args(argc, argv, "2:", &arr, &k, &kw_args);
|
703
|
+
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
|
704
|
+
filter = kw_values[0] != Qundef ? kw_values[0] : Qnil;
|
705
|
+
|
706
|
+
const size_t dim = NUM2SIZET(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
578
707
|
|
579
708
|
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
580
709
|
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
@@ -589,17 +718,28 @@ private:
|
|
589
718
|
return Qnil;
|
590
719
|
}
|
591
720
|
|
721
|
+
CustomFilterFunctor* filter_func = nullptr;
|
722
|
+
if (!NIL_P(filter)) {
|
723
|
+
try {
|
724
|
+
filter_func = new CustomFilterFunctor(filter);
|
725
|
+
} catch (const std::bad_alloc& e) {
|
726
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
727
|
+
return Qnil;
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
592
731
|
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
593
|
-
for (
|
732
|
+
for (size_t i = 0; i < dim; i++) {
|
594
733
|
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
595
734
|
}
|
596
735
|
|
597
736
|
std::priority_queue<std::pair<float, size_t>> result =
|
598
|
-
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (
|
737
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, NUM2SIZET(k), filter_func);
|
599
738
|
|
600
739
|
ruby_xfree(vec);
|
740
|
+
if (filter_func) delete filter_func;
|
601
741
|
|
602
|
-
if (result.size() != (
|
742
|
+
if (result.size() != NUM2SIZET(k)) {
|
603
743
|
rb_warning("Cannot return as many search results as the requested number of neighbors.");
|
604
744
|
}
|
605
745
|
|
@@ -609,7 +749,7 @@ private:
|
|
609
749
|
while (!result.empty()) {
|
610
750
|
const std::pair<float, size_t>& result_tuple = result.top();
|
611
751
|
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
612
|
-
rb_ary_unshift(neighbors_arr,
|
752
|
+
rb_ary_unshift(neighbors_arr, SIZET2NUM(result_tuple.second));
|
613
753
|
result.pop();
|
614
754
|
}
|
615
755
|
|
@@ -651,16 +791,16 @@ private:
|
|
651
791
|
};
|
652
792
|
|
653
793
|
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
654
|
-
get_hnsw_bruteforcesearch(self)->removePoint((
|
794
|
+
get_hnsw_bruteforcesearch(self)->removePoint(NUM2SIZET(idx));
|
655
795
|
return Qnil;
|
656
796
|
};
|
657
797
|
|
658
798
|
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
659
|
-
return
|
799
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->maxelements_);
|
660
800
|
};
|
661
801
|
|
662
802
|
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
663
|
-
return
|
803
|
+
return SIZET2NUM(get_hnsw_bruteforcesearch(self)->cur_element_count);
|
664
804
|
};
|
665
805
|
};
|
666
806
|
|