hnswlib 0.5.3 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,69 +29,68 @@ VALUE rb_cHnswlibHierarchicalNSW;
29
29
  VALUE rb_cHnswlibBruteforceSearch;
30
30
 
31
31
  class RbHnswlibL2Space {
32
- public:
33
- static VALUE hnsw_l2space_alloc(VALUE self) {
34
- hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
- new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
- return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
- };
38
-
39
- static void hnsw_l2space_free(void* ptr) {
40
- ((hnswlib::L2Space*)ptr)->~L2Space();
41
- ruby_xfree(ptr);
42
- };
43
-
44
- static size_t hnsw_l2space_size(const void* ptr) {
45
- return sizeof(*((hnswlib::L2Space*)ptr));
46
- };
47
-
48
- static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
49
- hnswlib::L2Space* ptr;
50
- TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
51
- return ptr;
52
- };
53
-
54
- static VALUE define_class(VALUE rb_mHnswlib) {
55
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
56
- rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
57
- rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
58
- rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
59
- rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
60
- return rb_cHnswlibL2Space;
61
- };
62
-
63
- private:
64
- static const rb_data_type_t hnsw_l2space_type;
65
-
66
- static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
67
- rb_iv_set(self, "@dim", dim);
68
- hnswlib::L2Space* ptr = get_hnsw_l2space(self);
69
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
32
+ public:
33
+ static VALUE hnsw_l2space_alloc(VALUE self) {
34
+ hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
+ new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
+ return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
+ };
38
+
39
+ static void hnsw_l2space_free(void* ptr) {
40
+ ((hnswlib::L2Space*)ptr)->~L2Space();
41
+ ruby_xfree(ptr);
42
+ };
43
+
44
+ static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
45
+
46
+ static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
47
+ hnswlib::L2Space* ptr;
48
+ TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
49
+ return ptr;
50
+ };
51
+
52
+ static VALUE define_class(VALUE rb_mHnswlib) {
53
+ rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
54
+ rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
55
+ rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
56
+ rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
57
+ rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
58
+ return rb_cHnswlibL2Space;
59
+ };
60
+
61
+ private:
62
+ static const rb_data_type_t hnsw_l2space_type;
63
+
64
+ static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
+ rb_iv_set(self, "@dim", dim);
66
+ hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
+ new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
68
+ return Qnil;
69
+ };
70
+
71
+ static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
73
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
74
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
70
75
  return Qnil;
71
- };
72
-
73
- static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
74
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
76
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
77
- return Qnil;
78
- }
79
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
80
- for (int i = 0; i < dim; i++) {
81
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
82
- }
83
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) {
85
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
86
- }
87
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
88
- const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
89
- ruby_xfree(vec_a);
90
- ruby_xfree(vec_b);
91
- return DBL2NUM((double)dist);
92
- };
76
+ }
77
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
78
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
79
+ return Qnil;
80
+ }
81
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
82
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
83
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
85
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
86
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
87
+ ruby_xfree(vec_a);
88
+ ruby_xfree(vec_b);
89
+ return DBL2NUM((double)dist);
90
+ };
93
91
  };
94
92
 
93
+ // clang-format off
95
94
  const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
96
95
  "RbHnswlibL2Space",
97
96
  {
@@ -103,71 +102,71 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
103
102
  NULL,
104
103
  RUBY_TYPED_FREE_IMMEDIATELY
105
104
  };
105
+ // clang-format on
106
106
 
