hnswlib 0.5.1 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -29,69 +29,64 @@ VALUE rb_cHnswlibHierarchicalNSW;
29
29
  VALUE rb_cHnswlibBruteforceSearch;
30
30
 
31
31
  class RbHnswlibL2Space {
32
- public:
33
- static VALUE hnsw_l2space_alloc(VALUE self) {
34
- hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
- new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
- return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
- };
38
-
39
- static void hnsw_l2space_free(void* ptr) {
40
- ((hnswlib::L2Space*)ptr)->~L2Space();
41
- ruby_xfree(ptr);
42
- };
43
-
44
- static size_t hnsw_l2space_size(const void* ptr) {
45
- return sizeof(*((hnswlib::L2Space*)ptr));
46
- };
47
-
48
- static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
49
- hnswlib::L2Space* ptr;
50
- TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
51
- return ptr;
52
- };
53
-
54
- static VALUE define_class(VALUE rb_mHnswlib) {
55
- rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
56
- rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
57
- rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
58
- rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
59
- rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
60
- return rb_cHnswlibL2Space;
61
- };
62
-
63
- private:
64
- static const rb_data_type_t hnsw_l2space_type;
65
-
66
- static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
67
- rb_iv_set(self, "@dim", dim);
68
- hnswlib::L2Space* ptr = get_hnsw_l2space(self);
69
- new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
32
+ public:
33
+ static VALUE hnsw_l2space_alloc(VALUE self) {
34
+ hnswlib::L2Space* ptr = (hnswlib::L2Space*)ruby_xmalloc(sizeof(hnswlib::L2Space));
35
+ new (ptr) hnswlib::L2Space(); // dummy call to constructor for GC.
36
+ return TypedData_Wrap_Struct(self, &hnsw_l2space_type, ptr);
37
+ };
38
+
39
+ static void hnsw_l2space_free(void* ptr) {
40
+ ((hnswlib::L2Space*)ptr)->~L2Space();
41
+ ruby_xfree(ptr);
42
+ };
43
+
44
+ static size_t hnsw_l2space_size(const void* ptr) { return sizeof(*((hnswlib::L2Space*)ptr)); };
45
+
46
+ static hnswlib::L2Space* get_hnsw_l2space(VALUE self) {
47
+ hnswlib::L2Space* ptr;
48
+ TypedData_Get_Struct(self, hnswlib::L2Space, &hnsw_l2space_type, ptr);
49
+ return ptr;
50
+ };
51
+
52
+ static VALUE define_class(VALUE rb_mHnswlib) {
53
+ rb_cHnswlibL2Space = rb_define_class_under(rb_mHnswlib, "L2Space", rb_cObject);
54
+ rb_define_alloc_func(rb_cHnswlibL2Space, hnsw_l2space_alloc);
55
+ rb_define_method(rb_cHnswlibL2Space, "initialize", RUBY_METHOD_FUNC(_hnsw_l2space_init), 1);
56
+ rb_define_method(rb_cHnswlibL2Space, "distance", RUBY_METHOD_FUNC(_hnsw_l2space_distance), 2);
57
+ rb_define_attr(rb_cHnswlibL2Space, "dim", 1, 0);
58
+ return rb_cHnswlibL2Space;
59
+ };
60
+
61
+ private:
62
+ static const rb_data_type_t hnsw_l2space_type;
63
+
64
+ static VALUE _hnsw_l2space_init(VALUE self, VALUE dim) {
65
+ rb_iv_set(self, "@dim", dim);
66
+ hnswlib::L2Space* ptr = get_hnsw_l2space(self);
67
+ new (ptr) hnswlib::L2Space(NUM2INT(rb_iv_get(self, "@dim")));
68
+ return Qnil;
69
+ };
70
+
71
+ static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
72
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
73
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
74
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
70
75
  return Qnil;
71
- };
72
-
73
- static VALUE _hnsw_l2space_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
74
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
75
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
76
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
77
- return Qnil;
78
- }
79
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
80
- for (int i = 0; i < dim; i++) {
81
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
82
- }
83
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
84
- for (int i = 0; i < dim; i++) {
85
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
86
- }
87
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
88
- const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
89
- ruby_xfree(vec_a);
90
- ruby_xfree(vec_b);
91
- return DBL2NUM((double)dist);
92
- };
76
+ }
77
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
78
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
79
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
80
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
81
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_l2space(self)->get_dist_func();
82
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_l2space(self)->get_dist_func_param());
83
+ ruby_xfree(vec_a);
84
+ ruby_xfree(vec_b);
85
+ return DBL2NUM((double)dist);
86
+ };
93
87
  };
94
88
 
89
+ // clang-format off
95
90
  const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
96
91
  "RbHnswlibL2Space",
