hnswlib 0.5.3 → 0.6.2

Sign up to get free protection for your applications and to get access to all the features.
@@ -29,69 +29,68 @@ VALUE rb_cHnswlibHierarchicalNSW;
29
29
  VALUE rb_cHnswlibBruteforceSearch;
30
30
 
31
31
  class RbHnswlibL2Space {
32
- public:
33
- static VALUE hnsw_l2space_alloc(VALUE self) {
34
- hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
- new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
- return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
- };
38
-
39
- static void hnsw_l2space_free(void* ptr) {
40
- ((hnswlib::L2Space*)ptr)->~L2Space();
41
- ruby_xfree(ptr);
42
- };
43
-
44
- static size_t hnsw_l2space_size(const void* ptr) {
45
- return sizeof(*((hnswlib::L2Space*)ptr));
46
- };
47
-
48
- static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
49
- hnswlib::L2Space* ptr;
50
- TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
51
- return ptr;
52
- };
53
-
54
- static VALUE define_class(VALUE rb_mHnswlib) {
55
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
56
- rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
57
- rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
58
- rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
59
- rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
60
- return rb_cHnswlibL2Space;
61
- };
62
-
63
- private:
64
- static const rb_data_type_t hnsw_l2space_type;
65
-
66
- static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
67
- rb_iv_set(self, "@dim", dim);
68
- hnswlib::L2Space* ptr = get_hnsw_l2space(self);
69
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
32
+ public:
33
+ static VALUE hnsw_l2space_alloc(VALUE self) {
34
+ hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
+ new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
+ return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
+ };
38
+
39
+ static void hnsw_l2space_free(void* ptr) {
40
+ ((hnswlib::L2Space*)ptr)->~L2Space();
41
+ ruby_xfree(ptr);
42
+ };
43
+
44
+ static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
45
+
46
+ static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
47
+ hnswlib::L2Space* ptr;
48
+ TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
49
+ return ptr;
50
+ };
51
+
52
+ static VALUE define_class(VALUE rb_mHnswlib) {
53
+ rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
54
+ rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
55
+ rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
56
+ rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
57
+ rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
58
+ return rb_cHnswlibL2Space;
59
+ };
60
+
61
+ private:
62
+ static const rb_data_type_t hnsw_l2space_type;
63
+
64
+ static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
+ rb_iv_set(self, "@dim", dim);
66
+ hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
+ new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
68
+ return Qnil;
69
+ };
70
+
71
+ static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
73
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
74
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
70
75
  return Qnil;
71
- };
72
-
73
- static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
74
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
76
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
77
- return Qnil;
78
- }
79
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
80
- for (int i = 0; i < dim; i++) {
81
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
82
- }
83
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) {
85
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
86
- }
87
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
88
- const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
89
- ruby_xfree(vec_a);
90
- ruby_xfree(vec_b);
91
- return DBL2NUM((double)dist);
92
- };
76
+ }
77
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
78
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
79
+ return Qnil;
80
+ }
81
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
82
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
83
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
85
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
86
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
87
+ ruby_xfree(vec_a);
88
+ ruby_xfree(vec_b);
89
+ return DBL2NUM((double)dist);
90
+ };
93
91
  };
94
92
 
93
+ // clang-format off
95
94
  const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
96
95
  "RbHnswlibL2Space",
97
96
  {
@@ -103,71 +102,71 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
103
102
  NULL,
104
103
  RUBY_TYPED_FREE_IMMEDIATELY
105
104
  };
105
+ // clang-format on
106
106
 