107
107
  class RbHnswlibInnerProductSpace {
108
- public:
109
- static VALUE hnsw_ipspace_alloc(VALUE self) {
110
- hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
111
- new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
112
- return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
113
- };
114
-
115
- static void hnsw_ipspace_free(void* ptr) {
116
- ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
117
- ruby_xfree(ptr);
118
- };
119
-
120
- static size_t hnsw_ipspace_size(const void* ptr) {
121
- return sizeof(*((hnswlib::InnerProductSpace*)ptr));
122
- };
123
-
124
- static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
125
- hnswlib::InnerProductSpace* ptr;
126
- TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
127
- return ptr;
128
- };
129
-
130
- static VALUE define_class(VALUE rb_mHnswlib) {
131
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
132
- rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
133
- rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
134
- rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
135
- rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
136
- return rb_cHnswlibInnerProductSpace;
137
- };
138
-
139
- private:
140
- static const rb_data_type_t hnsw_ipspace_type;
141
-
142
- static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
143
- rb_iv_set(self, "@dim", dim);
144
- hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
145
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
108
+ public:
109
+ static VALUE hnsw_ipspace_alloc(VALUE self) {
110
+ hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
111
+ new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
112
+ return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
113
+ };
114
+
115
+ static void hnsw_ipspace_free(void* ptr) {
116
+ ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
117
+ ruby_xfree(ptr);
118
+ };
119
+
120
+ static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
121
+
122
+ static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
123
+ hnswlib::InnerProductSpace* ptr;
124
+ TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
125
+ return ptr;
126
+ };
127
+
128
+ static VALUE define_class(VALUE rb_mHnswlib) {
129
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
130
+ rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
131
+ rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
132
+ rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
133
+ rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
134
+ return rb_cHnswlibInnerProductSpace;
135
+ };
136
+
137
+ private:
138
+ static const rb_data_type_t hnsw_ipspace_type;
139
+
140
+ static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
141
+ rb_iv_set(self, "@dim", dim);
142
+ hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
143
+ new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
144
+ return Qnil;
145
+ };
146
+
147
+ static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
148
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
149
+ if (!RB_TYPE_P(arr_a, T_ARRAY) || !RB_TYPE_P(arr_b, T_ARRAY)) {
150
+ rb_raise(rb_eArgError, "Expect input vector to be Ruby Array.");
146
151
  return Qnil;
147
- };
148
-
149
- static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
150
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
152
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
153
- return Qnil;
154
- }
155
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
156
- for (int i = 0; i < dim; i++) {
157
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
158
- }
159
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) {
161
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
162
- }
163
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
164
- const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
165
- ruby_xfree(vec_a);
166
- ruby_xfree(vec_b);
167
- return DBL2NUM((double)dist);
168
- };
152
+ }
153
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
154
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
155
+ return Qnil;
156
+ }
157
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
158
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
159
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
161
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
162
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
163
+ ruby_xfree(vec_a);
164
+ ruby_xfree(vec_b);
165
+ return DBL2NUM((double)dist);
166
+ };
169
167
  };
170
168
 
169
+ // clang-format off
171
170
  const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
172
171
  "RbHnswlibInnerProductSpace",
173
172
  {
@@ -179,294 +178,288 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
179
178
  NULL,
180
179
  RUBY_TYPED_FREE_IMMEDIATELY
181
180
  };
181
+ // clang-format on
182
182
 
