annoy-rb 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +21 -0
- data/.rspec +3 -0
- data/.travis.yml +12 -0
- data/CHANGELOG.md +2 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +8 -0
- data/LICENSE.txt +177 -0
- data/README.md +58 -0
- data/Rakefile +14 -0
- data/annoy-rb.gemspec +27 -0
- data/ext/annoy/annoy.cpp +30 -0
- data/ext/annoy/annoy.hpp +300 -0
- data/ext/annoy/extconf.rb +9 -0
- data/ext/annoy/src/annoylib.h +1334 -0
- data/ext/annoy/src/kissrandom.h +106 -0
- data/lib/annoy.rb +174 -0
- data/lib/annoy/version.rb +7 -0
- metadata +65 -0
data/ext/annoy/annoy.hpp
ADDED
@@ -0,0 +1,300 @@
|
|
1
|
+
/**
|
2
|
+
* Annoy.rb is a Ruby binding for the Annoy (Approximate Nearest Neighbors Oh Yeah).
|
3
|
+
*
|
4
|
+
* Copyright (c) 2020 Atsushi Tatsuma
|
5
|
+
*
|
6
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
* you may not use this file except in compliance with the License.
|
8
|
+
* You may obtain a copy of the License at
|
9
|
+
*
|
10
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
*
|
12
|
+
* Unless required by applicable law or agreed to in writing, software
|
13
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
* See the License for the specific language governing permissions and
|
16
|
+
* limitations under the License.
|
17
|
+
*/
|
18
|
+
|
19
|
+
#ifndef ANNOY_HPP
|
20
|
+
#define ANNOY_HPP 1
|
21
|
+
|
22
|
+
#include <typeinfo>
|
23
|
+
|
24
|
+
#include <ruby.h>
|
25
|
+
#include <annoylib.h>
|
26
|
+
#include <kissrandom.h>
|
27
|
+
|
28
|
+
typedef AnnoyIndex<int, double, Angular, Kiss64Random> AnnoyIndexAngular;
|
29
|
+
typedef AnnoyIndex<int, double, DotProduct, Kiss64Random> AnnoyIndexDotProduct;
|
30
|
+
typedef AnnoyIndex<int, uint64_t, Hamming, Kiss64Random> AnnoyIndexHamming;
|
31
|
+
typedef AnnoyIndex<int, double, Euclidean, Kiss64Random> AnnoyIndexEuclidean;
|
32
|
+
typedef AnnoyIndex<int, double, Manhattan, Kiss64Random> AnnoyIndexManhattan;
|
33
|
+
|
34
|
+
template<class T, typename F> class RbAnnoyIndex
|
35
|
+
{
|
36
|
+
public:
|
37
|
+
static VALUE annoy_index_alloc(VALUE self) {
|
38
|
+
T* ptr = (T*)ruby_xmalloc(sizeof(T));
|
39
|
+
return Data_Wrap_Struct(self, NULL, annoy_index_free, ptr);
|
40
|
+
};
|
41
|
+
|
42
|
+
static void annoy_index_free(T* ptr) {
|
43
|
+
ptr->~AnnoyIndex();
|
44
|
+
ruby_xfree(ptr);
|
45
|
+
};
|
46
|
+
|
47
|
+
static T* get_annoy_index(VALUE self) {
|
48
|
+
T* ptr;
|
49
|
+
Data_Get_Struct(self, T, ptr);
|
50
|
+
return ptr;
|
51
|
+
};
|
52
|
+
|
53
|
+
static VALUE define_class(VALUE rb_mAnnoy, const char* class_name) {
|
54
|
+
VALUE rb_cAnnoyIndex = rb_define_class_under(rb_mAnnoy, class_name, rb_cObject);
|
55
|
+
rb_define_alloc_func(rb_cAnnoyIndex, annoy_index_alloc);
|
56
|
+
rb_define_method(rb_cAnnoyIndex, "initialize", RUBY_METHOD_FUNC(_annoy_index_init), 1);
|
57
|
+
rb_define_method(rb_cAnnoyIndex, "add_item", RUBY_METHOD_FUNC(_annoy_index_add_item), 2);
|
58
|
+
rb_define_method(rb_cAnnoyIndex, "build", RUBY_METHOD_FUNC(_annoy_index_build), 1);
|
59
|
+
rb_define_method(rb_cAnnoyIndex, "save", RUBY_METHOD_FUNC(_annoy_index_save), 2);
|
60
|
+
rb_define_method(rb_cAnnoyIndex, "load", RUBY_METHOD_FUNC(_annoy_index_load), 2);
|
61
|
+
rb_define_method(rb_cAnnoyIndex, "unload", RUBY_METHOD_FUNC(_annoy_index_unload), 0);
|
62
|
+
rb_define_method(rb_cAnnoyIndex, "get_nns_by_item", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_item), 4);
|
63
|
+
rb_define_method(rb_cAnnoyIndex, "get_nns_by_vector", RUBY_METHOD_FUNC(_annoy_index_get_nns_by_vector), 4);
|
64
|
+
rb_define_method(rb_cAnnoyIndex, "get_item", RUBY_METHOD_FUNC(_annoy_index_get_item), 1);
|
65
|
+
rb_define_method(rb_cAnnoyIndex, "get_distance", RUBY_METHOD_FUNC(_annoy_index_get_distance), 2);
|
66
|
+
rb_define_method(rb_cAnnoyIndex, "get_n_items", RUBY_METHOD_FUNC(_annoy_index_get_n_items), 0);
|
67
|
+
rb_define_method(rb_cAnnoyIndex, "get_n_trees", RUBY_METHOD_FUNC(_annoy_index_get_n_trees), 0);
|
68
|
+
rb_define_method(rb_cAnnoyIndex, "on_disk_build", RUBY_METHOD_FUNC(_annoy_index_on_disk_build), 1);
|
69
|
+
rb_define_method(rb_cAnnoyIndex, "set_seed", RUBY_METHOD_FUNC(_annoy_index_set_seed), 1);
|
70
|
+
rb_define_method(rb_cAnnoyIndex, "verbose", RUBY_METHOD_FUNC(_annoy_index_verbose), 1);
|
71
|
+
rb_define_method(rb_cAnnoyIndex, "get_f", RUBY_METHOD_FUNC(_annoy_index_get_f), 0);
|
72
|
+
return rb_cAnnoyIndex;
|
73
|
+
};
|
74
|
+
|
75
|
+
private:
|
76
|
+
|
77
|
+
static VALUE _annoy_index_init(VALUE self, VALUE _n_dims) {
|
78
|
+
const int n_dims = NUM2INT(_n_dims);
|
79
|
+
T* ptr = get_annoy_index(self);
|
80
|
+
new (ptr) T(n_dims);
|
81
|
+
return Qnil;
|
82
|
+
};
|
83
|
+
|
84
|
+
static VALUE _annoy_index_add_item(VALUE self, VALUE _idx, VALUE arr) {
|
85
|
+
const int idx = NUM2INT(_idx);
|
86
|
+
const int n_dims = get_annoy_index(self)->get_f();
|
87
|
+
|
88
|
+
if (!RB_TYPE_P(arr, T_ARRAY)) {
|
89
|
+
rb_raise(rb_eArgError, "Expect item vector to be Array.");
|
90
|
+
return Qfalse;
|
91
|
+
}
|
92
|
+
|
93
|
+
if (n_dims != RARRAY_LEN(arr)) {
|
94
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
95
|
+
return Qfalse;
|
96
|
+
}
|
97
|
+
|
98
|
+
std::vector<F> vec(n_dims, 0);
|
99
|
+
for (int i = 0; i < n_dims; i++) {
|
100
|
+
vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(arr, i)) : NUM2UINT(rb_ary_entry(arr, i));
|
101
|
+
}
|
102
|
+
|
103
|
+
char* error;
|
104
|
+
if (!get_annoy_index(self)->add_item(idx, &vec[0], &error)) {
|
105
|
+
rb_raise(rb_eRuntimeError, "%s", error);
|
106
|
+
free(error);
|
107
|
+
return Qfalse;
|
108
|
+
}
|
109
|
+
|
110
|
+
return Qtrue;
|
111
|
+
};
|
112
|
+
|
113
|
+
static VALUE _annoy_index_build(VALUE self, VALUE _n_trees) {
|
114
|
+
const int n_trees = NUM2INT(_n_trees);
|
115
|
+
char* error;
|
116
|
+
|
117
|
+
if (!get_annoy_index(self)->build(n_trees, &error)) {
|
118
|
+
rb_raise(rb_eRuntimeError, "%s", error);
|
119
|
+
free(error);
|
120
|
+
return Qfalse;
|
121
|
+
}
|
122
|
+
|
123
|
+
return Qtrue;
|
124
|
+
};
|
125
|
+
|
126
|
+
static VALUE _annoy_index_save(VALUE self, VALUE _filename, VALUE _prefault) {
|
127
|
+
const char* filename = StringValuePtr(_filename);
|
128
|
+
const bool prefault = _prefault == Qtrue ? true : false;
|
129
|
+
char* error;
|
130
|
+
|
131
|
+
if (!get_annoy_index(self)->save(filename, prefault, &error)) {
|
132
|
+
rb_raise(rb_eRuntimeError, "%s", error);
|
133
|
+
free(error);
|
134
|
+
return Qfalse;
|
135
|
+
}
|
136
|
+
|
137
|
+
return Qtrue;
|
138
|
+
};
|
139
|
+
|
140
|
+
static VALUE _annoy_index_load(VALUE self, VALUE _filename, VALUE _prefault) {
|
141
|
+
const char* filename = StringValuePtr(_filename);
|
142
|
+
const bool prefault = _prefault == Qtrue ? true : false;
|
143
|
+
char* error;
|
144
|
+
|
145
|
+
if (!get_annoy_index(self)->load(filename, prefault, &error)) {
|
146
|
+
rb_raise(rb_eRuntimeError, "%s", error);
|
147
|
+
free(error);
|
148
|
+
return Qfalse;
|
149
|
+
}
|
150
|
+
|
151
|
+
return Qtrue;
|
152
|
+
};
|
153
|
+
|
154
|
+
static VALUE _annoy_index_unload(VALUE self) {
|
155
|
+
get_annoy_index(self)->unload();
|
156
|
+
return Qnil;
|
157
|
+
};
|
158
|
+
|
159
|
+
static VALUE _annoy_index_get_nns_by_item(VALUE self, VALUE _idx, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
|
160
|
+
const int idx = NUM2INT(_idx);
|
161
|
+
const int n_neighbors = NUM2INT(_n_neighbors);
|
162
|
+
const int search_k = NUM2INT(_search_k);
|
163
|
+
const bool include_distances = _include_distances == Qtrue ? true : false;
|
164
|
+
std::vector<int> neighbors;
|
165
|
+
std::vector<F> distances;
|
166
|
+
|
167
|
+
get_annoy_index(self)->get_nns_by_item(idx, n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
|
168
|
+
|
169
|
+
const int sz_neighbors = neighbors.size();
|
170
|
+
VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
|
171
|
+
|
172
|
+
for (int i = 0; i < sz_neighbors; i++) {
|
173
|
+
rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
|
174
|
+
}
|
175
|
+
|
176
|
+
if (include_distances) {
|
177
|
+
const int sz_distances = distances.size();
|
178
|
+
VALUE distances_arr = rb_ary_new2(sz_distances);
|
179
|
+
for (int i = 0; i < sz_distances; i++) {
|
180
|
+
rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
|
181
|
+
}
|
182
|
+
VALUE res = rb_ary_new2(2);
|
183
|
+
rb_ary_store(res, 0, neighbors_arr);
|
184
|
+
rb_ary_store(res, 1, distances_arr);
|
185
|
+
return res;
|
186
|
+
}
|
187
|
+
|
188
|
+
return neighbors_arr;
|
189
|
+
};
|
190
|
+
|
191
|
+
static VALUE _annoy_index_get_nns_by_vector(VALUE self, VALUE _vec, VALUE _n_neighbors, VALUE _search_k, VALUE _include_distances) {
|
192
|
+
const int n_dims = get_annoy_index(self)->get_f();
|
193
|
+
|
194
|
+
if (!RB_TYPE_P(_vec, T_ARRAY)) {
|
195
|
+
rb_raise(rb_eArgError, "Expect item vector to be Array.");
|
196
|
+
return Qfalse;
|
197
|
+
}
|
198
|
+
|
199
|
+
if (n_dims != RARRAY_LEN(_vec)) {
|
200
|
+
rb_raise(rb_eArgError, "Array size does not match to index dimensionality.");
|
201
|
+
return Qfalse;
|
202
|
+
}
|
203
|
+
|
204
|
+
std::vector<F> vec(n_dims, 0);
|
205
|
+
for (int i = 0; i < n_dims; i++) {
|
206
|
+
vec[i] = typeid(F) == typeid(double) ? NUM2DBL(rb_ary_entry(_vec, i)) : NUM2UINT(rb_ary_entry(_vec, i));
|
207
|
+
}
|
208
|
+
|
209
|
+
const int n_neighbors = NUM2INT(_n_neighbors);
|
210
|
+
const int search_k = NUM2INT(_search_k);
|
211
|
+
const bool include_distances = _include_distances == Qtrue ? true : false;
|
212
|
+
std::vector<int> neighbors;
|
213
|
+
std::vector<F> distances;
|
214
|
+
|
215
|
+
get_annoy_index(self)->get_nns_by_vector(&vec[0], n_neighbors, search_k, &neighbors, include_distances ? &distances : NULL);
|
216
|
+
|
217
|
+
const int sz_neighbors = neighbors.size();
|
218
|
+
VALUE neighbors_arr = rb_ary_new2(sz_neighbors);
|
219
|
+
|
220
|
+
for (int i = 0; i < sz_neighbors; i++) {
|
221
|
+
rb_ary_store(neighbors_arr, i, INT2NUM(neighbors[i]));
|
222
|
+
}
|
223
|
+
|
224
|
+
if (include_distances) {
|
225
|
+
const int sz_distances = distances.size();
|
226
|
+
VALUE distances_arr = rb_ary_new2(sz_distances);
|
227
|
+
for (int i = 0; i < sz_distances; i++) {
|
228
|
+
rb_ary_store(distances_arr, i, typeid(F) == typeid(double) ? DBL2NUM(distances[i]) : UINT2NUM(distances[i]));
|
229
|
+
}
|
230
|
+
VALUE res = rb_ary_new2(2);
|
231
|
+
rb_ary_store(res, 0, neighbors_arr);
|
232
|
+
rb_ary_store(res, 1, distances_arr);
|
233
|
+
return res;
|
234
|
+
}
|
235
|
+
|
236
|
+
return neighbors_arr;
|
237
|
+
};
|
238
|
+
|
239
|
+
static VALUE _annoy_index_get_item(VALUE self, VALUE _idx) {
|
240
|
+
const int idx = NUM2INT(_idx);
|
241
|
+
const int n_dims = get_annoy_index(self)->get_f();
|
242
|
+
std::vector<F> vec(n_dims, 0);
|
243
|
+
VALUE arr = rb_ary_new2(n_dims);
|
244
|
+
|
245
|
+
get_annoy_index(self)->get_item(idx, &vec[0]);
|
246
|
+
|
247
|
+
for (int i = 0; i < n_dims; i++) {
|
248
|
+
rb_ary_store(arr, i, typeid(F) == typeid(double) ? DBL2NUM(vec[i]) : UINT2NUM(vec[i]));
|
249
|
+
}
|
250
|
+
|
251
|
+
return arr;
|
252
|
+
};
|
253
|
+
|
254
|
+
static VALUE _annoy_index_get_distance(VALUE self, VALUE _i, VALUE _j) {
|
255
|
+
const int i = NUM2INT(_i);
|
256
|
+
const int j = NUM2INT(_j);
|
257
|
+
const double dist = get_annoy_index(self)->get_distance(i, j);
|
258
|
+
return DBL2NUM(dist);
|
259
|
+
};
|
260
|
+
|
261
|
+
static VALUE _annoy_index_get_n_items(VALUE self) {
|
262
|
+
const int32_t n_items = get_annoy_index(self)->get_n_items();
|
263
|
+
return INT2NUM(n_items);
|
264
|
+
};
|
265
|
+
|
266
|
+
static VALUE _annoy_index_get_n_trees(VALUE self) {
|
267
|
+
const int32_t n_trees = get_annoy_index(self)->get_n_trees();
|
268
|
+
return INT2NUM(n_trees);
|
269
|
+
};
|
270
|
+
|
271
|
+
static VALUE _annoy_index_on_disk_build(VALUE self, VALUE _filename) {
|
272
|
+
const char* filename = StringValuePtr(_filename);
|
273
|
+
char* error;
|
274
|
+
if (!get_annoy_index(self)->on_disk_build(filename, &error)) {
|
275
|
+
rb_raise(rb_eRuntimeError, "%s", error);
|
276
|
+
free(error);
|
277
|
+
return Qfalse;
|
278
|
+
}
|
279
|
+
return Qtrue;
|
280
|
+
};
|
281
|
+
|
282
|
+
static VALUE _annoy_index_set_seed(VALUE self, VALUE _seed) {
|
283
|
+
const int seed = NUM2INT(_seed);
|
284
|
+
get_annoy_index(self)->set_seed(seed);
|
285
|
+
return Qnil;
|
286
|
+
};
|
287
|
+
|
288
|
+
static VALUE _annoy_index_verbose(VALUE self, VALUE _flag) {
|
289
|
+
const bool flag = _flag == Qtrue ? true : false;
|
290
|
+
get_annoy_index(self)->verbose(flag);
|
291
|
+
return Qnil;
|
292
|
+
};
|
293
|
+
|
294
|
+
static VALUE _annoy_index_get_f(VALUE self) {
|
295
|
+
const int32_t f = get_annoy_index(self)->get_f();
|
296
|
+
return INT2NUM(f);
|
297
|
+
};
|
298
|
+
};
|
299
|
+
|
300
|
+
#endif /* ANNOY_HPP */
|
@@ -0,0 +1,1334 @@
|
|
1
|
+
// Copyright (c) 2013 Spotify AB
|
2
|
+
//
|
3
|
+
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
4
|
+
// use this file except in compliance with the License. You may obtain a copy of
|
5
|
+
// the License at
|
6
|
+
//
|
7
|
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
//
|
9
|
+
// Unless required by applicable law or agreed to in writing, software
|
10
|
+
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
11
|
+
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
12
|
+
// License for the specific language governing permissions and limitations under
|
13
|
+
// the License.
|
14
|
+
|
15
|
+
|
16
|
+
#ifndef ANNOYLIB_H
|
17
|
+
#define ANNOYLIB_H
|
18
|
+
|
19
|
+
#include <stdio.h>
|
20
|
+
#include <sys/stat.h>
|
21
|
+
#ifndef _MSC_VER
|
22
|
+
#include <unistd.h>
|
23
|
+
#endif
|
24
|
+
#include <stdio.h>
|
25
|
+
#include <stdlib.h>
|
26
|
+
#include <sys/types.h>
|
27
|
+
#include <fcntl.h>
|
28
|
+
#include <stddef.h>
|
29
|
+
|
30
|
+
#if defined(_MSC_VER) && _MSC_VER == 1500
|
31
|
+
typedef unsigned char uint8_t;
|
32
|
+
typedef signed __int32 int32_t;
|
33
|
+
typedef unsigned __int64 uint64_t;
|
34
|
+
typedef signed __int64 int64_t;
|
35
|
+
#else
|
36
|
+
#include <stdint.h>
|
37
|
+
#endif
|
38
|
+
|
39
|
+
#if defined(_MSC_VER) || defined(__MINGW32__)
|
40
|
+
// a bit hacky, but override some definitions to support 64 bit
|
41
|
+
#define off_t int64_t
|
42
|
+
#define lseek_getsize(fd) _lseeki64(fd, 0, SEEK_END)
|
43
|
+
#ifndef NOMINMAX
|
44
|
+
#define NOMINMAX
|
45
|
+
#endif
|
46
|
+
#include "mman.h"
|
47
|
+
#include <windows.h>
|
48
|
+
#else
|
49
|
+
#include <sys/mman.h>
|
50
|
+
#define lseek_getsize(fd) lseek(fd, 0, SEEK_END)
|
51
|
+
#endif
|
52
|
+
|
53
|
+
#include <cerrno>
|
54
|
+
#include <string.h>
|
55
|
+
#include <math.h>
|
56
|
+
#include <vector>
|
57
|
+
#include <algorithm>
|
58
|
+
#include <queue>
|
59
|
+
#include <limits>
|
60
|
+
|
61
|
+
#ifdef _MSC_VER
|
62
|
+
// Needed for Visual Studio to disable runtime checks for memcpy
|
63
|
+
#pragma runtime_checks("s", off)
|
64
|
+
#endif
|
65
|
+
|
66
|
+
// This allows others to supply their own logger / error printer without
|
67
|
+
// requiring Annoy to import their headers. See RcppAnnoy for a use case.
|
68
|
+
#ifndef __ERROR_PRINTER_OVERRIDE__
|
69
|
+
#define showUpdate(...) { fprintf(stderr, __VA_ARGS__ ); }
|
70
|
+
#else
|
71
|
+
#define showUpdate(...) { __ERROR_PRINTER_OVERRIDE__( __VA_ARGS__ ); }
|
72
|
+
#endif
|
73
|
+
|
74
|
+
// Portable alloc definition, cf Writing R Extensions, Section 1.6.4
|
75
|
+
#ifdef __GNUC__
|
76
|
+
// Includes GCC, clang and Intel compilers
|
77
|
+
# undef alloca
|
78
|
+
# define alloca(x) __builtin_alloca((x))
|
79
|
+
#elif defined(__sun) || defined(_AIX)
|
80
|
+
// this is necessary (and sufficient) for Solaris 10 and AIX 6:
|
81
|
+
# include <alloca.h>
|
82
|
+
#endif
|
83
|
+
|
84
|
+
/**
 * Reports a failure described by the current `errno`.
 *
 * Logs "<msg>: <strerror text> (<errno>)" through showUpdate and, when `error`
 * is non-NULL, stores a malloc()ed copy of the same text (without the trailing
 * newline) in *error. The caller owns that buffer and must free() it. If the
 * allocation fails, *error is set to NULL.
 */