107
107
  class RbHnswlibInnerProductSpace {
108
- public:
109
- static VALUE hnsw_ipspace_alloc(VALUE self) {
110
- hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
111
- new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
112
- return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
113
- };
114
-
115
- static void hnsw_ipspace_free(void* ptr) {
116
- ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
117
- ruby_xfree(ptr);
118
- };
119
-
120
- static size_t hnsw_ipspace_size(const void* ptr) {
121
- return sizeof(*((hnswlib::InnerProductSpace*)ptr));
122
- };
123
-
124
- static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
125
- hnswlib::InnerProductSpace* ptr;
126
- TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
127
- return ptr;
128
- };
129
-
130
- static VALUE define_class(VALUE rb_mHnswlib) {
131
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
132
- rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
133
- rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
134
- rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
135
- rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
136
- return rb_cHnswlibInnerProductSpace;
137
- };
138
-
139
- private:
140
- static const rb_data_type_t hnsw_ipspace_type;
141
-
142
- static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
143
- rb_iv_set(self, "@dim", dim);
144
- hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
145
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
108
+ public:
109
+ static VALUE hnsw_ipspace_alloc(VALUE self) {
110
+ hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
111
+ new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
112
+ return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
113
+ };
114
+
115
+ static void hnsw_ipspace_free(void* ptr) {
116
+ ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
117
+ ruby_xfree(ptr);
118
+ };
119
+
120
+ static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
121
+
122
+ static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
123
+ hnswlib::InnerProductSpace* ptr;
124
+ TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
125
+ return ptr;
126
+ };
127
+
128
+ static VALUE define_class(VALUE rb_mHnswlib) {
129
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
130
+ rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
131
+ rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
132
+ rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
133
+ rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
134
+ return rb_cHnswlibInnerProductSpace;
135
+ };
136
+
137
+ private:
138
+ static const rb_data_type_t hnsw_ipspace_type;
139
+
140
+ static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
141
+ rb_iv_set(self, "@dim", dim);
142
+ hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
143
+ new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
144
+ return Qnil;
145
+ };
146
+
147
+ static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
148
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
149
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
150
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
146
151
  return Qnil;
147
- };
148
-
149
- static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
150
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
152
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
153
- return Qnil;
154
- }
155
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
156
- for (int i = 0; i < dim; i++) {
157
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
158
- }
159
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) {
161
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
162
- }
163
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
164
- const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
165
- ruby_xfree(vec_a);
166
- ruby_xfree(vec_b);
167
- return DBL2NUM((double)dist);
168
- };
152
+ }
153
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
154
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
155
+ return Qnil;
156
+ }
157
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
158
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
159
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
161
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
162
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
163
+ ruby_xfree(vec_a);
164
+ ruby_xfree(vec_b);
165
+ return DBL2NUM((double)dist);
166
+ };
169
167
  };
170
168
 
169
+ // clang-format off
171
170
  const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
172
171
  "RbHnswlibInnerProductSpace",
173
172
  {
@@ -179,294 +178,288 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
179
178
  NULL,
180
179
  RUBY_TYPED_FREE_IMMEDIATELY
181
180
  };
181
+ // clang-format on
182
182
 
