annoy-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,300 @@
1
+ /**
2
+ * Annoy.rb is a Ruby binding for the Annoy (Approximate Nearest Neighbors Oh Yeah).
3
+ *
4
+ * Copyright (c) 2020 Atsushi Tatsuma
5
+ *
6
+ * Licensed under the Apache License, Version 2.0 (the "License");
7
+ * you may not use this file except in compliance with the License.
8
+ * You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing, software
13
+ * distributed under the License is distributed on an "AS IS" BASIS,
14
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ * See the License for the specific language governing permissions and
16
+ * limitations under the License.
17
+ */
18
+
19
+ #ifndef ANNOY_HPP
20
+ #define ANNOY_HPP 1
21
+
22
+ #include <typeinfo>
23
+
24
+ #include <ruby.h>
25
+ #include <annoylib.h>
26
+ #include <kissrandom.h>
27
+
28
+ typedef AnnoyIndex<int, double, Angular, Kiss64Random> AnnoyIndexAngular;
29
+ typedef AnnoyIndex<int, double, DotProduct, Kiss64Random> AnnoyIndexDotProduct;
30
+ typedef AnnoyIndex<int, uint64_t, Hamming, Kiss64Random> AnnoyIndexHamming;
31
+ typedef AnnoyIndex<int, double, Euclidean, Kiss64Random> AnnoyIndexEuclidean;
32
+ typedef AnnoyIndex<int, double, Manhattan, Kiss64Random> AnnoyIndexManhattan;
33
+
34
+ template<class T, typename F> class RbAnnoyIndex
35
+ {
36
+ public:
37
+ static VALUE annoy_index_alloc(VALUE self) {
38
+ T* ptr = (T*)ruby_xmalloc(sizeof(T));
39
+ return Data_Wrap_Struct(self, NULL, annoy_index_free, ptr);
40
+ };
41
+
42
+ static void annoy_index_free(T* ptr) {
43
+ ptr->~AnnoyIndex();
44
+ ruby_xfree(ptr);
45
+ };
46
+
47
+ static T* get_annoy_index(VALUE self) {
48
+ T* ptr;
49
+ Data_Get_Struct(self, T, ptr);
50
+ return ptr;
51
+ };
52
+
53
+ static VALUE define_class(VALUE rb_mAnnoy, const char* class_name) {
54
+ VALUE rb_cAnnoyIndex = rb_define_class_under(rb_mAnnoy, class_name, rb_cObject);
55
+ rb_define_alloc_func(rb_cAnnoyIndex, annoy_index_alloc);
56
+ rb_define_method(rb_cAnnoyIndex, "initialize", RUBY_METHOD_FUNC(_annoy_index_init), 1);
57
+ rb_define_method(rb_cAnnoyIndex, "add_item", RUBY_METHOD_FUNC(_annoy_index_add_item), 2);
58
+ rb_define_method(rb_cAnnoyIndex, "build", RUBY_METHOD_FUNC(_annoy_index_build), 1);
59
+ rb_define_method(rb_cAnnoyIndex, "save", RUBY_METHOD_FUNC(_annoy_index_save), 2);
60
+ rb_define_method(rb_cAnnoyIndex, "load", RUBY_METHOD_FUNC(_annoy_index_load), 2);
61
+ rb_define_method(rb_cAnnoyIndex, "unload", RUBY_METHOD_FUNC(_annoy_index_unload), 0);
62
+ rb_define_method(rb_cAnnoyIndex, "get_nns_by_item", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_item), 4);
63
+ rb_define_method(rb_cAnnoyIndex, "get_nns_by_vector", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_vector), 4);
64
+ rb_define_method(rb_cAnnoyIndex, "get_item", RUBY_METHOD_FUNC(_annoy_index_get_item), 1);
65
+ rb_define_method(rb_cAnnoyIndex, "get_distance", RUBY_METHOD_FUNC(_annoy_index_get_distance), 2);
66
+ rb_define_method(rb_cAnnoyIndex, "get_n_items", RUBY_METHOD_FUNC(_annoy_index_get_n_items), 0);
67
+ rb_define_method(rb_cAnnoyIndex, "get_n_trees", RUBY_METHOD_FUNC(_annoy_index_get_n_trees), 0);
68
+ rb_define_method(rb_cAnnoyIndex, "on_disk_build", RUBY_METHOD_FUNC(_annoy_index_on_disk_build), 1);
69
+ rb_define_method(rb_cAnnoyIndex, "set_seed", RUBY_METHOD_FUNC(_annoy_index_set_seed), 1);
70
+ rb_define_method(rb_cAnnoyIndex, "verbose", RUBY_METHOD_FUNC(_annoy_index_verbose), 1);
71
+ rb_define_method(rb_cAnnoyIndex, "get_f", RUBY_METHOD_FUNC(_annoy_index_get_f), 0);
72
+ return rb_cAnnoyIndex;
73
+ };
74
+
75
+ private:
76
+
77
+ static VALUE _annoy_index_init(VALUE self, VALUE _n_dims) {
78
+ const int n_dims = NUM2INT(_n_dims);
79
+ T* ptr = get_annoy_index(self);
80
+ new (ptr) T(n_dims);
81
+ return Qnil;
82
+ };
83
+
84
+ static VALUE _annoy_index_add_item(VALUE self, VALUE _idx, VALUE arr) {
85
+ const int idx = NUM2INT(_idx);
86
+ const int n_dims = get_annoy_index(self)->get_f();
87
+
88
+ if (!RB_TYPE_P(arr, T_ARRAY)) {
89
+ rb_raise(rb_eArgError, "Expect item vector to be Array.");
90
+ return Qfalse;
91
+ }
92
+
93
+ if (n_dims != RARRAY_LEN(arr)) {
94
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
95
+ return Qfalse;
96
+ }
97
+
98
+ std::vector<F> vec(n_dims, 0);
99
+ for (int i = 0; i < n_dims; i++) {
100
+ vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(arr, i)) : NUM2UINT(rb_ary_entry(arr, i));
101
+ }
102
+
103
+ char* error;
104
+ if (!get_annoy_index(self)->add_item(idx, &vec[0], &error)) {
105
+ rb_raise(rb_eRuntimeError, "%s", error);
106
+ free(error);
107
+ return Qfalse;
108
+ }
109
+
110
+ return Qtrue;
111
+ };
112
+
113
+ static VALUE _annoy_index_build(VALUE self, VALUE _n_trees) {
114
+ const int n_trees = NUM2INT(_n_trees);
115
+ char* error;
116
+
117
+ if (!get_annoy_index(self)->build(n_trees, &error)) {
118
+ rb_raise(rb_eRuntimeError, "%s", error);
119
+ free(error);
120
+ return Qfalse;
121
+ }
122
+
123
+ return Qtrue;
124
+ };
125
+
126
+ static VALUE _annoy_index_save(VALUE self, VALUE _filename, VALUE _prefault) {
127
+ const char* filename = StringValuePtr(_filename);
128
+ const bool prefault = _prefault == Qtrue ? true : false;
129
+ char* error;
130
+
131
+ if (!get_annoy_index(self)->save(filename, prefault, &error)) {
132
+ rb_raise(rb_eRuntimeError, "%s", error);
133
+ free(error);
134
+ return Qfalse;
135
+ }
136
+
137
+ return Qtrue;
138
+ };
139
+
140
+ static VALUE _annoy_index_load(VALUE self, VALUE _filename, VALUE _prefault) {
141
+ const char* filename = StringValuePtr(_filename);
142
+ const bool prefault = _prefault == Qtrue ? true : false;
143
+ char* error;
144
+
145
+ if (!get_annoy_index(self)->load(filename, prefault, &error)) {
146
+ rb_raise(rb_eRuntimeError, "%s", error);
147
+ free(error);
148
+ return Qfalse;
149
+ }
150
+
151
+ return Qtrue;
152
+ };
153
+
154
+ static VALUE _annoy_index_unload(VALUE self) {
155
+ get_annoy_index(self)->unload();
156
+ return Qnil;
157
+ };
158
+
159
+ static VALUE _annoy_index_get_nns_by_item(VALUE self, VALUE _idx, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
160
+ const int idx = NUM2INT(_idx);
161
+ const int n_neighbors = NUM2INT(_n_neighbors);
162
+ const int search_k = NUM2INT(_search_k);
163
+ const bool include_distances = _include_distances == Qtrue ? true : false;
164
+ std::vector<int> neighbors;
165
+ std::vector<F> distances;
166
+
167
+ get_annoy_index(self)->get_nns_by_item(idx, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
168
+
169
+ const int sz_neighbors = neighbors.size();
170
+ VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
171
+
172
+ for (int i = 0; i < sz_neighbors; i++) {
173
+ rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
174
+ }
175
+
176
+ if (include_distances) {
177
+ const int sz_distances = distances.size();
178
+ VALUE distances_arr = rb_ary_new2(sz_distances);
179
+ for (int i = 0; i < sz_distances; i++) {
180
+ rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
181
+ }
182
+ VALUE res = rb_ary_new2(2);
183
+ rb_ary_store(res, 0, neighbors_arr);
184
+ rb_ary_store(res, 1, distances_arr);
185
+ return res;
186
+ }
187
+
188
+ return neighbors_arr;
189
+ };
190
+
191
+ static VALUE _annoy_index_get_nns_by_vector(VALUE self, VALUE _vec, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
192
+ const int n_dims = get_annoy_index(self)->get_f();
193
+
194
+ if (!RB_TYPE_P(_vec, T_ARRAY)) {
195
+ rb_raise(rb_eArgError, "Expect item vector to be Array.");
196
+ return Qfalse;
197
+ }
198
+
199
+ if (n_dims != RARRAY_LEN(_vec)) {
200
+ rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
201
+ return Qfalse;
202
+ }
203
+
204
+ std::vector<F> vec(n_dims, 0);
205
+ for (int i = 0; i < n_dims; i++) {
206
+ vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(_vec, i)) : NUM2UINT(rb_ary_entry(_vec, i));
207
+ }
208
+
209
+ const int n_neighbors = NUM2INT(_n_neighbors);
210
+ const int search_k = NUM2INT(_search_k);
211
+ const bool include_distances = _include_distances == Qtrue ? true : false;
212
+ std::vector<int> neighbors;
213
+ std::vector<F> distances;
214
+
215
+ get_annoy_index(self)->get_nns_by_vector(&vec[0], n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
216
+
217
+ const int sz_neighbors = neighbors.size();
218
+ VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
219
+
220
+ for (int i = 0; i < sz_neighbors; i++) {
221
+ rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
222
+ }
223
+
224
+ if (include_distances) {
225
+ const int sz_distances = distances.size();
226
+ VALUE distances_arr = rb_ary_new2(sz_distances);
227
+ for (int i = 0; i < sz_distances; i++) {
228
+ rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
229
+ }
230
+ VALUE res = rb_ary_new2(2);
231
+ rb_ary_store(res, 0, neighbors_arr);
232
+ rb_ary_store(res, 1, distances_arr);
233
+ return res;
234
+ }
235
+
236
+ return neighbors_arr;
237
+ };
238
+
239
+ static VALUE _annoy_index_get_item(VALUE self, VALUE _idx) {
240
+ const int idx = NUM2INT(_idx);
241
+ const int n_dims = get_annoy_index(self)->get_f();
242
+ std::vector<F> vec(n_dims, 0);
243
+ VALUE arr = rb_ary_new2(n_dims);
244
+
245
+ get_annoy_index(self)->get_item(idx, &vec[0]);
246
+
247
+ for (int i = 0; i < n_dims; i++) {
248
+ rb_ary_store(arr, i, typeid(F) == typeid(double) ? DBL2NUM(vec[i]) : UINT2NUM(vec[i]));
249
+ }
250
+
251
+ return arr;
252
+ };
253
+
254
+ static VALUE _annoy_index_get_distance(VALUE self, VALUE _i, VALUE _j) {
255
+ const int i = NUM2INT(_i);
256
+ const int j = NUM2INT(_j);
257
+ const double dist = get_annoy_index(self)->get_distance(i, j);
258
+ return DBL2NUM(dist);
259
+ };
260
+
261
+ static VALUE _annoy_index_get_n_items(VALUE self) {
262
+ const int32_t n_items = get_annoy_index(self)->get_n_items();
263
+ return INT2NUM(n_items);
264
+ };
265
+
266
+ static VALUE _annoy_index_get_n_trees(VALUE self) {
267
+ const int32_t n_trees = get_annoy_index(self)->get_n_trees();
268
+ return INT2NUM(n_trees);
269
+ };
270
+
271
+ static VALUE _annoy_index_on_disk_build(VALUE self, VALUE _filename) {
272
+ const char* filename = StringValuePtr(_filename);
273
+ char* error;
274
+ if (!get_annoy_index(self)->on_disk_build(filename, &error)) {
275
+ rb_raise(rb_eRuntimeError, "%s", error);
276
+ free(error);
277
+ return Qfalse;
278
+ }
279
+ return Qtrue;
280
+ };
281
+
282
+ static VALUE _annoy_index_set_seed(VALUE self, VALUE _seed) {
283
+ const int seed = NUM2INT(_seed);
284
+ get_annoy_index(self)->set_seed(seed);
285
+ return Qnil;
286
+ };
287
+
288
+ static VALUE _annoy_index_verbose(VALUE self, VALUE _flag) {
289
+ const bool flag = _flag == Qtrue ? true : false;
290
+ get_annoy_index(self)->verbose(flag);
291
+ return Qnil;
292
+ };
293
+
294
+ static VALUE _annoy_index_get_f(VALUE self) {
295
+ const int32_t f = get_annoy_index(self)->get_f();
296
+ return INT2NUM(f);
297
+ };
298
+ };
299
+
300
+ #endif /* ANNOY_HPP */
@@ -0,0 +1,9 @@
1
require 'mkmf'

