hnswlib 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +2 -2
- data/ext/hnswlib/hnswlibext.hpp +547 -575
- data/lib/hnswlib/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a6aed34b1cfcf85d2df88ccb2f1a3b417780e0f614fe1a1e3870a060479207ff
|
4
|
+
data.tar.gz: ef5272e25876da853ae49cb26d39ecb7e7d708a25e3442d51e4d998f227334a9
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 0f9c77270306ea4363d94d45025f7e5e14da66ea5bd9c2ec2b0f62ec578afd2ee75c73a5cce2dac668a70adb9fdc653bfe9eb72e1608928fc648bb73571efd48
|
7
|
+
data.tar.gz: 3d5ab21ea7b9ae8d79f182fb406d98f147d7c534e4be992be238c2c4f450989cb353aa159d62e342ce3e58a7ed8756c59528400554f6229c5298ebbb23414ec3
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,10 @@
|
|
1
|
+
## [0.6.0] - 2022-04-16
|
2
|
+
|
3
|
+
**Breaking change:**
|
4
|
+
|
5
|
+
- Change the `search_knn` method of `HierarchicalNSW` to output a warning message instead of raising RuntimeError
|
6
|
+
when the number of search results is less than the requested number of neighbors.
|
7
|
+
|
1
8
|
## [0.5.3] - 2022-03-05
|
2
9
|
|
3
10
|
- Add error handling for std::runtime_error thrown from hnswlib.
|
data/README.md
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
[Build Status](https://github.com/yoshoku/hnswlib.rb/actions/workflows/build.yml)
|
4
4
|
[Gem Version](https://badge.fury.io/rb/hnswlib)
|
5
5
|
[License](https://github.com/yoshoku/hnswlib.rb/blob/main/LICENSE.txt)
|
6
|
-
[Documentation](https://yoshoku.github.io/hnswlib.rb/doc/)
|
7
7
|
|
8
8
|
Hnswlib.rb provides Ruby bindings for the [Hnswlib](https://github.com/nmslib/hnswlib)
|
9
9
|
that implements approximate nearest-neighbor search based on
|
@@ -71,4 +71,4 @@ The gem is available as open source under the terms of the [Apache-2.0 License](
|
|
71
71
|
|
72
72
|
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/hnswlib.rb.
|
73
73
|
This project is intended to be a safe, welcoming space for collaboration,
|
74
|
-
and contributors are expected to adhere to the [Contributor Covenant](
|
74
|
+
and contributors are expected to adhere to the [Contributor Covenant](https://contributor-covenant.org) code of conduct.
|
data/ext/hnswlib/hnswlibext.hpp
CHANGED
@@ -29,69 +29,64 @@ VALUE rb_cHnswlibHierarchicalNSW;
|
|
29
29
|
VALUE rb_cHnswlibBruteforceSearch;
|
30
30
|
|
31
31
|
class RbHnswlibL2Space {
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
32
|
+
public:
|
33
|
+
static VALUE hnsw_l2space_alloc(VALUE self) {
|
34
|
+
hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
|
35
|
+
new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
|
36
|
+
return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
|
37
|
+
};
|
38
|
+
|
39
|
+
static void hnsw_l2space_free(void* ptr) {
|
40
|
+
((hnswlib::L2Space*)ptr)->~L2Space();
|
41
|
+
ruby_xfree(ptr);
|
42
|
+
};
|
43
|
+
|
44
|
+
static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
|
45
|
+
|
46
|
+
static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
|
47
|
+
hnswlib::L2Space* ptr;
|
48
|
+
TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
|
49
|
+
return ptr;
|
50
|
+
};
|
51
|
+
|
52
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
53
|
+
rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
|
54
|
+
rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
|
55
|
+
rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
|
56
|
+
rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
|
57
|
+
rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
|
58
|
+
return rb_cHnswlibL2Space;
|
59
|
+
};
|
60
|
+
|
61
|
+
private:
|
62
|
+
static const rb_data_type_t hnsw_l2space_type;
|
63
|
+
|
64
|
+
static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
|
65
|
+
rb_iv_set(self, "@dim", dim);
|
66
|
+
hnswlib::L2Space* ptr = get_hnsw_l2space(self);
|
67
|
+
new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
|
68
|
+
return Qnil;
|
69
|
+
};
|
70
|
+
|
71
|
+
static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
72
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
73
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
74
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
70
75
|
return Qnil;
|
71
|
-
}
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
}
|
83
|
-
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
84
|
-
for (int i = 0; i < dim; i++) {
|
85
|
-
vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
86
|
-
}
|
87
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
88
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
89
|
-
ruby_xfree(vec_a);
|
90
|
-
ruby_xfree(vec_b);
|
91
|
-
return DBL2NUM((double)dist);
|
92
|
-
};
|
76
|
+
}
|
77
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
78
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
79
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
80
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
81
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
|
82
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
|
83
|
+
ruby_xfree(vec_a);
|
84
|
+
ruby_xfree(vec_b);
|
85
|
+
return DBL2NUM((double)dist);
|
86
|
+
};
|
93
87
|
};
|
94
88
|
|
89
|
+
// clang-format off
|
95
90
|
const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
96
91
|
"RbHnswlibL2Space",
|
97
92
|
{
|
@@ -103,71 +98,67 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
|
|
103
98
|
NULL,
|
104
99
|
RUBY_TYPED_FREE_IMMEDIATELY
|
105
100
|
};
|
101
|
+
// clang-format on
|
106
102
|
|
107
103
|
class RbHnswlibInnerProductSpace {
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
104
|
+
public:
|
105
|
+
static VALUE hnsw_ipspace_alloc(VALUE self) {
|
106
|
+
hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
|
107
|
+
new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
|
108
|
+
return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
|
109
|
+
};
|
110
|
+
|
111
|
+
static void hnsw_ipspace_free(void* ptr) {
|
112
|
+
((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
|
113
|
+
ruby_xfree(ptr);
|
114
|
+
};
|
115
|
+
|
116
|
+
static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
|
117
|
+
|
118
|
+
static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
|
119
|
+
hnswlib::InnerProductSpace* ptr;
|
120
|
+
TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
|
121
|
+
return ptr;
|
122
|
+
};
|
123
|
+
|
124
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
125
|
+
rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
|
126
|
+
rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
|
127
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
|
128
|
+
rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
|
129
|
+
rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
|
130
|
+
return rb_cHnswlibInnerProductSpace;
|
131
|
+
};
|
132
|
+
|
133
|
+
private:
|
134
|
+
static const rb_data_type_t hnsw_ipspace_type;
|
135
|
+
|
136
|
+
static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
|
137
|
+
rb_iv_set(self, "@dim", dim);
|
138
|
+
hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
|
139
|
+
new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
|
140
|
+
return Qnil;
|
141
|
+
};
|
142
|
+
|
143
|
+
static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
|
144
|
+
const int dim = NUM2INT(rb_iv_get(self, "@dim"));
|
145
|
+
if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
|
146
|
+
rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
|
146
147
|
return Qnil;
|
147
|
-
}
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
}
|
159
|
-
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
160
|
-
for (int i = 0; i < dim; i++) {
|
161
|
-
vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
162
|
-
}
|
163
|
-
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
164
|
-
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
165
|
-
ruby_xfree(vec_a);
|
166
|
-
ruby_xfree(vec_b);
|
167
|
-
return DBL2NUM((double)dist);
|
168
|
-
};
|
148
|
+
}
|
149
|
+
float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
|
150
|
+
for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
|
151
|
+
float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
|
152
|
+
for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
|
153
|
+
hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
|
154
|
+
const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
|
155
|
+
ruby_xfree(vec_a);
|
156
|
+
ruby_xfree(vec_b);
|
157
|
+
return DBL2NUM((double)dist);
|
158
|
+
};
|
169
159
|
};
|
170
160
|
|
161
|
+
// clang-format off
|
171
162
|
const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
172
163
|
"RbHnswlibInnerProductSpace",
|
173
164
|
{
|
@@ -179,294 +170,278 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
|
|
179
170
|
NULL,
|
180
171
|
RUBY_TYPED_FREE_IMMEDIATELY
|
181
172
|
};
|
173
|
+
// clang-format on
|
182
174
|
|
183
175
|
class RbHnswlibHierarchicalNSW {
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
|
242
|
-
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
243
|
-
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
244
|
-
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
245
|
-
|
246
|
-
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
247
|
-
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
248
|
-
return Qnil;
|
249
|
-
}
|
250
|
-
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
251
|
-
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
252
|
-
return Qnil;
|
253
|
-
}
|
254
|
-
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
255
|
-
rb_raise(rb_eTypeError, "expected m, Integer");
|
256
|
-
return Qnil;
|
257
|
-
}
|
258
|
-
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
259
|
-
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
260
|
-
return Qnil;
|
261
|
-
}
|
262
|
-
if (!RB_INTEGER_TYPE_P(kw_values[4])) {
|
263
|
-
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
264
|
-
return Qnil;
|
265
|
-
}
|
266
|
-
|
267
|
-
rb_iv_set(self, "@space", kw_values[0]);
|
268
|
-
hnswlib::SpaceInterface<float>* space;
|
269
|
-
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
270
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
271
|
-
} else {
|
272
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
273
|
-
}
|
274
|
-
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
275
|
-
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
276
|
-
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
277
|
-
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
278
|
-
|
279
|
-
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
280
|
-
try {
|
281
|
-
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
282
|
-
} catch(const std::runtime_error& e) {
|
283
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
284
|
-
return Qnil;
|
285
|
-
}
|
286
|
-
|
176
|
+
public:
|
177
|
+
static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
|
178
|
+
hnswlib::HierarchicalNSW<float>* ptr =
|
179
|
+
(hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
|
180
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
|
181
|
+
return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
|
182
|
+
};
|
183
|
+
|
184
|
+
static void hnsw_hierarchicalnsw_free(void* ptr) {
|
185
|
+
((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
|
186
|
+
ruby_xfree(ptr);
|
187
|
+
};
|
188
|
+
|
189
|
+
static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
|
190
|
+
|
191
|
+
static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
|
192
|
+
hnswlib::HierarchicalNSW<float>* ptr;
|
193
|
+
TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
|
194
|
+
return ptr;
|
195
|
+
};
|
196
|
+
|
197
|
+
static VALUE define_class(VALUE rb_mHnswlib) {
|
198
|
+
rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
|
199
|
+
rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
|
200
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
|
201
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
|
202
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
|
203
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
|
204
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
|
205
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
|
206
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
|
207
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
|
208
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
|
209
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
|
210
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
|
211
|
+
rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
|
212
|
+
rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
|
213
|
+
return rb_cHnswlibHierarchicalNSW;
|
214
|
+
};
|
215
|
+
|
216
|
+
private:
|
217
|
+
static const rb_data_type_t hnsw_hierarchicalnsw_type;
|
218
|
+
|
219
|
+
static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
|
220
|
+
VALUE kw_args = Qnil;
|
221
|
+
ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
|
222
|
+
rb_intern("random_seed")};
|
223
|
+
VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
|
224
|
+
rb_scan_args(argc, argv, ":", &kw_args);
|
225
|
+
rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
|
226
|
+
if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
|
227
|
+
if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
|
228
|
+
if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
|
229
|
+
|
230
|
+
if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
|
231
|
+
rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
|
232
|
+
rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
|
287
233
|
return Qnil;
|
288
|
-
}
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
234
|
+
}
|
235
|
+
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
236
|
+
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
237
|
+
return Qnil;
|
238
|
+
}
|
239
|
+
if (!RB_INTEGER_TYPE_P(kw_values[2])) {
|
240
|
+
rb_raise(rb_eTypeError, "expected m, Integer");
|
241
|
+
return Qnil;
|
242
|
+
}
|
243
|
+
if (!RB_INTEGER_TYPE_P(kw_values[3])) {
|
244
|
+
rb_raise(rb_eTypeError, "expected ef_construction, Integer");
|
245
|
+
return Qnil;
|
246
|
+
}
|
247
|
+
if (!RB_INTEGER_TYPE_P(kw_values[4])) {
|
248
|
+
rb_raise(rb_eTypeError, "expected random_seed, Integer");
|
249
|
+
return Qnil;
|
250
|
+
}
|
251
|
+
|
252
|
+
rb_iv_set(self, "@space", kw_values[0]);
|
253
|
+
hnswlib::SpaceInterface<float>* space;
|
254
|
+
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
255
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
256
|
+
} else {
|
257
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
258
|
+
}
|
259
|
+
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
260
|
+
const size_t m = (size_t)NUM2INT(kw_values[2]);
|
261
|
+
const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
|
262
|
+
const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
|
263
|
+
|
264
|
+
hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
|
265
|
+
try {
|
266
|
+
new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
|
267
|
+
} catch (const std::runtime_error& e) {
|
268
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
269
|
+
return Qnil;
|
270
|
+
}
|
307
271
|
|
308
|
-
|
309
|
-
|
310
|
-
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
311
|
-
}
|
272
|
+
return Qnil;
|
273
|
+
};
|
312
274
|
|
313
|
-
|
275
|
+
static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
|
276
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
314
277
|
|
315
|
-
|
316
|
-
|
317
|
-
|
278
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
279
|
+
rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
|
280
|
+
return Qfalse;
|
281
|
+
}
|
282
|
+
if (!RB_INTEGER_TYPE_P(idx)) {
|
283
|
+
rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
|
284
|
+
return Qfalse;
|
285
|
+
}
|
286
|
+
if (dim != RARRAY_LEN(arr)) {
|
287
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
288
|
+
return Qfalse;
|
289
|
+
}
|
318
290
|
|
319
|
-
|
320
|
-
|
291
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
292
|
+
for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
321
293
|
|
322
|
-
|
323
|
-
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
324
|
-
return Qnil;
|
325
|
-
}
|
294
|
+
get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
|
326
295
|
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
}
|
296
|
+
ruby_xfree(vec);
|
297
|
+
return Qtrue;
|
298
|
+
};
|
331
299
|
|
332
|
-
|
333
|
-
|
334
|
-
return Qnil;
|
335
|
-
}
|
300
|
+
static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
|
301
|
+
const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
|
336
302
|
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
303
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
304
|
+
rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
|
305
|
+
return Qnil;
|
306
|
+
}
|
307
|
+
if (!RB_INTEGER_TYPE_P(k)) {
|
308
|
+
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
309
|
+
return Qnil;
|
310
|
+
}
|
311
|
+
if (dim != RARRAY_LEN(arr)) {
|
312
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
313
|
+
return Qnil;
|
314
|
+
}
|
341
315
|
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
ruby_xfree(vec);
|
347
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
348
|
-
return Qnil;
|
349
|
-
}
|
316
|
+
float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
|
317
|
+
for (int i = 0; i < dim; i++) {
|
318
|
+
vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
|
319
|
+
}
|
350
320
|
|
321
|
+
std::priority_queue<std::pair<float, size_t>> result;
|
322
|
+
try {
|
323
|
+
result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
|
324
|
+
} catch (const std::runtime_error& e) {
|
351
325
|
ruby_xfree(vec);
|
352
|
-
|
353
|
-
if (result.size() != (size_t)NUM2INT(k)) {
|
354
|
-
rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
|
355
|
-
return Qnil;
|
356
|
-
}
|
357
|
-
|
358
|
-
VALUE distances_arr = rb_ary_new2(result.size());
|
359
|
-
VALUE neighbors_arr = rb_ary_new2(result.size());
|
360
|
-
|
361
|
-
for (int i = NUM2INT(k) - 1; i >= 0; i--) {
|
362
|
-
const std::pair<float, size_t>& result_tuple = result.top();
|
363
|
-
rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
|
364
|
-
rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
|
365
|
-
result.pop();
|
366
|
-
}
|
367
|
-
|
368
|
-
VALUE ret = rb_ary_new2(2);
|
369
|
-
rb_ary_store(ret, 0, neighbors_arr);
|
370
|
-
rb_ary_store(ret, 1, distances_arr);
|
371
|
-
return ret;
|
372
|
-
};
|
373
|
-
|
374
|
-
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
375
|
-
std::string filename(StringValuePtr(_filename));
|
376
|
-
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
377
|
-
RB_GC_GUARD(_filename);
|
326
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
378
327
|
return Qnil;
|
379
|
-
}
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
328
|
+
}
|
329
|
+
|
330
|
+
ruby_xfree(vec);
|
331
|
+
|
332
|
+
if (result.size() != (size_t)NUM2INT(k)) {
|
333
|
+
rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
|
334
|
+
}
|
335
|
+
|
336
|
+
VALUE distances_arr = rb_ary_new();
|
337
|
+
VALUE neighbors_arr = rb_ary_new();
|
338
|
+
|
339
|
+
while (!result.empty()) {
|
340
|
+
const std::pair<float, size_t>& result_tuple = result.top();
|
341
|
+
rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
|
342
|
+
rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
|
343
|
+
result.pop();
|
344
|
+
}
|
345
|
+
|
346
|
+
VALUE ret = rb_ary_new2(2);
|
347
|
+
rb_ary_store(ret, 0, neighbors_arr);
|
348
|
+
rb_ary_store(ret, 1, distances_arr);
|
349
|
+
return ret;
|
350
|
+
};
|
351
|
+
|
352
|
+
static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
|
353
|
+
std::string filename(StringValuePtr(_filename));
|
354
|
+
get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
|
355
|
+
RB_GC_GUARD(_filename);
|
356
|
+
return Qnil;
|
357
|
+
};
|
358
|
+
|
359
|
+
static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
|
360
|
+
std::string filename(StringValuePtr(_filename));
|
361
|
+
VALUE ivspace = rb_iv_get(self, "@space");
|
362
|
+
hnswlib::SpaceInterface<float>* space;
|
363
|
+
if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
|
364
|
+
space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
|
365
|
+
} else {
|
366
|
+
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
|
367
|
+
}
|
368
|
+
hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
|
369
|
+
if (index->data_level0_memory_) free(index->data_level0_memory_);
|
370
|
+
if (index->linkLists_) {
|
371
|
+
for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
|
372
|
+
if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
|
373
|
+
}
|
374
|
+
free(index->linkLists_);
|
375
|
+
}
|
376
|
+
if (index->visited_list_pool_) delete index->visited_list_pool_;
|
377
|
+
try {
|
378
|
+
index->loadIndex(filename, space);
|
379
|
+
} catch (const std::runtime_error& e) {
|
380
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
406
381
|
return Qnil;
|
407
|
-
}
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
}
|
421
|
-
return ret;
|
422
|
-
};
|
423
|
-
|
424
|
-
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
425
|
-
VALUE ret = rb_ary_new();
|
426
|
-
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
|
427
|
-
rb_ary_push(ret, INT2NUM((int)kv.first));
|
428
|
-
}
|
429
|
-
return ret;
|
430
|
-
};
|
431
|
-
|
432
|
-
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
433
|
-
try {
|
434
|
-
get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
|
435
|
-
} catch(const std::runtime_error& e) {
|
436
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
437
|
-
return Qnil;
|
438
|
-
}
|
382
|
+
}
|
383
|
+
RB_GC_GUARD(_filename);
|
384
|
+
return Qnil;
|
385
|
+
};
|
386
|
+
|
387
|
+
static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
|
388
|
+
VALUE ret = Qnil;
|
389
|
+
try {
|
390
|
+
std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
|
391
|
+
ret = rb_ary_new2(vec.size());
|
392
|
+
for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
|
393
|
+
} catch (const std::runtime_error& e) {
|
394
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
439
395
|
return Qnil;
|
440
|
-
}
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
396
|
+
}
|
397
|
+
return ret;
|
398
|
+
};
|
399
|
+
|
400
|
+
static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
|
401
|
+
VALUE ret = rb_ary_new();
|
402
|
+
for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
|
403
|
+
return ret;
|
404
|
+
};
|
405
|
+
|
406
|
+
static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
|
407
|
+
try {
|
408
|
+
get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
|
409
|
+
} catch (const std::runtime_error& e) {
|
410
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
453
411
|
return Qnil;
|
454
|
-
}
|
412
|
+
}
|
413
|
+
return Qnil;
|
414
|
+
};
|
455
415
|
|
456
|
-
|
457
|
-
|
416
|
+
static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
|
417
|
+
if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
|
418
|
+
rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
|
458
419
|
return Qnil;
|
459
|
-
}
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
420
|
+
}
|
421
|
+
try {
|
422
|
+
get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
|
423
|
+
} catch (const std::runtime_error& e) {
|
424
|
+
rb_raise(rb_eRuntimeError, "%s", e.what());
|
425
|
+
return Qnil;
|
426
|
+
}
|
427
|
+
return Qnil;
|
428
|
+
};
|
429
|
+
|
430
|
+
static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
|
431
|
+
get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
|
432
|
+
return Qnil;
|
433
|
+
};
|
434
|
+
|
435
|
+
static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
|
436
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
|
437
|
+
};
|
438
|
+
|
439
|
+
static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
|
440
|
+
return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
|
441
|
+
};
|
468
442
|
};
|
469
443
|
|
444
|
+
// clang-format off
|
470
445
|
const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
471
446
|
"RbHnswlibHierarchicalNSW",
|
472
447
|
{
|
@@ -478,210 +453,206 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
|
|
478
453
|
NULL,
|
479
454
|
RUBY_TYPED_FREE_IMMEDIATELY
|
480
455
|
};
|
456
|
+
// clang-format on
|
481
457
|
|
482
458
|
class RbHnswlibBruteforceSearch {
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
return Qnil;
|
533
|
-
}
|
534
|
-
if (!RB_INTEGER_TYPE_P(kw_values[1])) {
|
535
|
-
rb_raise(rb_eTypeError, "expected max_elements, Integer");
|
536
|
-
return Qnil;
|
537
|
-
}
|
538
|
-
|
539
|
-
rb_iv_set(self, "@space", kw_values[0]);
|
540
|
-
hnswlib::SpaceInterface<float>* space;
|
541
|
-
if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
|
542
|
-
space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
|
543
|
-
} else {
|
544
|
-
space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
|
545
|
-
}
|
546
|
-
const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
|
547
|
-
|
548
|
-
hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
|
549
|
-
try {
|
550
|
-
new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
|
551
|
-
} catch(const std::runtime_error& e) {
|
552
|
-
rb_raise(rb_eRuntimeError, "%s", e.what());
|
553
|
-
return Qnil;
|
554
|
-
}
|
555
|
-
|
459
|
+
public:
|
460
|
+
static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
|
461
|
+
hnswlib::BruteforceSearch<float>* ptr =
|
462
|
+
(hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
|
463
|
+
new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
|
464
|
+
return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
|
465
|
+
};
|
466
|
+
|
467
|
+
static void hnsw_bruteforcesearch_free(void* ptr) {
|
468
|
+
((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
|
469
|
+
ruby_xfree(ptr);
|
470
|
+
};
|
471
|
+
|
472
|
+
// Size function: reports sizeof the index object itself. `ptr` is not
// dereferenced, so buffers owned by the index are not counted.
static size_t hnsw_bruteforcesearch_size(const void* ptr) {
  (void)ptr;
  return sizeof(hnswlib::BruteforceSearch<float>);
}
|
473
|
+
|
474
|
+
// Returns the index wrapped by `self`, fetched through the typed-data
// accessor for hnsw_bruteforcesearch_type.
static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
  typedef hnswlib::BruteforceSearch<float> Index;
  Index* index = NULL;
  TypedData_Get_Struct(self, Index, &hnsw_bruteforcesearch_type, index);
  return index;
}
|
479
|
+
|
480
|
+
// Defines the BruteforceSearch class under the given module, stores it in
// rb_cHnswlibBruteforceSearch, registers the allocator, the instance methods
// and the read-only `space` attribute, and returns the class.
static VALUE define_class(VALUE rb_mHnswlib) {
  const VALUE klass = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
  rb_cHnswlibBruteforceSearch = klass;

  rb_define_alloc_func(klass, hnsw_bruteforcesearch_alloc);

  // `initialize` takes a variable argument list (-1); the others are fixed arity.
  rb_define_method(klass, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
  rb_define_method(klass, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
  rb_define_method(klass, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
  rb_define_method(klass, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
  rb_define_method(klass, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
  rb_define_method(klass, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
  rb_define_method(klass, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
  rb_define_method(klass, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);

  // Reader only (read = 1, write = 0).
  rb_define_attr(klass, "space", 1, 0);
  return klass;
}
|
494
|
+
|
495
|
+
private:
|
496
|
+
static const rb_data_type_t hnsw_bruteforcesearch_type;
|
497
|
+
|
498
|
+
// BruteforceSearch#initialize(space:, max_elements:)
//
// Validates the two required keyword arguments, stores the space object in
// @space, and constructs the index in place inside the storage wrapped by
// `self`.
//
// Raises TypeError when `space` is neither an L2Space nor an
// InnerProductSpace instance, or when `max_elements` is not an Integer.
// A std::runtime_error thrown by the index constructor is re-raised as
// RuntimeError. Returns nil.
static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
  VALUE options = Qnil;
  ID keys[2] = {rb_intern("space"), rb_intern("max_elements")};
  VALUE values[2] = {Qundef, Qundef};
  rb_scan_args(argc, argv, ":", &options);
  rb_get_kwargs(options, keys, 2, 0, values);  // 2 required keywords, 0 optional

  const VALUE space_obj = values[0];
  const VALUE max_elements_obj = values[1];

  const bool is_l2 = RTEST(rb_obj_is_instance_of(space_obj, rb_cHnswlibL2Space));
  if (!is_l2 && !RTEST(rb_obj_is_instance_of(space_obj, rb_cHnswlibInnerProductSpace))) {
    rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
  }
  if (!RB_INTEGER_TYPE_P(max_elements_obj)) {
    rb_raise(rb_eTypeError, "expected max_elements, Integer");
  }

  rb_iv_set(self, "@space", space_obj);

  hnswlib::SpaceInterface<float>* space;
  if (is_l2) {
    space = RbHnswlibL2Space::get_hnsw_l2space(space_obj);
  } else {
    space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(space_obj);
  }
  const size_t max_elements = (size_t)NUM2INT(max_elements_obj);

  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  try {
    new (index) hnswlib::BruteforceSearch<float>(space, max_elements);
  } catch (const std::runtime_error& e) {
    rb_raise(rb_eRuntimeError, "%s", e.what());
  }

  return Qnil;
}
|
534
|
+
|
535
|
+
// BruteforceSearch#add_point(arr, idx)
//
// Copies the Ruby Array `arr` into a temporary float buffer and adds it to
// the index under label `idx`.
//
// Raises ArgumentError when `arr` is not an Array, `idx` is not an Integer,
// or the array length differs from the @dim of @space. A std::runtime_error
// thrown by addPoint is re-raised as RuntimeError. Returns true on success.
static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
  const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));

  if (!RB_TYPE_P(arr, T_ARRAY)) {
    rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
  }
  if (!RB_INTEGER_TYPE_P(idx)) {
    rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
  }
  if (dim != RARRAY_LEN(arr)) {
    rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
  }

  // ALLOCV_N yields a temporary buffer that Ruby reclaims on its own if an
  // exception unwinds past this frame. NUM2DBL raises for non-numeric
  // elements and NUM2INT raises for out-of-range labels; with a plain
  // ruby_xmalloc buffer either raise would leak it.
  VALUE vec_holder = 0;
  float* vec = ALLOCV_N(float, vec_holder, dim);
  for (int i = 0; i < dim; i++) {
    vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
  }
  const size_t label = (size_t)NUM2INT(idx);

  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  try {
    index->addPoint((void*)vec, label);
  } catch (const std::runtime_error& e) {
    ALLOCV_END(vec_holder);
    rb_raise(rb_eRuntimeError, "%s", e.what());
  }

  ALLOCV_END(vec_holder);
  return Qtrue;
}
|
596
565
|
|
597
|
-
|
598
|
-
|
599
|
-
return Qnil;
|
600
|
-
}
|
601
|
-
|
602
|
-
if (!RB_INTEGER_TYPE_P(k)) {
|
603
|
-
rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
|
604
|
-
return Qnil;
|
605
|
-
}
|
566
|
+
// BruteforceSearch#search_knn(arr, k)
//
// Searches the index for the `k` nearest neighbors of the query vector
// `arr`. Returns a two-element Array [neighbors, distances]; both arrays are
// filled from the back while popping the result queue, so the entry popped
// last ends up at position 0.
//
// Raises ArgumentError when `arr` is not an Array, `k` is not an Integer, or
// the array length differs from the @dim of @space. Raises RuntimeError when
// the search returns a number of results different from `k`.
static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
  const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));

  if (!RB_TYPE_P(arr, T_ARRAY)) {
    rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
  }
  if (!RB_INTEGER_TYPE_P(k)) {
    rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
  }
  if (dim != RARRAY_LEN(arr)) {
    rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
  }

  // ALLOCV_N yields a temporary buffer that Ruby reclaims on its own if an
  // exception unwinds past this frame, so a raising NUM2DBL / NUM2INT below
  // cannot leak it.
  VALUE vec_holder = 0;
  float* vec = ALLOCV_N(float, vec_holder, dim);
  for (int i = 0; i < dim; i++) {
    vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
  }

  // Convert k once and reuse it for the search, the size check and the fill loop.
  const int n_neighbors = NUM2INT(k);
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);

  VALUE distances_arr = Qnil;
  VALUE neighbors_arr = Qnil;
  bool complete = false;
  {
    // The queue lives in its own scope so that its destructor has already run
    // by the time rb_raise is called below; rb_raise does not return and
    // would otherwise skip the destructor.
    std::priority_queue<std::pair<float, size_t>> result =
      index->searchKnn((void*)vec, (size_t)n_neighbors);

    complete = (result.size() == (size_t)n_neighbors);
    if (complete) {
      distances_arr = rb_ary_new2(result.size());
      neighbors_arr = rb_ary_new2(result.size());
      for (int i = n_neighbors - 1; i >= 0; i--) {
        const std::pair<float, size_t>& result_tuple = result.top();
        rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
        rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
        result.pop();
      }
    }
  }
  ALLOCV_END(vec_holder);

  if (!complete) {
    rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
  }

  VALUE ret = rb_ary_new2(2);
  rb_ary_store(ret, 0, neighbors_arr);
  rb_ary_store(ret, 1, distances_arr);
  return ret;
}
|
612
|
+
|
613
|
+
// BruteforceSearch#save_index(filename)
//
// Copies the Ruby string into a std::string and passes it to saveIndex on
// the wrapped index. Returns nil.
static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
  const char* path = StringValuePtr(_filename);
  const std::string filename(path);
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  index->saveIndex(filename);
  RB_GC_GUARD(_filename);
  return Qnil;
}
|
619
|
+
|
620
|
+
// BruteforceSearch#load_index(filename)
//
// Releases the data buffer currently held by the wrapped index and calls
// loadIndex with the given file name and the space stored in @space.
// A std::runtime_error thrown by loadIndex is re-raised as RuntimeError.
// Returns nil.
static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
  std::string filename(StringValuePtr(_filename));
  VALUE ivspace = rb_iv_get(self, "@space");
  hnswlib::SpaceInterface<float>* space;
  if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
    space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
  } else {
    space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
  }
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  if (index->data_) {
    free(index->data_);
    // Clear the pointer right away. If loadIndex throws before it installs a
    // new buffer, the index would otherwise keep pointing at freed memory and
    // a later load_index call would free that same pointer a second time.
    index->data_ = nullptr;
  }
  try {
    index->loadIndex(filename, space);
  } catch (const std::runtime_error& e) {
    rb_raise(rb_eRuntimeError, "%s", e.what());
    return Qnil;
  }
  RB_GC_GUARD(_filename);
  return Qnil;
}
|
640
|
+
|
641
|
+
// BruteforceSearch#remove_point(idx)
//
// Converts `idx` with NUM2INT and passes it to removePoint on the wrapped
// index. Returns nil.
static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
  hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  const size_t label = (size_t)NUM2INT(idx);
  index->removePoint(label);
  return Qnil;
}
|
645
|
+
|
646
|
+
// BruteforceSearch#max_elements
//
// Returns the index's maxelements_ field, narrowed to int, as a Ruby Integer.
static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
  const hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  const int capacity = (int)(index->maxelements_);
  return INT2NUM(capacity);
}
|
649
|
+
|
650
|
+
// BruteforceSearch#current_count
//
// Returns the index's cur_element_count field, narrowed to int, as a Ruby
// Integer.
static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
  const hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
  const int count = (int)(index->cur_element_count);
  return INT2NUM(count);
}
|
683
653
|
};
|
684
654
|
|
655
|
+
// clang-format off
|
685
656
|
const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
686
657
|
"RbHnswlibBruteforceSearch",
|
687
658
|
{
|
@@ -693,5 +664,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
|
|
693
664
|
NULL,
|
694
665
|
RUBY_TYPED_FREE_IMMEDIATELY
|
695
666
|
};
|
667
|
+
// clang-format on
|
696
668
|
|
697
669
|
#endif /* HNSWLIBEXT_HPP */
|
data/lib/hnswlib/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: hnswlib
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.3
|
4
|
+
version: 0.6.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-04-16 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: Hnswlib.rb provides Ruby bindings for the Hnswlib.
|
14
14
|
email:
|