183
183
  class RbHnswlibHierarchicalNSW {
184
- public:
185
- static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
186
- hnswlib::HierarchicalNSW<float>* ptr = (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
187
- new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
188
- return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
189
- };
190
-
191
- static void hnsw_hierarchicalnsw_free(void* ptr) {
192
- ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
193
- ruby_xfree(ptr);
194
- };
195
-
196
- static size_t hnsw_hierarchicalnsw_size(const void* ptr) {
197
- return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr));
198
- };
199
-
200
- static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
201
- hnswlib::HierarchicalNSW<float>* ptr;
202
- TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
203
- return ptr;
204
- };
205
-
206
- static VALUE define_class(VALUE rb_mHnswlib) {
207
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
208
- rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
211
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
213
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
214
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
215
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
216
- rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
217
- rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
218
- rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
219
- rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
220
- rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
221
- rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
222
- return rb_cHnswlibHierarchicalNSW;
223
- };
224
-
225
- private:
226
- static const rb_data_type_t hnsw_hierarchicalnsw_type;
227
-
228
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
229
- VALUE kw_args = Qnil;
230
- ID kw_table[5] = {
231
- rb_intern("space"),
232
- rb_intern("max_elements"),
233
- rb_intern("m"),
234
- rb_intern("ef_construction"),
235
- rb_intern("random_seed")
236
- };
237
- VALUE kw_values[5] = {
238
- Qundef, Qundef, Qundef, Qundef, Qundef
239
- };
240
- rb_scan_args(argc, argv, ":", &kw_args);
241
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
242
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
243
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
244
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
245
-
246
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
247
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
248
- return Qnil;
249
- }
250
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
251
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
252
- return Qnil;
253
- }
254
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
255
- rb_raise(rb_eTypeError, "expected m, Integer");
256
- return Qnil;
257
- }
258
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
259
- rb_raise(rb_eTypeError, "expected ef_construction, Integer");
260
- return Qnil;
261
- }
262
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
263
- rb_raise(rb_eTypeError, "expected random_seed, Integer");
264
- return Qnil;
265
- }
266
-
267
- rb_iv_set(self, "@space", kw_values[0]);
268
- hnswlib::SpaceInterface<float>* space;
269
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
270
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
271
- } else {
272
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
273
- }
274
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
275
- const size_t m = (size_t)NUM2INT(kw_values[2]);
276
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
277
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
278
-
279
- hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
280
- try {
281
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
282
- } catch(const std::runtime_error& e) {
283
- rb_raise(rb_eRuntimeError, "%s", e.what());
284
- return Qnil;
285
- }
286
-
184
+ public:
185
+ static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
186
+ hnswlib::HierarchicalNSW<float>* ptr =
187
+ (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
188
+ new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
189
+ return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
190
+ };
191
+
192
+ static void hnsw_hierarchicalnsw_free(void* ptr) {
193
+ ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
194
+ ruby_xfree(ptr);
195
+ };
196
+
197
+ static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
198
+
199
+ static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
200
+ hnswlib::HierarchicalNSW<float>* ptr;
201
+ TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
202
+ return ptr;
203
+ };
204
+
205
+ static VALUE define_class(VALUE rb_mHnswlib) {
206
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
207
+ rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
208
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
209
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
210
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
211
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
212
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
213
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
214
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
215
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
216
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
217
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
218
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
219
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
220
+ rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
221
+ return rb_cHnswlibHierarchicalNSW;
222
+ };
223
+
224
+ private:
225
+ static const rb_data_type_t hnsw_hierarchicalnsw_type;
226
+
227
+ static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
228
+ VALUE kw_args = Qnil;
229
+ ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
230
+ rb_intern("random_seed")};
231
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
232
+ rb_scan_args(argc, argv, ":", &kw_args);
233
+ rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
234
+ if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
235
+ if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
236
+ if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
237
+
238
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
239
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
240
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
287
241
  return Qnil;
288
- };
289
-
290
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
291
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
292
-
293
- if (!RB_TYPE_P(arr, T_ARRAY)) {
294
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
295
- return Qfalse;
296
- }
297
-
298
- if (!RB_INTEGER_TYPE_P(idx)) {
299
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
300
- return Qfalse;
301
- }
302
-
303
- if (dim != RARRAY_LEN(arr)) {
304
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
305
- return Qfalse;
306
- }
242
+ }
243
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
244
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
245
+ return Qnil;
246
+ }
247
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
248
+ rb_raise(rb_eTypeError, "expected m, Integer");
249
+ return Qnil;
250
+ }
251
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
252
+ rb_raise(rb_eTypeError, "expected ef_construction, Integer");
253
+ return Qnil;
254
+ }
255
+ if (!RB_INTEGER_TYPE_P(kw_values[4])) {
256
+ rb_raise(rb_eTypeError, "expected random_seed, Integer");
257
+ return Qnil;
258
+ }
259
+
260
+ rb_iv_set(self, "@space", kw_values[0]);
261
+ hnswlib::SpaceInterface<float>* space;
262
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
263
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
264
+ } else {
265
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
266
+ }
267
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
268
+ const size_t m = (size_t)NUM2INT(kw_values[2]);
269
+ const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
270
+ const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
271
+
272
+ hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
273
+ try {
274
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
275
+ } catch (const std::runtime_error& e) {
276
+ rb_raise(rb_eRuntimeError, "%s", e.what());
277
+ return Qnil;
278
+ }
307
279
 