183
183
  class RbHnswlibHierarchicalNSW {
184
- public:
185
- static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
186
- hnswlib::HierarchicalNSW<float>* ptr = (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
187
- new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
188
- return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
189
- };
190
-
191
- static void hnsw_hierarchicalnsw_free(void* ptr) {
192
- ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
193
- ruby_xfree(ptr);
194
- };
195
-
196
- static size_t hnsw_hierarchicalnsw_size(const void* ptr) {
197
- return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr));
198
- };
199
-
200
- static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
201
- hnswlib::HierarchicalNSW<float>* ptr;
202
- TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
203
- return ptr;
204
- };
205
-
206
- static VALUE define_class(VALUE rb_mHnswlib) {
207
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
208
- rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
211
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
213
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
214
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
215
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
216
- rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
217
- rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
218
- rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
219
- rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
220
- rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
221
- rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
222
- return rb_cHnswlibHierarchicalNSW;
223
- };
224
-
225
- private:
226
- static const rb_data_type_t hnsw_hierarchicalnsw_type;
227
-
228
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
229
- VALUE kw_args = Qnil;
230
- ID kw_table[5] = {
231
- rb_intern("space"),
232
- rb_intern("max_elements"),
233
- rb_intern("m"),
234
- rb_intern("ef_construction"),
235
- rb_intern("random_seed")
236
- };
237
- VALUE kw_values[5] = {
238
- Qundef, Qundef, Qundef, Qundef, Qundef
239
- };
240
- rb_scan_args(argc, argv, ":", &kw_args);
241
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
242
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
243
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
244
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
245
-
246
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
247
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
248
- return Qnil;
249
- }
250
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
251
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
252
- return Qnil;
253
- }
254
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
255
- rb_raise(rb_eTypeError, "expected m, Integer");
256
- return Qnil;
257
- }
258
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
259
- rb_raise(rb_eTypeError, "expected ef_construction, Integer");
260
- return Qnil;
261
- }
262
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
263
- rb_raise(rb_eTypeError, "expected random_seed, Integer");
264
- return Qnil;
265
- }
266
-
267
- rb_iv_set(self, "@space", kw_values[0]);
268
- hnswlib::SpaceInterface<float>* space;
269
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
270
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
271
- } else {
272
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
273
- }
274
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
275
- const size_t m = (size_t)NUM2INT(kw_values[2]);
276
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
277
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
278
-
279
- hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
280
- try {
281
- new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
282
- } catch(const std::runtime_error& e) {
283
- rb_raise(rb_eRuntimeError, "%s", e.what());
284
- return Qnil;
285
- }
286
-
184
+ public:
185
+ static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
186
+ hnswlib::HierarchicalNSW<float>* ptr =
187
+ (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
188
+ new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
189
+ return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
190
+ };
191
+
192
+ static void hnsw_hierarchicalnsw_free(void* ptr) {
193
+ ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
194
+ ruby_xfree(ptr);
195
+ };
196
+
197
+ static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
198
+
199
+ static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
200
+ hnswlib::HierarchicalNSW<float>* ptr;
201
+ TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
202
+ return ptr;
203
+ };
204
+
205
+ static VALUE define_class(VALUE rb_mHnswlib) {
206
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
207
+ rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
208
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
209
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
210
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
211
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
212
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
213
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
214
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
215
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
216
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
217
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
218
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
219
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
220
+ rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
221
+ return rb_cHnswlibHierarchicalNSW;
222
+ };
223
+
224
+ private:
225
+ static const rb_data_type_t hnsw_hierarchicalnsw_type;
226
+
227
+ static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
228
+ VALUE kw_args = Qnil;
229
+ ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
230
+ rb_intern("random_seed")};
231
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
232
+ rb_scan_args(argc, argv, ":", &kw_args);
233
+ rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
234
+ if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
235
+ if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
236
+ if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
237
+
238
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
239
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
240
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
287
241
  return Qnil;
288
- };
289
-
290
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
291
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
292
-
293
- if (!RB_TYPE_P(arr, T_ARRAY)) {
294
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
295
- return Qfalse;
296
- }
297
-
298
- if (!RB_INTEGER_TYPE_P(idx)) {
299
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
300
- return Qfalse;
301
- }
302
-
303
- if (dim != RARRAY_LEN(arr)) {
304
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
305
- return Qfalse;
306
- }
242
+ }
243
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
244
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
245
+ return Qnil;
246
+ }
247
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
248
+ rb_raise(rb_eTypeError, "expected m, Integer");
249
+ return Qnil;
250
+ }
251
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
252
+ rb_raise(rb_eTypeError, "expected ef_construction, Integer");
253
+ return Qnil;
254
+ }
255
+ if (!RB_INTEGER_TYPE_P(kw_values[4])) {
256
+ rb_raise(rb_eTypeError, "expected random_seed, Integer");
257
+ return Qnil;
258
+ }
259
+
260
+ rb_iv_set(self, "@space", kw_values[0]);
261
+ hnswlib::SpaceInterface<float>* space;
262
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
263
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
264
+ } else {
265
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
266
+ }
267
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
268
+ const size_t m = (size_t)NUM2INT(kw_values[2]);
269
+ const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
270
+ const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
271
+
272
+ hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
273
+ try {
274
+ new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
275
+ } catch (const std::runtime_error& e) {
276
+ rb_raise(rb_eRuntimeError, "%s", e.what());
277
+ return Qnil;
278
+ }
307
279
 
