hnswlib 0.5.1 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/README.md +2 -2
- data/ext/hnswlib/hnswlibext.hpp +544 -518
- data/ext/hnswlib/src/space_ip.h +61 -23
- data/lib/hnswlib/version.rb +2 -2
- metadata +3 -3
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -29,69 +29,64 @@ VALUE rb_cHnswlibHierarchicalNSW;
|
|
29
29
|
VALUE rb_cHnswlibBruteforceSearch;
|
30
30
|
|
31
31
|
class RbHnswlibL2Space {
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
32
|
+
public:
|
33
|
+
static VALUE hnsw_l2space_alloc(VALUE self) {
|
34
|
+
hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
|
35
|
+
new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
|
36
|
+
return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
|
37
|
+
};
|
38
|
+
|
39
|
+
static void hnsw_l2space_free(void* ptr) {
|
40
|
+
((hnswlib::L2Space*)ptr)->~L2Space();
|
41
|
+
ruby_xfree(ptr);
|
42
|
+
};
|
43
|
+
|
44
|
+
static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
|
45
|
+
|
46
|
+
static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
|
47
|
+
hnswlib::L2Space* ptr;
|
48
|
+
TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
|
49
|
+
return ptr;
|
50
|
+
};
|
51
|
+
|
52
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
53
|
+
rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
|
54
|
+
rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
|
55
|
+
rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
|
56
|
+
rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
|
57
|
+
rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
|
58
|
+
return rb_cHnswlibL2Space;
|
59
|
+
};
|
60
|
+
|
61
|
+
private:
|
62
|
+
static const rb_data_type_t hnsw_l2space_type;
|
63
|
+
|
64
|
+
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
|
+
rb_iv_set(self, "@dim", dim);
|
66
|
+
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
+
new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
|
68
|
+
return Qnil;
|
69
|
+
};
|
70
|
+
|
71
|
+
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
73
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
74
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
70
75
|
return Qnil;
|
71
|
-
}
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
}
|
83
|
-
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (int i = 0; i < dim; i++) {
|
85
|
-
vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
86
|
-
}
|
87
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
88
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
89
|
-
ruby_xfree(vec_a);
|
90
|
-
ruby_xfree(vec_b);
|
91
|
-
return DBL2NUM((double)dist);
|
92
|
-
};
|
76
|
+
}
|
77
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
78
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
79
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
80
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
81
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
82
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
83
|
+
ruby_xfree(vec_a);
|
84
|
+
ruby_xfree(vec_b);
|
85
|
+
return DBL2NUM((double)dist);
|
86
|
+
};
|
93
87
|
};
|
94
88
|
|
89
|
+
// clang-format off
|
95
90
|
const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
96
91
|
"RbHnswlibL2Space",
|
97
92
|
{
|
@@ -103,71 +98,67 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
|
103
98
|
NULL,
|
104
99
|
RUBY_TYPED_FREE_IMMEDIATELY
|
105
100
|
};
|
101
|
+
// clang-format on
|
106
102
|
|
107
103
|
class RbHnswlibInnerProductSpace {
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
104
|
+
public:
|
105
|
+
static VALUE hnsw_ipspace_alloc(VALUE self) {
|
106
|
+
hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
|
107
|
+
new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
|
108
|
+
return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
|
109
|
+
};
|
110
|
+
|
111
|
+
static void hnsw_ipspace_free(void* ptr) {
|
112
|
+
((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
|
113
|
+
ruby_xfree(ptr);
|
114
|
+
};
|
115
|
+
|
116
|
+
static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
|
117
|
+
|
118
|
+
static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
|
119
|
+
hnswlib::InnerProductSpace* ptr;
|
120
|
+
TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
|
121
|
+
return ptr;
|
122
|
+
};
|
123
|
+
|
124
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
125
|
+
rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
|
126
|
+
rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
|
127
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
|
128
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
|
129
|
+
rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
|
130
|
+
return rb_cHnswlibInnerProductSpace;
|
131
|
+
};
|
132
|
+
|
133
|
+
private:
|
134
|
+
static const rb_data_type_t hnsw_ipspace_type;
|
135
|
+
|
136
|
+
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
137
|
+
rb_iv_set(self, "@dim", dim);
|
138
|
+
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
139
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
|
140
|
+
return Qnil;
|
141
|
+
};
|
142
|
+
|
143
|
+
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
144
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
145
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
146
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
146
147
|
return Qnil;
|
147
|
-
}
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
}
|
159
|
-
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (int i = 0; i < dim; i++) {
|
161
|
-
vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
162
|
-
}
|
163
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
164
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
165
|
-
ruby_xfree(vec_a);
|
166
|
-
ruby_xfree(vec_b);
|
167
|
-
return DBL2NUM((double)dist);
|
168
|
-
};
|
148
|
+
}
|
149
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
150
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
151
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
152
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
153
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
154
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
155
|
+
ruby_xfree(vec_a);
|
156
|
+
ruby_xfree(vec_b);
|
157
|
+
return DBL2NUM((double)dist);
|
158
|
+
};
|
169
159
|
};
|
170
160
|
|
161
|
+
// clang-format off
|
171
162
|
const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
172
163
|
"RbHnswlibInnerProductSpace",
|
173
164
|
{
|
@@ -179,258 +170,278 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
179
170
|
NULL,
|
180
171
|
RUBY_TYPED_FREE_IMMEDIATELY
|
181
172
|
};
|
173
|
+
// clang-format on
|
182
174
|
|
183
175
|
class RbHnswlibHierarchicalNSW {
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
275
|
-
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
276
|
-
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
277
|
-
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
278
|
-
|
279
|
-
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
176
|
+
public:
|
177
|
+
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
178
|
+
hnswlib::HierarchicalNSW<float>* ptr =
|
179
|
+
(hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
|
180
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
|
181
|
+
return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
|
182
|
+
};
|
183
|
+
|
184
|
+
static void hnsw_hierarchicalnsw_free(void* ptr) {
|
185
|
+
((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
|
186
|
+
ruby_xfree(ptr);
|
187
|
+
};
|
188
|
+
|
189
|
+
static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
|
190
|
+
|
191
|
+
static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
|
192
|
+
hnswlib::HierarchicalNSW<float>* ptr;
|
193
|
+
TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
|
194
|
+
return ptr;
|
195
|
+
};
|
196
|
+
|
197
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
198
|
+
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
199
|
+
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
200
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
201
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
|
202
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
|
203
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
204
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
205
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
206
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
207
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
208
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
209
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
210
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
211
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
212
|
+
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
213
|
+
return rb_cHnswlibHierarchicalNSW;
|
214
|
+
};
|
215
|
+
|
216
|
+
private:
|
217
|
+
static const rb_data_type_t hnsw_hierarchicalnsw_type;
|
218
|
+
|
219
|
+
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
220
|
+
VALUE kw_args = Qnil;
|
221
|
+
ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
|
222
|
+
rb_intern("random_seed")};
|
223
|
+
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
224
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
225
|
+
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
|
226
|
+
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
227
|
+
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
228
|
+
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
229
|
+
|
230
|
+
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
231
|
+
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
232
|
+
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
233
|
+
return Qnil;
|
234
|
+
}
|
235
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
236
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
237
|
+
return Qnil;
|
238
|
+
}
|
239
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
240
|
+
rb_raise(rb_eTypeError, "expected m, Integer");
|
241
|
+
return Qnil;
|
242
|
+
}
|
243
|
+
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
244
|
+
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
245
|
+
return Qnil;
|
246
|
+
}
|
247
|
+
if (!RB_INTEGER_TYPE_P(kw_values[4])) {
|
248
|
+
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
249
|
+
return Qnil;
|
250
|
+
}
|
251
|
+
|
252
|
+
rb_iv_set(self, "@space", kw_values[0]);
|
253
|
+
hnswlib::SpaceInterface<float>* space;
|
254
|
+
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
255
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
256
|
+
} else {
|
257
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
258
|
+
}
|
259
|
+
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
260
|
+
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
261
|
+
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
262
|
+
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
263
|
+
|
264
|
+
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
265
|
+
try {
|
280
266
|
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
281
|
-
|
267
|
+
} catch (const std::runtime_error& e) {
|
268
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
282
269
|
return Qnil;
|
283
|
-
}
|
284
|
-
|
285
|
-
static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
|
286
|
-
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
287
|
-
|
288
|
-
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
289
|
-
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
290
|
-
return Qfalse;
|
291
|
-
}
|
292
|
-
|
293
|
-
if (!RB_INTEGER_TYPE_P(idx)) {
|
294
|
-
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
295
|
-
return Qfalse;
|
296
|
-
}
|
270
|
+
}
|
297
271
|
|
298
|
-
|
299
|
-
|
300
|
-
return Qfalse;
|
301
|
-
}
|
302
|
-
|
303
|
-
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
304
|
-
for (int i = 0; i < dim; i++) {
|
305
|
-
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
306
|
-
}
|
272
|
+
return Qnil;
|
273
|
+
};
|
307
274
|
|
308
|
-
|
275
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
|
276
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
309
277
|
|
310
|
-
|
311
|
-
|
312
|
-
|
278
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
279
|
+
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
280
|
+
return Qfalse;
|
281
|
+
}
|
282
|
+
if (!RB_INTEGER_TYPE_P(idx)) {
|
283
|
+
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
284
|
+
return Qfalse;
|
285
|
+
}
|
286
|
+
if (dim != RARRAY_LEN(arr)) {
|
287
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
288
|
+
return Qfalse;
|
289
|
+
}
|
313
290
|
|
314
|
-
|
315
|
-
|
291
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
292
|
+
for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
316
293
|
|
317
|
-
|
318
|
-
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
319
|
-
return Qnil;
|
320
|
-
}
|
294
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
|
321
295
|
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
}
|
296
|
+
ruby_xfree(vec);
|
297
|
+
return Qtrue;
|
298
|
+
};
|
326
299
|
|
327
|
-
|
328
|
-
|
329
|
-
return Qnil;
|
330
|
-
}
|
300
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
|
301
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
331
302
|
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
303
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
304
|
+
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
305
|
+
return Qnil;
|
306
|
+
}
|
307
|
+
if (!RB_INTEGER_TYPE_P(k)) {
|
308
|
+
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
309
|
+
return Qnil;
|
310
|
+
}
|
311
|
+
if (dim != RARRAY_LEN(arr)) {
|
312
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
313
|
+
return Qnil;
|
314
|
+
}
|
336
315
|
|
337
|
-
|
338
|
-
|
316
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
317
|
+
for (int i = 0; i < dim; i++) {
|
318
|
+
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
319
|
+
}
|
339
320
|
|
321
|
+
std::priority_queue<std::pair<float, size_t>> result;
|
322
|
+
try {
|
323
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
|
324
|
+
} catch (const std::runtime_error& e) {
|
340
325
|
ruby_xfree(vec);
|
341
|
-
|
342
|
-
if (result.size() != (size_t)NUM2INT(k)) {
|
343
|
-
rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
|
344
|
-
return Qnil;
|
345
|
-
}
|
346
|
-
|
347
|
-
VALUE distances_arr = rb_ary_new2(result.size());
|
348
|
-
VALUE neighbors_arr = rb_ary_new2(result.size());
|
349
|
-
|
350
|
-
for (int i = NUM2INT(k) - 1; i >= 0; i--) {
|
351
|
-
const std::pair<float, size_t>& result_tuple = result.top();
|
352
|
-
rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
|
353
|
-
rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
|
354
|
-
result.pop();
|
355
|
-
}
|
356
|
-
|
357
|
-
VALUE ret = rb_ary_new2(2);
|
358
|
-
rb_ary_store(ret, 0, neighbors_arr);
|
359
|
-
rb_ary_store(ret, 1, distances_arr);
|
360
|
-
return ret;
|
361
|
-
};
|
362
|
-
|
363
|
-
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
364
|
-
std::string filename(StringValuePtr(_filename));
|
365
|
-
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
366
|
-
RB_GC_GUARD(_filename);
|
326
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
367
327
|
return Qnil;
|
368
|
-
}
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
328
|
+
}
|
329
|
+
|
330
|
+
ruby_xfree(vec);
|
331
|
+
|
332
|
+
if (result.size() != (size_t)NUM2INT(k)) {
|
333
|
+
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
334
|
+
}
|
335
|
+
|
336
|
+
VALUE distances_arr = rb_ary_new();
|
337
|
+
VALUE neighbors_arr = rb_ary_new();
|
338
|
+
|
339
|
+
while (!result.empty()) {
|
340
|
+
const std::pair<float, size_t>& result_tuple = result.top();
|
341
|
+
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
342
|
+
rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
|
343
|
+
result.pop();
|
344
|
+
}
|
345
|
+
|
346
|
+
VALUE ret = rb_ary_new2(2);
|
347
|
+
rb_ary_store(ret, 0, neighbors_arr);
|
348
|
+
rb_ary_store(ret, 1, distances_arr);
|
349
|
+
return ret;
|
350
|
+
};
|
351
|
+
|
352
|
+
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
353
|
+
std::string filename(StringValuePtr(_filename));
|
354
|
+
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
355
|
+
RB_GC_GUARD(_filename);
|
356
|
+
return Qnil;
|
357
|
+
};
|
358
|
+
|
359
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
|
360
|
+
std::string filename(StringValuePtr(_filename));
|
361
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
362
|
+
hnswlib::SpaceInterface<float>* space;
|
363
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
364
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
365
|
+
} else {
|
366
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
367
|
+
}
|
368
|
+
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
369
|
+
if (index->data_level0_memory_) free(index->data_level0_memory_);
|
370
|
+
if (index->linkLists_) {
|
371
|
+
for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
|
372
|
+
if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
|
373
|
+
}
|
374
|
+
free(index->linkLists_);
|
375
|
+
}
|
376
|
+
if (index->visited_list_pool_) delete index->visited_list_pool_;
|
377
|
+
try {
|
378
|
+
index->loadIndex(filename, space);
|
379
|
+
} catch (const std::runtime_error& e) {
|
380
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
381
381
|
return Qnil;
|
382
|
-
}
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
return
|
396
|
-
}
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
382
|
+
}
|
383
|
+
RB_GC_GUARD(_filename);
|
384
|
+
return Qnil;
|
385
|
+
};
|
386
|
+
|
387
|
+
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
388
|
+
VALUE ret = Qnil;
|
389
|
+
try {
|
390
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
|
391
|
+
ret = rb_ary_new2(vec.size());
|
392
|
+
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
393
|
+
} catch (const std::runtime_error& e) {
|
394
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
395
|
+
return Qnil;
|
396
|
+
}
|
397
|
+
return ret;
|
398
|
+
};
|
399
|
+
|
400
|
+
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
401
|
+
VALUE ret = rb_ary_new();
|
402
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
|
403
|
+
return ret;
|
404
|
+
};
|
405
|
+
|
406
|
+
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
407
|
+
try {
|
407
408
|
get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
|
409
|
+
} catch (const std::runtime_error& e) {
|
410
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
408
411
|
return Qnil;
|
409
|
-
}
|
412
|
+
}
|
413
|
+
return Qnil;
|
414
|
+
};
|
410
415
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
return Qnil;
|
415
|
-
}
|
416
|
-
get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
|
416
|
+
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
417
|
+
if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
418
|
+
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
417
419
|
return Qnil;
|
418
|
-
}
|
419
|
-
|
420
|
-
|
421
|
-
|
420
|
+
}
|
421
|
+
try {
|
422
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
|
423
|
+
} catch (const std::runtime_error& e) {
|
424
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
422
425
|
return Qnil;
|
423
|
-
}
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
426
|
+
}
|
427
|
+
return Qnil;
|
428
|
+
};
|
429
|
+
|
430
|
+
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
431
|
+
get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
|
432
|
+
return Qnil;
|
433
|
+
};
|
434
|
+
|
435
|
+
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
436
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
|
437
|
+
};
|
438
|
+
|
439
|
+
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
440
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
|
441
|
+
};
|
432
442
|
};
|
433
443
|
|
444
|
+
// clang-format off
|
434
445
|
const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
435
446
|
"RbHnswlibHierarchicalNSW",
|
436
447
|
{
|
@@ -442,192 +453,206 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
|
442
453
|
NULL,
|
443
454
|
RUBY_TYPED_FREE_IMMEDIATELY
|
444
455
|
};
|
456
|
+
// clang-format on
|
445
457
|
|
446
458
|
class RbHnswlibBruteforceSearch {
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
459
|
+
public:
|
460
|
+
static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
|
461
|
+
hnswlib::BruteforceSearch<float>* ptr =
|
462
|
+
(hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
|
463
|
+
new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
|
464
|
+
return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
|
465
|
+
};
|
466
|
+
|
467
|
+
static void hnsw_bruteforcesearch_free(void* ptr) {
|
468
|
+
((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
|
469
|
+
ruby_xfree(ptr);
|
470
|
+
};
|
471
|
+
|
472
|
+
static size_t hnsw_bruteforcesearch_size(const void* ptr) { return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); };
|
473
|
+
|
474
|
+
static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
|
475
|
+
hnswlib::BruteforceSearch<float>* ptr;
|
476
|
+
TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
|
477
|
+
return ptr;
|
478
|
+
};
|
479
|
+
|
480
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
481
|
+
rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
|
482
|
+
rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
|
483
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
|
484
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
|
485
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
|
486
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
|
487
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
|
488
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
|
489
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
|
490
|
+
rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
|
491
|
+
rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
|
492
|
+
return rb_cHnswlibBruteforceSearch;
|
493
|
+
};
|
494
|
+
|
495
|
+
private:
|
496
|
+
static const rb_data_type_t hnsw_bruteforcesearch_type;
|
497
|
+
|
498
|
+
static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
|
499
|
+
VALUE kw_args = Qnil;
|
500
|
+
ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
|
501
|
+
VALUE kw_values[2] = {Qundef, Qundef};
|
502
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
503
|
+
rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
|
504
|
+
|
505
|
+
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
506
|
+
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
507
|
+
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
508
|
+
return Qnil;
|
509
|
+
}
|
510
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
511
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
512
|
+
return Qnil;
|
513
|
+
}
|
514
|
+
|
515
|
+
rb_iv_set(self, "@space", kw_values[0]);
|
516
|
+
hnswlib::SpaceInterface<float>* space;
|
517
|
+
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
518
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
519
|
+
} else {
|
520
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
521
|
+
}
|
522
|
+
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
523
|
+
|
524
|
+
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
525
|
+
try {
|
513
526
|
new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
|
514
|
-
|
527
|
+
} catch (const std::runtime_error& e) {
|
528
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
515
529
|
return Qnil;
|
516
|
-
}
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
get_hnsw_bruteforcesearch(self)->addPoint((void
|
542
|
-
|
530
|
+
}
|
531
|
+
|
532
|
+
return Qnil;
|
533
|
+
};
|
534
|
+
|
535
|
+
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
|
536
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
537
|
+
|
538
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
539
|
+
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
540
|
+
return Qfalse;
|
541
|
+
}
|
542
|
+
if (!RB_INTEGER_TYPE_P(idx)) {
|
543
|
+
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
544
|
+
return Qfalse;
|
545
|
+
}
|
546
|
+
if (dim != RARRAY_LEN(arr)) {
|
547
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
548
|
+
return Qfalse;
|
549
|
+
}
|
550
|
+
|
551
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
552
|
+
for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
553
|
+
|
554
|
+
try {
|
555
|
+
get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
|
556
|
+
} catch (const std::runtime_error& e) {
|
543
557
|
ruby_xfree(vec);
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
|
548
|
-
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
549
|
-
|
550
|
-
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
551
|
-
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
552
|
-
return Qnil;
|
553
|
-
}
|
554
|
-
|
555
|
-
if (!RB_INTEGER_TYPE_P(k)) {
|
556
|
-
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
557
|
-
return Qnil;
|
558
|
-
}
|
559
|
-
|
560
|
-
if (dim != RARRAY_LEN(arr)) {
|
561
|
-
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
562
|
-
return Qnil;
|
563
|
-
}
|
564
|
-
|
565
|
-
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
566
|
-
for (int i = 0; i < dim; i++) {
|
567
|
-
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
568
|
-
}
|
558
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
559
|
+
return Qfalse;
|
560
|
+
}
|
569
561
|
|
570
|
-
|
571
|
-
|
562
|
+
ruby_xfree(vec);
|
563
|
+
return Qtrue;
|
564
|
+
};
|
572
565
|
|
573
|
-
|
566
|
+
static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
|
567
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
574
568
|
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
569
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
570
|
+
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
571
|
+
return Qnil;
|
572
|
+
}
|
573
|
+
if (!RB_INTEGER_TYPE_P(k)) {
|
574
|
+
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
575
|
+
return Qnil;
|
576
|
+
}
|
577
|
+
if (dim != RARRAY_LEN(arr)) {
|
578
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
579
|
+
return Qnil;
|
580
|
+
}
|
579
581
|
|
580
|
-
|
581
|
-
|
582
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
583
|
+
for (int i = 0; i < dim; i++) {
|
584
|
+
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
585
|
+
}
|
582
586
|
|
583
|
-
|
584
|
-
|
585
|
-
rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
|
586
|
-
rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
|
587
|
-
result.pop();
|
588
|
-
}
|
587
|
+
std::priority_queue<std::pair<float, size_t>> result =
|
588
|
+
get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
|
589
589
|
|
590
|
-
|
591
|
-
rb_ary_store(ret, 0, neighbors_arr);
|
592
|
-
rb_ary_store(ret, 1, distances_arr);
|
593
|
-
return ret;
|
594
|
-
};
|
590
|
+
ruby_xfree(vec);
|
595
591
|
|
596
|
-
|
597
|
-
|
598
|
-
get_hnsw_bruteforcesearch(self)->saveIndex(filename);
|
599
|
-
RB_GC_GUARD(_filename);
|
600
|
-
return Qnil;
|
601
|
-
};
|
602
|
-
|
603
|
-
static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
|
604
|
-
std::string filename(StringValuePtr(_filename));
|
605
|
-
VALUE ivspace = rb_iv_get(self, "@space");
|
606
|
-
hnswlib::SpaceInterface<float>* space;
|
607
|
-
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
608
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
609
|
-
} else {
|
610
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
611
|
-
}
|
612
|
-
get_hnsw_bruteforcesearch(self)->loadIndex(filename, space);
|
613
|
-
RB_GC_GUARD(_filename);
|
592
|
+
if (result.size() != (size_t)NUM2INT(k)) {
|
593
|
+
rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
|
614
594
|
return Qnil;
|
615
|
-
}
|
616
|
-
|
617
|
-
|
618
|
-
|
595
|
+
}
|
596
|
+
|
597
|
+
VALUE distances_arr = rb_ary_new2(result.size());
|
598
|
+
VALUE neighbors_arr = rb_ary_new2(result.size());
|
599
|
+
|
600
|
+
for (int i = NUM2INT(k) - 1; i >= 0; i--) {
|
601
|
+
const std::pair<float, size_t>& result_tuple = result.top();
|
602
|
+
rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
|
603
|
+
rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
|
604
|
+
result.pop();
|
605
|
+
}
|
606
|
+
|
607
|
+
VALUE ret = rb_ary_new2(2);
|
608
|
+
rb_ary_store(ret, 0, neighbors_arr);
|
609
|
+
rb_ary_store(ret, 1, distances_arr);
|
610
|
+
return ret;
|
611
|
+
};
|
612
|
+
|
613
|
+
static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
|
614
|
+
std::string filename(StringValuePtr(_filename));
|
615
|
+
get_hnsw_bruteforcesearch(self)->saveIndex(filename);
|
616
|
+
RB_GC_GUARD(_filename);
|
617
|
+
return Qnil;
|
618
|
+
};
|
619
|
+
|
620
|
+
static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
|
621
|
+
std::string filename(StringValuePtr(_filename));
|
622
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
623
|
+
hnswlib::SpaceInterface<float>* space;
|
624
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
625
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
626
|
+
} else {
|
627
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
628
|
+
}
|
629
|
+
hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
|
630
|
+
if (index->data_) free(index->data_);
|
631
|
+
try {
|
632
|
+
index->loadIndex(filename, space);
|
633
|
+
} catch (const std::runtime_error& e) {
|
634
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
619
635
|
return Qnil;
|
620
|
-
}
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
636
|
+
}
|
637
|
+
RB_GC_GUARD(_filename);
|
638
|
+
return Qnil;
|
639
|
+
};
|
640
|
+
|
641
|
+
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
|
642
|
+
get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
|
643
|
+
return Qnil;
|
644
|
+
};
|
645
|
+
|
646
|
+
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
|
647
|
+
return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
|
648
|
+
};
|
649
|
+
|
650
|
+
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
|
651
|
+
return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
|
652
|
+
};
|
629
653
|
};
|
630
654
|
|
655
|
+
// clang-format off
|
631
656
|
const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
632
657
|
"RbHnswlibBruteforceSearch",
|
633
658
|
{
|
@@ -639,5 +664,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
|
639
664
|
NULL,
|
640
665
|
RUBY_TYPED_FREE_IMMEDIATELY
|
641
666
|
};
|
667
|
+
// clang-format on
|
642
668
|
|
643
669
|
#endif /* HNSWLIBEXT_HPP */
|