inline void set_error_from_errno(char **error, const char* msg) {
  // Capture errno up front: the logging call below may overwrite it.
  const int saved_errno = errno;
  showUpdate("%s: %s (%d)\n", msg, strerror(saved_errno), saved_errno);
  if (error) {
    const char* reason = strerror(saved_errno);
    // Size the buffer from the text itself instead of assuming 256 bytes is
    // enough; 32 covers the punctuation, the decimal errno and the NUL.
    const size_t capacity = strlen(msg) + strlen(reason) + 32;
    *error = (char *)malloc(capacity);
    if (*error) {
      sprintf(*error, "%s: %s (%d)", msg, reason, saved_errno);
    }
  }
}
|
91
|
+
|
92
|
+
/**
 * Reports a failure described by the fixed message `msg`.
 *
 * Logs the message through showUpdate and, when `error` is non-NULL, stores a
 * malloc()ed copy of it in *error. The caller owns that buffer and must
 * free() it. If the allocation fails, *error is set to NULL.
 */
inline void set_error_from_string(char **error, const char* msg) {
  showUpdate("%s\n", msg);
  if (error) {
    const size_t size = strlen(msg) + 1;  // include the terminating NUL
    *error = (char *)malloc(size);
    if (*error) {
      memcpy(*error, msg, size);
    }
  }
}
|
99
|
+
|
100
|
+
// We let the v array in the Node struct take whatever space is needed, so this is a mostly insignificant number.
|
101
|
+
// Compilers need *some* size defined for the v array, and some memory checking tools will flag for buffer overruns if this is set too low.
|
102
|
+
#define V_ARRAY_SIZE 65536
|
103
|
+
|
104
|
+
#ifndef _MSC_VER
|
105
|
+
#define popcount __builtin_popcountll
|
106
|
+
#else // See #293, #358
|
107
|
+
#define isnan(x) _isnan(x)
|
108
|
+
#define popcount cole_popcount
|
109
|
+
#endif
|
110
|
+
|
111
|
+
#if !defined(NO_MANUAL_VECTORIZATION) && defined(__GNUC__) && (__GNUC__ >6) && defined(__AVX512F__) // See #402
|
112
|
+
#define USE_AVX512
|
113
|
+
#elif !defined(NO_MANUAL_VECTORIZATION) && defined(__AVX__) && defined (__SSE__) && defined(__SSE2__) && defined(__SSE3__)
|
114
|
+
#define USE_AVX
|
115
|
+
#else
|
116
|
+
#endif
|
117
|
+
|
118
|
+
#if defined(USE_AVX) || defined(USE_AVX512)
|
119
|
+
#if defined(_MSC_VER)
|
120
|
+
#include <intrin.h>
|
121
|
+
#elif defined(__GNUC__)
|
122
|
+
#include <x86intrin.h>
|
123
|
+
#endif
|
124
|
+
#endif
|
125
|
+
|
126
|
+
#if !defined(__MINGW32__)
|
127
|
+
#define FTRUNCATE_SIZE(x) static_cast<int64_t>(x)
|
128
|
+
#else
|
129
|
+
#define FTRUNCATE_SIZE(x) (x)
|
130
|
+
#endif
|
131
|
+
|
132
|
+
using std::vector;
|
133
|
+
using std::pair;
|
134
|
+
using std::numeric_limits;
|
135
|
+
using std::make_pair;
|
136
|
+
|
137
|
+
/**
 * Resizes the file behind `_fd` to `new_size` bytes and replaces the mapping
 * at *_ptr (currently `old_size` bytes) with one of `new_size` bytes.
 *
 * On Linux the mapping is resized with mremap and may move; elsewhere the old
 * mapping is unmapped and the file is mapped again. *_ptr receives whatever
 * the mapping call returned.
 *
 * Returns true only when both the truncation and the (re)mapping succeeded;
 * on a false return *_ptr may be MAP_FAILED and must not be dereferenced.
 */
