hnswlib 0.5.3 → 0.6.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/README.md +2 -2
- data/ext/hnswlib/hnswlibext.cpp +1 -2
- data/ext/hnswlib/hnswlibext.hpp +568 -576
- data/ext/hnswlib/src/bruteforce.h +5 -1
- data/ext/hnswlib/src/hnswalg.h +8 -7
- data/lib/hnswlib/version.rb +1 -1
- metadata +4 -3
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -29,69 +29,68 @@ VALUE rb_cHnswlibHierarchicalNSW;
|
|
29
29
|
// Ruby class object for Hnswlib::BruteforceSearch; assigned in RbHnswlibBruteforceSearch::define_class.
VALUE rb_cHnswlibBruteforceSearch;
|
30
30
|
|
31
31
|
class RbHnswlibL2Space {
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
32
|
+
public:
|
33
|
+
static VALUE hnsw_l2space_alloc(VALUE self) {
|
34
|
+
hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
|
35
|
+
new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
|
36
|
+
return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
|
37
|
+
};
|
38
|
+
|
39
|
+
static void hnsw_l2space_free(void* ptr) {
|
40
|
+
((hnswlib::L2Space*)ptr)->~L2Space();
|
41
|
+
ruby_xfree(ptr);
|
42
|
+
};
|
43
|
+
|
44
|
+
static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
|
45
|
+
|
46
|
+
static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
|
47
|
+
hnswlib::L2Space* ptr;
|
48
|
+
TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
|
49
|
+
return ptr;
|
50
|
+
};
|
51
|
+
|
52
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
53
|
+
rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
|
54
|
+
rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
|
55
|
+
rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
|
56
|
+
rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
|
57
|
+
rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
|
58
|
+
return rb_cHnswlibL2Space;
|
59
|
+
};
|
60
|
+
|
61
|
+
private:
|
62
|
+
static const rb_data_type_t hnsw_l2space_type;
|
63
|
+
|
64
|
+
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
|
+
rb_iv_set(self, "@dim", dim);
|
66
|
+
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
+
new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
|
68
|
+
return Qnil;
|
69
|
+
};
|
70
|
+
|
71
|
+
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
73
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
74
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
70
75
|
return Qnil;
|
71
|
-
}
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
}
|
87
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
88
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
89
|
-
ruby_xfree(vec_a);
|
90
|
-
ruby_xfree(vec_b);
|
91
|
-
return DBL2NUM((double)dist);
|
92
|
-
};
|
76
|
+
}
|
77
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
78
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
79
|
+
return Qnil;
|
80
|
+
}
|
81
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
82
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
83
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
85
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
86
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
87
|
+
ruby_xfree(vec_a);
|
88
|
+
ruby_xfree(vec_b);
|
89
|
+
return DBL2NUM((double)dist);
|
90
|
+
};
|
93
91
|
};
|
94
92
|
|
93
|
+
// clang-format off
|
95
94
|
const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
96
95
|
"RbHnswlibL2Space",
|
97
96
|
{
|
@@ -103,71 +102,71 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
|
103
102
|
NULL,
|
104
103
|
RUBY_TYPED_FREE_IMMEDIATELY
|
105
104
|
};
|
105
|
+
// clang-format on
|
106
106
|
|
107
107
|
class RbHnswlibInnerProductSpace {
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
108
|
+
public:
|
109
|
+
static VALUE hnsw_ipspace_alloc(VALUE self) {
|
110
|
+
hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
|
111
|
+
new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
|
112
|
+
return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
|
113
|
+
};
|
114
|
+
|
115
|
+
static void hnsw_ipspace_free(void* ptr) {
|
116
|
+
((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
|
117
|
+
ruby_xfree(ptr);
|
118
|
+
};
|
119
|
+
|
120
|
+
static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
|
121
|
+
|
122
|
+
static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
|
123
|
+
hnswlib::InnerProductSpace* ptr;
|
124
|
+
TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
|
125
|
+
return ptr;
|
126
|
+
};
|
127
|
+
|
128
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
129
|
+
rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
|
130
|
+
rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
|
131
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
|
132
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
|
133
|
+
rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
|
134
|
+
return rb_cHnswlibInnerProductSpace;
|
135
|
+
};
|
136
|
+
|
137
|
+
private:
|
138
|
+
static const rb_data_type_t hnsw_ipspace_type;
|
139
|
+
|
140
|
+
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
141
|
+
rb_iv_set(self, "@dim", dim);
|
142
|
+
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
143
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
|
144
|
+
return Qnil;
|
145
|
+
};
|
146
|
+
|
147
|
+
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
148
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
149
|
+
if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
|
150
|
+
rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
|
146
151
|
return Qnil;
|
147
|
-
}
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
}
|
163
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
164
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
165
|
-
ruby_xfree(vec_a);
|
166
|
-
ruby_xfree(vec_b);
|
167
|
-
return DBL2NUM((double)dist);
|
168
|
-
};
|
152
|
+
}
|
153
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
154
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
155
|
+
return Qnil;
|
156
|
+
}
|
157
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
158
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
159
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
161
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
162
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
163
|
+
ruby_xfree(vec_a);
|
164
|
+
ruby_xfree(vec_b);
|
165
|
+
return DBL2NUM((double)dist);
|
166
|
+
};
|
169
167
|
};
|
170
168
|
|
169
|
+
// clang-format off
|
171
170
|
const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
172
171
|
"RbHnswlibInnerProductSpace",
|
173
172
|
{
|
@@ -179,294 +178,288 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
179
178
|
NULL,
|
180
179
|
RUBY_TYPED_FREE_IMMEDIATELY
|
181
180
|
};
|
181
|
+
// clang-format on
|
182
182
|
|
183
183
|
class RbHnswlibHierarchicalNSW {
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
|
242
|
-
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
243
|
-
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
244
|
-
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
245
|
-
|
246
|
-
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
247
|
-
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
248
|
-
return Qnil;
|
249
|
-
}
|
250
|
-
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
251
|
-
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
252
|
-
return Qnil;
|
253
|
-
}
|
254
|
-
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
255
|
-
rb_raise(rb_eTypeError, "expected m, Integer");
|
256
|
-
return Qnil;
|
257
|
-
}
|
258
|
-
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
259
|
-
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
260
|
-
return Qnil;
|
261
|
-
}
|
262
|
-
if (!RB_INTEGER_TYPE_P(kw_values[4])) {
|
263
|
-
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
264
|
-
return Qnil;
|
265
|
-
}
|
266
|
-
|
267
|
-
rb_iv_set(self, "@space", kw_values[0]);
|
268
|
-
hnswlib::SpaceInterface<float>* space;
|
269
|
-
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
270
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
271
|
-
} else {
|
272
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
273
|
-
}
|
274
|
-
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
275
|
-
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
276
|
-
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
277
|
-
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
278
|
-
|
279
|
-
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
280
|
-
try {
|
281
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
282
|
-
} catch(const std::runtime_error& e) {
|
283
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
284
|
-
return Qnil;
|
285
|
-
}
|
286
|
-
|
184
|
+
public:
|
185
|
+
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
186
|
+
hnswlib::HierarchicalNSW<float>* ptr =
|
187
|
+
(hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
|
188
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
|
189
|
+
return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
|
190
|
+
};
|
191
|
+
|
192
|
+
static void hnsw_hierarchicalnsw_free(void* ptr) {
|
193
|
+
((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
|
194
|
+
ruby_xfree(ptr);
|
195
|
+
};
|
196
|
+
|
197
|
+
static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
|
198
|
+
|
199
|
+
static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
|
200
|
+
hnswlib::HierarchicalNSW<float>* ptr;
|
201
|
+
TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
|
202
|
+
return ptr;
|
203
|
+
};
|
204
|
+
|
205
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
206
|
+
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
207
|
+
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
208
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
209
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
|
210
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
|
211
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
212
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
213
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
214
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
215
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
216
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
217
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
218
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
219
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
220
|
+
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
221
|
+
return rb_cHnswlibHierarchicalNSW;
|
222
|
+
};
|
223
|
+
|
224
|
+
private:
|
225
|
+
static const rb_data_type_t hnsw_hierarchicalnsw_type;
|
226
|
+
|
227
|
+
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
228
|
+
VALUE kw_args = Qnil;
|
229
|
+
ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
|
230
|
+
rb_intern("random_seed")};
|
231
|
+
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
232
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
233
|
+
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
|
234
|
+
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
235
|
+
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
236
|
+
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
237
|
+
|
238
|
+
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
239
|
+
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
240
|
+
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
287
241
|
return Qnil;
|
288
|
-
}
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
242
|
+
}
|
243
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
244
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
245
|
+
return Qnil;
|
246
|
+
}
|
247
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
248
|
+
rb_raise(rb_eTypeError, "expected m, Integer");
|
249
|
+
return Qnil;
|
250
|
+
}
|
251
|
+
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
252
|
+
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
253
|
+
return Qnil;
|
254
|
+
}
|
255
|
+
if (!RB_INTEGER_TYPE_P(kw_values[4])) {
|
256
|
+
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
257
|
+
return Qnil;
|
258
|
+
}
|
259
|
+
|
260
|
+
rb_iv_set(self, "@space", kw_values[0]);
|
261
|
+
hnswlib::SpaceInterface<float>* space;
|
262
|
+
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
263
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
264
|
+
} else {
|
265
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
266
|
+
}
|
267
|
+
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
268
|
+
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
269
|
+
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
270
|
+
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
271
|
+
|
272
|
+
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
273
|
+
try {
|
274
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
275
|
+
} catch (const std::runtime_error& e) {
|
276
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
277
|
+
return Qnil;
|
278
|
+
}
|
307
279
|
|
308
|
-
|
309
|
-
|
310
|
-
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
311
|
-
}
|
280
|
+
return Qnil;
|
281
|
+
};
|
312
282
|
|
313
|
-
|
283
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
|
284
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
314
285
|
|
315
|
-
|
316
|
-
|
317
|
-
|
286
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
287
|
+
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
288
|
+
return Qfalse;
|
289
|
+
}
|
290
|
+
if (!RB_INTEGER_TYPE_P(idx)) {
|
291
|
+
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
292
|
+
return Qfalse;
|
293
|
+
}
|
294
|
+
if (dim != RARRAY_LEN(arr)) {
|
295
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
296
|
+
return Qfalse;
|
297
|
+
}
|
318
298
|
|
319
|
-
|
320
|
-
|
299
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
300
|
+
for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
321
301
|
|
322
|
-
|
323
|
-
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
324
|
-
return Qnil;
|
325
|
-
}
|
302
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
|
326
303
|
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
}
|
304
|
+
ruby_xfree(vec);
|
305
|
+
return Qtrue;
|
306
|
+
};
|
331
307
|
|
332
|
-
|
333
|
-
|
334
|
-
return Qnil;
|
335
|
-
}
|
308
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
|
309
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
336
310
|
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
311
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
312
|
+
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
313
|
+
return Qnil;
|
314
|
+
}
|
315
|
+
if (!RB_INTEGER_TYPE_P(k)) {
|
316
|
+
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
317
|
+
return Qnil;
|
318
|
+
}
|
319
|
+
if (dim != RARRAY_LEN(arr)) {
|
320
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
321
|
+
return Qnil;
|
322
|
+
}
|
341
323
|
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
ruby_xfree(vec);
|
347
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
-
return Qnil;
|
349
|
-
}
|
324
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
325
|
+
for (int i = 0; i < dim; i++) {
|
326
|
+
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
327
|
+
}
|
350
328
|
|
329
|
+
std::priority_queue<std::pair<float, size_t>> result;
|
330
|
+
try {
|
331
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
|
332
|
+
} catch (const std::runtime_error& e) {
|
351
333
|
ruby_xfree(vec);
|
352
|
-
|
353
|
-
if (result.size() != (size_t)NUM2INT(k)) {
|
354
|
-
rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
|
355
|
-
return Qnil;
|
356
|
-
}
|
357
|
-
|
358
|
-
VALUE distances_arr = rb_ary_new2(result.size());
|
359
|
-
VALUE neighbors_arr = rb_ary_new2(result.size());
|
360
|
-
|
361
|
-
for (int i = NUM2INT(k) - 1; i >= 0; i--) {
|
362
|
-
const std::pair<float, size_t>& result_tuple = result.top();
|
363
|
-
rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
|
364
|
-
rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
|
365
|
-
result.pop();
|
366
|
-
}
|
367
|
-
|
368
|
-
VALUE ret = rb_ary_new2(2);
|
369
|
-
rb_ary_store(ret, 0, neighbors_arr);
|
370
|
-
rb_ary_store(ret, 1, distances_arr);
|
371
|
-
return ret;
|
372
|
-
};
|
373
|
-
|
374
|
-
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
375
|
-
std::string filename(StringValuePtr(_filename));
|
376
|
-
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
377
|
-
RB_GC_GUARD(_filename);
|
334
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
378
335
|
return Qnil;
|
379
|
-
}
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
336
|
+
}
|
337
|
+
|
338
|
+
ruby_xfree(vec);
|
339
|
+
|
340
|
+
if (result.size() != (size_t)NUM2INT(k)) {
|
341
|
+
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
342
|
+
}
|
343
|
+
|
344
|
+
VALUE distances_arr = rb_ary_new();
|
345
|
+
VALUE neighbors_arr = rb_ary_new();
|
346
|
+
|
347
|
+
while (!result.empty()) {
|
348
|
+
const std::pair<float, size_t>& result_tuple = result.top();
|
349
|
+
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
350
|
+
rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
|
351
|
+
result.pop();
|
352
|
+
}
|
353
|
+
|
354
|
+
VALUE ret = rb_ary_new2(2);
|
355
|
+
rb_ary_store(ret, 0, neighbors_arr);
|
356
|
+
rb_ary_store(ret, 1, distances_arr);
|
357
|
+
return ret;
|
358
|
+
};
|
359
|
+
|
360
|
+
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
361
|
+
std::string filename(StringValuePtr(_filename));
|
362
|
+
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
363
|
+
RB_GC_GUARD(_filename);
|
364
|
+
return Qnil;
|
365
|
+
};
|
366
|
+
|
367
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
|
368
|
+
std::string filename(StringValuePtr(_filename));
|
369
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
370
|
+
hnswlib::SpaceInterface<float>* space;
|
371
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
372
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
373
|
+
} else {
|
374
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
375
|
+
}
|
376
|
+
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
377
|
+
if (index->data_level0_memory_) {
|
378
|
+
free(index->data_level0_memory_);
|
379
|
+
index->data_level0_memory_ = nullptr;
|
380
|
+
}
|
381
|
+
if (index->linkLists_) {
|
382
|
+
for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
|
383
|
+
if (index->element_levels_[i] > 0 && index->linkLists_[i]) {
|
384
|
+
free(index->linkLists_[i]);
|
385
|
+
index->linkLists_[i] = nullptr;
|
395
386
|
}
|
396
|
-
free(index->linkLists_);
|
397
387
|
}
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
388
|
+
free(index->linkLists_);
|
389
|
+
index->linkLists_ = nullptr;
|
390
|
+
}
|
391
|
+
if (index->visited_list_pool_) {
|
392
|
+
delete index->visited_list_pool_;
|
393
|
+
index->visited_list_pool_ = nullptr;
|
394
|
+
}
|
395
|
+
try {
|
396
|
+
index->loadIndex(filename, space);
|
397
|
+
} catch (const std::runtime_error& e) {
|
398
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
406
399
|
return Qnil;
|
407
|
-
}
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
}
|
421
|
-
return ret;
|
422
|
-
};
|
423
|
-
|
424
|
-
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
425
|
-
VALUE ret = rb_ary_new();
|
426
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
|
427
|
-
rb_ary_push(ret, INT2NUM((int)kv.first));
|
428
|
-
}
|
429
|
-
return ret;
|
430
|
-
};
|
431
|
-
|
432
|
-
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
433
|
-
try {
|
434
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
|
435
|
-
} catch(const std::runtime_error& e) {
|
436
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
437
|
-
return Qnil;
|
438
|
-
}
|
400
|
+
}
|
401
|
+
RB_GC_GUARD(_filename);
|
402
|
+
return Qnil;
|
403
|
+
};
|
404
|
+
|
405
|
+
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
406
|
+
VALUE ret = Qnil;
|
407
|
+
try {
|
408
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
|
409
|
+
ret = rb_ary_new2(vec.size());
|
410
|
+
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
411
|
+
} catch (const std::runtime_error& e) {
|
412
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
439
413
|
return Qnil;
|
440
|
-
}
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
414
|
+
}
|
415
|
+
return ret;
|
416
|
+
};
|
417
|
+
|
418
|
+
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
419
|
+
VALUE ret = rb_ary_new();
|
420
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
|
421
|
+
return ret;
|
422
|
+
};
|
423
|
+
|
424
|
+
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
425
|
+
try {
|
426
|
+
get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
|
427
|
+
} catch (const std::runtime_error& e) {
|
428
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
453
429
|
return Qnil;
|
454
|
-
}
|
430
|
+
}
|
431
|
+
return Qnil;
|
432
|
+
};
|
455
433
|
|
456
|
-
|
457
|
-
|
434
|
+
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
435
|
+
if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
436
|
+
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
458
437
|
return Qnil;
|
459
|
-
}
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
438
|
+
}
|
439
|
+
try {
|
440
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
|
441
|
+
} catch (const std::runtime_error& e) {
|
442
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
443
|
+
return Qnil;
|
444
|
+
}
|
445
|
+
return Qnil;
|
446
|
+
};
|
447
|
+
|
448
|
+
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
449
|
+
get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
|
450
|
+
return Qnil;
|
451
|
+
};
|
452
|
+
|
453
|
+
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
454
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
|
455
|
+
};
|
456
|
+
|
457
|
+
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
458
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
|
459
|
+
};
|
468
460
|
};
|
469
461
|
|
462
|
+
// clang-format off
|
470
463
|
const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
471
464
|
"RbHnswlibHierarchicalNSW",
|
472
465
|
{
|
@@ -478,210 +471,208 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
|
478
471
|
NULL,
|
479
472
|
RUBY_TYPED_FREE_IMMEDIATELY
|
480
473
|
};
|
474
|
+
// clang-format on
|
481
475
|
|
482
476
|
class RbHnswlibBruteforceSearch {
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
return Qnil;
|
533
|
-
}
|
534
|
-
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
535
|
-
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
536
|
-
return Qnil;
|
537
|
-
}
|
538
|
-
|
539
|
-
rb_iv_set(self, "@space", kw_values[0]);
|
540
|
-
hnswlib::SpaceInterface<float>* space;
|
541
|
-
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
542
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
543
|
-
} else {
|
544
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
545
|
-
}
|
546
|
-
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
547
|
-
|
548
|
-
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
549
|
-
try {
|
550
|
-
new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
|
551
|
-
} catch(const std::runtime_error& e) {
|
552
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
553
|
-
return Qnil;
|
554
|
-
}
|
555
|
-
|
477
|
+
public:
|
478
|
+
static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
|
479
|
+
hnswlib::BruteforceSearch<float>* ptr =
|
480
|
+
(hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
|
481
|
+
new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
|
482
|
+
return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
|
483
|
+
};
|
484
|
+
|
485
|
+
static void hnsw_bruteforcesearch_free(void* ptr) {
|
486
|
+
((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
|
487
|
+
ruby_xfree(ptr);
|
488
|
+
};
|
489
|
+
|
490
|
+
static size_t hnsw_bruteforcesearch_size(const void* ptr) { return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); };
|
491
|
+
|
492
|
+
static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
|
493
|
+
hnswlib::BruteforceSearch<float>* ptr;
|
494
|
+
TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
|
495
|
+
return ptr;
|
496
|
+
};
|
497
|
+
|
498
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
499
|
+
rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
|
500
|
+
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
501
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
502
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
503
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
|
504
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
505
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
506
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
507
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
|
508
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
|
509
|
+
rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
|
510
|
+
return rb_cHnswlibBruteforceSearch;
|
511
|
+
};
|
512
|
+
|
513
|
+
private:
  // Ruby typed-data descriptor for the wrapped hnswlib::BruteforceSearch<float>.
  // Declared here and defined out of line after the class body.
  static const rb_data_type_t hnsw_bruteforcesearch_type;