97
92
  {
@@ -103,71 +98,67 @@ const rb_data_type_t RbHnswlibL2Space::hnsw_l2space_type = {
103
98
  NULL,
104
99
  RUBY_TYPED_FREE_IMMEDIATELY
105
100
  };
101
+ // clang-format on
106
102
 
107
103
  class RbHnswlibInnerProductSpace {
108
- public:
109
- static VALUE hnsw_ipspace_alloc(VALUE self) {
110
- hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
111
- new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
112
- return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
113
- };
114
-
115
- static void hnsw_ipspace_free(void* ptr) {
116
- ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
117
- ruby_xfree(ptr);
118
- };
119
-
120
- static size_t hnsw_ipspace_size(const void* ptr) {
121
- return sizeof(*((hnswlib::InnerProductSpace*)ptr));
122
- };
123
-
124
- static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
125
- hnswlib::InnerProductSpace* ptr;
126
- TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
127
- return ptr;
128
- };
129
-
130
- static VALUE define_class(VALUE rb_mHnswlib) {
131
- rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
132
- rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
133
- rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
134
- rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
135
- rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
136
- return rb_cHnswlibInnerProductSpace;
137
- };
138
-
139
- private:
140
- static const rb_data_type_t hnsw_ipspace_type;
141
-
142
- static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
143
- rb_iv_set(self, "@dim", dim);
144
- hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
145
- new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
104
+ public:
105
+ static VALUE hnsw_ipspace_alloc(VALUE self) {
106
+ hnswlib::InnerProductSpace* ptr = (hnswlib::InnerProductSpace*)ruby_xmalloc(sizeof(hnswlib::InnerProductSpace));
107
+ new (ptr) hnswlib::InnerProductSpace(); // dummy call to constructor for GC.
108
+ return TypedData_Wrap_Struct(self, &hnsw_ipspace_type, ptr);
109
+ };
110
+
111
+ static void hnsw_ipspace_free(void* ptr) {
112
+ ((hnswlib::InnerProductSpace*)ptr)->~InnerProductSpace();
113
+ ruby_xfree(ptr);
114
+ };
115
+
116
+ static size_t hnsw_ipspace_size(const void* ptr) { return sizeof(*((hnswlib::InnerProductSpace*)ptr)); };
117
+
118
+ static hnswlib::InnerProductSpace* get_hnsw_ipspace(VALUE self) {
119
+ hnswlib::InnerProductSpace* ptr;
120
+ TypedData_Get_Struct(self, hnswlib::InnerProductSpace, &hnsw_ipspace_type, ptr);
121
+ return ptr;
122
+ };
123
+
124
+ static VALUE define_class(VALUE rb_mHnswlib) {
125
+ rb_cHnswlibInnerProductSpace = rb_define_class_under(rb_mHnswlib, "InnerProductSpace", rb_cObject);
126
+ rb_define_alloc_func(rb_cHnswlibInnerProductSpace, hnsw_ipspace_alloc);
127
+ rb_define_method(rb_cHnswlibInnerProductSpace, "initialize", RUBY_METHOD_FUNC(_hnsw_ipspace_init), 1);
128
+ rb_define_method(rb_cHnswlibInnerProductSpace, "distance", RUBY_METHOD_FUNC(_hnsw_ipspace_distance), 2);
129
+ rb_define_attr(rb_cHnswlibInnerProductSpace, "dim", 1, 0);
130
+ return rb_cHnswlibInnerProductSpace;
131
+ };
132
+
133
+ private:
134
+ static const rb_data_type_t hnsw_ipspace_type;
135
+
136
+ static VALUE _hnsw_ipspace_init(VALUE self, VALUE dim) {
137
+ rb_iv_set(self, "@dim", dim);
138
+ hnswlib::InnerProductSpace* ptr = get_hnsw_ipspace(self);
139
+ new (ptr) hnswlib::InnerProductSpace(NUM2INT(rb_iv_get(self, "@dim")));
140
+ return Qnil;
141
+ };
142
+
143
+ static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
144
+ const int dim = NUM2INT(rb_iv_get(self, "@dim"));
145
+ if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
146
+ rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
146
147
  return Qnil;
147
- };
148
-
149
- static VALUE _hnsw_ipspace_distance(VALUE self, VALUE arr_a, VALUE arr_b) {
150
- const int dim = NUM2INT(rb_iv_get(self, "@dim"));
151
- if (dim != RARRAY_LEN(arr_a) || dim != RARRAY_LEN(arr_b)) {
152
- rb_raise(rb_eArgError, "Array size does not match to space dimensionality.");
153
- return Qnil;
154
- }
155
- float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
156
- for (int i = 0; i < dim; i++) {
157
- vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
158
- }
159
- float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
160
- for (int i = 0; i < dim; i++) {
161
- vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
162
- }
163
- hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
164
- const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
165
- ruby_xfree(vec_a);
166
- ruby_xfree(vec_b);
167
- return DBL2NUM((double)dist);
168
- };
148
+ }
149
+ float* vec_a = (float*)ruby_xmalloc(dim * sizeof(float));
150
+ for (int i = 0; i < dim; i++) vec_a[i] = (float)NUM2DBL(rb_ary_entry(arr_a, i));
151
+ float* vec_b = (float*)ruby_xmalloc(dim * sizeof(float));
152
+ for (int i = 0; i < dim; i++) vec_b[i] = (float)NUM2DBL(rb_ary_entry(arr_b, i));
153
+ hnswlib::DISTFUNC<float> dist_func = get_hnsw_ipspace(self)->get_dist_func();
154
+ const float dist = dist_func(vec_a, vec_b, get_hnsw_ipspace(self)->get_dist_func_param());
155
+ ruby_xfree(vec_a);
156
+ ruby_xfree(vec_b);
157
+ return DBL2NUM((double)dist);
158
+ };
169
159
  };
170
160
 
161
+ // clang-format off
171
162
  const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
172
163
  "RbHnswlibInnerProductSpace",