308
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
309
- for (int i = 0; i < dim; i++) {
310
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
311
- }
280
+ return Qnil;
281
+ };
312
282
 
313
- get_hnsw_hierarchicalnsw(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
283
+ static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
284
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
314
285
 
315
- ruby_xfree(vec);
316
- return Qtrue;
317
- };
286
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
287
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
288
+ return Qfalse;
289
+ }
290
+ if (!RB_INTEGER_TYPE_P(idx)) {
291
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
292
+ return Qfalse;
293
+ }
294
+ if (dim != RARRAY_LEN(arr)) {
295
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
296
+ return Qfalse;
297
+ }
318
298
 
319
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
320
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
299
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
300
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
321
301
 
322
- if (!RB_TYPE_P(arr, T_ARRAY)) {
323
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
324
- return Qnil;
325
- }
302
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
326
303
 
327
- if (!RB_INTEGER_TYPE_P(k)) {
328
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
329
- return Qnil;
330
- }
304
+ ruby_xfree(vec);
305
+ return Qtrue;
306
+ };
331
307
 
332
- if (dim != RARRAY_LEN(arr)) {
333
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
334
- return Qnil;
335
- }
308
+ static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
309
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
336
310
 
337
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
338
- for (int i = 0; i < dim; i++) {
339
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
340
- }
311
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
312
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
313
+ return Qnil;
314
+ }
315
+ if (!RB_INTEGER_TYPE_P(k)) {
316
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
317
+ return Qnil;
318
+ }
319
+ if (dim != RARRAY_LEN(arr)) {
320
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
321
+ return Qnil;
322
+ }
341
323
 
342
- std::priority_queue<std::pair<float, size_t>> result;
343
- try {
344
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
345
- } catch(const std::runtime_error& e) {
346
- ruby_xfree(vec);
347
- rb_raise(rb_eRuntimeError, "%s", e.what());
348
- return Qnil;
349
- }
324
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
325
+ for (int i = 0; i < dim; i++) {
326
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
327
+ }
350
328
 
329
+ std::priority_queue<std::pair<float, size_t>> result;
330
+ try {
331
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
332
+ } catch (const std::runtime_error& e) {
351
333
  ruby_xfree(vec);
352
-
353
- if (result.size() != (size_t)NUM2INT(k)) {
354
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
355
- return Qnil;
356
- }
357
-
358
- VALUE distances_arr = rb_ary_new2(result.size());
359
- VALUE neighbors_arr = rb_ary_new2(result.size());
360
-
361
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
362
- const std::pair<float, size_t>& result_tuple = result.top();
363
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
364
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
365
- result.pop();
366
- }
367
-
368
- VALUE ret = rb_ary_new2(2);
369
- rb_ary_store(ret, 0, neighbors_arr);
370
- rb_ary_store(ret, 1, distances_arr);
371
- return ret;
372
- };
373
-
374
- static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
375
- std::string filename(StringValuePtr(_filename));
376
- get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
377
- RB_GC_GUARD(_filename);
334
+ rb_raise(rb_eRuntimeError, "%s", e.what());
378
335
  return Qnil;