308
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
309
- for (int i = 0; i < dim; i++) {
310
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
311
- }
280
+ return Qnil;
281
+ };
312
282
 
313
- get_hnsw_hierarchicalnsw(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
283
+ static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
284
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
314
285
 
315
- ruby_xfree(vec);
316
- return Qtrue;
317
- };
286
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
287
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
288
+ return Qfalse;
289
+ }
290
+ if (!RB_INTEGER_TYPE_P(idx)) {
291
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
292
+ return Qfalse;
293
+ }
294
+ if (dim != RARRAY_LEN(arr)) {
295
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
296
+ return Qfalse;
297
+ }
318
298
 
319
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
320
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
299
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
300
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
321
301
 
322
- if (!RB_TYPE_P(arr, T_ARRAY)) {
323
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
324
- return Qnil;
325
- }
302
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
326
303
 
327
- if (!RB_INTEGER_TYPE_P(k)) {
328
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
329
- return Qnil;
330
- }
304
+ ruby_xfree(vec);
305
+ return Qtrue;
306
+ };
331
307
 
332
- if (dim != RARRAY_LEN(arr)) {
333
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
334
- return Qnil;
335
- }
308
+ static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
309
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
336
310
 
337
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
338
- for (int i = 0; i < dim; i++) {
339
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
340
- }
311
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
312
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
313
+ return Qnil;
314
+ }
315
+ if (!RB_INTEGER_TYPE_P(k)) {
316
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
317
+ return Qnil;
318
+ }
319
+ if (dim != RARRAY_LEN(arr)) {
320
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
321
+ return Qnil;
322
+ }
341
323
 
342
- std::priority_queue<std::pair<float, size_t>> result;
343
- try {
344
- result = get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
345
- } catch(const std::runtime_error& e) {
346
- ruby_xfree(vec);
347
- rb_raise(rb_eRuntimeError, "%s", e.what());
348
- return Qnil;
349
- }
324
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
325
+ for (int i = 0; i < dim; i++) {
326
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
327
+ }
350
328
 
329
+ std::priority_queue<std::pair<float, size_t>> result;
330
+ try {
331
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
332
+ } catch (const std::runtime_error& e) {
351
333
  ruby_xfree(vec);
352
-
353
- if (result.size() != (size_t)NUM2INT(k)) {
354
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
355
- return Qnil;
356
- }
357
-
358
- VALUE distances_arr = rb_ary_new2(result.size());
359
- VALUE neighbors_arr = rb_ary_new2(result.size());
360
-
361
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
362
- const std::pair<float, size_t>& result_tuple = result.top();
363
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
364
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
365
- result.pop();
366
- }
367
-
368
- VALUE ret = rb_ary_new2(2);
369
- rb_ary_store(ret, 0, neighbors_arr);
370
- rb_ary_store(ret, 1, distances_arr);
371
- return ret;
372
- };
373
-
374
- static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
375
- std::string filename(StringValuePtr(_filename));
376
- get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
377
- RB_GC_GUARD(_filename);
334
+ rb_raise(rb_eRuntimeError, "%s", e.what());
378
335
  return Qnil;