inline bool remap_memory_and_truncate(void** _ptr, int _fd, size_t old_size, size_t new_size) {
#ifdef __linux__
  *_ptr = mremap(*_ptr, old_size, new_size, MREMAP_MAYMOVE);
  bool ok = ftruncate(_fd, new_size) != -1;
#else
  munmap(*_ptr, old_size);
  bool ok = ftruncate(_fd, FTRUNCATE_SIZE(new_size)) != -1;
#ifdef MAP_POPULATE
  *_ptr = mmap(*_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0);
#else
  *_ptr = mmap(*_ptr, new_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0);
#endif
#endif
  // A failed mapping must not be reported as success.
  return ok && *_ptr != MAP_FAILED;
}
|
152
|
+
|
153
|
+
namespace {
|
154
|
+
|
155
|
+
/**
 * Returns the address of the i-th node in the raw buffer `_nodes`, where every
 * node occupies `_s` bytes.
 */
template<typename S, typename Node>
inline Node* get_node_ptr(const void* _nodes, const size_t _s, const S i) {
  uint8_t* const first_byte = (uint8_t *)_nodes;
  uint8_t* const node_start = first_byte + (_s * i);
  return (Node*)node_start;
}
|
159
|
+
|
160
|
+
/**
 * Returns the sum of x[k] * y[k] over the first f elements, accumulated in
 * index order. Returns 0 when f <= 0.
 */
template<typename T>
inline T dot(const T* x, const T* y, int f) {
  T acc = 0;
  int k = 0;
  while (k < f) {
    acc += x[k] * y[k];
    ++k;
  }
  return acc;
}
|
170
|
+
|
171
|
+
/**
 * Returns the sum of |x[k] - y[k]| over the first f elements, accumulated in
 * index order.
 */
template<typename T>
inline T manhattan_distance(const T* x, const T* y, int f) {
  T total = 0.0;
  for (int left = f; left > 0; --left) {
    total += fabs(*x - *y);
    ++x;
    ++y;
  }
  return total;
}
|
178
|
+
|
179
|
+
/**
 * Returns the sum of (x[k] - y[k])^2 over the first f elements, i.e. the
 * squared Euclidean distance (no square root is taken here).
 */
template<typename T>
inline T euclidean_distance(const T* x, const T* y, int f) {
  // Differences are squared one by one rather than expanding into dot
  // products, to avoid catastrophic cancellation (see #314).
  T sum_sq = 0.0;
  int k = 0;
  while (k < f) {
    const T delta = x[k] - y[k];
    sum_sq += delta * delta;
    ++k;
  }
  return sum_sq;
}
|
191
|
+
|
192
|
+
#ifdef USE_AVX
|
193
|
+
// Horizontal single sum of 256bit vector.
|
194
|
+
inline float hsum256_ps_avx(__m256 v) {
|
195
|
+
const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(v, 1), _mm256_castps256_ps128(v));
|
196
|
+
const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
|
197
|
+
const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
198
|
+
return _mm_cvtss_f32(x32);
|
199
|
+
}
|
200
|
+
|
201
|
+
template<>
|
202
|
+
inline float dot<float>(const float* x, const float *y, int f) {
|
203
|
+
float result = 0;
|
204
|
+
if (f > 7) {
|
205
|
+
__m256 d = _mm256_setzero_ps();
|
206
|
+
for (; f > 7; f -= 8) {
|
207
|
+
d = _mm256_add_ps(d, _mm256_mul_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y)));
|
208
|
+
x += 8;
|
209
|
+
y += 8;
|
210
|
+
}
|
211
|
+
// Sum all floats in dot register.
|
212
|
+
result += hsum256_ps_avx(d);
|
213
|
+
}
|
214
|
+
// Don't forget the remaining values.
|
215
|
+
for (; f > 0; f--) {
|
216
|
+
result += *x * *y;
|
217
|
+
x++;
|
218
|
+
y++;
|
219
|
+
}
|
220
|
+
return result;
|
221
|
+
}
|
222
|
+
|
223
|
+
template<>
|
224
|
+
inline float manhattan_distance<float>(const float* x, const float* y, int f) {
|
225
|
+
float result = 0;
|
226
|
+
int i = f;
|
227
|
+
if (f > 7) {
|
228
|
+
__m256 manhattan = _mm256_setzero_ps();
|
229
|
+
__m256 minus_zero = _mm256_set1_ps(-0.0f);
|
230
|
+
for (; i > 7; i -= 8) {
|
231
|
+
const __m256 x_minus_y = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
|
232
|
+
const __m256 distance = _mm256_andnot_ps(minus_zero, x_minus_y); // Absolute value of x_minus_y (forces sign bit to zero)
|
233
|
+
manhattan = _mm256_add_ps(manhattan, distance);
|
234
|
+
x += 8;
|
235
|
+
y += 8;
|
236
|
+
}
|
237
|
+
// Sum all floats in manhattan register.
|
238
|
+
result = hsum256_ps_avx(manhattan);
|
239
|
+
}
|
240
|
+
// Don't forget the remaining values.
|
241
|
+
for (; i > 0; i--) {
|
242
|
+
result += fabsf(*x - *y);
|
243
|
+
x++;
|
244
|
+
y++;
|
245
|
+
}
|
246
|
+
return result;
|
247
|
+
}
|
248
|
+
|
249
|
+
template<>
|
250
|
+
inline float euclidean_distance<float>(const float* x, const float* y, int f) {
|
251
|
+
float result=0;
|
252
|
+
if (f > 7) {
|
253
|
+
__m256 d = _mm256_setzero_ps();
|
254
|
+
for (; f > 7; f -= 8) {
|
255
|
+
const __m256 diff = _mm256_sub_ps(_mm256_loadu_ps(x), _mm256_loadu_ps(y));
|
256
|
+
d = _mm256_add_ps(d, _mm256_mul_ps(diff, diff)); // no support for fmadd in AVX...
|
257
|
+
x += 8;
|
258
|
+
y += 8;
|
259
|
+
}
|
260
|
+
// Sum all floats in dot register.
|
261
|
+
result = hsum256_ps_avx(d);
|
262
|
+
}
|
263
|
+
// Don't forget the remaining values.
|
264
|
+
for (; f > 0; f--) {
|
265
|
+
float tmp = *x - *y;
|
266
|
+
result += tmp * tmp;
|
267
|
+
x++;
|
268
|
+
y++;
|
269
|
+
}
|
270
|
+
return result;
|
271
|
+
}
|
272
|
+
|
273
|
+
#endif
|
274
|
+
|
275
|
+
#ifdef USE_AVX512
|
276
|
+
template<>
|
277
|
+
inline float dot<float>(const float* x, const float *y, int f) {
|
278
|
+
float result = 0;
|
279
|
+
if (f > 15) {
|
280
|
+
__m512 d = _mm512_setzero_ps();
|
281
|
+
for (; f > 15; f -= 16) {
|
282
|
+
//AVX512F includes FMA
|
283
|
+
d = _mm512_fmadd_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y), d);
|
284
|
+
x += 16;
|
285
|
+
y += 16;
|
286
|
+
}
|
287
|
+
// Sum all floats in dot register.
|
288
|
+
result += _mm512_reduce_add_ps(d);
|
289
|
+
}
|
290
|
+
// Don't forget the remaining values.
|
291
|
+
for (; f > 0; f--) {
|
292
|
+
result += *x * *y;
|
293
|
+
x++;
|
294
|
+
y++;
|
295
|
+
}
|
296
|
+
return result;
|
297
|
+
}
|
298
|
+
|
299
|
+
template<>
|
300
|
+
inline float manhattan_distance<float>(const float* x, const float* y, int f) {
|
301
|
+
float result = 0;
|
302
|
+
int i = f;
|
303
|
+
if (f > 15) {
|
304
|
+
__m512 manhattan = _mm512_setzero_ps();
|
305
|
+
for (; i > 15; i -= 16) {
|
306
|
+
const __m512 x_minus_y = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y));
|
307
|
+
manhattan = _mm512_add_ps(manhattan, _mm512_abs_ps(x_minus_y));
|
308
|
+
x += 16;
|
309
|
+
y += 16;
|
310
|
+
}
|
311
|
+
// Sum all floats in manhattan register.
|
312
|
+
result = _mm512_reduce_add_ps(manhattan);
|
313
|
+
}
|
314
|
+
// Don't forget the remaining values.
|
315
|
+
for (; i > 0; i--) {
|
316
|
+
result += fabsf(*x - *y);
|
317
|
+
x++;
|
318
|
+
y++;
|
319
|
+
}
|
320
|
+
return result;
|
321
|
+
}
|
322
|
+
|
323
|
+
template<>
|
324
|
+
inline float euclidean_distance<float>(const float* x, const float* y, int f) {
|
325
|
+
float result=0;
|
326
|
+
if (f > 15) {
|
327
|
+
__m512 d = _mm512_setzero_ps();
|
328
|
+
for (; f > 15; f -= 16) {
|
329
|
+
const __m512 diff = _mm512_sub_ps(_mm512_loadu_ps(x), _mm512_loadu_ps(y));
|
330
|
+
d = _mm512_fmadd_ps(diff, diff, d);
|
331
|
+
x += 16;
|
332
|
+
y += 16;
|
333
|
+
}
|
334
|
+
// Sum all floats in dot register.
|
335
|
+
result = _mm512_reduce_add_ps(d);
|
336
|
+
}
|
337
|
+
// Don't forget the remaining values.
|
338
|
+
for (; f > 0; f--) {
|
339
|
+
float tmp = *x - *y;
|
340
|
+
result += tmp * tmp;
|
341
|
+
x++;
|
342
|
+
y++;
|
343
|
+
}
|
344
|
+
return result;
|
345
|
+
}
|
346
|
+
|
347
|
+
#endif
|
348
|
+
|
349
|
+
|
350
|
+
/** Returns the Euclidean length of the first f elements of v: sqrt(dot(v, v, f)). */
template<typename T>
inline T get_norm(T* v, int f) {
  const T squared_length = dot(v, v, f);
  return sqrt(squared_length);
}
|
354
|
+
|
355
|
+
template<typename T, typename Random, typename Distance, typename Node>
|
356
|
+
inline void two_means(const vector<Node*>& nodes, int f, Random& random, bool cosine, Node* p, Node* q) {
|
357
|
+
/*
|
358
|
+
This algorithm is a huge heuristic. Empirically it works really well, but I
|
359
|
+
can't motivate it well. The basic idea is to keep two centroids and assign
|
360
|
+
points to either one of them. We weight each centroid by the number of points
|
361
|
+
assigned to it, so to balance it.
|
362
|
+
*/
|
363
|
+
static int iteration_steps = 200;
|
364
|
+
size_t count = nodes.size();
|
365
|
+
|
366
|
+
size_t i = random.index(count);
|
367
|
+
size_t j = random.index(count-1);
|
368
|
+
j += (j >= i); // ensure that i != j
|
369
|
+
|
370
|
+
Distance::template copy_node<T, Node>(p, nodes[i], f);
|
371
|
+
Distance::template copy_node<T, Node>(q, nodes[j], f);
|
372
|
+
|
373
|
+
if (cosine) { Distance::template normalize<T, Node>(p, f); Distance::template normalize<T, Node>(q, f); }
|
374
|
+
Distance::init_node(p, f);
|
375
|
+
Distance::init_node(q, f);
|
376
|
+
|
377
|
+
int ic = 1, jc = 1;
|
378
|
+
for (int l = 0; l < iteration_steps; l++) {
|
379
|
+
size_t k = random.index(count);
|
380
|
+
T di = ic * Distance::distance(p, nodes[k], f),
|
381
|
+
dj = jc * Distance::distance(q, nodes[k], f);
|
382
|
+
T norm = cosine ? get_norm(nodes[k]->v, f) : 1;
|
383
|
+
if (!(norm > T(0))) {
|
384
|
+
continue;
|
385
|
+
}
|
386
|
+
if (di < dj) {
|
387
|
+
for (int z = 0; z < f; z++)
|
388
|
+
p->v[z] = (p->v[z] * ic + nodes[k]->v[z] / norm) / (ic + 1);
|
389
|
+
Distance::init_node(p, f);
|
390
|
+
ic++;
|
391
|
+
} else if (dj < di) {
|
392
|
+
for (int z = 0; z < f; z++)
|
393
|
+
q->v[z] = (q->v[z] * jc + nodes[k]->v[z] / norm) / (jc + 1);
|
394
|
+
Distance::init_node(q, f);
|
395
|
+
jc++;
|
396
|
+
}
|
397
|
+
}
|
398
|
+
}
|
399
|
+
} // namespace
|
400
|
+
|
401
|
+
/**
 * Default implementations of the hooks shared by the metric structs that
 * derive from it; a metric redefines the ones it needs.
 */