379
- };
380
-
381
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
382
- std::string filename(StringValuePtr(_filename));
383
- VALUE ivspace = rb_iv_get(self, "@space");
384
- hnswlib::SpaceInterface<float>* space;
385
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
386
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
387
- } else {
388
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
389
- }
390
- hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
391
- if (index->data_level0_memory_) free(index->data_level0_memory_);
392
- if (index->linkLists_) {
393
- for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
394
- if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
336
+ }
337
+
338
+ ruby_xfree(vec);
339
+
340
+ if (result.size() != (size_t)NUM2INT(k)) {
341
+ rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
342
+ }
343
+
344
+ VALUE distances_arr = rb_ary_new();
345
+ VALUE neighbors_arr = rb_ary_new();
346
+
347
+ while (!result.empty()) {
348
+ const std::pair<float, size_t>& result_tuple = result.top();
349
+ rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
350
+ rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
351
+ result.pop();
352
+ }
353
+
354
+ VALUE ret = rb_ary_new2(2);
355
+ rb_ary_store(ret, 0, neighbors_arr);
356
+ rb_ary_store(ret, 1, distances_arr);
357
+ return ret;
358
+ };
359
+
360
+ static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
361
+ std::string filename(StringValuePtr(_filename));
362
+ get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
363
+ RB_GC_GUARD(_filename);
364
+ return Qnil;
365
+ };
366
+
367
+ static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
368
+ std::string filename(StringValuePtr(_filename));
369
+ VALUE ivspace = rb_iv_get(self, "@space");
370
+ hnswlib::SpaceInterface<float>* space;
371
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
372
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
373
+ } else {
374
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
375
+ }
376
+ hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
377
+ if (index->data_level0_memory_) {
378
+ free(index->data_level0_memory_);
379
+ index->data_level0_memory_ = nullptr;
380
+ }
381
+ if (index->linkLists_) {
382
+ for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
383
+ if (index->element_levels_[i] > 0 && index->linkLists_[i]) {
384
+ free(index->linkLists_[i]);
385
+ index->linkLists_[i] = nullptr;
395
386
  }
396
- free(index->linkLists_);
397
387
  }
398
- if (index->visited_list_pool_) delete index->visited_list_pool_;
399
- try {
400
- index->loadIndex(filename, space);
401
- } catch(const std::runtime_error& e) {
402
- rb_raise(rb_eRuntimeError, "%s", e.what());
403
- return Qnil;
404
- }
405
- RB_GC_GUARD(_filename);
388
+ free(index->linkLists_);
389
+ index->linkLists_ = nullptr;
390
+ }
391
+ if (index->visited_list_pool_) {
392
+ delete index->visited_list_pool_;
393
+ index->visited_list_pool_ = nullptr;
394
+ }
395
+ try {
396
+ index->loadIndex(filename, space);
397
+ } catch (const std::runtime_error& e) {
398
+ rb_raise(rb_eRuntimeError, "%s", e.what());
406
399
  return Qnil;
407
- };
408
-
409
- static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
410
- VALUE ret = Qnil;
411
- try {
412
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
413
- ret = rb_ary_new2(vec.size());
414
- for (size_t i = 0; i < vec.size(); i++) {
415
- rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
416
- }
417
- } catch(const std::runtime_error& e) {
418
- rb_raise(rb_eRuntimeError, "%s", e.what());
419
- return Qnil;
420
- }
421
- return ret;
422
- };
423
-
424
- static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
425
- VALUE ret = rb_ary_new();
426
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
427
- rb_ary_push(ret, INT2NUM((int)kv.first));
428
- }
429
- return ret;
430
- };
431
-
432
- static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
433
- try {
434
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
435
- } catch(const std::runtime_error& e) {
436
- rb_raise(rb_eRuntimeError, "%s", e.what());
437
- return Qnil;
438
- }
400
+ }
401
+ RB_GC_GUARD(_filename);
402
+ return Qnil;
403
+ };
404
+
405
+ static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
406
+ VALUE ret = Qnil;
407
+ try {
408
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
409
+ ret = rb_ary_new2(vec.size());
410
+ for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
411
+ } catch (const std::runtime_error& e) {
412
+ rb_raise(rb_eRuntimeError, "%s", e.what());
439
413
  return Qnil;
440
- };
441
-
442
- static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
443
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
444
- rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
445
- return Qnil;
446
- }
447
- try {
448
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
449
- } catch(const std::runtime_error& e) {
450
- rb_raise(rb_eRuntimeError, "%s", e.what());
451
- return Qnil;
452
- }
414
+ }
415
+ return ret;
416
+ };
417
+
418
+ static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
419
+ VALUE ret = rb_ary_new();
420
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
421
+ return ret;
422
+ };
423
+
424
+ static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
425
+ try {
426
+ get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
427
+ } catch (const std::runtime_error& e) {
428
+ rb_raise(rb_eRuntimeError, "%s", e.what());
453
429
  return Qnil;
454
- };
430
+ }
431
+ return Qnil;
432
+ };
455
433
 