|
515
|
+
|
516
|
+
static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
  // Hnswlib::BruteforceSearch#initialize(space:, max_elements:)
  //
  // Both keywords are required: rb_get_kwargs is called with 2 required and
  // 0 optional keys, so a missing one raises ArgumentError.
  // space must be an instance of Hnswlib::L2Space or Hnswlib::InnerProductSpace
  // (TypeError otherwise). max_elements must be a non-negative Integer.
  // Raises RuntimeError if the C++ constructor throws std::runtime_error.
  VALUE kw_args = Qnil;
  ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
  VALUE kw_values[2] = {Qundef, Qundef};
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);

  if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
        rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
    rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(kw_values[1])) {
    rb_raise(rb_eTypeError, "expected max_elements, Integer");
    return Qnil;
  }

  rb_iv_set(self, "@space", kw_values[0]);
  hnswlib::SpaceInterface<float>* space;
  if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
    space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
  } else {
    space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
  }

  // A negative value would wrap to an enormous size_t in the cast below.
  const int max_elements_arg = NUM2INT(kw_values[1]);
  if (max_elements_arg < 0) {
    rb_raise(rb_eArgError, "expected max_elements, non-negative Integer");
    return Qnil;
  }
  const size_t max_elements = (size_t)max_elements_arg;

  hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);

  // Build the Ruby exception inside the handler, but raise it only after the
  // handler has exited. rb_raise longjmps, and jumping out of a C++ catch block
  // skips the runtime's end-of-catch cleanup, so the in-flight C++ exception
  // would never be released.
  VALUE error = Qnil;
  try {
    new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
  } catch (const std::runtime_error& e) {
    error = rb_exc_new_cstr(rb_eRuntimeError, e.what());
  }
  if (!NIL_P(error)) {
    rb_exc_raise(error);
    return Qnil;
  }

  return Qnil;
};
|
552
|
+
|
553
|
+
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
  // Hnswlib::BruteforceSearch#add_point(arr, idx)
  //
  // arr must be a Ruby Array whose length equals @space's @dim. Each element is
  // converted with NUM2DBL and narrowed to float. idx must be an Integer and is
  // passed to addPoint as the point's label.
  // Returns true on success. Raises ArgumentError on bad arguments, and
  // RuntimeError if addPoint throws std::runtime_error.
  const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));

  if (!RB_TYPE_P(arr, T_ARRAY)) {
    rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
    return Qfalse;
  }
  if (!RB_INTEGER_TYPE_P(idx)) {
    rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
    return Qfalse;
  }
  if (dim != RARRAY_LEN(arr)) {
    rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
    return Qfalse;
  }

  // ALLOCV_N yields a buffer that Ruby can still reclaim if NUM2DBL (or NUM2INT
  // below) raises mid-way, e.g. on a non-numeric element. A plain ruby_xmalloc
  // buffer would leak when that raise longjmps past the free.
  VALUE vec_buf;
  float* vec = ALLOCV_N(float, vec_buf, dim);
  for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));

  // Capture the failure inside the handler and raise after it has exited.
  // Longjmping out of a C++ catch block would skip the release of the
  // in-flight C++ exception.
  VALUE error = Qnil;
  try {
    get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
  } catch (const std::runtime_error& e) {
    error = rb_exc_new_cstr(rb_eRuntimeError, e.what());
  }
  ALLOCV_END(vec_buf);

  if (!NIL_P(error)) {
    rb_exc_raise(error);
    return Qfalse;
  }
  return Qtrue;
};
|
636
583
|
|
637
|
-
|
638
|
-
|
639
|
-
rb_ary_store(ret, 1, distances_arr);
|
640
|
-
return ret;
|
641
|
-
};
|
584
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
  // Hnswlib::BruteforceSearch#search_knn(arr, k)
  //
  // arr must be a Ruby Array whose length equals @space's @dim. k must be a
  // non-negative Integer.
  // Returns [neighbors, distances]: labels (Integers) and distances (Floats),
  // both ordered by ascending distance.
  // Calls rb_warning if fewer than k results come back.
  const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));

  if (!RB_TYPE_P(arr, T_ARRAY)) {
    rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
    return Qnil;
  }
  if (!RB_INTEGER_TYPE_P(k)) {
    rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
    return Qnil;
  }
  if (dim != RARRAY_LEN(arr)) {
    rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
    return Qnil;
  }

  // A negative k would wrap to an enormous size_t neighbor count when cast below.
  const int n_neighbors = NUM2INT(k);
  if (n_neighbors < 0) {
    rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be non-negative.");
    return Qnil;
  }

  // ALLOCV_N yields a buffer that Ruby can still reclaim if NUM2DBL raises on a
  // non-numeric element. A ruby_xmalloc buffer would leak on that longjmp.
  VALUE vec_buf;
  float* vec = ALLOCV_N(float, vec_buf, dim);
  for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));

  std::priority_queue<std::pair<float, size_t>> result =
      get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)n_neighbors);
  ALLOCV_END(vec_buf);

  if (result.size() != (size_t)n_neighbors) {
    rb_warning("Cannot return as many search results as the requested number of neighbors.");
  }

  // std::priority_queue is a max-heap, so the farthest neighbor is popped first.
  // Filling the arrays from the back yields ascending-distance order.
  const long n_results = (long)result.size();
  VALUE distances_arr = rb_ary_new2(n_results);
  VALUE neighbors_arr = rb_ary_new2(n_results);
  for (long i = n_results - 1; i >= 0; i--) {
    const std::pair<float, size_t>& result_tuple = result.top();
    rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
    rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
    result.pop();
  }

  VALUE ret = rb_ary_new2(2);
  rb_ary_store(ret, 0, neighbors_arr);
  rb_ary_store(ret, 1, distances_arr);
  return ret;
};
|
629
|
+
|
630
|
+
static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
  // Hnswlib::BruteforceSearch#save_index(filename)
  // Writes the index to the given path. StringValuePtr converts _filename via
  // to_str and raises TypeError when that is not possible.
  std::string path = StringValuePtr(_filename);
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  index->saveIndex(path);
  // Keep the Ruby string alive until its contents are no longer needed.
  RB_GC_GUARD(_filename);
  return Qnil;
};
|
636
|
+
|
637
|
+
static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
  // Hnswlib::BruteforceSearch#load_index(filename)
  //
  // Replaces the wrapped index's contents with the file at filename, using the
  // space stored in @space. Raises TypeError if filename is not String-like,
  // and RuntimeError if loadIndex throws std::runtime_error.
  //
  // Resolve the path first so a bad argument raises before data_ is released.
  const char* path = StringValuePtr(_filename);

  VALUE ivspace = rb_iv_get(self, "@space");
  hnswlib::SpaceInterface<float>* space;
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
    space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
  } else {
    space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
  }

  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  // NOTE(review): presumably loadIndex allocates a fresh data_ buffer, so the
  // current one is released first to avoid leaking it. Confirm in bruteforce.h.
  // Resetting to nullptr prevents a double free if loading fails.
  if (index->data_) {
    free(index->data_);
    index->data_ = nullptr;
  }

  VALUE error = Qnil;
  {
    // Scoped so the std::string is destroyed before any raise below: Ruby's
    // raise longjmps and would skip its destructor. The Ruby exception is only
    // built inside the handler; raising from within a C++ catch block would
    // skip the release of the in-flight C++ exception.
    const std::string filename(path);
    try {
      index->loadIndex(filename, space);
    } catch (const std::runtime_error& e) {
      error = rb_exc_new_cstr(rb_eRuntimeError, e.what());
    }
  }
  RB_GC_GUARD(_filename);

  if (!NIL_P(error)) {
    rb_exc_raise(error);
    return Qnil;
  }
  return Qnil;
};
|
660
|
+
|
661
|
+
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
  // Hnswlib::BruteforceSearch#remove_point(idx)
  // Removes the point stored under label idx. The label is converted with
  // NUM2INT and then widened to size_t.
  const size_t label = (size_t)NUM2INT(idx);
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  index->removePoint(label);
  return Qnil;
};
|
665
|
+
|
666
|
+
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
  // Hnswlib::BruteforceSearch#max_elements
  // Returns the wrapped index's maxelements_ field as a Ruby Integer, narrowed
  // to int (matching the NUM2INT conversion used on input).
  const hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  const int capacity = static_cast<int>(index->maxelements_);
  return INT2NUM(capacity);
};
|
669
|
+
|
670
|
+
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
  // Hnswlib::BruteforceSearch#current_count
  // Returns the wrapped index's cur_element_count field as a Ruby Integer,
  // narrowed to int.
  const hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  const int count = static_cast<int>(index->cur_element_count);
  return INT2NUM(count);
};
|
683
673
|
};
|
684
674
|
|
675
|
+
// clang-format off
|
685
676
|
const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
686
677
|
"RbHnswlibBruteforceSearch",
|
687
678
|
{
|
@@ -693,5 +684,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
|
693
684
|
NULL,
|
694
685
|
RUBY_TYPED_FREE_IMMEDIATELY
|
695
686
|
};
|
687
|
+
// clang-format on
|
696
688
|
|
697
689
|
#endif /* HNSWLIBEXT_HPP */
|