173
164
  {
@@ -179,258 +170,278 @@ const rb_data_type_t RbHnswlibInnerProductSpace::hnsw_ipspace_type = {
179
170
  NULL,
180
171
  RUBY_TYPED_FREE_IMMEDIATELY
181
172
  };
173
+ // clang-format on
182
174
 
183
175
  class RbHnswlibHierarchicalNSW {
184
- public:
185
- static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
186
- hnswlib::HierarchicalNSW<float>* ptr = (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
187
- new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
188
- return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
189
- };
190
-
191
- static void hnsw_hierarchicalnsw_free(void* ptr) {
192
- ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
193
- ruby_xfree(ptr);
194
- };
195
-
196
- static size_t hnsw_hierarchicalnsw_size(const void* ptr) {
197
- return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr));
198
- };
199
-
200
- static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
201
- hnswlib::HierarchicalNSW<float>* ptr;
202
- TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
203
- return ptr;
204
- };
205
-
206
- static VALUE define_class(VALUE rb_mHnswlib) {
207
- rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
208
- rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
209
- rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
210
- rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
211
- rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
212
- rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
213
- rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
214
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
215
- rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
216
- rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
217
- rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
218
- rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
219
- rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
220
- rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
221
- rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
222
- return rb_cHnswlibHierarchicalNSW;
223
- };
224
-
225
- private:
226
- static const rb_data_type_t hnsw_hierarchicalnsw_type;
227
-
228
- static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
229
- VALUE kw_args = Qnil;
230
- ID kw_table[5] = {
231
- rb_intern("space"),
232
- rb_intern("max_elements"),
233
- rb_intern("m"),
234
- rb_intern("ef_construction"),
235
- rb_intern("random_seed")
236
- };
237
- VALUE kw_values[5] = {
238
- Qundef, Qundef, Qundef, Qundef, Qundef
239
- };
240
- rb_scan_args(argc, argv, ":", &kw_args);
241
- rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
242
- if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
243
- if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
244
- if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
245
-
246
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
247
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
248
- return Qnil;
249
- }
250
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
251
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
252
- return Qnil;
253
- }
254
- if (!RB_INTEGER_TYPE_P(kw_values[2])) {
255
- rb_raise(rb_eTypeError, "expected m, Integer");
256
- return Qnil;
257
- }
258
- if (!RB_INTEGER_TYPE_P(kw_values[3])) {
259
- rb_raise(rb_eTypeError, "expected ef_construction, Integer");
260
- return Qnil;
261
- }
262
- if (!RB_INTEGER_TYPE_P(kw_values[4])) {
263
- rb_raise(rb_eTypeError, "expected random_seed, Integer");
264
- return Qnil;
265
- }
266
-
267
- rb_iv_set(self, "@space", kw_values[0]);
268
- hnswlib::SpaceInterface<float>* space;
269
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
270
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
271
- } else {
272
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
273
- }
274
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
275
- const size_t m = (size_t)NUM2INT(kw_values[2]);
276
- const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
277
- const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
278
-
279
- hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
176
+ public:
177
+ static VALUE hnsw_hierarchicalnsw_alloc(VALUE self) {
178
+ hnswlib::HierarchicalNSW<float>* ptr =
179
+ (hnswlib::HierarchicalNSW<float>*)ruby_xmalloc(sizeof(hnswlib::HierarchicalNSW<float>));
180
+ new (ptr) hnswlib::HierarchicalNSW<float>(); // dummy call to constructor for GC.
181
+ return TypedData_Wrap_Struct(self, &hnsw_hierarchicalnsw_type, ptr);
182
+ };
183
+
184
+ static void hnsw_hierarchicalnsw_free(void* ptr) {
185
+ ((hnswlib::HierarchicalNSW<float>*)ptr)->~HierarchicalNSW();
186
+ ruby_xfree(ptr);
187
+ };
188
+
189
+ static size_t hnsw_hierarchicalnsw_size(const void* ptr) { return sizeof(*((hnswlib::HierarchicalNSW<float>*)ptr)); };
190
+
191
+ static hnswlib::HierarchicalNSW<float>* get_hnsw_hierarchicalnsw(VALUE self) {
192
+ hnswlib::HierarchicalNSW<float>* ptr;
193
+ TypedData_Get_Struct(self, hnswlib::HierarchicalNSW<float>, &hnsw_hierarchicalnsw_type, ptr);
194
+ return ptr;
195
+ };
196
+
197
+ static VALUE define_class(VALUE rb_mHnswlib) {
198
+ rb_cHnswlibHierarchicalNSW = rb_define_class_under(rb_mHnswlib, "HierarchicalNSW", rb_cObject);
199
+ rb_define_alloc_func(rb_cHnswlibHierarchicalNSW, hnsw_hierarchicalnsw_alloc);
200
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "initialize", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_init), -1);
201
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "add_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_add_point), 2);
202
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "search_knn", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_search_knn), 2);
203
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "save_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_save_index), 1);
204
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "load_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_load_index), 1);
205
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_point", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_point), 1);
206
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "get_ids", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_get_ids), 0);
207
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "mark_deleted", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_mark_deleted), 1);
208
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "resize_index", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_resize_index), 1);
209
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "set_ef", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_set_ef), 1);
210
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "max_elements", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_max_elements), 0);
211
+ rb_define_method(rb_cHnswlibHierarchicalNSW, "current_count", RUBY_METHOD_FUNC(_hnsw_hierarchicalnsw_current_count), 0);
212
+ rb_define_attr(rb_cHnswlibHierarchicalNSW, "space", 1, 0);
213
+ return rb_cHnswlibHierarchicalNSW;
214
+ };
215
+
216
+ private:
217
+ static const rb_data_type_t hnsw_hierarchicalnsw_type;
218
+
219
+ static VALUE _hnsw_hierarchicalnsw_init(int argc, VALUE* argv, VALUE self) {
220
+ VALUE kw_args = Qnil;
221
+ ID kw_table[5] = {rb_intern("space"), rb_intern("max_elements"), rb_intern("m"), rb_intern("ef_construction"),
222
+ rb_intern("random_seed")};
223
+ VALUE kw_values[5] = {Qundef, Qundef, Qundef, Qundef, Qundef};
224
+ rb_scan_args(argc, argv, ":", &kw_args);
225
+ rb_get_kwargs(kw_args, kw_table, 2, 3, kw_values);
226
+ if (kw_values[2] == Qundef) kw_values[2] = INT2NUM(16);
227
+ if (kw_values[3] == Qundef) kw_values[3] = INT2NUM(200);
228
+ if (kw_values[4] == Qundef) kw_values[4] = INT2NUM(100);
229
+
230
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
231
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
232
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
233
+ return Qnil;
234
+ }
235
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
236
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
237
+ return Qnil;
238
+ }
239
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
240
+ rb_raise(rb_eTypeError, "expected m, Integer");
241
+ return Qnil;
242
+ }
243
+ if (!RB_INTEGER_TYPE_P(kw_values[3])) {
244
+ rb_raise(rb_eTypeError, "expected ef_construction, Integer");
245
+ return Qnil;
246
+ }
247
+ if (!RB_INTEGER_TYPE_P(kw_values[4])) {
248
+ rb_raise(rb_eTypeError, "expected random_seed, Integer");
249
+ return Qnil;
250
+ }
251
+
252
+ rb_iv_set(self, "@space", kw_values[0]);
253
+ hnswlib::SpaceInterface<float>* space;
254
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
255
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
256
+ } else {
257
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
258
+ }
259
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
260
+ const size_t m = (size_t)NUM2INT(kw_values[2]);
261
+ const size_t ef_construction = (size_t)NUM2INT(kw_values[3]);
262
+ const size_t random_seed = (size_t)NUM2INT(kw_values[4]);
263
+
264
+ hnswlib::HierarchicalNSW<float>* ptr = get_hnsw_hierarchicalnsw(self);
265
+ try {
280
266
  new (ptr) hnswlib::HierarchicalNSW<float>(space, max_elements, m, ef_construction, random_seed);
281
-
267
+ } catch (const std::runtime_error& e) {
268
+ rb_raise(rb_eRuntimeError, "%s", e.what());
282
269
  return Qnil;
283
- };
284
-
285
- static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
286
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
287
-
288
- if (!RB_TYPE_P(arr, T_ARRAY)) {
289
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
290
- return Qfalse;
291
- }
292
-
293
- if (!RB_INTEGER_TYPE_P(idx)) {
294
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
295
- return Qfalse;
296
- }
270
+ }
297
271
 