379
- };
380
-
381
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
382
- std::string filename(StringValuePtr(_filename));
383
- VALUE ivspace = rb_iv_get(self, "@space");
384
- hnswlib::SpaceInterface<float>* space;
385
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
386
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
387
- } else {
388
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
389
- }
390
- hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
391
- if (index->data_level0_memory_) free(index->data_level0_memory_);
392
- if (index->linkLists_) {
393
- for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
394
- if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
336
+ }
337
+
338
+ ruby_xfree(vec);
339
+
340
+ if (result.size() != (size_t)NUM2INT(k)) {
341
+ rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
342
+ }
343
+
344
+ VALUE distances_arr = rb_ary_new();
345
+ VALUE neighbors_arr = rb_ary_new();
346
+
347
+ while (!result.empty()) {
348
+ const std::pair<float, size_t>& result_tuple = result.top();
349
+ rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
350
+ rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
351
+ result.pop();
352
+ }
353
+
354
+ VALUE ret = rb_ary_new2(2);
355
+ rb_ary_store(ret, 0, neighbors_arr);
356
+ rb_ary_store(ret, 1, distances_arr);
357
+ return ret;
358
+ };
359
+
360
+ static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
361
+ std::string filename(StringValuePtr(_filename));
362
+ get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
363
+ RB_GC_GUARD(_filename);
364
+ return Qnil;
365
+ };
366
+
367
+ static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
368
+ std::string filename(StringValuePtr(_filename));
369
+ VALUE ivspace = rb_iv_get(self, "@space");
370
+ hnswlib::SpaceInterface<float>* space;
371
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
372
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
373
+ } else {
374
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
375
+ }
376
+ hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
377
+ if (index->data_level0_memory_) {
378
+ free(index->data_level0_memory_);
379
+ index->data_level0_memory_ = nullptr;
380
+ }
381
+ if (index->linkLists_) {
382
+ for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
383
+ if (index->element_levels_[i] > 0 && index->linkLists_[i]) {
384
+ free(index->linkLists_[i]);
385
+ index->linkLists_[i] = nullptr;
395
386
  }
396
- free(index->linkLists_);
397
387
  }
398
- if (index->visited_list_pool_) delete index->visited_list_pool_;
399
- try {
400
- index->loadIndex(filename, space);
401
- } catch(const std::runtime_error& e) {
402
- rb_raise(rb_eRuntimeError, "%s", e.what());
403
- return Qnil;
404
- }
405
- RB_GC_GUARD(_filename);
388
+ free(index->linkLists_);
389
+ index->linkLists_ = nullptr;
390
+ }
391
+ if (index->visited_list_pool_) {
392
+ delete index->visited_list_pool_;
393
+ index->visited_list_pool_ = nullptr;
394
+ }
395
+ try {
396
+ index->loadIndex(filename, space);
397
+ } catch (const std::runtime_error& e) {
398
+ rb_raise(rb_eRuntimeError, "%s", e.what());
406
399
  return Qnil;
407
- };
408
-
409
- static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
410
- VALUE ret = Qnil;
411
- try {
412
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
413
- ret = rb_ary_new2(vec.size());
414
- for (size_t i = 0; i < vec.size(); i++) {
415
- rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
416
- }
417
- } catch(const std::runtime_error& e) {
418
- rb_raise(rb_eRuntimeError, "%s", e.what());
419
- return Qnil;
420
- }
421
- return ret;
422
- };
423
-
424
- static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
425
- VALUE ret = rb_ary_new();
426
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
427
- rb_ary_push(ret, INT2NUM((int)kv.first));
428
- }
429
- return ret;
430
- };
431
-
432
- static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
433
- try {
434
- get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
435
- } catch(const std::runtime_error& e) {
436
- rb_raise(rb_eRuntimeError, "%s", e.what());
437
- return Qnil;
438
- }
400
+ }
401
+ RB_GC_GUARD(_filename);
402
+ return Qnil;
403
+ };
404
+
405
+ static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
406
+ VALUE ret = Qnil;
407
+ try {
408
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
409
+ ret = rb_ary_new2(vec.size());
410
+ for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
411
+ } catch (const std::runtime_error& e) {
412
+ rb_raise(rb_eRuntimeError, "%s", e.what());
439
413
  return Qnil;
440
- };
441
-
442
- static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
443
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
444
- rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
445
- return Qnil;
446
- }
447
- try {
448
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
449
- } catch(const std::runtime_error& e) {
450
- rb_raise(rb_eRuntimeError, "%s", e.what());
451
- return Qnil;
452
- }
414
+ }
415
+ return ret;
416
+ };
417
+
418
+ static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
419
+ VALUE ret = rb_ary_new();
420
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
421
+ return ret;
422
+ };
423
+
424
+ static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
425
+ try {
426
+ get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
427
+ } catch (const std::runtime_error& e) {
428
+ rb_raise(rb_eRuntimeError, "%s", e.what());
453
429
  return Qnil;
454
- };
430
+ }
431
+ return Qnil;
432
+ };
455
433
 