456
- static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
457
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
434
+ static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
435
+ if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
436
+ rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
458
437
  return Qnil;
459
- };
460
-
461
- static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
462
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
463
- };
464
-
465
- static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
466
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
467
- };
438
+ }
439
+ try {
440
+ get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
441
+ } catch (const std::runtime_error& e) {
442
+ rb_raise(rb_eRuntimeError, "%s", e.what());
443
+ return Qnil;
444
+ }
445
+ return Qnil;
446
+ };
447
+
448
+ static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
449
+ get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
450
+ return Qnil;
451
+ };
452
+
453
+ static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
454
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
455
+ };
456
+
457
+ static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
458
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
459
+ };
468
460
  };
469
461
 
462
+ // clang-format off
470
463
  const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
471
464
  "RbHnswlibHierarchicalNSW",
472
465
  {
@@ -478,210 +471,208 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
478
471
  NULL,
479
472
  RUBY_TYPED_FREE_IMMEDIATELY
480
473
  };
474
+ // clang-format on
481
475
 
482
476
  class RbHnswlibBruteforceSearch {
483
- public:
484
- static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
485
- hnswlib::BruteforceSearch<float>* ptr = (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
486
- new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
487
- return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
488
- };
489
-
490
- static void hnsw_bruteforcesearch_free(void* ptr) {
491
- ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
492
- ruby_xfree(ptr);
493
- };
494
-
495
- static size_t hnsw_bruteforcesearch_size(const void* ptr) {
496
- return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr));
497
- };
498
-
499
- static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
500
- hnswlib::BruteforceSearch<float>* ptr;
501
- TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
502
- return ptr;
503
- };
504
-
505
- static VALUE define_class(VALUE rb_mHnswlib) {
506
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
507
- rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
508
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
509
- rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
510
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
511
- rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
512
- rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
513
- rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
514
- rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
515
- rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
516
- rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
517
- return rb_cHnswlibBruteforceSearch;
518
- };
519
-
520
- private:
521
- static const rb_data_type_t hnsw_bruteforcesearch_type;
522
-
523
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
524
- VALUE kw_args = Qnil;
525
- ID kw_table[2] = { rb_intern("space"), rb_intern("max_elements") };
526
- VALUE kw_values[2] = { Qundef, Qundef };
527
- rb_scan_args(argc, argv, ":", &kw_args);
528
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
529
-
530
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
531
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
532
- return Qnil;
533
- }
534
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
535
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
536
- return Qnil;
537
- }
538
-
539
- rb_iv_set(self, "@space", kw_values[0]);
540
- hnswlib::SpaceInterface<float>* space;
541
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
542
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
543
- } else {
544
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
545
- }
546
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
547
-
548
- hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
549
- try {
550
- new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
551
- } catch(const std::runtime_error& e) {
552
- rb_raise(rb_eRuntimeError, "%s", e.what());
553
- return Qnil;
554
- }
555
-
477
+ public:
478
+ static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
479
+ hnswlib::BruteforceSearch<float>* ptr =
480
+ (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
481
+ new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
482
+ return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
483
+ };
484
+
485
+ static void hnsw_bruteforcesearch_free(void* ptr) {
486
+ ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
487
+ ruby_xfree(ptr);
488
+ };
489
+
490
+ static size_t hnsw_bruteforcesearch_size(const void* ptr) { return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); };
491
+
492
+ static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
493
+ hnswlib::BruteforceSearch<float>* ptr;
494
+ TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
495
+ return ptr;
496
+ };
497
+
498
+ static VALUE define_class(VALUE rb_mHnswlib) {
499
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
500
+ rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
501
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
502
+ rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
503
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
504
+ rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
505
+ rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
506
+ rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
507
+ rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
508
+ rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
509
+ rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
510
+ return rb_cHnswlibBruteforceSearch;
511
+ };
512
+
513
+ private:
514
+ static const rb_data_type_t hnsw_bruteforcesearch_type;
515
+
516
+ static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
517
+ VALUE kw_args = Qnil;
518
+ ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
519
+ VALUE kw_values[2] = {Qundef, Qundef};
520
+ rb_scan_args(argc, argv, ":", &kw_args);
521
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
522
+
523
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
524
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
525
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
556
526
  return Qnil;