298
- if (dim != RARRAY_LEN(arr)) {
299
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
300
- return Qfalse;
301
- }
302
-
303
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
304
- for (int i = 0; i < dim; i++) {
305
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
306
- }
272
+ return Qnil;
273
+ };
307
274
 
308
- get_hnsw_hierarchicalnsw(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
275
+ static VALUE _hnsw_hierarchicalnsw_add_point(VALUE self, VALUE arr, VALUE idx) {
276
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
309
277
 
310
- ruby_xfree(vec);
311
- return Qtrue;
312
- };
278
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
279
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
280
+ return Qfalse;
281
+ }
282
+ if (!RB_INTEGER_TYPE_P(idx)) {
283
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
284
+ return Qfalse;
285
+ }
286
+ if (dim != RARRAY_LEN(arr)) {
287
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
288
+ return Qfalse;
289
+ }
313
290
 
314
- static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
315
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
291
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
292
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
316
293
 
317
- if (!RB_TYPE_P(arr, T_ARRAY)) {
318
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
319
- return Qnil;
320
- }
294
+ get_hnsw_hierarchicalnsw(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
321
295
 
322
- if (!RB_INTEGER_TYPE_P(k)) {
323
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
324
- return Qnil;
325
- }
296
+ ruby_xfree(vec);
297
+ return Qtrue;
298
+ };
326
299
 
327
- if (dim != RARRAY_LEN(arr)) {
328
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
329
- return Qnil;
330
- }
300
+ static VALUE _hnsw_hierarchicalnsw_search_knn(VALUE self, VALUE arr, VALUE k) {
301
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
331
302
 
332
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
333
- for (int i = 0; i < dim; i++) {
334
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
335
- }
303
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
304
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
305
+ return Qnil;
306
+ }
307
+ if (!RB_INTEGER_TYPE_P(k)) {
308
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
309
+ return Qnil;
310
+ }
311
+ if (dim != RARRAY_LEN(arr)) {
312
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
313
+ return Qnil;
314
+ }
336
315
 