456
- static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
457
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
434
+ static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
435
+ if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
436
+ rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
458
437
  return Qnil;
459
- };
460
-
461
- static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
462
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
463
- };
464
-
465
- static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
466
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
467
- };
438
+ }
439
+ try {
440
+ get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
441
+ } catch (const std::runtime_error& e) {
442
+ rb_raise(rb_eRuntimeError, "%s", e.what());
443
+ return Qnil;
444
+ }
445
+ return Qnil;
446
+ };
447
+
448
+ static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
449
+ get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
450
+ return Qnil;
451
+ };
452
+
453
+ static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
454
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
455
+ };
456
+
457
+ static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
458
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
459
+ };
468
460
  };
469
461
 
462
+ // clang-format off
470
463
  const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
471
464
  "RbHnswlibHierarchicalNSW",
472
465
  {
@@ -478,210 +471,208 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
478
471
  NULL,
479
472
  RUBY_TYPED_FREE_IMMEDIATELY
480
473
  };
474
+ // clang-format on
481
475
 
482
476
  class RbHnswlibBruteforceSearch {
483
- public:
484
- static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
485
- hnswlib::BruteforceSearch<float>* ptr = (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
486
- new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
487
- return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
488
- };
489
-
490
- static void hnsw_bruteforcesearch_free(void* ptr) {
491
- ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
492
- ruby_xfree(ptr);
493
- };
494
-
495
- static size_t hnsw_bruteforcesearch_size(const void* ptr) {
496
- return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr));
497
- };
498
-
499
- static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
500
- hnswlib::BruteforceSearch<float>* ptr;
501
- TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
502
- return ptr;
503
- };
504
-
505
- static VALUE define_class(VALUE rb_mHnswlib) {
506
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
507
- rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
508
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
509
- rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
510
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
511
- rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
512
- rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
513
- rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
514
- rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
515
- rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
516
- rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
517
- return rb_cHnswlibBruteforceSearch;
518
- };
519
-
520
- private:
521
- static const rb_data_type_t hnsw_bruteforcesearch_type;
522
-
523
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
524
- VALUE kw_args = Qnil;
525
- ID kw_table[2] = { rb_intern("space"), rb_intern("max_elements") };
526
- VALUE kw_values[2] = { Qundef, Qundef };
527
- rb_scan_args(argc, argv, ":", &kw_args);
528
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
529
-
530
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
531
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
532
- return Qnil;
533
- }
534
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
535
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
536
- return Qnil;
537
- }
538
-
539
- rb_iv_set(self, "@space", kw_values[0]);
540
- hnswlib::SpaceInterface<float>* space;
541
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
542
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
543
- } else {
544
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
545
- }
546
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
547
-
548
- hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
549
- try {
550
- new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
551
- } catch(const std::runtime_error& e) {
552
- rb_raise(rb_eRuntimeError, "%s", e.what());
553
- return Qnil;
554
- }
555
-
477
+ public:
478
+ static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
479
+ hnswlib::BruteforceSearch<float>* ptr =
480
+ (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
481
+ new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
482
+ return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
483
+ };
484
+
485
+ static void hnsw_bruteforcesearch_free(void* ptr) {
486
+ ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
487
+ ruby_xfree(ptr);
488
+ };
489
+
490
+ static size_t hnsw_bruteforcesearch_size(const void* ptr) { return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); };
491
+
492
+ static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
493
+ hnswlib::BruteforceSearch<float>* ptr;
494
+ TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
495
+ return ptr;
496
+ };
497
+
498
+ static VALUE define_class(VALUE rb_mHnswlib) {
499
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
500
+ rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
501
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
502
+ rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
503
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
504
+ rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
505
+ rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
506
+ rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
507
+ rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
508
+ rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
509
+ rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
510
+ return rb_cHnswlibBruteforceSearch;
511
+ };
512
+
513
+ private:
514
+ static const rb_data_type_t hnsw_bruteforcesearch_type;
515
+
516
+ static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
517
+ VALUE kw_args = Qnil;
518
+ ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
519
+ VALUE kw_values[2] = {Qundef, Qundef};
520
+ rb_scan_args(argc, argv, ":", &kw_args);
521
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
522
+
523
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
524
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
525
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
556
526
  return Qnil;