557
- };
558
-
559
- static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
560
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
561
-
562
- if (!RB_TYPE_P(arr, T_ARRAY)) {
563
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
564
- return Qfalse;
565
- }
566
-
567
- if (!RB_INTEGER_TYPE_P(idx)) {
568
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
569
- return Qfalse;
570
- }
571
-
572
- if (dim != RARRAY_LEN(arr)) {
573
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
574
- return Qfalse;
575
- }
576
-
577
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
578
- for (int i = 0; i < dim; i++) {
579
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
580
- }
581
-
582
- try {
583
- get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
584
- } catch(const std::runtime_error& e) {
585
- ruby_xfree(vec);
586
- rb_raise(rb_eRuntimeError, "%s", e.what());
587
- return Qfalse;
588
- }
589
-
590
- ruby_xfree(vec);
591
- return Qtrue;
592
- };
593
-
594
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
595
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
596
-
597
- if (!RB_TYPE_P(arr, T_ARRAY)) {
598
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
599
- return Qnil;
600
- }
601
-
602
- if (!RB_INTEGER_TYPE_P(k)) {
603
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
604
- return Qnil;
605
- }
606
-
607
- if (dim != RARRAY_LEN(arr)) {
608
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
609
- return Qnil;
610
- }
611
-
612
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
613
- for (int i = 0; i < dim; i++) {
614
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
615
- }
616
-
617
- std::priority_queue<std::pair<float, size_t>> result =
618
- get_hnsw_bruteforcesearch(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
619
-
527
+ }
528
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
529
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
530
+ return Qnil;
531
+ }
532
+
533
+ rb_iv_set(self, "@space", kw_values[0]);
534
+ hnswlib::SpaceInterface<float>* space;
535
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
536
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
537
+ } else {
538
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
539
+ }
540
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
541
+
542
+ hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
543
+ try {
544
+ new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
545
+ } catch (const std::runtime_error& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
548
+ }
549
+
550
+ return Qnil;
551
+ };
552
+
553
+ static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
554
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
555
+
556
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
557
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
558
+ return Qfalse;
559
+ }
560
+ if (!RB_INTEGER_TYPE_P(idx)) {
561
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
562
+ return Qfalse;
563
+ }
564
+ if (dim != RARRAY_LEN(arr)) {
565
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
566
+ return Qfalse;
567
+ }
568
+
569
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
570
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
571
+
572
+ try {
573
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
574
+ } catch (const std::runtime_error& e) {
620
575
  ruby_xfree(vec);
576
+ rb_raise(rb_eRuntimeError, "%s", e.what());
577
+ return Qfalse;
578
+ }
621
579
 
622
- if (result.size() != (size_t)NUM2INT(k)) {
623
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
624
- return Qnil;
625
- }
626
-
627
- VALUE distances_arr = rb_ary_new2(result.size());
628
- VALUE neighbors_arr = rb_ary_new2(result.size());
629
-
630
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
631
- const std::pair<float, size_t>& result_tuple = result.top();
632
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
633
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
634
- result.pop();
635
- }
580
+ ruby_xfree(vec);
581
+ return Qtrue;
582
+ };
636
583
 
637
- VALUE ret = rb_ary_new2(2);
638
- rb_ary_store(ret, 0, neighbors_arr);
639
- rb_ary_store(ret, 1, distances_arr);
640
- return ret;
641
- };
584
+ static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
585
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
642
586
 