337
- std::priority_queue<std::pair<float, size_t>> result =
338
- get_hnsw_hierarchicalnsw(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
316
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
317
+ for (int i = 0; i < dim; i++) {
318
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
319
+ }
339
320
 
321
+ std::priority_queue<std::pair<float, size_t>> result;
322
+ try {
323
+ result = get_hnsw_hierarchicalnsw(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
324
+ } catch (const std::runtime_error& e) {
340
325
  ruby_xfree(vec);
341
-
342
- if (result.size() != (size_t)NUM2INT(k)) {
343
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
344
- return Qnil;
345
- }
346
-
347
- VALUE distances_arr = rb_ary_new2(result.size());
348
- VALUE neighbors_arr = rb_ary_new2(result.size());
349
-
350
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
351
- const std::pair<float, size_t>& result_tuple = result.top();
352
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
353
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
354
- result.pop();
355
- }
356
-
357
- VALUE ret = rb_ary_new2(2);
358
- rb_ary_store(ret, 0, neighbors_arr);
359
- rb_ary_store(ret, 1, distances_arr);
360
- return ret;
361
- };
362
-
363
- static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
364
- std::string filename(StringValuePtr(_filename));
365
- get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
366
- RB_GC_GUARD(_filename);
326
+ rb_raise(rb_eRuntimeError, "%s", e.what());
367
327
  return Qnil;
368
- };
369
-
370
- static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
371
- std::string filename(StringValuePtr(_filename));
372
- VALUE ivspace = rb_iv_get(self, "@space");
373
- hnswlib::SpaceInterface<float>* space;
374
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
375
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
376
- } else {
377
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
378
- }
379
- get_hnsw_hierarchicalnsw(self)->loadIndex(filename, space);
380
- RB_GC_GUARD(_filename);
328
+ }
329
+
330
+ ruby_xfree(vec);
331
+
332
+ if (result.size() != (size_t)NUM2INT(k)) {
333
+ rb_warning("Cannot return as many search results as the requested number of neighbors. Probably ef or M is too small.");
334
+ }
335
+
336
+ VALUE distances_arr = rb_ary_new();
337
+ VALUE neighbors_arr = rb_ary_new();
338
+
339
+ while (!result.empty()) {
340
+ const std::pair<float, size_t>& result_tuple = result.top();
341
+ rb_ary_unshift(distances_arr, DBL2NUM((double)result_tuple.first));
342
+ rb_ary_unshift(neighbors_arr, INT2NUM((int)result_tuple.second));
343
+ result.pop();
344
+ }
345
+
346
+ VALUE ret = rb_ary_new2(2);
347
+ rb_ary_store(ret, 0, neighbors_arr);
348
+ rb_ary_store(ret, 1, distances_arr);
349
+ return ret;
350
+ };
351
+
352
+ static VALUE _hnsw_hierarchicalnsw_save_index(VALUE self, VALUE _filename) {
353
+ std::string filename(StringValuePtr(_filename));
354
+ get_hnsw_hierarchicalnsw(self)->saveIndex(filename);
355
+ RB_GC_GUARD(_filename);
356
+ return Qnil;
357
+ };
358
+
359
+ static VALUE _hnsw_hierarchicalnsw_load_index(VALUE self, VALUE _filename) {
360
+ std::string filename(StringValuePtr(_filename));
361
+ VALUE ivspace = rb_iv_get(self, "@space");
362
+ hnswlib::SpaceInterface<float>* space;
363
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
364
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
365
+ } else {
366
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
367
+ }
368
+ hnswlib::HierarchicalNSW<float>* index = get_hnsw_hierarchicalnsw(self);
369
+ if (index->data_level0_memory_) free(index->data_level0_memory_);
370
+ if (index->linkLists_) {
371
+ for (hnswlib::tableint i = 0; i < index->cur_element_count; i++) {
372
+ if (index->element_levels_[i] > 0 && index->linkLists_[i]) free(index->linkLists_[i]);
373
+ }
374
+ free(index->linkLists_);
375
+ }
376
+ if (index->visited_list_pool_) delete index->visited_list_pool_;
377
+ try {
378
+ index->loadIndex(filename, space);
379
+ } catch (const std::runtime_error& e) {
380
+ rb_raise(rb_eRuntimeError, "%s", e.what());
381
381
  return Qnil;
382
- };
383
-
384
- static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
385
- VALUE ret = Qnil;
386
- try {
387
- std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
388
- ret = rb_ary_new2(vec.size());
389
- for (size_t i = 0; i < vec.size(); i++) {
390
- rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
391
- }
392
- } catch(std::runtime_error const& e) {
393
- rb_raise(rb_eRuntimeError, "%s", e.what());
394
- }
395
- return ret;
396
- };
397
-
398
- static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
399
- VALUE ret = rb_ary_new();
400
- for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) {
401
- rb_ary_push(ret, INT2NUM((int)kv.first));
402
- }
403
- return ret;
404
- };
405
-
406
- static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
382
+ }
383
+ RB_GC_GUARD(_filename);
384
+ return Qnil;
385
+ };
386
+
387
+ static VALUE _hnsw_hierarchicalnsw_get_point(VALUE self, VALUE idx) {
388
+ VALUE ret = Qnil;
389
+ try {
390
+ std::vector<float> vec = get_hnsw_hierarchicalnsw(self)->template getDataByLabel<float>((size_t)NUM2INT(idx));
391
+ ret = rb_ary_new2(vec.size());
392
+ for (size_t i = 0; i < vec.size(); i++) rb_ary_store(ret, i, DBL2NUM((double)vec[i]));
393
+ } catch (const std::runtime_error& e) {
394
+ rb_raise(rb_eRuntimeError, "%s", e.what());
395
+ return Qnil;
396
+ }
397
+ return ret;
398
+ };
399
+
400
+ static VALUE _hnsw_hierarchicalnsw_get_ids(VALUE self) {
401
+ VALUE ret = rb_ary_new();
402
+ for (auto kv : get_hnsw_hierarchicalnsw(self)->label_lookup_) rb_ary_push(ret, INT2NUM((int)kv.first));
403
+ return ret;
404
+ };
405
+
406
+ static VALUE _hnsw_hierarchicalnsw_mark_deleted(VALUE self, VALUE idx) {
407
+ try {
407
408
  get_hnsw_hierarchicalnsw(self)->markDelete((size_t)NUM2INT(idx));
409
+ } catch (const std::runtime_error& e) {
410
+ rb_raise(rb_eRuntimeError, "%s", e.what());
408
411
  return Qnil;
409
- };
412
+ }
413
+ return Qnil;
414
+ };
410
415
 