# The extension cannot be linked without the C++ standard library.
unless have_library('stdc++')
  abort 'libstdc++ is not found.'
end

# NOTE(review): -march=native targets the CPU of the build machine; the
# resulting binary may not run on a different CPU -- confirm this is intended.
$CXXFLAGS << " -march=native"

# Make the src/ directory visible to both the compiler and make.
$INCFLAGS << " -I$(srcdir)/src"
$VPATH << "$(srcdir)/src"

create_makefile('annoy/annoy')
@@ -0,0 +1,1334 @@
1
+ // Copyright (c) 2013 Spotify AB
2
+ //
3
+ // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4
+ // use this file except in compliance with the License. You may obtain a copy of
5
+ // the License at
6
+ //
7
+ // http://www.apache.org/licenses/LICENSE-2.0
8
+ //
9
+ // Unless required by applicable law or agreed to in writing, software
10
+ // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11
+ // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12
+ // License for the specific language governing permissions and limitations under
13
+ // the License.
14
+
15
+
16
+ #ifndef ANNOYLIB_H
17
+ #define ANNOYLIB_H
18
+
19
+ #include <stdio.h>
20
+ #include <sys/stat.h>
21
+ #ifndef _MSC_VER
22
+ #include <unistd.h>
23
+ #endif
24
+ #include <stdio.h>
25
+ #include <stdlib.h>
26
+ #include <sys/types.h>
27
+ #include <fcntl.h>
28
+ #include <stddef.h>
29
+
30
+ #if defined(_MSC_VER) && _MSC_VER == 1500
31
+ typedef unsigned char uint8_t;
32
+ typedef signed __int32 int32_t;
33
+ typedef unsigned __int64 uint64_t;
34
+ typedef signed __int64 int64_t;
35
+ #else
36
+ #include <stdint.h>
37
+ #endif
38
+
39
+ #if defined(_MSC_VER) || defined(__MINGW32__)
40
+ // a bit hacky, but override some definitions to support 64 bit
41
+ #define off_t int64_t
42
+ #define lseek_getsize(fd) _lseeki64(fd, 0, SEEK_END)
43
+ #ifndef NOMINMAX
44
+ #define NOMINMAX
45
+ #endif
46
+ #include "mman.h"
47
+ #include <windows.h>
48
+ #else
49
+ #include <sys/mman.h>
50
+ #define lseek_getsize(fd) lseek(fd, 0, SEEK_END)
51
+ #endif
52
+
53
+ #include <cerrno>
54
+ #include <string.h>
55
+ #include <math.h>
56
+ #include <vector>
57
+ #include <algorithm>
58
+ #include <queue>
59
+ #include <limits>
60
+
61
+ #ifdef _MSC_VER
62
+ // Needed for Visual Studio to disable runtime checks for memcpy
63
+ #pragma runtime_checks("s", off)
64
+ #endif
65
+
66
+ // This allows others to supply their own logger / error printer without
67
+ // requiring Annoy to import their headers. See RcppAnnoy for a use case.
68
+ #ifndef __ERROR_PRINTER_OVERRIDE__
69
+ #define showUpdate(...) { fprintf(stderr, __VA_ARGS__ ); }
70
+ #else
71
+ #define showUpdate(...) { __ERROR_PRINTER_OVERRIDE__( __VA_ARGS__ ); }
72
+ #endif
73
+
74
+ // Portable alloc definition, cf Writing R Extensions, Section 1.6.4
75
+ #ifdef __GNUC__
76
+ // Includes GCC, clang and Intel compilers
77
+ # undef alloca
78
+ # define alloca(x) __builtin_alloca((x))
79
+ #elif defined(__sun) || defined(_AIX)
80
+ // this is necessary (and sufficient) for Solaris 10 and AIX 6:
81
+ # include <alloca.h>
82
+ #endif
83
+
84
+ inline void set_error_from_errno(char **error, const char* msg) {
85
+ showUpdate("%s: %s (%d)\n", msg, strerror(errno), errno);
86
+ if (error) {
87
+ *error = (char *)malloc(256); // TODO: win doesn't support snprintf
88
+ sprintf(*error, "%s: %s (%d)", msg, strerror(errno), errno);
89
+ }
90
+ }
91
+
92
+ inline void set_error_from_string(char **error, const char* msg) {
93
+ showUpdate("%s\n", msg);
94
+ if (error) {
95
+ *error = (char *)malloc(strlen(msg) + 1);
96
+ strcpy(*error, msg);
97
+ }
98
+ }
99
+
100
+ // We let the v array in the Node struct take whatever space is needed, so this is a mostly insignificant number.
101
+ // Compilers need *some* size defined for the v array, and some memory checking tools will flag for buffer overruns if this is set too low.
102
+ #define V_ARRAY_SIZE 65536
103
+
104
+ #ifndef _MSC_VER
105
+ #define popcount __builtin_popcountll
106
+ #else // See #293, #358
107
+ #define isnan(x) _isnan(x)
108
+ #define popcount cole_popcount
109
+ #endif
110
+
111
+ #if !defined(NO_MANUAL_VECTORIZATION) && defined(__GNUC__) && (__GNUC__ >6) && defined(__AVX512F__) // See #402
112
+ #define USE_AVX512
113
+ #elif !defined(NO_MANUAL_VECTORIZATION) && defined(__AVX__) && defined (__SSE__) && defined(__SSE2__) && defined(__SSE3__)
114
+ #define USE_AVX
115
+ #else
116
+ #endif
117
+
118
+ #if defined(USE_AVX) || defined(USE_AVX512)
119
+ #if defined(_MSC_VER)
120
+ #include <intrin.h>
121
+ #elif defined(__GNUC__)
122
+ #include <x86intrin.h>
123
+ #endif
124
+ #endif
125
+
126
+ #if !defined(__MINGW32__)
127
+ #define FTRUNCATE_SIZE(x) static_cast<int64_t>(x)
128
+ #else
129
+ #define FTRUNCATE_SIZE(x) (x)
130
+ #endif
131
+
132
+ using std::vector;
133
+ using std::pair;
134
+ using std::numeric_limits;
135
+ using std::make_pair;
136
+
137
// Resizes the file behind `_fd` to `new_size` bytes and replaces the mapping
// at *_ptr (currently `old_size` bytes) with one of `new_size` bytes.
// *_ptr receives whatever the mapping call returned, MAP_FAILED included.
// Returns true only if both the truncation and the (re)mapping succeeded.
inline bool remap_memory_and_truncate(void** _ptr, int _fd, size_t old_size, size_t new_size) {
#ifdef __linux__
    // Linux can resize the existing mapping directly.
    *_ptr = mremap(*_ptr, old_size, new_size, MREMAP_MAYMOVE);
    bool ok = ftruncate(_fd, new_size) != -1;
#else
    // Elsewhere: drop the old mapping, resize the file, map it again.
    munmap(*_ptr, old_size);
    bool ok = ftruncate(_fd, FTRUNCATE_SIZE(new_size)) != -1;
#ifdef MAP_POPULATE
    *_ptr = mmap(*_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0);
#else
    *_ptr = mmap(*_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0);
#endif
#endif
    // A failed mapping is a failure of the whole operation, not just of the truncation.
    if (*_ptr == MAP_FAILED) {
      ok = false;
    }
    return ok;
}
152
+
153
+ namespace {
154
+
155
// Returns the address of the i-th node in the buffer `_nodes`, where
// consecutive nodes are `_s` bytes apart.
template<typename S, typename Node>
inline Node* get_node_ptr(const void* _nodes, const size_t _s, const S i) {
  uint8_t* base = (uint8_t *)_nodes;
  return (Node*)(base + (_s * i));
}
159
+
160
// Dot product of the first `f` components of x and y, accumulated in index order.
template<typename T>
inline T dot(const T* x, const T* y, int f) {
  T acc = 0;
  for (int k = 0; k < f; ++k) {
    acc += x[k] * y[k];
  }
  return acc;
}
170
+
171
// Sum of the absolute component-wise differences of the first `f` components of x and y.
template<typename T>
inline T manhattan_distance(const T* x, const T* y, int f) {
  T total = 0.0;
  const T* const x_end = x + f;
  while (x < x_end) {
    total += fabs(*x - *y);
    ++x;
    ++y;
  }
  return total;
}
178
+
179
// Squared Euclidean distance over the first `f` components of x and y
// (no square root is taken here).
template<typename T>
inline T euclidean_distance(const T* x, const T* y, int f) {
  // Sum squared differences directly rather than expanding into dot
  // products: that avoids the catastrophic cancellation reported in #314.
  T sum = 0.0;
  for (int k = 0; k < f; ++k) {
    const T delta = x[k] - y[k];
    sum += delta * delta;
  }
  return sum;
}
191
+
192
+ #ifdef USE_AVX
193
// Adds up the eight floats held in a 256-bit register and returns the total.
inline float hsum256_ps_avx(__m256 v) {
  const __m128 upper = _mm256_extractf128_ps(v, 1);
  const __m128 lower = _mm256_castps256_ps128(v);
  const __m128 sum4 = _mm_add_ps(upper, lower);                         // 8 -> 4 lanes
  const __m128 sum2 = _mm_add_ps(sum4, _mm_movehl_ps(sum4, sum4));      // 4 -> 2 lanes
  const __m128 sum1 = _mm_add_ss(sum2, _mm_shuffle_ps(sum2, sum2, 0x55)); // 2 -> 1 lane
  return _mm_cvtss_f32(sum1);
}
200
+
201
+ template<>
202
+ inline float dot<float>(const float* x, const float *y, int f) {
203
+ float result = 0;
204
+ if (f > 7) {
205
+ __m256 d = _mm256_setzero_ps();
206
+ for (; f > 7; f -= 8) {
207
+ d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)));
208
+ x += 8;
209
+ y += 8;
210
+ }
211
+ // Sum all floats in dot register.
212
+ result += hsum256_ps_avx(d);
213
+ }
214
+ // Don't forget the remaining values.
215
+ for (; f > 0; f--) {
216
+ result += *x * *y;
217
+ x++;
218
+ y++;
219
+ }
220
+ return result;
221
+ }
222
+
223
+ template<>
224
+ inline float manhattan_distance<float>(const float* x, const float* y, int f) {
225
+ float result = 0;
226
+ int i = f;
227
+ if (f > 7) {
228
+ __m256 manhattan = _mm256_setzero_ps();
229
+ __m256 minus_zero = _mm256_set1_ps(-0.0f);
230
+ for (; i > 7; i -= 8) {
231
+ const __m256 x_minus_y = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
232
+ const __m256 distance = _mm256_andnot_ps(minus_zero, x_minus_y); // Absolute value of x_minus_y (forces sign bit to zero)
233
+ manhattan = _mm256_add_ps(manhattan, distance);
234
+ x += 8;
235
+ y += 8;
236
+ }
237
+ // Sum all floats in manhattan register.
238
+ result = hsum256_ps_avx(manhattan);
239
+ }
240
+ // Don't forget the remaining values.
241
+ for (; i > 0; i--) {
242
+ result += fabsf(*x - *y);
243
+ x++;
244
+ y++;
245
+ }
246
+ return result;
247
+ }
248
+
249
+ template<>
250
+ inline float euclidean_distance<float>(const float* x, const float* y, int f) {
251
+ float result=0;
252
+ if (f > 7) {
253
+ __m256 d = _mm256_setzero_ps();
254
+ for (; f > 7; f -= 8) {
255
+ const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
256
+ d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX...
257
+ x += 8;
258
+ y += 8;
259
+ }
260
+ // Sum all floats in dot register.
261
+ result = hsum256_ps_avx(d);
262
+ }
263
+ // Don't forget the remaining values.
264
+ for (; f > 0; f--) {
265
+ float tmp = *x - *y;
266
+ result += tmp * tmp;
267
+ x++;
268
+ y++;
269
+ }
270
+ return result;
271
+ }
272
+
273
+ #endif
274
+
275
+ #ifdef USE_AVX512
276
+ template<>
277
+ inline float dot<float>(const float* x, const float *y, int f) {
278
+ float result = 0;
279
+ if (f > 15) {
280
+ __m512 d = _mm512_setzero_ps();
281
+ for (; f > 15; f -= 16) {
282
+ //AVX512F includes FMA
283
+ d = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), d);
284
+ x += 16;
285
+ y += 16;
286
+ }
287
+ // Sum all floats in dot register.
288
+ result += _mm512_reduce_add_ps(d);
289
+ }
290
+ // Don't forget the remaining values.
291
+ for (; f > 0; f--) {
292
+ result += *x * *y;
293
+ x++;
294
+ y++;
295
+ }
296
+ return result;
297
+ }
298
+
299
+ template<>
300
+ inline float manhattan_distance<float>(const float* x, const float* y, int f) {
301
+ float result = 0;
302
+ int i = f;
303
+ if (f > 15) {
304
+ __m512 manhattan = _mm512_setzero_ps();
305
+ for (; i > 15; i -= 16) {
306
+ const __m512 x_minus_y = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y));
307
+ manhattan = _mm512_add_ps(manhattan, _mm512_abs_ps(x_minus_y));
308
+ x += 16;
309
+ y += 16;
310
+ }
311
+ // Sum all floats in manhattan register.
312
+ result = _mm512_reduce_add_ps(manhattan);
313
+ }
314
+ // Don't forget the remaining values.
315
+ for (; i > 0; i--) {
316
+ result += fabsf(*x - *y);
317
+ x++;
318
+ y++;
319
+ }
320
+ return result;
321
+ }
322
+
323
+ template<>
324
+ inline float euclidean_distance<float>(const float* x, const float* y, int f) {
325
+ float result=0;
326
+ if (f > 15) {
327
+ __m512 d = _mm512_setzero_ps();
328
+ for (; f > 15; f -= 16) {
329
+ const __m512 diff = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y));
330
+ d = _mm512_fmadd_ps(diff, diff, d);
331
+ x += 16;
332
+ y += 16;
333
+ }
334
+ // Sum all floats in dot register.
335
+ result = _mm512_reduce_add_ps(d);
336
+ }
337
+ // Don't forget the remaining values.
338
+ for (; f > 0; f--) {
339
+ float tmp = *x - *y;
340
+ result += tmp * tmp;
341
+ x++;
342
+ y++;
343
+ }
344
+ return result;
345
+ }
346
+
347
+ #endif
348
+
349
+
350
// Euclidean length of the first `f` components of v: sqrt(dot(v, v)).
template<typename T>
inline T get_norm(T* v, int f) {
  const T squared_length = dot(v, v, f);
  return sqrt(squared_length);
}
354
+
355
+ template<typename T, typename Random, typename Distance, typename Node>
356
+ inline void two_means(const vector<Node*>& nodes, int f, Random& random, bool cosine, Node* p, Node* q) {
357
+ /*
358
+ This algorithm is a huge heuristic. Empirically it works really well, but I
359
+ can't motivate it well. The basic idea is to keep two centroids and assign
360
+ points to either one of them. We weight each centroid by the number of points
361
+ assigned to it, so to balance it.
362
+ */
363
+ static int iteration_steps = 200;
364
+ size_t count = nodes.size();
365
+
366
+ size_t i = random.index(count);
367
+ size_t j = random.index(count-1);
368
+ j += (j >= i); // ensure that i != j
369
+
370
+ Distance::template copy_node<T, Node>(p, nodes[i], f);
371
+ Distance::template copy_node<T, Node>(q, nodes[j], f);
372
+
373
+ if (cosine) { Distance::template normalize<T, Node>(p, f); Distance::template normalize<T, Node>(q, f); }
374
+ Distance::init_node(p, f);
375
+ Distance::init_node(q, f);
376
+
377
+ int ic = 1, jc = 1;
378
+ for (int l = 0; l < iteration_steps; l++) {
379
+ size_t k = random.index(count);
380
+ T di = ic * Distance::distance(p, nodes[k], f),
381
+ dj = jc * Distance::distance(q, nodes[k], f);
382
+ T norm = cosine ? get_norm(nodes[k]->v, f) : 1;
383
+ if (!(norm > T(0))) {
384
+ continue;
385
+ }
386
+ if (di < dj) {
387
+ for (int z = 0; z < f; z++)
388
+ p->v[z] = (p->v[z] * ic + nodes[k]->v[z] / norm) / (ic + 1);
389
+ Distance::init_node(p, f);
390
+ ic++;
391
+ } else if (dj < di) {
392
+ for (int z = 0; z < f; z++)
393
+ q->v[z] = (q->v[z] * jc + nodes[k]->v[z] / norm) / (jc + 1);
394
+ Distance::init_node(q, f);
395
+ jc++;
396
+ }
397
+ }
398
+ }
399
+ } // namespace
400
+
401
// Default implementations shared by the metric structs, which override
// whichever of these hooks they need.
struct Base {
  // Hook for metric-specific work over the complete set of nodes in an
  // index. The default does nothing.
  template<typename T, typename S, typename Node>
  static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) {
  }

  // Hook for giving metric-specific node fields a sane default. The default
  // does nothing.
  template<typename Node>
  static inline void zero_value(Node* dest) {
  }

  // Copies the first `f` vector components from `source` into `dest`.
  template<typename T, typename Node>
  static inline void copy_node(Node* dest, const Node* source, const int f) {
    const size_t bytes = f * sizeof(T);
    memcpy(dest->v, source->v, bytes);
  }

  // Scales the node's vector to unit length; a vector whose norm is not
  // positive is left untouched.
  template<typename T, typename Node>
  static inline void normalize(Node* node, int f) {
    T norm = get_norm(node->v, f);
    if (!(norm > 0)) {
      return;
    }
    for (int z = 0; z < f; z++) {
      node->v[z] /= norm;
    }
  }
};
427
+
428
+ struct Angular : Base {
429
+ template<typename S, typename T>
430
+ struct Node {
431
+ /*
432
+ * We store a binary tree where each node has two things
433
+ * - A vector associated with it
434
+ * - Two children
435
+ * All nodes occupy the same amount of memory
436
+ * All nodes with n_descendants == 1 are leaf nodes.
437
+ * A memory optimization is that for nodes with 2 <= n_descendants <= K,
438
+ * we skip the vector. Instead we store a list of all descendants. K is
439
+ * determined by the number of items that fits in the space of the vector.
440
+ * For nodes with n_descendants == 1 the vector is a data point.
441
+ * For nodes with n_descendants > K the vector is the normal of the split plane.
442
+ * Note that we can't really do sizeof(node<T>) because we cheat and allocate
443
+ * more memory to be able to fit the vector outside
444
+ */
445
+ S n_descendants;
446
+ union {
447
+ S children[2]; // Will possibly store more than 2
448
+ T norm;
449
+ };
450
+ T v[V_ARRAY_SIZE];
451
+ };
452
+ template<typename S, typename T>
453
+ static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
454
+ // want to calculate (a/|a| - b/|b|)^2
455
+ // = a^2 / a^2 + b^2 / b^2 - 2ab/|a||b|
456
+ // = 2 - 2cos
457
+ T pp = x->norm ? x->norm : dot(x->v, x->v, f); // For backwards compatibility reasons, we need to fall back and compute the norm here
458
+ T qq = y->norm ? y->norm : dot(y->v, y->v, f);
459
+ T pq = dot(x->v, y->v, f);
460
+ T ppqq = pp * qq;
461
+ if (ppqq > 0) return 2.0 - 2.0 * pq / sqrt(ppqq);
462
+ else return 2.0; // cos is 0
463
+ }
464
+ template<typename S, typename T>
465
+ static inline T margin(const Node<S, T>* n, const T* y, int f) {
466
+ return dot(n->v, y, f);
467
+ }
468
+ template<typename S, typename T, typename Random>
469
+ static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
470
+ T dot = margin(n, y, f);
471
+ if (dot != 0)
472
+ return (dot > 0);
473
+ else
474
+ return (bool)random.flip();
475
+ }
476
+ template<typename S, typename T, typename Random>
477
+ static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
478
+ Node<S, T>* p = (Node<S, T>*)alloca(s);
479
+ Node<S, T>* q = (Node<S, T>*)alloca(s);
480
+ two_means<T, Random, Angular, Node<S, T> >(nodes, f, random, true, p, q);
481
+ for (int z = 0; z < f; z++)
482
+ n->v[z] = p->v[z] - q->v[z];
483
+ Base::normalize<T, Node<S, T> >(n, f);
484
+ }
485
+ template<typename T>
486
+ static inline T normalized_distance(T distance) {
487
+ // Used when requesting distances from Python layer
488
+ // Turns out sometimes the squared distance is -0.0
489
+ // so we have to make sure it's a positive number.
490
+ return sqrt(std::max(distance, T(0)));
491
+ }
492
+ template<typename T>
493
+ static inline T pq_distance(T distance, T margin, int child_nr) {
494
+ if (child_nr == 0)
495
+ margin = -margin;
496
+ return std::min(distance, margin);
497
+ }
498
+ template<typename T>
499
+ static inline T pq_initial_value() {
500
+ return numeric_limits<T>::infinity();
501
+ }
502
+ template<typename S, typename T>
503
+ static inline void init_node(Node<S, T>* n, int f) {
504
+ n->norm = dot(n->v, n->v, f);
505
+ }
506
+ static const char* name() {
507
+ return "angular";
508
+ }
509
+ };
510
+
511
+
512
+ struct DotProduct : Angular {
513
+ template<typename S, typename T>
514
+ struct Node {
515
+ /*
516
+ * This is an extension of the Angular node with an extra attribute for the scaled norm.
517
+ */
518
+ S n_descendants;
519
+ S children[2]; // Will possibly store more than 2
520
+ T dot_factor;
521
+ T v[V_ARRAY_SIZE];
522
+ };
523
+
524
+ static const char* name() {
525
+ return "dot";
526
+ }
527
+ template<typename S, typename T>
528
+ static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
529
+ return -dot(x->v, y->v, f);
530
+ }
531
+
532
+ template<typename Node>
533
+ static inline void zero_value(Node* dest) {
534
+ dest->dot_factor = 0;
535
+ }
536
+
537
+ template<typename S, typename T>
538
+ static inline void init_node(Node<S, T>* n, int f) {
539
+ }
540
+
541
+ template<typename T, typename Node>
542
+ static inline void copy_node(Node* dest, const Node* source, const int f) {
543
+ memcpy(dest->v, source->v, f * sizeof(T));
544
+ dest->dot_factor = source->dot_factor;
545
+ }
546
+
547
+ template<typename S, typename T, typename Random>
548
+ static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
549
+ Node<S, T>* p = (Node<S, T>*)alloca(s);
550
+ Node<S, T>* q = (Node<S, T>*)alloca(s);
551
+ DotProduct::zero_value(p);
552
+ DotProduct::zero_value(q);
553
+ two_means<T, Random, DotProduct, Node<S, T> >(nodes, f, random, true, p, q);
554
+ for (int z = 0; z < f; z++)
555
+ n->v[z] = p->v[z] - q->v[z];
556
+ n->dot_factor = p->dot_factor - q->dot_factor;
557
+ DotProduct::normalize<T, Node<S, T> >(n, f);
558
+ }
559
+
560
+ template<typename T, typename Node>
561
+ static inline void normalize(Node* node, int f) {
562
+ T norm = sqrt(dot(node->v, node->v, f) + pow(node->dot_factor, 2));
563
+ if (norm > 0) {
564
+ for (int z = 0; z < f; z++)
565
+ node->v[z] /= norm;
566
+ node->dot_factor /= norm;
567
+ }
568
+ }
569
+
570
+ template<typename S, typename T>
571
+ static inline T margin(const Node<S, T>* n, const T* y, int f) {
572
+ return dot(n->v, y, f) + (n->dot_factor * n->dot_factor);
573
+ }
574
+
575
+ template<typename S, typename T, typename Random>
576
+ static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
577
+ T dot = margin(n, y, f);
578
+ if (dot != 0)
579
+ return (dot > 0);
580
+ else
581
+ return (bool)random.flip();
582
+ }
583
+
584
+ template<typename T>
585
+ static inline T normalized_distance(T distance) {
586
+ return -distance;
587
+ }
588
+
589
+ template<typename T, typename S, typename Node>
590
+ static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) {
591
+ // This uses a method from Microsoft Research for transforming inner product spaces to cosine/angular-compatible spaces.
592
+ // (Bachrach et al., 2014, see https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf)
593
+
594
+ // Step one: compute the norm of each vector and store that in its extra dimension (f-1)
595
+ for (S i = 0; i < node_count; i++) {
596
+ Node* node = get_node_ptr<S, Node>(nodes, _s, i);
597
+ T norm = sqrt(dot(node->v, node->v, f));
598
+ if (isnan(norm)) norm = 0;
599
+ node->dot_factor = norm;
600
+ }
601
+
602
+ // Step two: find the maximum norm
603
+ T max_norm = 0;
604
+ for (S i = 0; i < node_count; i++) {
605
+ Node* node = get_node_ptr<S, Node>(nodes, _s, i);
606
+ if (node->dot_factor > max_norm) {
607
+ max_norm = node->dot_factor;
608
+ }
609
+ }
610
+
611
+ // Step three: set each vector's extra dimension to sqrt(max_norm^2 - norm^2)
612
+ for (S i = 0; i < node_count; i++) {
613
+ Node* node = get_node_ptr<S, Node>(nodes, _s, i);
614
+ T node_norm = node->dot_factor;
615
+
616
+ T dot_factor = sqrt(pow(max_norm, static_cast<T>(2.0)) - pow(node_norm, static_cast<T>(2.0)));
617
+ if (isnan(dot_factor)) dot_factor = 0;
618
+
619
+ node->dot_factor = dot_factor;
620
+ }
621
+ }
622
+ };
623
+
624
+ struct Hamming : Base {
625
+ template<typename S, typename T>
626
+ struct Node {
627
+ S n_descendants;
628
+ S children[2];
629
+ T v[V_ARRAY_SIZE];
630
+ };
631
+
632
+ static const size_t max_iterations = 20;
633
+
634
+ template<typename T>
635
+ static inline T pq_distance(T distance, T margin, int child_nr) {
636
+ return distance - (margin != (unsigned int) child_nr);
637
+ }
638
+
639
+ template<typename T>
640
+ static inline T pq_initial_value() {
641
+ return numeric_limits<T>::max();
642
+ }
643
+ template<typename T>
644
+ static inline int cole_popcount(T v) {
645
+ // Note: Only used with MSVC 9, which lacks intrinsics and fails to
646
+ // calculate std::bitset::count for v > 32bit. Uses the generalized
647
+ // approach by Eric Cole.
648
+ // See https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64
649
+ v = v - ((v >> 1) & (T)~(T)0/3);
650
+ v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3);
651
+ v = (v + (v >> 4)) & (T)~(T)0/255*15;
652
+ return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * 8;
653
+ }
654
+ template<typename S, typename T>
655
+ static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
656
+ size_t dist = 0;
657
+ for (int i = 0; i < f; i++) {
658
+ dist += popcount(x->v[i] ^ y->v[i]);
659
+ }
660
+ return dist;
661
+ }
662
+ template<typename S, typename T>
663
+ static inline bool margin(const Node<S, T>* n, const T* y, int f) {
664
+ static const size_t n_bits = sizeof(T) * 8;
665
+ T chunk = n->v[0] / n_bits;
666
+ return (y[chunk] & (static_cast<T>(1) << (n_bits - 1 - (n->v[0] % n_bits)))) != 0;
667
+ }
668
+ template<typename S, typename T, typename Random>
669
+ static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
670
+ return margin(n, y, f);
671
+ }
672
+ template<typename S, typename T, typename Random>
673
+ static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
674
+ size_t cur_size = 0;
675
+ size_t i = 0;
676
+ int dim = f * 8 * sizeof(T);
677
+ for (; i < max_iterations; i++) {
678
+ // choose random position to split at
679
+ n->v[0] = random.index(dim);
680
+ cur_size = 0;
681
+ for (typename vector<Node<S, T>*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) {
682
+ if (margin(n, (*it)->v, f)) {
683
+ cur_size++;
684
+ }
685
+ }
686
+ if (cur_size > 0 && cur_size < nodes.size()) {
687
+ break;
688
+ }
689
+ }
690
+ // brute-force search for splitting coordinate
691
+ if (i == max_iterations) {
692
+ int j = 0;
693
+ for (; j < dim; j++) {
694
+ n->v[0] = j;
695
+ cur_size = 0;
696
+ for (typename vector<Node<S, T>*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) {
697
+ if (margin(n, (*it)->v, f)) {
698
+ cur_size++;
699
+ }
700
+ }
701
+ if (cur_size > 0 && cur_size < nodes.size()) {
702
+ break;
703
+ }
704
+ }
705
+ }
706
+ }
707
+ template<typename T>
708
+ static inline T normalized_distance(T distance) {
709
+ return distance;
710
+ }
711
+ template<typename S, typename T>
712
+ static inline void init_node(Node<S, T>* n, int f) {
713
+ }
714
+ static const char* name() {
715
+ return "hamming";
716
+ }
717
+ };
718
+
719
+
720
+ struct Minkowski : Base {
721
+ template<typename S, typename T>
722
+ struct Node {
723
+ S n_descendants;
724
+ T a; // need an extra constant term to determine the offset of the plane
725
+ S children[2];
726
+ T v[V_ARRAY_SIZE];
727
+ };
728
+ template<typename S, typename T>
729
+ static inline T margin(const Node<S, T>* n, const T* y, int f) {
730
+ return n->a + dot(n->v, y, f);
731
+ }
732
+ template<typename S, typename T, typename Random>
733
+ static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
734
+ T dot = margin(n, y, f);
735
+ if (dot != 0)
736
+ return (dot > 0);
737
+ else
738
+ return (bool)random.flip();
739
+ }
740
+ template<typename T>
741
+ static inline T pq_distance(T distance, T margin, int child_nr) {
742
+ if (child_nr == 0)
743
+ margin = -margin;
744
+ return std::min(distance, margin);
745
+ }
746
+ template<typename T>
747
+ static inline T pq_initial_value() {
748
+ return numeric_limits<T>::infinity();
749
+ }
750
+ };
751
+
752
+
753
+ struct Euclidean : Minkowski {
754
+ template<typename S, typename T>
755
+ static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
756
+ return euclidean_distance(x->v, y->v, f);
757
+ }
758
+ template<typename S, typename T, typename Random>
759
+ static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
760
+ Node<S, T>* p = (Node<S, T>*)alloca(s);
761
+ Node<S, T>* q = (Node<S, T>*)alloca(s);
762
+ two_means<T, Random, Euclidean, Node<S, T> >(nodes, f, random, false, p, q);
763
+
764
+ for (int z = 0; z < f; z++)
765
+ n->v[z] = p->v[z] - q->v[z];
766
+ Base::normalize<T, Node<S, T> >(n, f);
767
+ n->a = 0.0;
768
+ for (int z = 0; z < f; z++)
769
+ n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2;
770
+ }
771
+ template<typename T>
772
+ static inline T normalized_distance(T distance) {
773
+ return sqrt(std::max(distance, T(0)));
774
+ }
775
+ template<typename S, typename T>
776
+ static inline void init_node(Node<S, T>* n, int f) {
777
+ }
778
+ static const char* name() {
779
+ return "euclidean";
780
+ }
781
+
782
+ };
783
+
784
+ struct Manhattan : Minkowski {
785
+ template<typename S, typename T>
786
+ static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
787
+ return manhattan_distance(x->v, y->v, f);
788
+ }
789
+ template<typename S, typename T, typename Random>
790
+ static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
791
+ Node<S, T>* p = (Node<S, T>*)alloca(s);
792
+ Node<S, T>* q = (Node<S, T>*)alloca(s);
793
+ two_means<T, Random, Manhattan, Node<S, T> >(nodes, f, random, false, p, q);
794
+
795
+ for (int z = 0; z < f; z++)
796
+ n->v[z] = p->v[z] - q->v[z];
797
+ Base::normalize<T, Node<S, T> >(n, f);
798
+ n->a = 0.0;
799
+ for (int z = 0; z < f; z++)
800
+ n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2;
801
+ }
802
+ template<typename T>
803
+ static inline T normalized_distance(T distance) {
804
+ return std::max(distance, T(0));
805
+ }
806
+ template<typename S, typename T>
807
+ static inline void init_node(Node<S, T>* n, int f) {
808
+ }
809
+ static const char* name() {
810
+ return "manhattan";
811
+ }
812
+ };
813
+
814
+ template<typename S, typename T>
815
+ class AnnoyIndexInterface {
816
+ public:
817
+ // Note that the methods with an **error argument will allocate memory and write the pointer to that string if error is non-NULL
818
+ virtual ~AnnoyIndexInterface() {};
819
+ virtual bool add_item(S item, const T* w, char** error=NULL) = 0;
820
+ virtual bool build(int q, char** error=NULL) = 0;
821
+ virtual bool unbuild(char** error=NULL) = 0;
822
+ virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
823
+ virtual void unload() = 0;
824
+ virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
825
+ virtual T get_distance(S i, S j) const = 0;
826
+ virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
827
+ virtual void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
828
+ virtual S get_n_items() const = 0;
829
+ virtual S get_n_trees() const = 0;
830
+ virtual void verbose(bool v) = 0;
831
+ virtual void get_item(S item, T* v) const = 0;
832
+ virtual void set_seed(int q) = 0;
833
+ virtual bool on_disk_build(const char* filename, char** error=NULL) = 0;
834
+ };
835
+
836
+ template<typename S, typename T, typename Distance, typename Random>
837
+ class AnnoyIndex : public AnnoyIndexInterface<S, T> {
838
+ /*
839
+ * We use random projection to build a forest of binary trees of all items.
840
+ * Basically just split the hyperspace into two sides by a hyperplane,
841
+ * then recursively split each of those subtrees etc.
842
+ * We create a tree like this q times. The default q is determined automatically
843
+ * in such a way that we at most use 2x as much memory as the vectors take.
844
+ */
845
+ public:
846
+ typedef Distance D;
847
+ typedef typename D::template Node<S, T> Node;
848
+
849
+ protected:
850
+ const int _f;
851
+ size_t _s;
852
+ S _n_items;
853
+ Random _random;
854
+ void* _nodes; // Could either be mmapped, or point to a memory buffer that we reallocate
855
+ S _n_nodes;
856
+ S _nodes_size;
857
+ vector<S> _roots;
858
+ S _K;
859
+ bool _loaded;
860
+ bool _verbose;
861
+ int _fd;
862
+ bool _on_disk;
863
+ bool _built;
864
+ public:
865
+
866
+ AnnoyIndex(int f) : _f(f), _random() {
867
+ _s = offsetof(Node, v) + _f * sizeof(T); // Size of each node
868
+ _verbose = false;
869
+ _built = false;
870
+ _K = (S) (((size_t) (_s - offsetof(Node, children))) / sizeof(S)); // Max number of descendants to fit into node
871
+ reinitialize(); // Reset everything
872
+ }
873
+ ~AnnoyIndex() {
874
+ unload();
875
+ }
876
+
877
+ int get_f() const {
878
+ return _f;
879
+ }
880
+
881
+ bool add_item(S item, const T* w, char** error=NULL) {
882
+ return add_item_impl(item, w, error);
883
+ }
884
+
885
+ template<typename W>
886
+ bool add_item_impl(S item, const W& w, char** error=NULL) {
887
+ if (_loaded) {
888
+ set_error_from_string(error, "You can't add an item to a loaded index");
889
+ return false;
890
+ }
891
+ _allocate_size(item + 1);
892
+ Node* n = _get(item);
893
+
894
+ D::zero_value(n);
895
+
896
+ n->children[0] = 0;
897
+ n->children[1] = 0;
898
+ n->n_descendants = 1;
899
+
900
+ for (int z = 0; z < _f; z++)
901
+ n->v[z] = w[z];
902
+
903
+ D::init_node(n, _f);
904
+
905
+ if (item >= _n_items)
906
+ _n_items = item + 1;
907
+
908
+ return true;
909
+ }
910
+
911
+ bool on_disk_build(const char* file, char** error=NULL) {
912
+ _on_disk = true;
913
+ _fd = open(file, O_RDWR | O_CREAT | O_TRUNC, (int) 0600);
914
+ if (_fd == -1) {
915
+ set_error_from_errno(error, "Unable to open");
916
+ _fd = 0;
917
+ return false;
918
+ }
919
+ _nodes_size = 1;
920
+ if (ftruncate(_fd, FTRUNCATE_SIZE(_s) * FTRUNCATE_SIZE(_nodes_size)) == -1) {
921
+ set_error_from_errno(error, "Unable to truncate");
922
+ return false;
923
+ }
924
+ #ifdef MAP_POPULATE
925
+ _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0);
926
+ #else
927
+ _nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0);
928
+ #endif
929
+ return true;
930
+ }
931
+
932
+ bool build(int q, char** error=NULL) {
933
+ if (_loaded) {
934
+ set_error_from_string(error, "You can't build a loaded index");
935
+ return false;
936
+ }
937
+
938
+ if (_built) {
939
+ set_error_from_string(error, "You can't build a built index");
940
+ return false;
941
+ }
942
+
943
+ D::template preprocess<T, S, Node>(_nodes, _s, _n_items, _f);
944
+
945
+ _n_nodes = _n_items;
946
+ while (1) {
947
+ if (q == -1 && _n_nodes >= _n_items * 2)
948
+ break;
949
+ if (q != -1 && _roots.size() >= (size_t)q)
950
+ break;
951
+ if (_verbose) showUpdate("pass %zd...\n", _roots.size());
952
+
953
+ vector<S> indices;
954
+ for (S i = 0; i < _n_items; i++) {
955
+ if (_get(i)->n_descendants >= 1) // Issue #223
956
+ indices.push_back(i);
957
+ }
958
+
959
+ _roots.push_back(_make_tree(indices, true));
960
+ }
961
+
962
+ // Also, copy the roots into the last segment of the array
963
+ // This way we can load them faster without reading the whole file
964
+ _allocate_size(_n_nodes + (S)_roots.size());
965
+ for (size_t i = 0; i < _roots.size(); i++)
966
+ memcpy(_get(_n_nodes + (S)i), _get(_roots[i]), _s);
967
+ _n_nodes += _roots.size();
968
+
969
+ if (_verbose) showUpdate("has %d nodes\n", _n_nodes);
970
+
971
+ if (_on_disk) {
972
+ if (!remap_memory_and_truncate(&_nodes, _fd,
973
+ static_cast<size_t>(_s) * static_cast<size_t>(_nodes_size),
974
+ static_cast<size_t>(_s) * static_cast<size_t>(_n_nodes))) {
975
+ // TODO: this probably creates an index in a corrupt state... not sure what to do
976
+ set_error_from_errno(error, "Unable to truncate");
977
+ return false;
978
+ }
979
+ _nodes_size = _n_nodes;
980
+ }
981
+ _built = true;
982
+ return true;
983
+ }
984
+
985
+ bool unbuild(char** error=NULL) {
986
+ if (_loaded) {
987
+ set_error_from_string(error, "You can't unbuild a loaded index");
988
+ return false;
989
+ }
990
+
991
+ _roots.clear();
992
+ _n_nodes = _n_items;
993
+ _built = false;
994
+
995
+ return true;
996
+ }
997
+
998
+ bool save(const char* filename, bool prefault=false, char** error=NULL) {
999
+ if (!_built) {
1000
+ set_error_from_string(error, "You can't save an index that hasn't been built");
1001
+ return false;
1002
+ }
1003
+ if (_on_disk) {
1004
+ return true;
1005
+ } else {
1006
+ // Delete file if it already exists (See issue #335)
1007
+ unlink(filename);
1008
+
1009
+ FILE *f = fopen(filename, "wb");
1010
+ if (f == NULL) {
1011
+ set_error_from_errno(error, "Unable to open");
1012
+ return false;
1013
+ }
1014
+
1015
+ if (fwrite(_nodes, _s, _n_nodes, f) != (size_t) _n_nodes) {
1016
+ set_error_from_errno(error, "Unable to write");
1017
+ return false;
1018
+ }
1019
+
1020
+ if (fclose(f) == EOF) {
1021
+ set_error_from_errno(error, "Unable to close");
1022
+ return false;
1023
+ }
1024
+
1025
+ unload();
1026
+ return load(filename, prefault, error);
1027
+ }
1028
+ }
1029
+
1030
+ void reinitialize() {
1031
+ _fd = 0;
1032
+ _nodes = NULL;
1033
+ _loaded = false;
1034
+ _n_items = 0;
1035
+ _n_nodes = 0;
1036
+ _nodes_size = 0;
1037
+ _on_disk = false;
1038
+ _roots.clear();
1039
+ }
1040
+
1041
+ void unload() {
1042
+ if (_on_disk && _fd) {
1043
+ close(_fd);
1044
+ munmap(_nodes, _s * _nodes_size);
1045
+ } else {
1046
+ if (_fd) {
1047
+ // we have mmapped data
1048
+ close(_fd);
1049
+ munmap(_nodes, _n_nodes * _s);
1050
+ } else if (_nodes) {
1051
+ // We have heap allocated data
1052
+ free(_nodes);
1053
+ }
1054
+ }
1055
+ reinitialize();
1056
+ if (_verbose) showUpdate("unloaded\n");
1057
+ }
1058
+
1059
+ bool load(const char* filename, bool prefault=false, char** error=NULL) {
1060
+ _fd = open(filename, O_RDONLY, (int)0400);
1061
+ if (_fd == -1) {
1062
+ set_error_from_errno(error, "Unable to open");
1063
+ _fd = 0;
1064
+ return false;
1065
+ }
1066
+ off_t size = lseek_getsize(_fd);
1067
+ if (size == -1) {
1068
+ set_error_from_errno(error, "Unable to get size");
1069
+ return false;
1070
+ } else if (size == 0) {
1071
+ set_error_from_errno(error, "Size of file is zero");
1072
+ return false;
1073
+ } else if (size % _s) {
1074
+ // Something is fishy with this index!
1075
+ set_error_from_errno(error, "Index size is not a multiple of vector size. Ensure you are opening using the same metric you used to create the index.");
1076
+ return false;
1077
+ }
1078
+
1079
+ int flags = MAP_SHARED;
1080
+ if (prefault) {
1081
+ #ifdef MAP_POPULATE
1082
+ flags |= MAP_POPULATE;
1083
+ #else
1084
+ showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform");
1085
+ #endif
1086
+ }
1087
+ _nodes = (Node*)mmap(0, size, PROT_READ, flags, _fd, 0);
1088
+ _n_nodes = (S)(size / _s);
1089
+
1090
+ // Find the roots by scanning the end of the file and taking the nodes with most descendants
1091
+ _roots.clear();
1092
+ S m = -1;
1093
+ for (S i = _n_nodes - 1; i >= 0; i--) {
1094
+ S k = _get(i)->n_descendants;
1095
+ if (m == -1 || k == m) {
1096
+ _roots.push_back(i);
1097
+ m = k;
1098
+ } else {
1099
+ break;
1100
+ }
1101
+ }
1102
+ // hacky fix: since the last root precedes the copy of all roots, delete it
1103
+ if (_roots.size() > 1 && _get(_roots.front())->children[0] == _get(_roots.back())->children[0])
1104
+ _roots.pop_back();
1105
+ _loaded = true;
1106
+ _built = true;
1107
+ _n_items = m;
1108
+ if (_verbose) showUpdate("found %lu roots with degree %d\n", _roots.size(), m);
1109
+ return true;
1110
+ }
1111
+
1112
+ T get_distance(S i, S j) const {
1113
+ return D::normalized_distance(D::distance(_get(i), _get(j), _f));
1114
+ }
1115
+
1116
+ void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
1117
+ // TODO: handle OOB
1118
+ const Node* m = _get(item);
1119
+ _get_all_nns(m->v, n, search_k, result, distances);
1120
+ }
1121
+
1122
+ void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
1123
+ _get_all_nns(w, n, search_k, result, distances);
1124
+ }
1125
+
1126
+ S get_n_items() const {
1127
+ return _n_items;
1128
+ }
1129
+
1130
+ S get_n_trees() const {
1131
+ return (S)_roots.size();
1132
+ }
1133
+
1134
+ void verbose(bool v) {
1135
+ _verbose = v;
1136
+ }
1137
+
1138
+ void get_item(S item, T* v) const {
1139
+ // TODO: handle OOB
1140
+ Node* m = _get(item);
1141
+ memcpy(v, m->v, (_f) * sizeof(T));
1142
+ }
1143
+
1144
+ void set_seed(int seed) {
1145
+ _random.set_seed(seed);
1146
+ }
1147
+
1148
+ protected:
1149
+ void _allocate_size(S n) {
1150
+ if (n > _nodes_size) {
1151
+ const double reallocation_factor = 1.3;
1152
+ S new_nodes_size = std::max(n, (S) ((_nodes_size + 1) * reallocation_factor));
1153
+ void *old = _nodes;
1154
+
1155
+ if (_on_disk) {
1156
+ if (!remap_memory_and_truncate(&_nodes, _fd,
1157
+ static_cast<size_t>(_s) * static_cast<size_t>(_nodes_size),
1158
+ static_cast<size_t>(_s) * static_cast<size_t>(new_nodes_size)) &&
1159
+ _verbose)
1160
+ showUpdate("File truncation error\n");
1161
+ } else {
1162
+ _nodes = realloc(_nodes, _s * new_nodes_size);
1163
+ memset((char *) _nodes + (_nodes_size * _s) / sizeof(char), 0, (new_nodes_size - _nodes_size) * _s);
1164
+ }
1165
+
1166
+ _nodes_size = new_nodes_size;
1167
+ if (_verbose) showUpdate("Reallocating to %d nodes: old_address=%p, new_address=%p\n", new_nodes_size, old, _nodes);
1168
+ }
1169
+ }
1170
+
1171
+ Node* _get(const S i) const {
1172
+ return get_node_ptr<S, Node>(_nodes, _s, i);
1173
+ }
1174
+
1175
+ double _split_imbalance(const vector<S>& left_indices, const vector<S>& right_indices) {
1176
+ double ls = (float)left_indices.size();
1177
+ double rs = (float)right_indices.size();
1178
+ float f = ls / (ls + rs + 1e-9); // Avoid 0/0
1179
+ return std::max(f, 1-f);
1180
+ }
1181
+
1182
+ S _make_tree(const vector<S>& indices, bool is_root) {
1183
+ // The basic rule is that if we have <= _K items, then it's a leaf node, otherwise it's a split node.
1184
+ // There's some regrettable complications caused by the problem that root nodes have to be "special":
1185
+ // 1. We identify root nodes by the arguable logic that _n_items == n->n_descendants, regardless of how many descendants they actually have
1186
+ // 2. Root nodes with only 1 child need to be a "dummy" parent
1187
+ // 3. Due to the _n_items "hack", we need to be careful with the cases where _n_items <= _K or _n_items > _K
1188
+ if (indices.size() == 1 && !is_root)
1189
+ return indices[0];
1190
+
1191
+ if (indices.size() <= (size_t)_K && (!is_root || (size_t)_n_items <= (size_t)_K || indices.size() == 1)) {
1192
+ _allocate_size(_n_nodes + 1);
1193
+ S item = _n_nodes++;
1194
+ Node* m = _get(item);
1195
+ m->n_descendants = is_root ? _n_items : (S)indices.size();
1196
+
1197
+ // Using std::copy instead of a loop seems to resolve issues #3 and #13,
1198
+ // probably because gcc 4.8 goes overboard with optimizations.
1199
+ // Using memcpy instead of std::copy for MSVC compatibility. #235
1200
+ // Only copy when necessary to avoid crash in MSVC 9. #293
1201
+ if (!indices.empty())
1202
+ memcpy(m->children, &indices[0], indices.size() * sizeof(S));
1203
+ return item;
1204
+ }
1205
+
1206
+ vector<Node*> children;
1207
+ for (size_t i = 0; i < indices.size(); i++) {
1208
+ S j = indices[i];
1209
+ Node* n = _get(j);
1210
+ if (n)
1211
+ children.push_back(n);
1212
+ }
1213
+
1214
+ vector<S> children_indices[2];
1215
+ Node* m = (Node*)alloca(_s);
1216
+
1217
+ for (int attempt = 0; attempt < 3; attempt++) {
1218
+ children_indices[0].clear();
1219
+ children_indices[1].clear();
1220
+ D::create_split(children, _f, _s, _random, m);
1221
+
1222
+ for (size_t i = 0; i < indices.size(); i++) {
1223
+ S j = indices[i];
1224
+ Node* n = _get(j);
1225
+ if (n) {
1226
+ bool side = D::side(m, n->v, _f, _random);
1227
+ children_indices[side].push_back(j);
1228
+ } else {
1229
+ showUpdate("No node for index %d?\n", j);
1230
+ }
1231
+ }
1232
+
1233
+ if (_split_imbalance(children_indices[0], children_indices[1]) < 0.95)
1234
+ break;
1235
+ }
1236
+
1237
+ // If we didn't find a hyperplane, just randomize sides as a last option
1238
+ while (_split_imbalance(children_indices[0], children_indices[1]) > 0.99) {
1239
+ if (_verbose)
1240
+ showUpdate("\tNo hyperplane found (left has %ld children, right has %ld children)\n",
1241
+ children_indices[0].size(), children_indices[1].size());
1242
+
1243
+ children_indices[0].clear();
1244
+ children_indices[1].clear();
1245
+
1246
+ // Set the vector to 0.0
1247
+ for (int z = 0; z < _f; z++)
1248
+ m->v[z] = 0;
1249
+
1250
+ for (size_t i = 0; i < indices.size(); i++) {
1251
+ S j = indices[i];
1252
+ // Just randomize...
1253
+ children_indices[_random.flip()].push_back(j);
1254
+ }
1255
+ }
1256
+
1257
+ int flip = (children_indices[0].size() > children_indices[1].size());
1258
+
1259
+ m->n_descendants = is_root ? _n_items : (S)indices.size();
1260
+ for (int side = 0; side < 2; side++) {
1261
+ // run _make_tree for the smallest child first (for cache locality)
1262
+ m->children[side^flip] = _make_tree(children_indices[side^flip], false);
1263
+ }
1264
+
1265
+
1266
+ _allocate_size(_n_nodes + 1);
1267
+ S item = _n_nodes++;
1268
+ memcpy(_get(item), m, _s);
1269
+
1270
+ return item;
1271
+ }
1272
+
1273
+ void _get_all_nns(const T* v, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
1274
+ Node* v_node = (Node *)alloca(_s);
1275
+ D::template zero_value<Node>(v_node);
1276
+ memcpy(v_node->v, v, sizeof(T) * _f);
1277
+ D::init_node(v_node, _f);
1278
+
1279
+ std::priority_queue<pair<T, S> > q;
1280
+
1281
+ if (search_k == -1) {
1282
+ search_k = n * _roots.size();
1283
+ }
1284
+
1285
+ for (size_t i = 0; i < _roots.size(); i++) {
1286
+ q.push(make_pair(Distance::template pq_initial_value<T>(), _roots[i]));
1287
+ }
1288
+
1289
+ std::vector<S> nns;
1290
+ while (nns.size() < (size_t)search_k && !q.empty()) {
1291
+ const pair<T, S>& top = q.top();
1292
+ T d = top.first;
1293
+ S i = top.second;
1294
+ Node* nd = _get(i);
1295
+ q.pop();
1296
+ if (nd->n_descendants == 1 && i < _n_items) {
1297
+ nns.push_back(i);
1298
+ } else if (nd->n_descendants <= _K) {
1299
+ const S* dst = nd->children;
1300
+ nns.insert(nns.end(), dst, &dst[nd->n_descendants]);
1301
+ } else {
1302
+ T margin = D::margin(nd, v, _f);
1303
+ q.push(make_pair(D::pq_distance(d, margin, 1), static_cast<S>(nd->children[1])));
1304
+ q.push(make_pair(D::pq_distance(d, margin, 0), static_cast<S>(nd->children[0])));
1305
+ }
1306
+ }
1307
+
1308
+ // Get distances for all items
1309
+ // To avoid calculating distance multiple times for any items, sort by id
1310
+ std::sort(nns.begin(), nns.end());
1311
+ vector<pair<T, S> > nns_dist;
1312
+ S last = -1;
1313
+ for (size_t i = 0; i < nns.size(); i++) {
1314
+ S j = nns[i];
1315
+ if (j == last)
1316
+ continue;
1317
+ last = j;
1318
+ if (_get(j)->n_descendants == 1) // This is only to guard a really obscure case, #284
1319
+ nns_dist.push_back(make_pair(D::distance(v_node, _get(j), _f), j));
1320
+ }
1321
+
1322
+ size_t m = nns_dist.size();
1323
+ size_t p = n < m ? n : m; // Return this many items
1324
+ std::partial_sort(nns_dist.begin(), nns_dist.begin() + p, nns_dist.end());
1325
+ for (size_t i = 0; i < p; i++) {
1326
+ if (distances)
1327
+ distances->push_back(D::normalized_distance(nns_dist[i].first));
1328
+ result->push_back(nns_dist[i].second);
1329
+ }
1330
+ }
1331
+ };
1332
+
1333
+ #endif
1334
+ // vim: tabstop=2 shiftwidth=2