struct Base {
  /**
   * Hook for any pre-processing over the whole set of nodes handed to an
   * index. The default does nothing.
   */
  template<typename T, typename S, typename Node>
  static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) {
  }

  /**
   * Hook for giving metric-specific node fields sane defaults. The default
   * does nothing.
   */
  template<typename Node>
  static inline void zero_value(Node* dest) {
  }

  /** Copies the first f vector components (f * sizeof(T) bytes) from source to dest. */
  template<typename T, typename Node>
  static inline void copy_node(Node* dest, const Node* source, const int f) {
    const size_t n_bytes = f * sizeof(T);
    memcpy(dest->v, source->v, n_bytes);
  }

  /**
   * Scales the node's vector to unit length in place. A vector whose norm is
   * not greater than zero is left untouched.
   */
  template<typename T, typename Node>
  static inline void normalize(Node* node, int f) {
    T norm = get_norm(node->v, f);
    if (!(norm > 0)) {
      return;
    }
    for (int z = 0; z < f; z++) {
      node->v[z] /= norm;
    }
  }
};
|
427
|
+
|
428
|
+
struct Angular : Base {
|
429
|
+
template<typename S, typename T>
|
430
|
+
struct Node {
|
431
|
+
/*
|
432
|
+
* We store a binary tree where each node has two things
|
433
|
+
* - A vector associated with it
|
434
|
+
* - Two children
|
435
|
+
* All nodes occupy the same amount of memory
|
436
|
+
* All nodes with n_descendants == 1 are leaf nodes.
|
437
|
+
* A memory optimization is that for nodes with 2 <= n_descendants <= K,
|
438
|
+
* we skip the vector. Instead we store a list of all descendants. K is
|
439
|
+
* determined by the number of items that fits in the space of the vector.
|
440
|
+
* For nodes with n_descendants == 1 the vector is a data point.
|
441
|
+
* For nodes with n_descendants > K the vector is the normal of the split plane.
|
442
|
+
* Note that we can't really do sizeof(node<T>) because we cheat and allocate
|
443
|
+
* more memory to be able to fit the vector outside
|
444
|
+
*/
|
445
|
+
S n_descendants;
|
446
|
+
union {
|
447
|
+
S children[2]; // Will possibly store more than 2
|
448
|
+
T norm;
|
449
|
+
};
|
450
|
+
T v[V_ARRAY_SIZE];
|
451
|
+
};
|
452
|
+
template<typename S, typename T>
|
453
|
+
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
|
454
|
+
// want to calculate (a/|a| - b/|b|)^2
|
455
|
+
// = a^2 / a^2 + b^2 / b^2 - 2ab/|a||b|
|
456
|
+
// = 2 - 2cos
|
457
|
+
T pp = x->norm ? x->norm : dot(x->v, x->v, f); // For backwards compatibility reasons, we need to fall back and compute the norm here
|
458
|
+
T qq = y->norm ? y->norm : dot(y->v, y->v, f);
|
459
|
+
T pq = dot(x->v, y->v, f);
|
460
|
+
T ppqq = pp * qq;
|
461
|
+
if (ppqq > 0) return 2.0 - 2.0 * pq / sqrt(ppqq);
|
462
|
+
else return 2.0; // cos is 0
|
463
|
+
}
|
464
|
+
template<typename S, typename T>
|
465
|
+
static inline T margin(const Node<S, T>* n, const T* y, int f) {
|
466
|
+
return dot(n->v, y, f);
|
467
|
+
}
|
468
|
+
template<typename S, typename T, typename Random>
|
469
|
+
static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
|
470
|
+
T dot = margin(n, y, f);
|
471
|
+
if (dot != 0)
|
472
|
+
return (dot > 0);
|
473
|
+
else
|
474
|
+
return (bool)random.flip();
|
475
|
+
}
|
476
|
+
template<typename S, typename T, typename Random>
|
477
|
+
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
|
478
|
+
Node<S, T>* p = (Node<S, T>*)alloca(s);
|
479
|
+
Node<S, T>* q = (Node<S, T>*)alloca(s);
|
480
|
+
two_means<T, Random, Angular, Node<S, T> >(nodes, f, random, true, p, q);
|
481
|
+
for (int z = 0; z < f; z++)
|
482
|
+
n->v[z] = p->v[z] - q->v[z];
|
483
|
+
Base::normalize<T, Node<S, T> >(n, f);
|
484
|
+
}
|
485
|
+
template<typename T>
|
486
|
+
static inline T normalized_distance(T distance) {
|
487
|
+
// Used when requesting distances from Python layer
|
488
|
+
// Turns out sometimes the squared distance is -0.0
|
489
|
+
// so we have to make sure it's a positive number.
|
490
|
+
return sqrt(std::max(distance, T(0)));
|
491
|
+
}
|
492
|
+
template<typename T>
|
493
|
+
static inline T pq_distance(T distance, T margin, int child_nr) {
|
494
|
+
if (child_nr == 0)
|
495
|
+
margin = -margin;
|
496
|
+
return std::min(distance, margin);
|
497
|
+
}
|
498
|
+
template<typename T>
|
499
|
+
static inline T pq_initial_value() {
|
500
|
+
return numeric_limits<T>::infinity();
|
501
|
+
}
|
502
|
+
template<typename S, typename T>
|
503
|
+
static inline void init_node(Node<S, T>* n, int f) {
|
504
|
+
n->norm = dot(n->v, n->v, f);
|
505
|
+
}
|
506
|
+
static const char* name() {
|
507
|
+
return "angular";
|
508
|
+
}
|
509
|
+
};
|
510
|
+
|
511
|
+
|
512
|
+
struct DotProduct : Angular {
|
513
|
+
template<typename S, typename T>
|
514
|
+
struct Node {
|
515
|
+
/*
|
516
|
+
* This is an extension of the Angular node with an extra attribute for the scaled norm.
|
517
|
+
*/
|
518
|
+
S n_descendants;
|
519
|
+
S children[2]; // Will possibly store more than 2
|
520
|
+
T dot_factor;
|
521
|
+
T v[V_ARRAY_SIZE];
|
522
|
+
};
|
523
|
+
|
524
|
+
static const char* name() {
|
525
|
+
return "dot";
|
526
|
+
}
|
527
|
+
template<typename S, typename T>
|
528
|
+
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
|
529
|
+
return -dot(x->v, y->v, f);
|
530
|
+
}
|
531
|
+
|
532
|
+
template<typename Node>
|
533
|
+
static inline void zero_value(Node* dest) {
|
534
|
+
dest->dot_factor = 0;
|
535
|
+
}
|
536
|
+
|
537
|
+
template<typename S, typename T>
|
538
|
+
static inline void init_node(Node<S, T>* n, int f) {
|
539
|
+
}
|
540
|
+
|
541
|
+
template<typename T, typename Node>
|
542
|
+
static inline void copy_node(Node* dest, const Node* source, const int f) {
|
543
|
+
memcpy(dest->v, source->v, f * sizeof(T));
|
544
|
+
dest->dot_factor = source->dot_factor;
|
545
|
+
}
|
546
|
+
|
547
|
+
template<typename S, typename T, typename Random>
|
548
|
+
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
|
549
|
+
Node<S, T>* p = (Node<S, T>*)alloca(s);
|
550
|
+
Node<S, T>* q = (Node<S, T>*)alloca(s);
|
551
|
+
DotProduct::zero_value(p);
|
552
|
+
DotProduct::zero_value(q);
|
553
|
+
two_means<T, Random, DotProduct, Node<S, T> >(nodes, f, random, true, p, q);
|
554
|
+
for (int z = 0; z < f; z++)
|
555
|
+
n->v[z] = p->v[z] - q->v[z];
|
556
|
+
n->dot_factor = p->dot_factor - q->dot_factor;
|
557
|
+
DotProduct::normalize<T, Node<S, T> >(n, f);
|
558
|
+
}
|
559
|
+
|
560
|
+
template<typename T, typename Node>
|
561
|
+
static inline void normalize(Node* node, int f) {
|
562
|
+
T norm = sqrt(dot(node->v, node->v, f) + pow(node->dot_factor, 2));
|
563
|
+
if (norm > 0) {
|
564
|
+
for (int z = 0; z < f; z++)
|
565
|
+
node->v[z] /= norm;
|
566
|
+
node->dot_factor /= norm;
|
567
|
+
}
|
568
|
+
}
|
569
|
+
|
570
|
+
template<typename S, typename T>
|
571
|
+
static inline T margin(const Node<S, T>* n, const T* y, int f) {
|
572
|
+
return dot(n->v, y, f) + (n->dot_factor * n->dot_factor);
|
573
|
+
}
|
574
|
+
|
575
|
+
template<typename S, typename T, typename Random>
|
576
|
+
static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
|
577
|
+
T dot = margin(n, y, f);
|
578
|
+
if (dot != 0)
|
579
|
+
return (dot > 0);
|
580
|
+
else
|
581
|
+
return (bool)random.flip();
|
582
|
+
}
|
583
|
+
|
584
|
+
template<typename T>
|
585
|
+
static inline T normalized_distance(T distance) {
|
586
|
+
return -distance;
|
587
|
+
}
|
588
|
+
|
589
|
+
template<typename T, typename S, typename Node>
|
590
|
+
static inline void preprocess(void* nodes, size_t _s, const S node_count, const int f) {
|
591
|
+
// This uses a method from Microsoft Research for transforming inner product spaces to cosine/angular-compatible spaces.
|
592
|
+
// (Bachrach et al., 2014, see https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf)
|
593
|
+
|
594
|
+
// Step one: compute the norm of each vector and store that in its extra dimension (f-1)
|
595
|
+
for (S i = 0; i < node_count; i++) {
|
596
|
+
Node* node = get_node_ptr<S, Node>(nodes, _s, i);
|
597
|
+
T norm = sqrt(dot(node->v, node->v, f));
|
598
|
+
if (isnan(norm)) norm = 0;
|
599
|
+
node->dot_factor = norm;
|
600
|
+
}
|
601
|
+
|
602
|
+
// Step two: find the maximum norm
|
603
|
+
T max_norm = 0;
|
604
|
+
for (S i = 0; i < node_count; i++) {
|
605
|
+
Node* node = get_node_ptr<S, Node>(nodes, _s, i);
|
606
|
+
if (node->dot_factor > max_norm) {
|
607
|
+
max_norm = node->dot_factor;
|
608
|
+
}
|
609
|
+
}
|
610
|
+
|
611
|
+
// Step three: set each vector's extra dimension to sqrt(max_norm^2 - norm^2)
|
612
|
+
for (S i = 0; i < node_count; i++) {
|
613
|
+
Node* node = get_node_ptr<S, Node>(nodes, _s, i);
|
614
|
+
T node_norm = node->dot_factor;
|
615
|
+
|
616
|
+
T dot_factor = sqrt(pow(max_norm, static_cast<T>(2.0)) - pow(node_norm, static_cast<T>(2.0)));
|
617
|
+
if (isnan(dot_factor)) dot_factor = 0;
|
618
|
+
|
619
|
+
node->dot_factor = dot_factor;
|
620
|
+
}
|
621
|
+
}
|
622
|
+
};
|
623
|
+
|
624
|
+
struct Hamming : Base {
|
625
|
+
template<typename S, typename T>
|
626
|
+
struct Node {
|
627
|
+
S n_descendants;
|
628
|
+
S children[2];
|
629
|
+
T v[V_ARRAY_SIZE];
|
630
|
+
};
|
631
|
+
|
632
|
+
static const size_t max_iterations = 20;
|
633
|
+
|
634
|
+
template<typename T>
|
635
|
+
static inline T pq_distance(T distance, T margin, int child_nr) {
|
636
|
+
return distance - (margin != (unsigned int) child_nr);
|
637
|
+
}
|
638
|
+
|
639
|
+
template<typename T>
|
640
|
+
static inline T pq_initial_value() {
|
641
|
+
return numeric_limits<T>::max();
|
642
|
+
}
|
643
|
+
template<typename T>
|
644
|
+
static inline int cole_popcount(T v) {
|
645
|
+
// Note: Only used with MSVC 9, which lacks intrinsics and fails to
|
646
|
+
// calculate std::bitset::count for v > 32bit. Uses the generalized
|
647
|
+
// approach by Eric Cole.
|
648
|
+
// See https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64
|
649
|
+
v = v - ((v >> 1) & (T)~(T)0/3);
|
650
|
+
v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3);
|
651
|
+
v = (v + (v >> 4)) & (T)~(T)0/255*15;
|
652
|
+
return (T)(v * ((T)~(T)0/255)) >> (sizeof(T) - 1) * 8;
|
653
|
+
}
|
654
|
+
template<typename S, typename T>
|
655
|
+
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
|
656
|
+
size_t dist = 0;
|
657
|
+
for (int i = 0; i < f; i++) {
|
658
|
+
dist += popcount(x->v[i] ^ y->v[i]);
|
659
|
+
}
|
660
|
+
return dist;
|
661
|
+
}
|
662
|
+
template<typename S, typename T>
|
663
|
+
static inline bool margin(const Node<S, T>* n, const T* y, int f) {
|
664
|
+
static const size_t n_bits = sizeof(T) * 8;
|
665
|
+
T chunk = n->v[0] / n_bits;
|
666
|
+
return (y[chunk] & (static_cast<T>(1) << (n_bits - 1 - (n->v[0] % n_bits)))) != 0;
|
667
|
+
}
|
668
|
+
template<typename S, typename T, typename Random>
|
669
|
+
static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
|
670
|
+
return margin(n, y, f);
|
671
|
+
}
|
672
|
+
template<typename S, typename T, typename Random>
|
673
|
+
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
|
674
|
+
size_t cur_size = 0;
|
675
|
+
size_t i = 0;
|
676
|
+
int dim = f * 8 * sizeof(T);
|
677
|
+
for (; i < max_iterations; i++) {
|
678
|
+
// choose random position to split at
|
679
|
+
n->v[0] = random.index(dim);
|
680
|
+
cur_size = 0;
|
681
|
+
for (typename vector<Node<S, T>*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) {
|
682
|
+
if (margin(n, (*it)->v, f)) {
|
683
|
+
cur_size++;
|
684
|
+
}
|
685
|
+
}
|
686
|
+
if (cur_size > 0 && cur_size < nodes.size()) {
|
687
|
+
break;
|
688
|
+
}
|
689
|
+
}
|
690
|
+
// brute-force search for splitting coordinate
|
691
|
+
if (i == max_iterations) {
|
692
|
+
int j = 0;
|
693
|
+
for (; j < dim; j++) {
|
694
|
+
n->v[0] = j;
|
695
|
+
cur_size = 0;
|
696
|
+
for (typename vector<Node<S, T>*>::const_iterator it = nodes.begin(); it != nodes.end(); ++it) {
|
697
|
+
if (margin(n, (*it)->v, f)) {
|
698
|
+
cur_size++;
|
699
|
+
}
|
700
|
+
}
|
701
|
+
if (cur_size > 0 && cur_size < nodes.size()) {
|
702
|
+
break;
|
703
|
+
}
|
704
|
+
}
|
705
|
+
}
|
706
|
+
}
|
707
|
+
template<typename T>
|
708
|
+
static inline T normalized_distance(T distance) {
|
709
|
+
return distance;
|
710
|
+
}
|
711
|
+
template<typename S, typename T>
|
712
|
+
static inline void init_node(Node<S, T>* n, int f) {
|
713
|
+
}
|
714
|
+
static const char* name() {
|
715
|
+
return "hamming";
|
716
|
+
}
|
717
|
+
};
|
718
|
+
|
719
|
+
|
720
|
+
struct Minkowski : Base {
|
721
|
+
template<typename S, typename T>
|
722
|
+
struct Node {
|
723
|
+
S n_descendants;
|
724
|
+
T a; // need an extra constant term to determine the offset of the plane
|
725
|
+
S children[2];
|
726
|
+
T v[V_ARRAY_SIZE];
|
727
|
+
};
|
728
|
+
template<typename S, typename T>
|
729
|
+
static inline T margin(const Node<S, T>* n, const T* y, int f) {
|
730
|
+
return n->a + dot(n->v, y, f);
|
731
|
+
}
|
732
|
+
template<typename S, typename T, typename Random>
|
733
|
+
static inline bool side(const Node<S, T>* n, const T* y, int f, Random& random) {
|
734
|
+
T dot = margin(n, y, f);
|
735
|
+
if (dot != 0)
|
736
|
+
return (dot > 0);
|
737
|
+
else
|
738
|
+
return (bool)random.flip();
|
739
|
+
}
|
740
|
+
template<typename T>
|
741
|
+
static inline T pq_distance(T distance, T margin, int child_nr) {
|
742
|
+
if (child_nr == 0)
|
743
|
+
margin = -margin;
|
744
|
+
return std::min(distance, margin);
|
745
|
+
}
|
746
|
+
template<typename T>
|
747
|
+
static inline T pq_initial_value() {
|
748
|
+
return numeric_limits<T>::infinity();
|
749
|
+
}
|
750
|
+
};
|
751
|
+
|
752
|
+
|
753
|
+
struct Euclidean : Minkowski {
|
754
|
+
template<typename S, typename T>
|
755
|
+
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
|
756
|
+
return euclidean_distance(x->v, y->v, f);
|
757
|
+
}
|
758
|
+
template<typename S, typename T, typename Random>
|
759
|
+
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
|
760
|
+
Node<S, T>* p = (Node<S, T>*)alloca(s);
|
761
|
+
Node<S, T>* q = (Node<S, T>*)alloca(s);
|
762
|
+
two_means<T, Random, Euclidean, Node<S, T> >(nodes, f, random, false, p, q);
|
763
|
+
|
764
|
+
for (int z = 0; z < f; z++)
|
765
|
+
n->v[z] = p->v[z] - q->v[z];
|
766
|
+
Base::normalize<T, Node<S, T> >(n, f);
|
767
|
+
n->a = 0.0;
|
768
|
+
for (int z = 0; z < f; z++)
|
769
|
+
n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2;
|
770
|
+
}
|
771
|
+
template<typename T>
|
772
|
+
static inline T normalized_distance(T distance) {
|
773
|
+
return sqrt(std::max(distance, T(0)));
|
774
|
+
}
|
775
|
+
template<typename S, typename T>
|
776
|
+
static inline void init_node(Node<S, T>* n, int f) {
|
777
|
+
}
|
778
|
+
static const char* name() {
|
779
|
+
return "euclidean";
|
780
|
+
}
|
781
|
+
|
782
|
+
};
|
783
|
+
|
784
|
+
struct Manhattan : Minkowski {
|
785
|
+
template<typename S, typename T>
|
786
|
+
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
|
787
|
+
return manhattan_distance(x->v, y->v, f);
|
788
|
+
}
|
789
|
+
template<typename S, typename T, typename Random>
|
790
|
+
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
|
791
|
+
Node<S, T>* p = (Node<S, T>*)alloca(s);
|
792
|
+
Node<S, T>* q = (Node<S, T>*)alloca(s);
|
793
|
+
two_means<T, Random, Manhattan, Node<S, T> >(nodes, f, random, false, p, q);
|
794
|
+
|
795
|
+
for (int z = 0; z < f; z++)
|
796
|
+
n->v[z] = p->v[z] - q->v[z];
|
797
|
+
Base::normalize<T, Node<S, T> >(n, f);
|
798
|
+
n->a = 0.0;
|
799
|
+
for (int z = 0; z < f; z++)
|
800
|
+
n->a += -n->v[z] * (p->v[z] + q->v[z]) / 2;
|
801
|
+
}
|
802
|
+
template<typename T>
|
803
|
+
static inline T normalized_distance(T distance) {
|
804
|
+
return std::max(distance, T(0));
|
805
|
+
}
|
806
|
+
template<typename S, typename T>
|
807
|
+
static inline void init_node(Node<S, T>* n, int f) {
|
808
|
+
}
|
809
|
+
static const char* name() {
|
810
|
+
return "manhattan";
|
811
|
+
}
|
812
|
+
};
|
813
|
+
|
814
|
+
template<typename S, typename T>
|
815
|
+
class AnnoyIndexInterface {
|
816
|
+
public:
|
817
|
+
// Note that the methods with an **error argument will allocate memory and write the pointer to that string if error is non-NULL
|
818
|
+
virtual ~AnnoyIndexInterface() {};
|
819
|
+
virtual bool add_item(S item, const T* w, char** error=NULL) = 0;
|
820
|
+
virtual bool build(int q, char** error=NULL) = 0;
|
821
|
+
virtual bool unbuild(char** error=NULL) = 0;
|
822
|
+
virtual bool save(const char* filename, bool prefault=false, char** error=NULL) = 0;
|
823
|
+
virtual void unload() = 0;
|
824
|
+
virtual bool load(const char* filename, bool prefault=false, char** error=NULL) = 0;
|
825
|
+
virtual T get_distance(S i, S j) const = 0;
|
826
|
+
virtual void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
|
827
|
+
virtual void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const = 0;
|
828
|
+
virtual S get_n_items() const = 0;
|
829
|
+
virtual S get_n_trees() const = 0;
|
830
|
+
virtual void verbose(bool v) = 0;
|
831
|
+
virtual void get_item(S item, T* v) const = 0;
|
832
|
+
virtual void set_seed(int q) = 0;
|
833
|
+
virtual bool on_disk_build(const char* filename, char** error=NULL) = 0;
|
834
|
+
};
|
835
|
+
|
836
|
+
template<typename S, typename T, typename Distance, typename Random>
|
837
|
+
class AnnoyIndex : public AnnoyIndexInterface<S, T> {
|
838
|
+
/*
|
839
|
+
* We use random projection to build a forest of binary trees of all items.
|
840
|
+
* Basically just split the hyperspace into two sides by a hyperplane,
|
841
|
+
* then recursively split each of those subtrees etc.
|
842
|
+
* We create a tree like this q times. The default q is determined automatically
|
843
|
+
* in such a way that we at most use 2x as much memory as the vectors take.
|
844
|
+
*/
|
845
|
+
public:
|
846
|
+
typedef Distance D;
|
847
|
+
typedef typename D::template Node<S, T> Node;
|
848
|
+
|
849
|
+
protected:
|
850
|
+
const int _f;
|
851
|
+
size_t _s;
|
852
|
+
S _n_items;
|
853
|
+
Random _random;
|
854
|
+
void* _nodes; // Could either be mmapped, or point to a memory buffer that we reallocate
|
855
|
+
S _n_nodes;
|
856
|
+
S _nodes_size;
|
857
|
+
vector<S> _roots;
|
858
|
+
S _K;
|
859
|
+
bool _loaded;
|
860
|
+
bool _verbose;
|
861
|
+
int _fd;
|
862
|
+
bool _on_disk;
|
863
|
+
bool _built;
|
864
|
+
public:
|
865
|
+
|
866
|
+
AnnoyIndex(int f) : _f(f), _random() {
|
867
|
+
_s = offsetof(Node, v) + _f * sizeof(T); // Size of each node
|
868
|
+
_verbose = false;
|
869
|
+
_built = false;
|
870
|
+
_K = (S) (((size_t) (_s - offsetof(Node, children))) / sizeof(S)); // Max number of descendants to fit into node
|
871
|
+
reinitialize(); // Reset everything
|
872
|
+
}
|
873
|
+
~AnnoyIndex() {
|
874
|
+
unload();
|
875
|
+
}
|
876
|
+
|
877
|
+
int get_f() const {
|
878
|
+
return _f;
|
879
|
+
}
|
880
|
+
|
881
|
+
bool add_item(S item, const T* w, char** error=NULL) {
|
882
|
+
return add_item_impl(item, w, error);
|
883
|
+
}
|
884
|
+
|
885
|
+
template<typename W>
|
886
|
+
bool add_item_impl(S item, const W& w, char** error=NULL) {
|
887
|
+
if (_loaded) {
|
888
|
+
set_error_from_string(error, "You can't add an item to a loaded index");
|
889
|
+
return false;
|
890
|
+
}
|
891
|
+
_allocate_size(item + 1);
|
892
|
+
Node* n = _get(item);
|
893
|
+
|
894
|
+
D::zero_value(n);
|
895
|
+
|
896
|
+
n->children[0] = 0;
|
897
|
+
n->children[1] = 0;
|
898
|
+
n->n_descendants = 1;
|
899
|
+
|
900
|
+
for (int z = 0; z < _f; z++)
|
901
|
+
n->v[z] = w[z];
|
902
|
+
|
903
|
+
D::init_node(n, _f);
|
904
|
+
|
905
|
+
if (item >= _n_items)
|
906
|
+
_n_items = item + 1;
|
907
|
+
|
908
|
+
return true;
|
909
|
+
}
|
910
|
+
|
911
|
+
bool on_disk_build(const char* file, char** error=NULL) {
|
912
|
+
_on_disk = true;
|
913
|
+
_fd = open(file, O_RDWR | O_CREAT | O_TRUNC, (int) 0600);
|
914
|
+
if (_fd == -1) {
|
915
|
+
set_error_from_errno(error, "Unable to open");
|
916
|
+
_fd = 0;
|
917
|
+
return false;
|
918
|
+
}
|
919
|
+
_nodes_size = 1;
|
920
|
+
if (ftruncate(_fd, FTRUNCATE_SIZE(_s) * FTRUNCATE_SIZE(_nodes_size)) == -1) {
|
921
|
+
set_error_from_errno(error, "Unable to truncate");
|
922
|
+
return false;
|
923
|
+
}
|
924
|
+
#ifdef MAP_POPULATE
|
925
|
+
_nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE, _fd, 0);
|
926
|
+
#else
|
927
|
+
_nodes = (Node*) mmap(0, _s * _nodes_size, PROT_READ | PROT_WRITE, MAP_SHARED, _fd, 0);
|
928
|
+
#endif
|
929
|
+
return true;
|
930
|
+
}
|
931
|
+
|
932
|
+
bool build(int q, char** error=NULL) {
|
933
|
+
if (_loaded) {
|
934
|
+
set_error_from_string(error, "You can't build a loaded index");
|
935
|
+
return false;
|
936
|
+
}
|
937
|
+
|
938
|
+
if (_built) {
|
939
|
+
set_error_from_string(error, "You can't build a built index");
|
940
|
+
return false;
|
941
|
+
}
|
942
|
+
|
943
|
+
D::template preprocess<T, S, Node>(_nodes, _s, _n_items, _f);
|
944
|
+
|
945
|
+
_n_nodes = _n_items;
|
946
|
+
while (1) {
|
947
|
+
if (q == -1 && _n_nodes >= _n_items * 2)
|
948
|
+
break;
|
949
|
+
if (q != -1 && _roots.size() >= (size_t)q)
|
950
|
+
break;
|
951
|
+
if (_verbose) showUpdate("pass %zd...\n", _roots.size());
|
952
|
+
|
953
|
+
vector<S> indices;
|
954
|
+
for (S i = 0; i < _n_items; i++) {
|
955
|
+
if (_get(i)->n_descendants >= 1) // Issue #223
|
956
|
+
indices.push_back(i);
|
957
|
+
}
|
958
|
+
|
959
|
+
_roots.push_back(_make_tree(indices, true));
|
960
|
+
}
|
961
|
+
|
962
|
+
// Also, copy the roots into the last segment of the array
|
963
|
+
// This way we can load them faster without reading the whole file
|
964
|
+
_allocate_size(_n_nodes + (S)_roots.size());
|
965
|
+
for (size_t i = 0; i < _roots.size(); i++)
|
966
|
+
memcpy(_get(_n_nodes + (S)i), _get(_roots[i]), _s);
|
967
|
+
_n_nodes += _roots.size();
|
968
|
+
|
969
|
+
if (_verbose) showUpdate("has %d nodes\n", _n_nodes);
|
970
|
+
|
971
|
+
if (_on_disk) {
|
972
|
+
if (!remap_memory_and_truncate(&_nodes, _fd,
|
973
|
+
static_cast<size_t>(_s) * static_cast<size_t>(_nodes_size),
|
974
|
+
static_cast<size_t>(_s) * static_cast<size_t>(_n_nodes))) {
|
975
|
+
// TODO: this probably creates an index in a corrupt state... not sure what to do
|
976
|
+
set_error_from_errno(error, "Unable to truncate");
|
977
|
+
return false;
|
978
|
+
}
|
979
|
+
_nodes_size = _n_nodes;
|
980
|
+
}
|
981
|
+
_built = true;
|
982
|
+
return true;
|
983
|
+
}
|
984
|
+
|
985
|
+
bool unbuild(char** error=NULL) {
|
986
|
+
if (_loaded) {
|
987
|
+
set_error_from_string(error, "You can't unbuild a loaded index");
|
988
|
+
return false;
|
989
|
+
}
|
990
|
+
|
991
|
+
_roots.clear();
|
992
|
+
_n_nodes = _n_items;
|
993
|
+
_built = false;
|
994
|
+
|
995
|
+
return true;
|
996
|
+
}
|
997
|
+
|
998
|
+
bool save(const char* filename, bool prefault=false, char** error=NULL) {
|
999
|
+
if (!_built) {
|
1000
|
+
set_error_from_string(error, "You can't save an index that hasn't been built");
|
1001
|
+
return false;
|
1002
|
+
}
|
1003
|
+
if (_on_disk) {
|
1004
|
+
return true;
|
1005
|
+
} else {
|
1006
|
+
// Delete file if it already exists (See issue #335)
|
1007
|
+
unlink(filename);
|
1008
|
+
|
1009
|
+
FILE *f = fopen(filename, "wb");
|
1010
|
+
if (f == NULL) {
|
1011
|
+
set_error_from_errno(error, "Unable to open");
|
1012
|
+
return false;
|
1013
|
+
}
|
1014
|
+
|
1015
|
+
if (fwrite(_nodes, _s, _n_nodes, f) != (size_t) _n_nodes) {
|
1016
|
+
set_error_from_errno(error, "Unable to write");
|
1017
|
+
return false;
|
1018
|
+
}
|
1019
|
+
|
1020
|
+
if (fclose(f) == EOF) {
|
1021
|
+
set_error_from_errno(error, "Unable to close");
|
1022
|
+
return false;
|
1023
|
+
}
|
1024
|
+
|
1025
|
+
unload();
|
1026
|
+
return load(filename, prefault, error);
|
1027
|
+
}
|
1028
|
+
}
|
1029
|
+
|
1030
|
+
void reinitialize() {
|
1031
|
+
_fd = 0;
|
1032
|
+
_nodes = NULL;
|
1033
|
+
_loaded = false;
|
1034
|
+
_n_items = 0;
|
1035
|
+
_n_nodes = 0;
|
1036
|
+
_nodes_size = 0;
|
1037
|
+
_on_disk = false;
|
1038
|
+
_roots.clear();
|
1039
|
+
}
|
1040
|
+
|
1041
|
+
void unload() {
|
1042
|
+
if (_on_disk && _fd) {
|
1043
|
+
close(_fd);
|
1044
|
+
munmap(_nodes, _s * _nodes_size);
|
1045
|
+
} else {
|
1046
|
+
if (_fd) {
|
1047
|
+
// we have mmapped data
|
1048
|
+
close(_fd);
|
1049
|
+
munmap(_nodes, _n_nodes * _s);
|
1050
|
+
} else if (_nodes) {
|
1051
|
+
// We have heap allocated data
|
1052
|
+
free(_nodes);
|
1053
|
+
}
|
1054
|
+
}
|
1055
|
+
reinitialize();
|
1056
|
+
if (_verbose) showUpdate("unloaded\n");
|
1057
|
+
}
|
1058
|
+
|
1059
|
+
bool load(const char* filename, bool prefault=false, char** error=NULL) {
|
1060
|
+
_fd = open(filename, O_RDONLY, (int)0400);
|
1061
|
+
if (_fd == -1) {
|
1062
|
+
set_error_from_errno(error, "Unable to open");
|
1063
|
+
_fd = 0;
|
1064
|
+
return false;
|
1065
|
+
}
|
1066
|
+
off_t size = lseek_getsize(_fd);
|
1067
|
+
if (size == -1) {
|
1068
|
+
set_error_from_errno(error, "Unable to get size");
|
1069
|
+
return false;
|
1070
|
+
} else if (size == 0) {
|
1071
|
+
set_error_from_errno(error, "Size of file is zero");
|
1072
|
+
return false;
|
1073
|
+
} else if (size % _s) {
|
1074
|
+
// Something is fishy with this index!
|
1075
|
+
set_error_from_errno(error, "Index size is not a multiple of vector size. Ensure you are opening using the same metric you used to create the index.");
|
1076
|
+
return false;
|
1077
|
+
}
|
1078
|
+
|
1079
|
+
int flags = MAP_SHARED;
|
1080
|
+
if (prefault) {
|
1081
|
+
#ifdef MAP_POPULATE
|
1082
|
+
flags |= MAP_POPULATE;
|
1083
|
+
#else
|
1084
|
+
showUpdate("prefault is set to true, but MAP_POPULATE is not defined on this platform");
|
1085
|
+
#endif
|
1086
|
+
}
|
1087
|
+
_nodes = (Node*)mmap(0, size, PROT_READ, flags, _fd, 0);
|
1088
|
+
_n_nodes = (S)(size / _s);
|
1089
|
+
|
1090
|
+
// Find the roots by scanning the end of the file and taking the nodes with most descendants
|
1091
|
+
_roots.clear();
|
1092
|
+
S m = -1;
|
1093
|
+
for (S i = _n_nodes - 1; i >= 0; i--) {
|
1094
|
+
S k = _get(i)->n_descendants;
|
1095
|
+
if (m == -1 || k == m) {
|
1096
|
+
_roots.push_back(i);
|
1097
|
+
m = k;
|
1098
|
+
} else {
|
1099
|
+
break;
|
1100
|
+
}
|
1101
|
+
}
|
1102
|
+
// hacky fix: since the last root precedes the copy of all roots, delete it
|
1103
|
+
if (_roots.size() > 1 && _get(_roots.front())->children[0] == _get(_roots.back())->children[0])
|
1104
|
+
_roots.pop_back();
|
1105
|
+
_loaded = true;
|
1106
|
+
_built = true;
|
1107
|
+
_n_items = m;
|
1108
|
+
if (_verbose) showUpdate("found %lu roots with degree %d\n", _roots.size(), m);
|
1109
|
+
return true;
|
1110
|
+
}
|
1111
|
+
|
1112
|
+
T get_distance(S i, S j) const {
|
1113
|
+
return D::normalized_distance(D::distance(_get(i), _get(j), _f));
|
1114
|
+
}
|
1115
|
+
|
1116
|
+
void get_nns_by_item(S item, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
|
1117
|
+
// TODO: handle OOB
|
1118
|
+
const Node* m = _get(item);
|
1119
|
+
_get_all_nns(m->v, n, search_k, result, distances);
|
1120
|
+
}
|
1121
|
+
|
1122
|
+
void get_nns_by_vector(const T* w, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
|
1123
|
+
_get_all_nns(w, n, search_k, result, distances);
|
1124
|
+
}
|
1125
|
+
|
1126
|
+
S get_n_items() const {
|
1127
|
+
return _n_items;
|
1128
|
+
}
|
1129
|
+
|
1130
|
+
S get_n_trees() const {
|
1131
|
+
return (S)_roots.size();
|
1132
|
+
}
|
1133
|
+
|
1134
|
+
void verbose(bool v) {
|
1135
|
+
_verbose = v;
|
1136
|
+
}
|
1137
|
+
|
1138
|
+
void get_item(S item, T* v) const {
|
1139
|
+
// TODO: handle OOB
|
1140
|
+
Node* m = _get(item);
|
1141
|
+
memcpy(v, m->v, (_f) * sizeof(T));
|
1142
|
+
}
|
1143
|
+
|
1144
|
+
void set_seed(int seed) {
|
1145
|
+
_random.set_seed(seed);
|
1146
|
+
}
|
1147
|
+
|
1148
|
+
protected:
|
1149
|
+
void _allocate_size(S n) {
|
1150
|
+
if (n > _nodes_size) {
|
1151
|
+
const double reallocation_factor = 1.3;
|
1152
|
+
S new_nodes_size = std::max(n, (S) ((_nodes_size + 1) * reallocation_factor));
|
1153
|
+
void *old = _nodes;
|
1154
|
+
|
1155
|
+
if (_on_disk) {
|
1156
|
+
if (!remap_memory_and_truncate(&_nodes, _fd,
|
1157
|
+
static_cast<size_t>(_s) * static_cast<size_t>(_nodes_size),
|
1158
|
+
static_cast<size_t>(_s) * static_cast<size_t>(new_nodes_size)) &&
|
1159
|
+
_verbose)
|
1160
|
+
showUpdate("File truncation error\n");
|
1161
|
+
} else {
|
1162
|
+
_nodes = realloc(_nodes, _s * new_nodes_size);
|
1163
|
+
memset((char *) _nodes + (_nodes_size * _s) / sizeof(char), 0, (new_nodes_size - _nodes_size) * _s);
|
1164
|
+
}
|
1165
|
+
|
1166
|
+
_nodes_size = new_nodes_size;
|
1167
|
+
if (_verbose) showUpdate("Reallocating to %d nodes: old_address=%p, new_address=%p\n", new_nodes_size, old, _nodes);
|
1168
|
+
}
|
1169
|
+
}
|
1170
|
+
|
1171
|
+
Node* _get(const S i) const {
|
1172
|
+
return get_node_ptr<S, Node>(_nodes, _s, i);
|
1173
|
+
}
|
1174
|
+
|
1175
|
+
double _split_imbalance(const vector<S>& left_indices, const vector<S>& right_indices) {
|
1176
|
+
double ls = (float)left_indices.size();
|
1177
|
+
double rs = (float)right_indices.size();
|
1178
|
+
float f = ls / (ls + rs + 1e-9); // Avoid 0/0
|
1179
|
+
return std::max(f, 1-f);
|
1180
|
+
}
|
1181
|
+
|
1182
|
+
S _make_tree(const vector<S>& indices, bool is_root) {
|
1183
|
+
// The basic rule is that if we have <= _K items, then it's a leaf node, otherwise it's a split node.
|
1184
|
+
// There's some regrettable complications caused by the problem that root nodes have to be "special":
|
1185
|
+
// 1. We identify root nodes by the arguable logic that _n_items == n->n_descendants, regardless of how many descendants they actually have
|
1186
|
+
// 2. Root nodes with only 1 child need to be a "dummy" parent
|
1187
|
+
// 3. Due to the _n_items "hack", we need to be careful with the cases where _n_items <= _K or _n_items > _K
|
1188
|
+
if (indices.size() == 1 && !is_root)
|
1189
|
+
return indices[0];
|
1190
|
+
|
1191
|
+
if (indices.size() <= (size_t)_K && (!is_root || (size_t)_n_items <= (size_t)_K || indices.size() == 1)) {
|
1192
|
+
_allocate_size(_n_nodes + 1);
|
1193
|
+
S item = _n_nodes++;
|
1194
|
+
Node* m = _get(item);
|
1195
|
+
m->n_descendants = is_root ? _n_items : (S)indices.size();
|
1196
|
+
|
1197
|
+
// Using std::copy instead of a loop seems to resolve issues #3 and #13,
|
1198
|
+
// probably because gcc 4.8 goes overboard with optimizations.
|
1199
|
+
// Using memcpy instead of std::copy for MSVC compatibility. #235
|
1200
|
+
// Only copy when necessary to avoid crash in MSVC 9. #293
|
1201
|
+
if (!indices.empty())
|
1202
|
+
memcpy(m->children, &indices[0], indices.size() * sizeof(S));
|
1203
|
+
return item;
|
1204
|
+
}
|
1205
|
+
|
1206
|
+
vector<Node*> children;
|
1207
|
+
for (size_t i = 0; i < indices.size(); i++) {
|
1208
|
+
S j = indices[i];
|
1209
|
+
Node* n = _get(j);
|
1210
|
+
if (n)
|
1211
|
+
children.push_back(n);
|
1212
|
+
}
|
1213
|
+
|
1214
|
+
vector<S> children_indices[2];
|
1215
|
+
Node* m = (Node*)alloca(_s);
|
1216
|
+
|
1217
|
+
for (int attempt = 0; attempt < 3; attempt++) {
|
1218
|
+
children_indices[0].clear();
|
1219
|
+
children_indices[1].clear();
|
1220
|
+
D::create_split(children, _f, _s, _random, m);
|
1221
|
+
|
1222
|
+
for (size_t i = 0; i < indices.size(); i++) {
|
1223
|
+
S j = indices[i];
|
1224
|
+
Node* n = _get(j);
|
1225
|
+
if (n) {
|
1226
|
+
bool side = D::side(m, n->v, _f, _random);
|
1227
|
+
children_indices[side].push_back(j);
|
1228
|
+
} else {
|
1229
|
+
showUpdate("No node for index %d?\n", j);
|
1230
|
+
}
|
1231
|
+
}
|
1232
|
+
|
1233
|
+
if (_split_imbalance(children_indices[0], children_indices[1]) < 0.95)
|
1234
|
+
break;
|
1235
|
+
}
|
1236
|
+
|
1237
|
+
// If we didn't find a hyperplane, just randomize sides as a last option
|
1238
|
+
while (_split_imbalance(children_indices[0], children_indices[1]) > 0.99) {
|
1239
|
+
if (_verbose)
|
1240
|
+
showUpdate("\tNo hyperplane found (left has %ld children, right has %ld children)\n",
|
1241
|
+
children_indices[0].size(), children_indices[1].size());
|
1242
|
+
|
1243
|
+
children_indices[0].clear();
|
1244
|
+
children_indices[1].clear();
|
1245
|
+
|
1246
|
+
// Set the vector to 0.0
|
1247
|
+
for (int z = 0; z < _f; z++)
|
1248
|
+
m->v[z] = 0;
|
1249
|
+
|
1250
|
+
for (size_t i = 0; i < indices.size(); i++) {
|
1251
|
+
S j = indices[i];
|
1252
|
+
// Just randomize...
|
1253
|
+
children_indices[_random.flip()].push_back(j);
|
1254
|
+
}
|
1255
|
+
}
|
1256
|
+
|
1257
|
+
int flip = (children_indices[0].size() > children_indices[1].size());
|
1258
|
+
|
1259
|
+
m->n_descendants = is_root ? _n_items : (S)indices.size();
|
1260
|
+
for (int side = 0; side < 2; side++) {
|
1261
|
+
// run _make_tree for the smallest child first (for cache locality)
|
1262
|
+
m->children[side^flip] = _make_tree(children_indices[side^flip], false);
|
1263
|
+
}
|
1264
|
+
|
1265
|
+
|
1266
|
+
_allocate_size(_n_nodes + 1);
|
1267
|
+
S item = _n_nodes++;
|
1268
|
+
memcpy(_get(item), m, _s);
|
1269
|
+
|
1270
|
+
return item;
|
1271
|
+
}
|
1272
|
+
|
1273
|
+
void _get_all_nns(const T* v, size_t n, int search_k, vector<S>* result, vector<T>* distances) const {
|
1274
|
+
Node* v_node = (Node *)alloca(_s);
|
1275
|
+
D::template zero_value<Node>(v_node);
|
1276
|
+
memcpy(v_node->v, v, sizeof(T) * _f);
|
1277
|
+
D::init_node(v_node, _f);
|
1278
|
+
|
1279
|
+
std::priority_queue<pair<T, S> > q;
|
1280
|
+
|
1281
|
+
if (search_k == -1) {
|
1282
|
+
search_k = n * _roots.size();
|
1283
|
+
}
|
1284
|
+
|
1285
|
+
for (size_t i = 0; i < _roots.size(); i++) {
|
1286
|
+
q.push(make_pair(Distance::template pq_initial_value<T>(), _roots[i]));
|
1287
|
+
}
|
1288
|
+
|
1289
|
+
std::vector<S> nns;
|
1290
|
+
while (nns.size() < (size_t)search_k && !q.empty()) {
|
1291
|
+
const pair<T, S>& top = q.top();
|
1292
|
+
T d = top.first;
|
1293
|
+
S i = top.second;
|
1294
|
+
Node* nd = _get(i);
|
1295
|
+
q.pop();
|
1296
|
+
if (nd->n_descendants == 1 && i < _n_items) {
|
1297
|
+
nns.push_back(i);
|
1298
|
+
} else if (nd->n_descendants <= _K) {
|
1299
|
+
const S* dst = nd->children;
|
1300
|
+
nns.insert(nns.end(), dst, &dst[nd->n_descendants]);
|
1301
|
+
} else {
|
1302
|
+
T margin = D::margin(nd, v, _f);
|
1303
|
+
q.push(make_pair(D::pq_distance(d, margin, 1), static_cast<S>(nd->children[1])));
|
1304
|
+
q.push(make_pair(D::pq_distance(d, margin, 0), static_cast<S>(nd->children[0])));
|
1305
|
+
}
|
1306
|
+
}
|
1307
|
+
|
1308
|
+
// Get distances for all items
|
1309
|
+
// To avoid calculating distance multiple times for any items, sort by id
|
1310
|
+
std::sort(nns.begin(), nns.end());
|
1311
|
+
vector<pair<T, S> > nns_dist;
|
1312
|
+
S last = -1;
|
1313
|
+
for (size_t i = 0; i < nns.size(); i++) {
|
1314
|
+
S j = nns[i];
|
1315
|
+
if (j == last)
|
1316
|
+
continue;
|
1317
|
+
last = j;
|
1318
|
+
if (_get(j)->n_descendants == 1) // This is only to guard a really obscure case, #284
|
1319
|
+
nns_dist.push_back(make_pair(D::distance(v_node, _get(j), _f), j));
|
1320
|
+
}
|
1321
|
+
|
1322
|
+
size_t m = nns_dist.size();
|
1323
|
+
size_t p = n < m ? n : m; // Return this many items
|
1324
|
+
std::partial_sort(nns_dist.begin(), nns_dist.begin() + p, nns_dist.end());
|
1325
|
+
for (size_t i = 0; i < p; i++) {
|
1326
|
+
if (distances)
|
1327
|
+
distances->push_back(D::normalized_distance(nns_dist[i].first));
|
1328
|
+
result->push_back(nns_dist[i].second);
|
1329
|
+
}
|
1330
|
+
}
|
1331
|
+
};
|
1332
|
+
|
1333
|
+
#endif
|
1334
|
+
// vim: tabstop=2 shiftwidth=2
|