411
- static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
412
- if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
413
- rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
414
- return Qnil;
415
- }
416
- get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
416
+ static VALUE _hnsw_hierarchicalnsw_resize_index(VALUE self, VALUE new_max_elements) {
417
+ if ((size_t)NUM2INT(new_max_elements) < get_hnsw_hierarchicalnsw(self)->cur_element_count) {
418
+ rb_raise(rb_eArgError, "Cannot resize, max element is less than the current number of elements.");
417
419
  return Qnil;
418
- };
419
-
420
- static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
421
- get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
420
+ }
421
+ try {
422
+ get_hnsw_hierarchicalnsw(self)->resizeIndex((size_t)NUM2INT(new_max_elements));
423
+ } catch (const std::runtime_error& e) {
424
+ rb_raise(rb_eRuntimeError, "%s", e.what());
422
425
  return Qnil;
423
- };
424
-
425
- static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
426
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
427
- };
428
-
429
- static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
430
- return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
431
- };
426
+ }
427
+ return Qnil;
428
+ };
429
+
430
+ static VALUE _hnsw_hierarchicalnsw_set_ef(VALUE self, VALUE ef) {
431
+ get_hnsw_hierarchicalnsw(self)->ef_ = (size_t)NUM2INT(ef);
432
+ return Qnil;
433
+ };
434
+
435
+ static VALUE _hnsw_hierarchicalnsw_max_elements(VALUE self) {
436
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->max_elements_));
437
+ };
438
+
439
+ static VALUE _hnsw_hierarchicalnsw_current_count(VALUE self) {
440
+ return INT2NUM((int)(get_hnsw_hierarchicalnsw(self)->cur_element_count));
441
+ };
432
442
  };
433
443
 
444
+ // clang-format off
434
445
  const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
435
446
  "RbHnswlibHierarchicalNSW",
