annoy-rb 0.5.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  /**
2
2
  * Annoy.rb is a Ruby binding for the Annoy (Approximate Nearest Neighbors Oh Yeah).
3
3
  *
4
- * Copyright (c) 2020 Atsushi Tatsuma
4
+ * Copyright (c) 2020-2022 Atsushi Tatsuma
5
5
  *
6
6
  * Licensed under the Apache License, Version 2.0 (the "License");
7
7
  * you may not use this file except in compliance with the License.
@@ -22,301 +22,306 @@
22
22
  #include <typeinfo>
23
23
 
24
24
  #include <ruby.h>
25
+
25
26
  #include <annoylib.h>
26
27
  #include <kissrandom.h>
27
28
 
29
+ using namespace Annoy;
30
+
28
31
  #ifdef ANNOYLIB_MULTITHREADED_BUILD
29
- typedef AnnoyIndexMultiThreadedBuildPolicy AnnoyIndexThreadedBuildPolicy;
32
+ typedef AnnoyIndexMultiThreadedBuildPolicy AnnoyIndexThreadedBuildPolicy;
30
33
  #else
31
- typedef AnnoyIndexSingleThreadedBuildPolicy AnnoyIndexThreadedBuildPolicy;
34
+ typedef AnnoyIndexSingleThreadedBuildPolicy AnnoyIndexThreadedBuildPolicy;
32
35
  #endif
33
36
 
34
- typedef AnnoyIndex<int, double, Angular, Kiss64Random, AnnoyIndexThreadedBuildPolicy> AnnoyIndexAngular;
35
- typedef AnnoyIndex<int, double, DotProduct, Kiss64Random, AnnoyIndexThreadedBuildPolicy> AnnoyIndexDotProduct;
36
- typedef AnnoyIndex<int, uint64_t, Hamming, Kiss64Random, AnnoyIndexThreadedBuildPolicy> AnnoyIndexHamming;
37
- typedef AnnoyIndex<int, double, Euclidean, Kiss64Random, AnnoyIndexThreadedBuildPolicy> AnnoyIndexEuclidean;
38
- typedef AnnoyIndex<int, double, Manhattan, Kiss64Random, AnnoyIndexThreadedBuildPolicy> AnnoyIndexManhattan;
39
-
40
- template<class T, typename F> class RbAnnoyIndex
41
- {
42
- public:
43
- static VALUE annoy_index_alloc(VALUE self) {
44
- T* ptr = (T*)ruby_xmalloc(sizeof(T));
45
- new (ptr) T();
46
- return TypedData_Wrap_Struct(self, &annoy_index_type, ptr);
47
- };
48
-
49
- static void annoy_index_free(void* ptr) {
50
- ((T*)ptr)->~AnnoyIndex();
51
- ruby_xfree(ptr);
52
- };
53
-
54
- static size_t annoy_index_size(const void* ptr) {
55
- return sizeof(*((T*)ptr));
56
- };
57
-
58
- static T* get_annoy_index(VALUE self) {
59
- T* ptr;
60
- TypedData_Get_Struct(self, T, &annoy_index_type, ptr);
61
- return ptr;
62
- };
63
-
64
- static VALUE define_class(VALUE rb_mAnnoy, const char* class_name) {
65
- VALUE rb_cAnnoyIndex = rb_define_class_under(rb_mAnnoy, class_name, rb_cObject);
66
- rb_define_alloc_func(rb_cAnnoyIndex, annoy_index_alloc);
67
- rb_define_method(rb_cAnnoyIndex, "initialize", RUBY_METHOD_FUNC(_annoy_index_init), 1);
68
- rb_define_method(rb_cAnnoyIndex, "add_item", RUBY_METHOD_FUNC(_annoy_index_add_item), 2);
69
- rb_define_method(rb_cAnnoyIndex, "build", RUBY_METHOD_FUNC(_annoy_index_build), 2);
70
- rb_define_method(rb_cAnnoyIndex, "save", RUBY_METHOD_FUNC(_annoy_index_save), 2);
71
- rb_define_method(rb_cAnnoyIndex, "load", RUBY_METHOD_FUNC(_annoy_index_load), 2);
72
- rb_define_method(rb_cAnnoyIndex, "unload", RUBY_METHOD_FUNC(_annoy_index_unload), 0);
73
- rb_define_method(rb_cAnnoyIndex, "get_nns_by_item", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_item), 4);
74
- rb_define_method(rb_cAnnoyIndex, "get_nns_by_vector", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_vector), 4);
75
- rb_define_method(rb_cAnnoyIndex, "get_item", RUBY_METHOD_FUNC(_annoy_index_get_item), 1);
76
- rb_define_method(rb_cAnnoyIndex, "get_distance", RUBY_METHOD_FUNC(_annoy_index_get_distance), 2);
77
- rb_define_method(rb_cAnnoyIndex, "get_n_items", RUBY_METHOD_FUNC(_annoy_index_get_n_items), 0);
78
- rb_define_method(rb_cAnnoyIndex, "get_n_trees", RUBY_METHOD_FUNC(_annoy_index_get_n_trees), 0);
79
- rb_define_method(rb_cAnnoyIndex, "on_disk_build", RUBY_METHOD_FUNC(_annoy_index_on_disk_build), 1);
80
- rb_define_method(rb_cAnnoyIndex, "set_seed", RUBY_METHOD_FUNC(_annoy_index_set_seed), 1);
81
- rb_define_method(rb_cAnnoyIndex, "verbose", RUBY_METHOD_FUNC(_annoy_index_verbose), 1);
82
- rb_define_method(rb_cAnnoyIndex, "get_f", RUBY_METHOD_FUNC(_annoy_index_get_f), 0);
83
- return rb_cAnnoyIndex;
84
- };
85
-
86
- private:
87
- static const rb_data_type_t annoy_index_type;
88
-
89
- static VALUE _annoy_index_init(VALUE self, VALUE _n_dims) {
90
- const int n_dims = NUM2INT(_n_dims);
91
- T* ptr = get_annoy_index(self);
92
- new (ptr) T(n_dims);
93
- return Qnil;
94
- };
95
-
96
- static VALUE _annoy_index_add_item(VALUE self, VALUE _idx, VALUE arr) {
97
- const int idx = NUM2INT(_idx);
98
- const int n_dims = get_annoy_index(self)->get_f();
99
-
100
- if (!RB_TYPE_P(arr, T_ARRAY)) {
101
- rb_raise(rb_eArgError, "Expect item vector to be Array.");
102
- return Qfalse;
103
- }
104
-
105
- if (n_dims != RARRAY_LEN(arr)) {
106
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
107
- return Qfalse;
108
- }
109
-
110
- F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
111
- for (int i = 0; i < n_dims; i++) {
112
- vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(arr, i)) : NUM2UINT(rb_ary_entry(arr, i));
113
- }
114
-
115
- char* error;
116
- if (!get_annoy_index(self)->add_item(idx, vec, &error)) {
117
- VALUE error_str = rb_str_new_cstr(error);
118
- free(error);
119
- ruby_xfree(vec);
120
- rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
121
- return Qfalse;
122
- }
123
-
37
+ // clang-format off
38
+ template<typename F> using AnnoyIndexAngular = AnnoyIndex<int32_t, F, Angular, Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
39
+ template<typename F> using AnnoyIndexDotProduct = AnnoyIndex<int32_t, F, DotProduct, Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
40
+ template<typename F> using AnnoyIndexHamming = AnnoyIndex<int32_t, F, Hamming, Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
41
+ template<typename F> using AnnoyIndexEuclidean = AnnoyIndex<int32_t, F, Euclidean, Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
42
+ template<typename F> using AnnoyIndexManhattan = AnnoyIndex<int32_t, F, Manhattan, Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
43
+ // clang-format on
44
+
45
+ template <class T, typename F> class RbAnnoyIndex {
46
+ public:
47
+ static VALUE annoy_index_alloc(VALUE self) {
48
+ T* ptr = (T*)ruby_xmalloc(sizeof(T));
49
+ new (ptr) T();
50
+ return TypedData_Wrap_Struct(self, &annoy_index_type, ptr);
51
+ };
52
+
53
+ static void annoy_index_free(void* ptr) {
54
+ ((T*)ptr)->~AnnoyIndex();
55
+ ruby_xfree(ptr);
56
+ };
57
+
58
+ static size_t annoy_index_size(const void* ptr) { return sizeof(*((T*)ptr)); };
59
+
60
+ static T* get_annoy_index(VALUE self) {
61
+ T* ptr;
62
+ TypedData_Get_Struct(self, T, &annoy_index_type, ptr);
63
+ return ptr;
64
+ };
65
+
66
+ static VALUE define_class(VALUE rb_mAnnoy, const char* class_name) {
67
+ VALUE rb_cAnnoyIndex = rb_define_class_under(rb_mAnnoy, class_name, rb_cObject);
68
+ rb_define_alloc_func(rb_cAnnoyIndex, annoy_index_alloc);
69
+ rb_define_method(rb_cAnnoyIndex, "initialize", RUBY_METHOD_FUNC(_annoy_index_init), 1);
70
+ rb_define_method(rb_cAnnoyIndex, "add_item", RUBY_METHOD_FUNC(_annoy_index_add_item), 2);
71
+ rb_define_method(rb_cAnnoyIndex, "build", RUBY_METHOD_FUNC(_annoy_index_build), 2);
72
+ rb_define_method(rb_cAnnoyIndex, "save", RUBY_METHOD_FUNC(_annoy_index_save), 2);
73
+ rb_define_method(rb_cAnnoyIndex, "load", RUBY_METHOD_FUNC(_annoy_index_load), 2);
74
+ rb_define_method(rb_cAnnoyIndex, "unload", RUBY_METHOD_FUNC(_annoy_index_unload), 0);
75
+ rb_define_method(rb_cAnnoyIndex, "get_nns_by_item", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_item), 4);
76
+ rb_define_method(rb_cAnnoyIndex, "get_nns_by_vector", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_vector), 4);
77
+ rb_define_method(rb_cAnnoyIndex, "get_item", RUBY_METHOD_FUNC(_annoy_index_get_item), 1);
78
+ rb_define_method(rb_cAnnoyIndex, "get_distance", RUBY_METHOD_FUNC(_annoy_index_get_distance), 2);
79
+ rb_define_method(rb_cAnnoyIndex, "get_n_items", RUBY_METHOD_FUNC(_annoy_index_get_n_items), 0);
80
+ rb_define_method(rb_cAnnoyIndex, "get_n_trees", RUBY_METHOD_FUNC(_annoy_index_get_n_trees), 0);
81
+ rb_define_method(rb_cAnnoyIndex, "on_disk_build", RUBY_METHOD_FUNC(_annoy_index_on_disk_build), 1);
82
+ rb_define_method(rb_cAnnoyIndex, "set_seed", RUBY_METHOD_FUNC(_annoy_index_set_seed), 1);
83
+ rb_define_method(rb_cAnnoyIndex, "verbose", RUBY_METHOD_FUNC(_annoy_index_verbose), 1);
84
+ rb_define_method(rb_cAnnoyIndex, "get_f", RUBY_METHOD_FUNC(_annoy_index_get_f), 0);
85
+ return rb_cAnnoyIndex;
86
+ };
87
+
88
+ private:
89
+ static const rb_data_type_t annoy_index_type;
90
+
91
+ static VALUE _annoy_index_init(VALUE self, VALUE _n_dims) {
92
+ const int n_dims = NUM2INT(_n_dims);
93
+ T* ptr = get_annoy_index(self);
94
+ new (ptr) T(n_dims);
95
+ return Qnil;
96
+ };
97
+
98
+ static VALUE _annoy_index_add_item(VALUE self, VALUE _idx, VALUE arr) {
99
+ const int32_t idx = (int32_t)NUM2INT(_idx);
100
+ const int n_dims = get_annoy_index(self)->get_f();
101
+
102
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
103
+ rb_raise(rb_eArgError, "Expect item vector to be Array.");
104
+ return Qfalse;
105
+ }
106
+
107
+ if (n_dims != RARRAY_LEN(arr)) {
108
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
109
+ return Qfalse;
110
+ }
111
+
112
+ F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
113
+ for (int i = 0; i < n_dims; i++) {
114
+ vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(arr, i)) : NUM2UINT(rb_ary_entry(arr, i));
115
+ }
116
+
117
+ char* error;
118
+ if (!get_annoy_index(self)->add_item(idx, vec, &error)) {
119
+ VALUE error_str = rb_str_new_cstr(error);
120
+ free(error);
124
121
  ruby_xfree(vec);
125
- return Qtrue;
126
- };
127
-
128
- static VALUE _annoy_index_build(VALUE self, VALUE _n_trees, VALUE _n_jobs) {
129
- const int n_trees = NUM2INT(_n_trees);
130
- const int n_jobs = NUM2INT(_n_jobs);
131
- char* error;
132
- if (!get_annoy_index(self)->build(n_trees, n_jobs, &error)) {
133
- VALUE error_str = rb_str_new_cstr(error);
134
- free(error);
135
- rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
136
- return Qfalse;
122
+ rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
123
+ return Qfalse;
124
+ }
125
+
126
+ ruby_xfree(vec);
127
+ return Qtrue;
128
+ };
129
+
130
+ static VALUE _annoy_index_build(VALUE self, VALUE _n_trees, VALUE _n_jobs) {
131
+ const int n_trees = NUM2INT(_n_trees);
132
+ const int n_jobs = NUM2INT(_n_jobs);
133
+ char* error;
134
+ if (!get_annoy_index(self)->build(n_trees, n_jobs, &error)) {
135
+ VALUE error_str = rb_str_new_cstr(error);
136
+ free(error);
137
+ rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
138
+ return Qfalse;
139
+ }
140
+ return Qtrue;
141
+ };
142
+
143
+ static VALUE _annoy_index_save(VALUE self, VALUE _filename, VALUE _prefault) {
144
+ const char* filename = StringValuePtr(_filename);
145
+ const bool prefault = _prefault == Qtrue ? true : false;
146
+ char* error;
147
+ if (!get_annoy_index(self)->save(filename, prefault, &error)) {
148
+ VALUE error_str = rb_str_new_cstr(error);
149
+ free(error);
150
+ rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
151
+ return Qfalse;
152
+ }
153
+ RB_GC_GUARD(_filename);
154
+ return Qtrue;
155
+ };
156
+
157
+ static VALUE _annoy_index_load(VALUE self, VALUE _filename, VALUE _prefault) {
158
+ const char* filename = StringValuePtr(_filename);
159
+ const bool prefault = _prefault == Qtrue ? true : false;
160
+ char* error;
161
+ if (!get_annoy_index(self)->load(filename, prefault, &error)) {
162
+ VALUE error_str = rb_str_new_cstr(error);
163
+ free(error);
164
+ rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
165
+ return Qfalse;
166
+ }
167
+ RB_GC_GUARD(_filename);
168
+ return Qtrue;
169
+ };
170
+
171
+ static VALUE _annoy_index_unload(VALUE self) {
172
+ get_annoy_index(self)->unload();
173
+ return Qnil;
174
+ };
175
+
176
+ static VALUE _annoy_index_get_nns_by_item(VALUE self, VALUE _idx, VALUE _n_neighbors, VALUE _search_k,
177
+ VALUE _include_distances) {
178
+ const int32_t idx = (int32_t)NUM2INT(_idx);
179
+ const int n_neighbors = NUM2INT(_n_neighbors);
180
+ const int search_k = NUM2INT(_search_k);
181
+ const bool include_distances = _include_distances == Qtrue ? true : false;
182
+ std::vector<int32_t> neighbors;
183
+ std::vector<F> distances;
184
+
185
+ get_annoy_index(self)->get_nns_by_item(idx, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
186
+
187
+ const int sz_neighbors = neighbors.size();
188
+ VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
189
+
190
+ for (int i = 0; i < sz_neighbors; i++) {
191
+ rb_ary_store(neighbors_arr, i, INT2NUM((int)(neighbors[i])));
192
+ }
193
+
194
+ if (include_distances) {
195
+ const int sz_distances = distances.size();
196
+ VALUE distances_arr = rb_ary_new2(sz_distances);
197
+ for (int i = 0; i < sz_distances; i++) {
198
+ rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
137
199
  }
138
- return Qtrue;
139
- };
140
-
141
- static VALUE _annoy_index_save(VALUE self, VALUE _filename, VALUE _prefault) {
142
- const char* filename = StringValuePtr(_filename);
143
- const bool prefault = _prefault == Qtrue ? true : false;
144
- char* error;
145
- if (!get_annoy_index(self)->save(filename, prefault, &error)) {
146
- VALUE error_str = rb_str_new_cstr(error);
147
- free(error);
148
- rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
149
- return Qfalse;
150
- }
151
- RB_GC_GUARD(_filename);
152
- return Qtrue;
153
- };
154
-
155
- static VALUE _annoy_index_load(VALUE self, VALUE _filename, VALUE _prefault) {
156
- const char* filename = StringValuePtr(_filename);
157
- const bool prefault = _prefault == Qtrue ? true : false;
158
- char* error;
159
- if (!get_annoy_index(self)->load(filename, prefault, &error)) {
160
- VALUE error_str = rb_str_new_cstr(error);
161
- free(error);
162
- rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
163
- return Qfalse;
164
- }
165
- RB_GC_GUARD(_filename);
166
- return Qtrue;
167
- };
168
-
169
- static VALUE _annoy_index_unload(VALUE self) {
170
- get_annoy_index(self)->unload();
171
- return Qnil;
172
- };
173
-
174
- static VALUE _annoy_index_get_nns_by_item(VALUE self, VALUE _idx, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
175
- const int idx = NUM2INT(_idx);
176
- const int n_neighbors = NUM2INT(_n_neighbors);
177
- const int search_k = NUM2INT(_search_k);
178
- const bool include_distances = _include_distances == Qtrue ? true : false;
179
- std::vector<int> neighbors;
180
- std::vector<F> distances;
181
-
182
- get_annoy_index(self)->get_nns_by_item(idx, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
183
-
184
- const int sz_neighbors = neighbors.size();
185
- VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
186
-
187
- for (int i = 0; i < sz_neighbors; i++) {
188
- rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
189
- }
190
-
191
- if (include_distances) {
192
- const int sz_distances = distances.size();
193
- VALUE distances_arr = rb_ary_new2(sz_distances);
194
- for (int i = 0; i < sz_distances; i++) {
195
- rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
196
- }
197
- VALUE res = rb_ary_new2(2);
198
- rb_ary_store(res, 0, neighbors_arr);
199
- rb_ary_store(res, 1, distances_arr);
200
- return res;
201
- }
202
-
203
- return neighbors_arr;
204
- };
205
-
206
- static VALUE _annoy_index_get_nns_by_vector(VALUE self, VALUE _vec, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
207
- const int n_dims = get_annoy_index(self)->get_f();
208
-
209
- if (!RB_TYPE_P(_vec, T_ARRAY)) {
210
- rb_raise(rb_eArgError, "Expect item vector to be Array.");
211
- return Qfalse;
212
- }
213
-
214
- if (n_dims != RARRAY_LEN(_vec)) {
215
- rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
216
- return Qfalse;
217
- }
218
-
219
- F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
220
- for (int i = 0; i < n_dims; i++) {
221
- vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(_vec, i)) : NUM2UINT(rb_ary_entry(_vec, i));
222
- }
223
-
224
- const int n_neighbors = NUM2INT(_n_neighbors);
225
- const int search_k = NUM2INT(_search_k);
226
- const bool include_distances = _include_distances == Qtrue ? true : false;
227
- std::vector<int> neighbors;
228
- std::vector<F> distances;
229
-
230
- get_annoy_index(self)->get_nns_by_vector(vec, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
231
-
232
- ruby_xfree(vec);
233
-
234
- const int sz_neighbors = neighbors.size();
235
- VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
236
-
237
- for (int i = 0; i < sz_neighbors; i++) {
238
- rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
239
- }
240
-
241
- if (include_distances) {
242
- const int sz_distances = distances.size();
243
- VALUE distances_arr = rb_ary_new2(sz_distances);
244
- for (int i = 0; i < sz_distances; i++) {
245
- rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
246
- }
247
- VALUE res = rb_ary_new2(2);
248
- rb_ary_store(res, 0, neighbors_arr);
249
- rb_ary_store(res, 1, distances_arr);
250
- return res;
251
- }
252
-
253
- return neighbors_arr;
254
- };
255
-
256
- static VALUE _annoy_index_get_item(VALUE self, VALUE _idx) {
257
- const int idx = NUM2INT(_idx);
258
- const int n_dims = get_annoy_index(self)->get_f();
259
- F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
260
- VALUE arr = rb_ary_new2(n_dims);
261
-
262
- get_annoy_index(self)->get_item(idx, vec);
263
-
264
- for (int i = 0; i < n_dims; i++) {
265
- rb_ary_store(arr, i, typeid(F) == typeid(double) ? DBL2NUM(vec[i]) : UINT2NUM(vec[i]));
266
- }
267
-
268
- ruby_xfree(vec);
269
- return arr;
270
- };
271
-
272
- static VALUE _annoy_index_get_distance(VALUE self, VALUE _i, VALUE _j) {
273
- const int i = NUM2INT(_i);
274
- const int j = NUM2INT(_j);
275
- const F dist = get_annoy_index(self)->get_distance(i, j);
276
- return typeid(F) == typeid(double) ? DBL2NUM(dist) : UINT2NUM(dist);
277
- };
278
-
279
- static VALUE _annoy_index_get_n_items(VALUE self) {
280
- const int32_t n_items = get_annoy_index(self)->get_n_items();
281
- return INT2NUM(n_items);
282
- };
283
-
284
- static VALUE _annoy_index_get_n_trees(VALUE self) {
285
- const int32_t n_trees = get_annoy_index(self)->get_n_trees();
286
- return INT2NUM(n_trees);
287
- };
288
-
289
- static VALUE _annoy_index_on_disk_build(VALUE self, VALUE _filename) {
290
- const char* filename = StringValuePtr(_filename);
291
- char* error;
292
- if (!get_annoy_index(self)->on_disk_build(filename, &error)) {
293
- VALUE error_str = rb_str_new_cstr(error);
294
- free(error);
295
- rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
296
- return Qfalse;
200
+ VALUE res = rb_ary_new2(2);
201
+ rb_ary_store(res, 0, neighbors_arr);
202
+ rb_ary_store(res, 1, distances_arr);
203
+ return res;
204
+ }
205
+
206
+ return neighbors_arr;
207
+ };
208
+
209
+ static VALUE _annoy_index_get_nns_by_vector(VALUE self, VALUE _vec, VALUE _n_neighbors, VALUE _search_k,
210
+ VALUE _include_distances) {
211
+ const int n_dims = get_annoy_index(self)->get_f();
212
+
213
+ if (!RB_TYPE_P(_vec, T_ARRAY)) {
214
+ rb_raise(rb_eArgError, "Expect item vector to be Array.");
215
+ return Qfalse;
216
+ }
217
+
218
+ if (n_dims != RARRAY_LEN(_vec)) {
219
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
220
+ return Qfalse;
221
+ }
222
+
223
+ F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
224
+ for (int i = 0; i < n_dims; i++) {
225
+ vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(_vec, i)) : NUM2UINT(rb_ary_entry(_vec, i));
226
+ }
227
+
228
+ const int n_neighbors = NUM2INT(_n_neighbors);
229
+ const int search_k = NUM2INT(_search_k);
230
+ const bool include_distances = _include_distances == Qtrue ? true : false;
231
+ std::vector<int32_t> neighbors;
232
+ std::vector<F> distances;
233
+
234
+ get_annoy_index(self)->get_nns_by_vector(vec, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
235
+
236
+ ruby_xfree(vec);
237
+
238
+ const int sz_neighbors = neighbors.size();
239
+ VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
240
+
241
+ for (int i = 0; i < sz_neighbors; i++) {
242
+ rb_ary_store(neighbors_arr, i, INT2NUM((int)(neighbors[i])));
243
+ }
244
+
245
+ if (include_distances) {
246
+ const int sz_distances = distances.size();
247
+ VALUE distances_arr = rb_ary_new2(sz_distances);
248
+ for (int i = 0; i < sz_distances; i++) {
249
+ rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
297
250
  }
298
- RB_GC_GUARD(_filename);
299
- return Qtrue;
300
- };
301
-
302
- static VALUE _annoy_index_set_seed(VALUE self, VALUE _seed) {
303
- const int seed = NUM2INT(_seed);
304
- get_annoy_index(self)->set_seed(seed);
305
- return Qnil;
306
- };
307
-
308
- static VALUE _annoy_index_verbose(VALUE self, VALUE _flag) {
309
- const bool flag = _flag == Qtrue ? true : false;
310
- get_annoy_index(self)->verbose(flag);
311
- return Qnil;
312
- };
313
-
314
- static VALUE _annoy_index_get_f(VALUE self) {
315
- const int32_t f = get_annoy_index(self)->get_f();
316
- return INT2NUM(f);
317
- };
251
+ VALUE res = rb_ary_new2(2);
252
+ rb_ary_store(res, 0, neighbors_arr);
253
+ rb_ary_store(res, 1, distances_arr);
254
+ return res;
255
+ }
256
+
257
+ return neighbors_arr;
258
+ };
259
+
260
+ static VALUE _annoy_index_get_item(VALUE self, VALUE _idx) {
261
+ const int32_t idx = (int32_t)NUM2INT(_idx);
262
+ const int n_dims = get_annoy_index(self)->get_f();
263
+ F* vec = (F*)ruby_xmalloc(n_dims * sizeof(F));
264
+ VALUE arr = rb_ary_new2(n_dims);
265
+
266
+ get_annoy_index(self)->get_item(idx, vec);
267
+
268
+ for (int i = 0; i < n_dims; i++) {
269
+ rb_ary_store(arr, i, typeid(F) == typeid(double) ? DBL2NUM(vec[i]) : UINT2NUM(vec[i]));
270
+ }
271
+
272
+ ruby_xfree(vec);
273
+ return arr;
274
+ };
275
+
276
+ static VALUE _annoy_index_get_distance(VALUE self, VALUE _i, VALUE _j) {
277
+ const int32_t i = (int32_t)NUM2INT(_i);
278
+ const int32_t j = (int32_t)NUM2INT(_j);
279
+ const F dist = get_annoy_index(self)->get_distance(i, j);
280
+ return typeid(F) == typeid(double) ? DBL2NUM(dist) : UINT2NUM(dist);
281
+ };
282
+
283
+ static VALUE _annoy_index_get_n_items(VALUE self) {
284
+ const int32_t n_items = get_annoy_index(self)->get_n_items();
285
+ return INT2NUM(n_items);
286
+ };
287
+
288
+ static VALUE _annoy_index_get_n_trees(VALUE self) {
289
+ const int32_t n_trees = get_annoy_index(self)->get_n_trees();
290
+ return INT2NUM(n_trees);
291
+ };
292
+
293
+ static VALUE _annoy_index_on_disk_build(VALUE self, VALUE _filename) {
294
+ const char* filename = StringValuePtr(_filename);
295
+ char* error;
296
+ if (!get_annoy_index(self)->on_disk_build(filename, &error)) {
297
+ VALUE error_str = rb_str_new_cstr(error);
298
+ free(error);
299
+ rb_raise(rb_eRuntimeError, "%s", StringValuePtr(error_str));
300
+ return Qfalse;
301
+ }
302
+ RB_GC_GUARD(_filename);
303
+ return Qtrue;
304
+ };
305
+
306
+ static VALUE _annoy_index_set_seed(VALUE self, VALUE _seed) {
307
+ const int seed = NUM2INT(_seed);
308
+ get_annoy_index(self)->set_seed(seed);
309
+ return Qnil;
310
+ };
311
+
312
+ static VALUE _annoy_index_verbose(VALUE self, VALUE _flag) {
313
+ const bool flag = _flag == Qtrue ? true : false;
314
+ get_annoy_index(self)->verbose(flag);
315
+ return Qnil;
316
+ };
317
+
318
+ static VALUE _annoy_index_get_f(VALUE self) {
319
+ const int32_t f = get_annoy_index(self)->get_f();
320
+ return INT2NUM(f);
321
+ };
318
322
  };
319
323
 
324
+ // clang-format off
320
325
  template<class T, typename F>
321
326
  const rb_data_type_t RbAnnoyIndex<T, F>::annoy_index_type = {
322
327
  "RbAnnoyIndex",
@@ -329,5 +334,6 @@ const rb_data_type_t RbAnnoyIndex<T, F>::annoy_index_type = {
329
334
  NULL,
330
335
  RUBY_TYPED_FREE_IMMEDIATELY
331
336
  };
337
+ // clang-format on
332
338
 
333
339
  #endif /* ANNOYEXT_HPP */
@@ -187,7 +187,7 @@
187
187
  same "printed page" as the copyright notice for easier
188
188
  identification within third-party archives.
189
189
 
190
- Copyright [yyyy] [name of copyright owner]
190
+ Copyright 2021 (c) Spotify and its affiliates.
191
191
 
192
192
  Licensed under the Apache License, Version 2.0 (the "License");
193
193
  you may not use this file except in compliance with the License.