643
- static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
644
- std::string filename(StringValuePtr(_filename));
645
- get_hnsw_bruteforcesearch(self)->saveIndex(filename);
646
- RB_GC_GUARD(_filename);
587
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
588
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
647
589
  return Qnil;
648
- };
649
-
650
- static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
651
- std::string filename(StringValuePtr(_filename));
652
- VALUE ivspace = rb_iv_get(self, "@space");
653
- hnswlib::SpaceInterface<float>* space;
654
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
655
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
656
- } else {
657
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
658
- }
659
- hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
660
- if (index->data_) free(index->data_);
661
- try {
662
- index->loadIndex(filename, space);
663
- } catch(const std::runtime_error& e) {
664
- rb_raise(rb_eRuntimeError, "%s", e.what());
665
- return Qnil;
666
- }
667
- RB_GC_GUARD(_filename);
590
+ }
591
+ if (!RB_INTEGER_TYPE_P(k)) {
592
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
668
593
  return Qnil;
669
- };
670
-
671
- static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
672
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
594
+ }
595
+ if (dim != RARRAY_LEN(arr)) {
596
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
673
597
  return Qnil;
674
- };
675
-
676
- static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
677
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
678
- };
679
-
680
- static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
681
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
682
- };
598
+ }
599
+
600
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
601
+ for (int i = 0; i < dim; i++) {
602
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
603
+ }
604
+
605
+ std::priority_queue<std::pair<float, size_t>> result =
606
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
607
+
608
+ ruby_xfree(vec);
609
+
610
+ if (result.size() != (size_t)NUM2INT(k)) {
611
+ rb_warning("Cannot return as many search results as the requested number of neighbors.");
612
+ }
613
+
614
+ VALUE distances_arr = rb_ary_new2(result.size());
615
+ VALUE neighbors_arr = rb_ary_new2(result.size());
616
+
617
+ while (!result.empty()) {
618
+ const std::pair<float, size_t>& result_tuple = result.top();
619
+ rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
620
+ rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
621
+ result.pop();
622
+ }
623
+
624
+ VALUE ret = rb_ary_new2(2);
625
+ rb_ary_store(ret, 0, neighbors_arr);
626
+ rb_ary_store(ret, 1, distances_arr);
627
+ return ret;
628
+ };
629
+
630
+ static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
631
+ std::string filename(StringValuePtr(_filename));
632
+ get_hnsw_bruteforcesearch(self)->saveIndex(filename);
633
+ RB_GC_GUARD(_filename);
634
+ return Qnil;
635
+ };
636
+
637
+ static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
638
+ std::string filename(StringValuePtr(_filename));
639
+ VALUE ivspace = rb_iv_get(self, "@space");
640
+ hnswlib::SpaceInterface<float>* space;
641
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
642
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
643
+ } else {
644
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
645
+ }
646
+ hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
647
+ if (index->data_) {
648
+ free(index->data_);
649
+ index->data_ = nullptr;
650
+ }
651
+ try {
652
+ index->loadIndex(filename, space);
653
+ } catch (const std::runtime_error& e) {
654
+ rb_raise(rb_eRuntimeError, "%s", e.what());
655
+ return Qnil;
656
+ }
657
+ RB_GC_GUARD(_filename);
658
+ return Qnil;
659
+ };
660
+
661
+ static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
662
+ get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
663
+ return Qnil;
664
+ };
665
+
666
+ static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
667
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
668
+ };
669
+
670
+ static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
671
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
672
+ };
683
673
  };
684
674
 
675
+ // clang-format off
685
676
  const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
686
677
  "RbHnswlibBruteforceSearch",
687
678
  {
@@ -693,5 +684,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
693
684
  NULL,
694
685
  RUBY_TYPED_FREE_IMMEDIATELY
695
686
  };
687
+ // clang-format on
696
688
 
697
689
  #endif /* HNSWLIBEXT_HPP */