436
447
  {
@@ -442,192 +453,206 @@ const rb_data_type_t RbHnswlibHierarchicalNSW::hnsw_hierarchicalnsw_type = {
442
453
  NULL,
443
454
  RUBY_TYPED_FREE_IMMEDIATELY
444
455
  };
456
+ // clang-format on
445
457
 
446
458
  class RbHnswlibBruteforceSearch {
447
- public:
448
- static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
449
- hnswlib::BruteforceSearch<float>* ptr = (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
450
- new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
451
- return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
452
- };
453
-
454
- static void hnsw_bruteforcesearch_free(void* ptr) {
455
- ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
456
- ruby_xfree(ptr);
457
- };
458
-
459
- static size_t hnsw_bruteforcesearch_size(const void* ptr) {
460
- return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr));
461
- };
462
-
463
- static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
464
- hnswlib::BruteforceSearch<float>* ptr;
465
- TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
466
- return ptr;
467
- };
468
-
469
- static VALUE define_class(VALUE rb_mHnswlib) {
470
- rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
471
- rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
472
- rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
473
- rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
474
- rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
475
- rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
476
- rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
477
- rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
478
- rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
479
- rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
480
- rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
481
- return rb_cHnswlibBruteforceSearch;
482
- };
483
-
484
- private:
485
- static const rb_data_type_t hnsw_bruteforcesearch_type;
486
-
487
- static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
488
- VALUE kw_args = Qnil;
489
- ID kw_table[2] = { rb_intern("space"), rb_intern("max_elements") };
490
- VALUE kw_values[2] = { Qundef, Qundef };
491
- rb_scan_args(argc, argv, ":", &kw_args);
492
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
493
-
494
- if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) || rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
495
- rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
496
- return Qnil;
497
- }
498
- if (!RB_INTEGER_TYPE_P(kw_values[1])) {
499
- rb_raise(rb_eTypeError, "expected max_elements, Integer");
500
- return Qnil;
501
- }
502
-
503
- rb_iv_set(self, "@space", kw_values[0]);
504
- hnswlib::SpaceInterface<float>* space;
505
- if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
506
- space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
507
- } else {
508
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
509
- }
510
- const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
511
-
512
- hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
459
+ public:
460
+ static VALUE hnsw_bruteforcesearch_alloc(VALUE self) {
461
+ hnswlib::BruteforceSearch<float>* ptr =
462
+ (hnswlib::BruteforceSearch<float>*)ruby_xmalloc(sizeof(hnswlib::BruteforceSearch<float>));
463
+ new (ptr) hnswlib::BruteforceSearch<float>(); // dummy call to constructor for GC.
464
+ return TypedData_Wrap_Struct(self, &hnsw_bruteforcesearch_type, ptr);
465
+ };
466
+
467
+ static void hnsw_bruteforcesearch_free(void* ptr) {
468
+ ((hnswlib::BruteforceSearch<float>*)ptr)->~BruteforceSearch();
469
+ ruby_xfree(ptr);
470
+ };
471
+
472
+ static size_t hnsw_bruteforcesearch_size(const void* ptr) { return sizeof(*((hnswlib::BruteforceSearch<float>*)ptr)); };
473
+
474
+ static hnswlib::BruteforceSearch<float>* get_hnsw_bruteforcesearch(VALUE self) {
475
+ hnswlib::BruteforceSearch<float>* ptr;
476
+ TypedData_Get_Struct(self, hnswlib::BruteforceSearch<float>, &hnsw_bruteforcesearch_type, ptr);
477
+ return ptr;
478
+ };
479
+
480
+ static VALUE define_class(VALUE rb_mHnswlib) {
481
+ rb_cHnswlibBruteforceSearch = rb_define_class_under(rb_mHnswlib, "BruteforceSearch", rb_cObject);
482
+ rb_define_alloc_func(rb_cHnswlibBruteforceSearch, hnsw_bruteforcesearch_alloc);
483
+ rb_define_method(rb_cHnswlibBruteforceSearch, "initialize", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_init), -1);
484
+ rb_define_method(rb_cHnswlibBruteforceSearch, "add_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_add_point), 2);
485
+ rb_define_method(rb_cHnswlibBruteforceSearch, "search_knn", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_search_knn), 2);
486
+ rb_define_method(rb_cHnswlibBruteforceSearch, "save_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_save_index), 1);
487
+ rb_define_method(rb_cHnswlibBruteforceSearch, "load_index", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_load_index), 1);
488
+ rb_define_method(rb_cHnswlibBruteforceSearch, "remove_point", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_remove_point), 1);
489
+ rb_define_method(rb_cHnswlibBruteforceSearch, "max_elements", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_max_elements), 0);
490
+ rb_define_method(rb_cHnswlibBruteforceSearch, "current_count", RUBY_METHOD_FUNC(_hnsw_bruteforcesearch_current_count), 0);
491
+ rb_define_attr(rb_cHnswlibBruteforceSearch, "space", 1, 0);
492
+ return rb_cHnswlibBruteforceSearch;
493
+ };
494
+
495
+ private:
496
+ static const rb_data_type_t hnsw_bruteforcesearch_type;
497
+
498
+ static VALUE _hnsw_bruteforcesearch_init(int argc, VALUE* argv, VALUE self) {
499
+ VALUE kw_args = Qnil;
500
+ ID kw_table[2] = {rb_intern("space"), rb_intern("max_elements")};
501
+ VALUE kw_values[2] = {Qundef, Qundef};
502
+ rb_scan_args(argc, argv, ":", &kw_args);
503
+ rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
504
+
505
+ if (!(rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space) ||
506
+ rb_obj_is_instance_of(kw_values[0], rb_cHnswlibInnerProductSpace))) {
507
+ rb_raise(rb_eTypeError, "expected space, Hnswlib::L2Space or Hnswlib::InnerProductSpace");
508
+ return Qnil;
509
+ }
510
+ if (!RB_INTEGER_TYPE_P(kw_values[1])) {
511
+ rb_raise(rb_eTypeError, "expected max_elements, Integer");
512
+ return Qnil;
513
+ }
514
+
515
+ rb_iv_set(self, "@space", kw_values[0]);
516
+ hnswlib::SpaceInterface<float>* space;
517
+ if (rb_obj_is_instance_of(kw_values[0], rb_cHnswlibL2Space)) {
518
+ space = RbHnswlibL2Space::get_hnsw_l2space(kw_values[0]);
519
+ } else {
520
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(kw_values[0]);
521
+ }
522
+ const size_t max_elements = (size_t)NUM2INT(kw_values[1]);
523
+
524
+ hnswlib::BruteforceSearch<float>* ptr = get_hnsw_bruteforcesearch(self);
525
+ try {
513
526
  new (ptr) hnswlib::BruteforceSearch<float>(space, max_elements);
514
-
527
+ } catch (const std::runtime_error& e) {
528
+ rb_raise(rb_eRuntimeError, "%s", e.what());
515
529
  return Qnil;
516
- };
517
-
518
- static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
519
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
520
-
521
- if (!RB_TYPE_P(arr, T_ARRAY)) {
522
- rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
523
- return Qfalse;
524
- }
525
-
526
- if (!RB_INTEGER_TYPE_P(idx)) {
527
- rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
528
- return Qfalse;
529
- }
530
-
531
- if (dim != RARRAY_LEN(arr)) {
532
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
533
- return Qfalse;
534
- }
535
-
536
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
537
- for (int i = 0; i < dim; i++) {
538
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
539
- }
540
-
541
- get_hnsw_bruteforcesearch(self)->addPoint((void *)vec, (size_t)NUM2INT(idx));
542
-
530
+ }
531
+
532
+ return Qnil;
533
+ };
534
+
535
+ static VALUE _hnsw_bruteforcesearch_add_point(VALUE self, VALUE arr, VALUE idx) {
536
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
537
+
538
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
539
+ rb_raise(rb_eArgError, "Expect point vector to be Ruby Array.");
540
+ return Qfalse;
541
+ }
542
+ if (!RB_INTEGER_TYPE_P(idx)) {
543
+ rb_raise(rb_eArgError, "Expect index to be Ruby Integer.");
544
+ return Qfalse;
545
+ }
546
+ if (dim != RARRAY_LEN(arr)) {
547
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
548
+ return Qfalse;
549
+ }
550
+
551
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
552
+ for (int i = 0; i < dim; i++) vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
553
+
554
+ try {
555
+ get_hnsw_bruteforcesearch(self)->addPoint((void*)vec, (size_t)NUM2INT(idx));
556
+ } catch (const std::runtime_error& e) {
543
557
  ruby_xfree(vec);
544
- return Qtrue;
545
- };
546
-
547
- static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
548
- const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
549
-
550
- if (!RB_TYPE_P(arr, T_ARRAY)) {
551
- rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
552
- return Qnil;
553
- }
554
-
555
- if (!RB_INTEGER_TYPE_P(k)) {
556
- rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
557
- return Qnil;
558
- }
559
-
560
- if (dim != RARRAY_LEN(arr)) {
561
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
562
- return Qnil;
563
- }
564
-
565
- float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
566
- for (int i = 0; i < dim; i++) {
567
- vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
568
- }
558
+ rb_raise(rb_eRuntimeError, "%s", e.what());
559
+ return Qfalse;
560
+ }
569
561
 
570
- std::priority_queue<std::pair<float, size_t>> result =
571
- get_hnsw_bruteforcesearch(self)->searchKnn((void *)vec, (size_t)NUM2INT(k));
562
+ ruby_xfree(vec);
563
+ return Qtrue;
564
+ };
572
565
 
573
- ruby_xfree(vec);
566
+ static VALUE _hnsw_bruteforcesearch_search_knn(VALUE self, VALUE arr, VALUE k) {
567
+ const int dim = NUM2INT(rb_iv_get(rb_iv_get(self, "@space"), "@dim"));
574
568
 
575
- if (result.size() != (size_t)NUM2INT(k)) {
576
- rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
577
- return Qnil;
578
- }
569
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
570
+ rb_raise(rb_eArgError, "Expect query vector to be Ruby Array.");
571
+ return Qnil;
572
+ }
573
+ if (!RB_INTEGER_TYPE_P(k)) {
574
+ rb_raise(rb_eArgError, "Expect the number of nearest neighbors to be Ruby Integer.");
575
+ return Qnil;
576
+ }
577
+ if (dim != RARRAY_LEN(arr)) {
578
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
579
+ return Qnil;
580
+ }
579
581
 