557
- };
558
-
559
- static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
560
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
561
-
562
- if (!RB_TYPE_P(arr, T_ARRAY)) {
563
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
564
- return Qfalse;
565
- }
566
-
567
- if (!RB_INTEGER_TYPE_P(idx)) {
568
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
569
- return Qfalse;
570
- }
571
-
572
- if (dim != RARRAY_LEN(arr)) {
573
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
574
- return Qfalse;
575
- }
576
-
577
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
578
- for (int i = 0; i < dim; i++) {
579
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
580
- }
581
-
582
- try {
583
- get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
584
- } catch(const std::runtime_error& e) {
585
- ruby_xfree(vec);
586
- rb_raise(rb_eRuntimeError, "%s", e.what());
587
- return Qfalse;
588
- }
589
-
590
- ruby_xfree(vec);
591
- return Qtrue;
592
- };
593
-
594
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
595
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
596
-
597
- if (!RB_TYPE_P(arr, T_ARRAY)) {
598
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
599
- return Qnil;
600
- }
601
-
602
- if (!RB_INTEGER_TYPE_P(k)) {
603
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
604
- return Qnil;
605
- }
606
-
607
- if (dim != RARRAY_LEN(arr)) {
608
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
609
- return Qnil;
610
- }
611
-
612
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
613
- for (int i = 0; i < dim; i++) {
614
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
615
- }
616
-
617
- std::priority_queue<std::pair<float, size_t>> result =
618
- get_hnsw_bruteforcesearch(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
619
-
527
+ }
528
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
529
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
530
+ return Qnil;
531
+ }
532
+
533
+ rb_iv_set(self, "@space", kw_values[0]);
534
+ hnswlib::SpaceInterface<float>* space;
535
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
536
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
537
+ } else {
538
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
539
+ }
540
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
541
+
542
+ hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
543
+ try {
544
+ new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
545
+ } catch (const std::runtime_error& e) {
546
+ rb_raise(rb_eRuntimeError, "%s", e.what());
547
+ return Qnil;
548
+ }
549
+
550
+ return Qnil;
551
+ };
552
+
553
+ static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
554
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
555
+
556
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
557
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
558
+ return Qfalse;
559
+ }
560
+ if (!RB_INTEGER_TYPE_P(idx)) {
561
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
562
+ return Qfalse;
563
+ }
564
+ if (dim != RARRAY_LEN(arr)) {
565
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
566
+ return Qfalse;
567
+ }
568
+
569
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
570
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
571
+
572
+ try {
573
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
574
+ } catch (const std::runtime_error& e) {
620
575
  ruby_xfree(vec);
576
+ rb_raise(rb_eRuntimeError, "%s", e.what());
577
+ return Qfalse;
578
+ }
621
579
 
622
- if (result.size() != (size_t)NUM2INT(k)) {
623
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
624
- return Qnil;
625
- }
626
-
627
- VALUE distances_arr = rb_ary_new2(result.size());
628
- VALUE neighbors_arr = rb_ary_new2(result.size());
629
-
630
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
631
- const std::pair<float, size_t>& result_tuple = result.top();
632
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
633
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
634
- result.pop();
635
- }
580
+ ruby_xfree(vec);
581
+ return Qtrue;
582
+ };
636
583
 
637
- VALUE ret = rb_ary_new2(2);
638
- rb_ary_store(ret, 0, neighbors_arr);
639
- rb_ary_store(ret, 1, distances_arr);
640
- return ret;
641
- };
584
+ static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
585
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
642
586
 