580
- VALUE distances_arr = rb_ary_new2(result.size());
581
- VALUE neighbors_arr = rb_ary_new2(result.size());
582
+ float* vec = (float*)ruby_xmalloc(dim * sizeof(float));
583
+ for (int i = 0; i < dim; i++) {
584
+ vec[i] = (float)NUM2DBL(rb_ary_entry(arr, i));
585
+ }
582
586
 
583
- for (int i = NUM2INT(k) - 1; i >= 0; i--) {
584
- const std::pair<float, size_t>& result_tuple = result.top();
585
- rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
586
- rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
587
- result.pop();
588
- }
587
+ std::priority_queue<std::pair<float, size_t>> result =
588
+ get_hnsw_bruteforcesearch(self)->searchKnn((void*)vec, (size_t)NUM2INT(k));
589
589
 
590
- VALUE ret = rb_ary_new2(2);
591
- rb_ary_store(ret, 0, neighbors_arr);
592
- rb_ary_store(ret, 1, distances_arr);
593
- return ret;
594
- };
590
+ ruby_xfree(vec);
595
591
 
596
- static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
597
- std::string filename(StringValuePtr(_filename));
598
- get_hnsw_bruteforcesearch(self)->saveIndex(filename);
599
- RB_GC_GUARD(_filename);
600
- return Qnil;
601
- };
602
-
603
- static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
604
- std::string filename(StringValuePtr(_filename));
605
- VALUE ivspace = rb_iv_get(self, "@space");
606
- hnswlib::SpaceInterface<float>* space;
607
- if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
608
- space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
609
- } else {
610
- space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
611
- }
612
- get_hnsw_bruteforcesearch(self)->loadIndex(filename, space);
613
- RB_GC_GUARD(_filename);
592
+ if (result.size() != (size_t)NUM2INT(k)) {
593
+ rb_raise(rb_eRuntimeError, "Cannot return the results in a contigious 2D array. Probably ef or M is too small.");
614
594
  return Qnil;
615
- };
616
-
617
- static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
618
- get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
595
+ }
596
+
597
+ VALUE distances_arr = rb_ary_new2(result.size());
598
+ VALUE neighbors_arr = rb_ary_new2(result.size());
599
+
600
+ for (int i = NUM2INT(k) - 1; i >= 0; i--) {
601
+ const std::pair<float, size_t>& result_tuple = result.top();
602
+ rb_ary_store(distances_arr, i, DBL2NUM((double)result_tuple.first));
603
+ rb_ary_store(neighbors_arr, i, INT2NUM((int)result_tuple.second));
604
+ result.pop();
605
+ }
606
+
607
+ VALUE ret = rb_ary_new2(2);
608
+ rb_ary_store(ret, 0, neighbors_arr);
609
+ rb_ary_store(ret, 1, distances_arr);
610
+ return ret;
611
+ };
612
+
613
+ static VALUE _hnsw_bruteforcesearch_save_index(VALUE self, VALUE _filename) {
614
+ std::string filename(StringValuePtr(_filename));
615
+ get_hnsw_bruteforcesearch(self)->saveIndex(filename);
616
+ RB_GC_GUARD(_filename);
617
+ return Qnil;
618
+ };
619
+
620
+ static VALUE _hnsw_bruteforcesearch_load_index(VALUE self, VALUE _filename) {
621
+ std::string filename(StringValuePtr(_filename));
622
+ VALUE ivspace = rb_iv_get(self, "@space");
623
+ hnswlib::SpaceInterface<float>* space;
624
+ if (rb_obj_is_instance_of(ivspace, rb_cHnswlibL2Space)) {
625
+ space = RbHnswlibL2Space::get_hnsw_l2space(ivspace);
626
+ } else {
627
+ space = RbHnswlibInnerProductSpace::get_hnsw_ipspace(ivspace);
628
+ }
629
+ hnswlib::BruteforceSearch<float>* index = get_hnsw_bruteforcesearch(self);
630
+ if (index->data_) free(index->data_);
631
+ try {
632
+ index->loadIndex(filename, space);
633
+ } catch (const std::runtime_error& e) {
634
+ rb_raise(rb_eRuntimeError, "%s", e.what());
619
635
  return Qnil;
620
- };
621
-
622
- static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
623
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
624
- };
625
-
626
- static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
627
- return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
628
- };
636
+ }
637
+ RB_GC_GUARD(_filename);
638
+ return Qnil;
639
+ };
640
+
641
+ static VALUE _hnsw_bruteforcesearch_remove_point(VALUE self, VALUE idx) {
642
+ get_hnsw_bruteforcesearch(self)->removePoint((size_t)NUM2INT(idx));
643
+ return Qnil;
644
+ };
645
+
646
+ static VALUE _hnsw_bruteforcesearch_max_elements(VALUE self) {
647
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->maxelements_));
648
+ };
649
+
650
+ static VALUE _hnsw_bruteforcesearch_current_count(VALUE self) {
651
+ return INT2NUM((int)(get_hnsw_bruteforcesearch(self)->cur_element_count));
652
+ };
629
653
  };
630
654
 
655
+ // clang-format off
631
656
  const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
632
657
  "RbHnswlibBruteforceSearch",
633
658
  {
@@ -639,5 +664,6 @@ const rb_data_type_t RbHnswlibBruteforceSearch::hnsw_bruteforcesearch_type = {
639
664
  NULL,
640
665
  RUBY_TYPED_FREE_IMMEDIATELY
641
666
  };
667
+ // clang-format on
642
668
 
643
669
  #endif /* HNSWLIBEXT_HPP */