643
- static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
644
- std::string filename(StringValuePtr(_filename));
645
- get_hnsw_bruteforcesearch(self)->saveIndex(filename);
646
- RB_GC_GUARD(_filename);
587
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
588
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
647
589
  return Qnil;
648
- };
649
-
650
- static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
651
- std::string filename(StringValuePtr(_filename));
652
- VALUE ivspace = rb_iv_get(self, "@space");
653
- hnswlib::SpaceInterface<float>* space;
654
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
655
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
656
- } else {
657
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
658
- }
659
- hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
660
- if (index->data_) free(index->data_);
661
- try {
662
- index->loadIndex(filename, space);
663
- } catch(const std::runtime_error& e) {
664
- rb_raise(rb_eRuntimeError, "%s", e.what());
665
- return Qnil;
666
- }
667
- RB_GC_GUARD(_filename);
590
+ }
591
+ if (!RB_INTEGER_TYPE_P(k)) {
592
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
668
593
  return Qnil;
669
- };
670
-
671
- static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
672
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
594
+ }
595
+ if (dim != RARRAY_LEN(arr)) {
596
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
673
597
  return Qnil;
674
- };
675
-
676
- static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
677
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
678
- };
679
-
680
- static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
681
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
682
- };
598
+ }
599
+
600
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
601
+ for (int i = 0; i < dim; i++) {
602
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
603
+ }
604
+
605
+ std::priority_queue<std::pair<float, size_t>> result =
606
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
607
+
608
+ ruby_xfree(vec);
609
+
610
+ if (result.size() != (size_t)NUM2INT(k)) {
611
+ rb_warning("Cannot return as many search results as the requested number of neighbors.");
612
+ }
613
+
614
+ VALUE distances_arr = rb_ary_new2(result.size());
615
+ VALUE neighbors_arr = rb_ary_new2(result.size());
616
+
617
+ while (!result.empty()) {
618
+ const std::pair<float, size_t>& result_tuple = result.top();
619
+ rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
620
+ rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
621
+ result.pop();
622
+ }
623
+
624
+ VALUE ret = rb_ary_new2(2);
625
+ rb_ary_store(ret, 0, neighbors_arr);
626
+ rb_ary_store(ret, 1, distances_arr);
627
+ return ret;
628
+ };
629
+
630
+ static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
631
+ std::string filename(StringValuePtr(_filename));
632
+ get_hnsw_bruteforcesearch(self)->saveIndex(filename);
633
+ RB_GC_GUARD(_filename);
634
+ return Qnil;
635
+ };
636
+
637
+ static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
638
+ std::string filename(StringValuePtr(_filename));
639
+ VALUE ivspace = rb_iv_get(self, "@space");
640
+ hnswlib::SpaceInterface<float>* space;
641
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
642
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
643
+ } else {
644
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
645
+ }
646
+ hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
647
+ if (index->data_) {
648
+ free(index->data_);
649
+ index->data_ = nullptr;
650
+ }
651
+ try {
652
+ index->loadIndex(filename, space);
653
+ } catch (const std::runtime_error& e) {
654
+ rb_raise(rb_eRuntimeError, "%s", e.what());
655
+ return Qnil;
656
+ }
657
+ RB_GC_GUARD(_filename);
658
+ return Qnil;
659
+ };
660
+
661
+ static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
662
+ get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
663
+ return Qnil;
664
+ };
665
+
666
+ static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
667
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
668
+ };
669
+
670
+ static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
671
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
672
+ };
683
673
  };
684
674
 
675
+ // clang-format off
685
676
  const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
686
677
  "RbHnswlibBruteforceSearch",
687
678
  {
@@ -693,5 +684,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
693
684
  NULL,
694
685
  RUBY_TYPED_FREE_IMMEDIATELY
695
686
  };
687
+ // clang-format on
696
688
 
697
689
  #endif /* HNSWLIBEXT_HPP */