faiss 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (77) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +18 -18
  4. data/README.md +1 -1
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/Clustering.cpp +318 -53
  7. data/vendor/faiss/Clustering.h +39 -11
  8. data/vendor/faiss/DirectMap.cpp +267 -0
  9. data/vendor/faiss/DirectMap.h +120 -0
  10. data/vendor/faiss/IVFlib.cpp +24 -4
  11. data/vendor/faiss/IVFlib.h +4 -0
  12. data/vendor/faiss/Index.h +5 -24
  13. data/vendor/faiss/Index2Layer.cpp +0 -1
  14. data/vendor/faiss/IndexBinary.h +7 -3
  15. data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
  16. data/vendor/faiss/IndexBinaryFlat.h +3 -0
  17. data/vendor/faiss/IndexBinaryHash.cpp +492 -0
  18. data/vendor/faiss/IndexBinaryHash.h +116 -0
  19. data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
  20. data/vendor/faiss/IndexBinaryIVF.h +14 -4
  21. data/vendor/faiss/IndexFlat.h +2 -1
  22. data/vendor/faiss/IndexHNSW.cpp +68 -16
  23. data/vendor/faiss/IndexHNSW.h +3 -3
  24. data/vendor/faiss/IndexIVF.cpp +72 -76
  25. data/vendor/faiss/IndexIVF.h +24 -5
  26. data/vendor/faiss/IndexIVFFlat.cpp +19 -54
  27. data/vendor/faiss/IndexIVFFlat.h +1 -11
  28. data/vendor/faiss/IndexIVFPQ.cpp +49 -26
  29. data/vendor/faiss/IndexIVFPQ.h +9 -10
  30. data/vendor/faiss/IndexIVFPQR.cpp +2 -2
  31. data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
  32. data/vendor/faiss/IndexLSH.h +4 -1
  33. data/vendor/faiss/IndexPreTransform.cpp +0 -1
  34. data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
  35. data/vendor/faiss/InvertedLists.cpp +0 -2
  36. data/vendor/faiss/MetaIndexes.cpp +0 -1
  37. data/vendor/faiss/MetricType.h +36 -0
  38. data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
  39. data/vendor/faiss/c_api/Clustering_c.h +11 -5
  40. data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
  41. data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
  42. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
  43. data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
  44. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
  45. data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
  46. data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
  47. data/vendor/faiss/gpu/GpuDistance.h +93 -0
  48. data/vendor/faiss/gpu/GpuIndex.h +7 -0
  49. data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
  50. data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
  51. data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
  52. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
  53. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
  54. data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
  55. data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
  56. data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
  57. data/vendor/faiss/impl/HNSW.cpp +0 -1
  58. data/vendor/faiss/impl/PolysemousTraining.h +5 -5
  59. data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
  60. data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
  61. data/vendor/faiss/impl/ProductQuantizer.h +42 -47
  62. data/vendor/faiss/impl/index_read.cpp +103 -7
  63. data/vendor/faiss/impl/index_write.cpp +101 -5
  64. data/vendor/faiss/impl/io.cpp +111 -1
  65. data/vendor/faiss/impl/io.h +38 -0
  66. data/vendor/faiss/index_factory.cpp +0 -1
  67. data/vendor/faiss/tests/test_merge.cpp +0 -1
  68. data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
  69. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
  70. data/vendor/faiss/utils/distances.cpp +4 -5
  71. data/vendor/faiss/utils/distances_simd.cpp +0 -1
  72. data/vendor/faiss/utils/hamming.cpp +85 -3
  73. data/vendor/faiss/utils/hamming.h +20 -0
  74. data/vendor/faiss/utils/utils.cpp +0 -96
  75. data/vendor/faiss/utils/utils.h +0 -15
  76. metadata +11 -3
  77. data/lib/faiss/ext.bundle +0 -0
@@ -13,6 +13,7 @@
13
13
 
14
14
  #include <faiss/IndexPreTransform.h>
15
15
  #include <faiss/impl/FaissAssert.h>
16
+ #include <faiss/MetaIndexes.h>
16
17
 
17
18
 
18
19
 
@@ -56,17 +57,35 @@ void check_compatible_for_merge (const Index * index0,
56
57
 
57
58
  }
58
59
 
59
- const IndexIVF * extract_index_ivf (const Index * index)
60
+ const IndexIVF * try_extract_index_ivf (const Index * index)
60
61
  {
61
62
  if (auto *pt =
62
63
  dynamic_cast<const IndexPreTransform *>(index)) {
63
64
  index = pt->index;
64
65
  }
65
66
 
67
+ if (auto *idmap =
68
+ dynamic_cast<const IndexIDMap *>(index)) {
69
+ index = idmap->index;
70
+ }
71
+ if (auto *idmap =
72
+ dynamic_cast<const IndexIDMap2 *>(index)) {
73
+ index = idmap->index;
74
+ }
75
+
66
76
  auto *ivf = dynamic_cast<const IndexIVF *>(index);
67
77
 
68
- FAISS_THROW_IF_NOT (ivf);
78
+ return ivf;
79
+ }
69
80
 
81
+ IndexIVF * try_extract_index_ivf (Index * index) {
82
+ return const_cast<IndexIVF*> (try_extract_index_ivf ((const Index*)(index)));
83
+ }
84
+
85
+ const IndexIVF * extract_index_ivf (const Index * index)
86
+ {
87
+ const IndexIVF *ivf = try_extract_index_ivf (index);
88
+ FAISS_THROW_IF_NOT (ivf);
70
89
  return ivf;
71
90
  }
72
91
 
@@ -74,6 +93,7 @@ IndexIVF * extract_index_ivf (Index * index) {
74
93
  return const_cast<IndexIVF*> (extract_index_ivf ((const Index*)(index)));
75
94
  }
76
95
 
96
+
77
97
  void merge_into(faiss::Index *index0, faiss::Index *index1, bool shift_ids) {
78
98
 
79
99
  check_compatible_for_merge (index0, index1);
@@ -146,8 +166,8 @@ void search_and_return_centroids(faiss::Index *index,
146
166
  if (result_centroid_ids)
147
167
  result_centroid_ids[i] = -1;
148
168
  } else {
149
- long list_no = label >> 32;
150
- long list_index = label & 0xffffffff;
169
+ long list_no = lo_listno (label);
170
+ long list_index = lo_offset (label);
151
171
  if (result_centroid_ids)
152
172
  result_centroid_ids[i] = list_no;
153
173
  labels[i] = index_ivf->invlists->get_single_id(list_no, list_index);
@@ -35,6 +35,10 @@ void check_compatible_for_merge (const Index * index1,
35
35
  const IndexIVF * extract_index_ivf (const Index * index);
36
36
  IndexIVF * extract_index_ivf (Index * index);
37
37
 
38
+ /// same as above but returns nullptr instead of throwing on failure
39
+ const IndexIVF * try_extract_index_ivf (const Index * index);
40
+ IndexIVF * try_extract_index_ivf (Index * index);
41
+
38
42
  /** Merge index1 into index0. Works on IndexIVF's and IndexIVF's
39
43
  * embedded in a IndexPreTransform. On output, the index1 is empty.
40
44
  *
@@ -10,7 +10,7 @@
10
10
  #ifndef FAISS_INDEX_H
11
11
  #define FAISS_INDEX_H
12
12
 
13
-
13
+ #include <faiss/MetricType.h>
14
14
  #include <cstdio>
15
15
  #include <typeinfo>
16
16
  #include <string>
@@ -18,7 +18,7 @@
18
18
 
19
19
  #define FAISS_VERSION_MAJOR 1
20
20
  #define FAISS_VERSION_MINOR 6
21
- #define FAISS_VERSION_PATCH 1
21
+ #define FAISS_VERSION_PATCH 3
22
22
 
23
23
  /**
24
24
  * @namespace faiss
@@ -39,34 +39,15 @@
39
39
 
40
40
  namespace faiss {
41
41
 
42
-
43
- /// Some algorithms support both an inner product version and a L2 search version.
44
- enum MetricType {
45
- METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
46
- METRIC_L2 = 1, ///< squared L2 search
47
- METRIC_L1, ///< L1 (aka cityblock)
48
- METRIC_Linf, ///< infinity distance
49
- METRIC_Lp, ///< L_p distance, p is given by metric_arg
50
-
51
- /// some additional metrics defined in scipy.spatial.distance
52
- METRIC_Canberra = 20,
53
- METRIC_BrayCurtis,
54
- METRIC_JensenShannon,
55
-
56
- };
57
-
58
-
59
42
  /// Forward declarations see AuxIndexStructures.h
60
43
  struct IDSelector;
61
44
  struct RangeSearchResult;
62
45
  struct DistanceComputer;
63
46
 
64
- /** Abstract structure for an index
65
- *
66
- * Supports adding vertices and searching them.
47
+ /** Abstract structure for an index, supports adding vectors and searching them.
67
48
  *
68
- * Currently only asymmetric queries are supported:
69
- * database-to-database queries are not implemented.
49
+ * All vectors provided at add or search time are 32-bit float arrays,
50
+ * although the internal representation may vary.
70
51
  */
71
52
  struct Index {
72
53
  using idx_t = int64_t; ///< all indices are this type
@@ -42,7 +42,6 @@
42
42
 
43
43
  namespace faiss {
44
44
 
45
- using idx_t = Index::idx_t;
46
45
 
47
46
  /*************************************
48
47
  * Index2Layer implementation
@@ -99,9 +99,13 @@ struct IndexBinary {
99
99
 
100
100
  /** Query n vectors of dimension d to the index.
101
101
  *
102
- * return all vectors with distance < radius. Note that many
103
- * indexes do not implement the range_search (only the k-NN search
104
- * is mandatory).
102
+ * return all vectors with distance < radius. Note that many indexes
103
+ * do not implement the range_search (only the k-NN search is
104
+ * mandatory). The distances are converted to float to reuse the
105
+ * RangeSearchResult structure, but they are integer. By convention,
106
+ * only distances < radius (strict comparison) are returned,
107
+ * ie. radius = 0 does not return any result and 1 returns only
108
+ * exact same vectors.
105
109
  *
106
110
  * @param x input vectors to search, size n * d / 8
107
111
  * @param radius search radius
@@ -79,5 +79,10 @@ void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const {
79
79
  memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
80
80
  }
81
81
 
82
+ void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
83
+ RangeSearchResult *result) const
84
+ {
85
+ hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
86
+ }
82
87
 
83
88
  } // namespace faiss
@@ -38,6 +38,9 @@ struct IndexBinaryFlat : IndexBinary {
38
38
  void search(idx_t n, const uint8_t *x, idx_t k,
39
39
  int32_t *distances, idx_t *labels) const override;
40
40
 
41
+ void range_search(idx_t n, const uint8_t *x, int radius,
42
+ RangeSearchResult *result) const override;
43
+
41
44
  void reconstruct(idx_t key, uint8_t *recons) const override;
42
45
 
43
46
  /** Remove some ids. Note that because of the indexing structure,
@@ -0,0 +1,492 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // Copyright 2004-present Facebook. All Rights Reserved
9
+ // -*- c++ -*-
10
+
11
+ #include <faiss/IndexBinaryHash.h>
12
+
13
#include <cstdio>
#include <cstring>
#include <memory>
15
+
16
+ #include <faiss/utils/hamming.h>
17
+ #include <faiss/utils/utils.h>
18
+
19
+ #include <faiss/impl/AuxIndexStructures.h>
20
+ #include <faiss/impl/FaissAssert.h>
21
+
22
+
23
+ namespace faiss {
24
+
25
+ void IndexBinaryHash::InvertedList::add (
26
+ idx_t id, size_t code_size, const uint8_t *code)
27
+ {
28
+ ids.push_back(id);
29
+ vecs.insert(vecs.end(), code, code + code_size);
30
+ }
31
+
32
+ IndexBinaryHash::IndexBinaryHash(int d, int b):
33
+ IndexBinary(d), b(b), nflip(0)
34
+ {
35
+ is_trained = true;
36
+ }
37
+
38
+ IndexBinaryHash::IndexBinaryHash(): b(0), nflip(0)
39
+ {
40
+ is_trained = true;
41
+ }
42
+
43
+ void IndexBinaryHash::reset()
44
+ {
45
+ invlists.clear();
46
+ ntotal = 0;
47
+ }
48
+
49
+
50
+ void IndexBinaryHash::add(idx_t n, const uint8_t *x)
51
+ {
52
+ add_with_ids(n, x, nullptr);
53
+ }
54
+
55
+ void IndexBinaryHash::add_with_ids(idx_t n, const uint8_t *x, const idx_t *xids)
56
+ {
57
+ uint64_t mask = ((uint64_t)1 << b) - 1;
58
+ // simplistic add function. Cannot really be parallelized.
59
+
60
+ for (idx_t i = 0; i < n; i++) {
61
+ idx_t id = xids ? xids[i] : ntotal + i;
62
+ const uint8_t * xi = x + i * code_size;
63
+ idx_t hash = *((uint64_t*)xi) & mask;
64
+ invlists[hash].add(id, code_size, xi);
65
+ }
66
+ ntotal += n;
67
+ }
68
+
69
+ namespace {
70
+
71
+
72
+ /** Enumerate all bit vectors of size nbit with up to maxflip 1s
73
+ * test in P127257851 P127258235
74
+ */
75
+ struct FlipEnumerator {
76
+ int nbit, nflip, maxflip;
77
+ uint64_t mask, x;
78
+
79
+ FlipEnumerator (int nbit, int maxflip): nbit(nbit), maxflip(maxflip) {
80
+ nflip = 0;
81
+ mask = 0;
82
+ x = 0;
83
+ }
84
+
85
+ bool next() {
86
+ if (x == mask) {
87
+ if (nflip == maxflip) {
88
+ return false;
89
+ }
90
+ // increase Hamming radius
91
+ nflip++;
92
+ mask = (((uint64_t)1 << nflip) - 1);
93
+ x = mask << (nbit - nflip);
94
+ return true;
95
+ }
96
+
97
+ int i = __builtin_ctzll(x);
98
+
99
+ if (i > 0) {
100
+ x ^= (uint64_t)3 << (i - 1);
101
+ } else {
102
+ // nb of LSB 1s
103
+ int n1 = __builtin_ctzll(~x);
104
+ // clear them
105
+ x &= ((uint64_t)(-1) << n1);
106
+ int n2 = __builtin_ctzll(x);
107
+ x ^= (((uint64_t)1 << (n1 + 2)) - 1) << (n2 - n1 - 1);
108
+ }
109
+ return true;
110
+ }
111
+
112
+ };
113
+
114
+ using idx_t = Index::idx_t;
115
+
116
+
117
+ struct RangeSearchResults {
118
+ int radius;
119
+ RangeQueryResult &qres;
120
+
121
+ inline void add (float dis, idx_t id) {
122
+ if (dis < radius) {
123
+ qres.add (dis, id);
124
+ }
125
+ }
126
+
127
+ };
128
+
129
+ struct KnnSearchResults {
130
+ // heap params
131
+ idx_t k;
132
+ int32_t * heap_sim;
133
+ idx_t * heap_ids;
134
+
135
+ using C = CMax<int, idx_t>;
136
+
137
+ inline void add (float dis, idx_t id) {
138
+ if (dis < heap_sim[0]) {
139
+ heap_pop<C> (k, heap_sim, heap_ids);
140
+ heap_push<C> (k, heap_sim, heap_ids, dis, id);
141
+ }
142
+ }
143
+
144
+ };
145
+
146
+ template<class HammingComputer, class SearchResults>
147
+ void
148
+ search_single_query_template(const IndexBinaryHash & index, const uint8_t *q,
149
+ SearchResults &res,
150
+ size_t &n0, size_t &nlist, size_t &ndis)
151
+ {
152
+ size_t code_size = index.code_size;
153
+ uint64_t mask = ((uint64_t)1 << index.b) - 1;
154
+ uint64_t qhash = *((uint64_t*)q) & mask;
155
+ HammingComputer hc (q, code_size);
156
+ FlipEnumerator fe(index.b, index.nflip);
157
+
158
+ // loop over neighbors that are at most at nflip bits
159
+ do {
160
+ uint64_t hash = qhash ^ fe.x;
161
+ auto it = index.invlists.find (hash);
162
+
163
+ if (it == index.invlists.end()) {
164
+ continue;
165
+ }
166
+
167
+ const IndexBinaryHash::InvertedList &il = it->second;
168
+
169
+ size_t nv = il.ids.size();
170
+
171
+ if (nv == 0) {
172
+ n0++;
173
+ } else {
174
+ const uint8_t *codes = il.vecs.data();
175
+ for (size_t i = 0; i < nv; i++) {
176
+ int dis = hc.hamming (codes);
177
+ res.add(dis, il.ids[i]);
178
+ codes += code_size;
179
+ }
180
+ ndis += nv;
181
+ nlist++;
182
+ }
183
+ } while(fe.next());
184
+ }
185
+
186
+ template<class SearchResults>
187
+ void
188
+ search_single_query(const IndexBinaryHash & index, const uint8_t *q,
189
+ SearchResults &res,
190
+ size_t &n0, size_t &nlist, size_t &ndis)
191
+ {
192
+ #define HC(name) search_single_query_template<name>(index, q, res, n0, nlist, ndis);
193
+ switch(index.code_size) {
194
+ case 4: HC(HammingComputer4); break;
195
+ case 8: HC(HammingComputer8); break;
196
+ case 16: HC(HammingComputer16); break;
197
+ case 20: HC(HammingComputer20); break;
198
+ case 32: HC(HammingComputer32); break;
199
+ default:
200
+ if (index.code_size % 8 == 0) {
201
+ HC(HammingComputerM8);
202
+ } else {
203
+ HC(HammingComputerDefault);
204
+ }
205
+ }
206
+ #undef HC
207
+ }
208
+
209
+
210
+ } // anonymous namespace
211
+
212
+
213
+
214
+ void IndexBinaryHash::range_search(idx_t n, const uint8_t *x, int radius,
215
+ RangeSearchResult *result) const
216
+ {
217
+
218
+ size_t nlist = 0, ndis = 0, n0 = 0;
219
+
220
+ #pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist)
221
+ {
222
+ RangeSearchPartialResult pres (result);
223
+
224
+ #pragma omp for
225
+ for (size_t i = 0; i < n; i++) { // loop queries
226
+ RangeQueryResult & qres = pres.new_result (i);
227
+ RangeSearchResults res = {radius, qres};
228
+ const uint8_t *q = x + i * code_size;
229
+
230
+ search_single_query (*this, q, res, n0, nlist, ndis);
231
+
232
+ }
233
+ pres.finalize ();
234
+ }
235
+ indexBinaryHash_stats.nq += n;
236
+ indexBinaryHash_stats.n0 += n0;
237
+ indexBinaryHash_stats.nlist += nlist;
238
+ indexBinaryHash_stats.ndis += ndis;
239
+ }
240
+
241
+ void IndexBinaryHash::search(idx_t n, const uint8_t *x, idx_t k,
242
+ int32_t *distances, idx_t *labels) const
243
+ {
244
+
245
+ using HeapForL2 = CMax<int32_t, idx_t>;
246
+ size_t nlist = 0, ndis = 0, n0 = 0;
247
+
248
+ #pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0)
249
+ for (size_t i = 0; i < n; i++) {
250
+ int32_t * simi = distances + k * i;
251
+ idx_t * idxi = labels + k * i;
252
+
253
+ heap_heapify<HeapForL2> (k, simi, idxi);
254
+ KnnSearchResults res = {k, simi, idxi};
255
+ const uint8_t *q = x + i * code_size;
256
+
257
+ search_single_query (*this, q, res, n0, nlist, ndis);
258
+
259
+ }
260
+ indexBinaryHash_stats.nq += n;
261
+ indexBinaryHash_stats.n0 += n0;
262
+ indexBinaryHash_stats.nlist += nlist;
263
+ indexBinaryHash_stats.ndis += ndis;
264
+ }
265
+
266
+ size_t IndexBinaryHash::hashtable_size() const
267
+ {
268
+ return invlists.size();
269
+ }
270
+
271
+
272
+ void IndexBinaryHash::display() const
273
+ {
274
+ for (auto it = invlists.begin(); it != invlists.end(); ++it) {
275
+ printf("%ld: [", it->first);
276
+ const std::vector<idx_t> & v = it->second.ids;
277
+ for (auto x: v) {
278
+ printf("%ld ", 0 + x);
279
+ }
280
+ printf("]\n");
281
+
282
+ }
283
+ }
284
+
285
+
286
+ void IndexBinaryHashStats::reset()
287
+ {
288
+ memset ((void*)this, 0, sizeof (*this));
289
+ }
290
+
291
+ IndexBinaryHashStats indexBinaryHash_stats;
292
+
293
+ /*******************************************************
294
+ * IndexBinaryMultiHash implementation
295
+ ******************************************************/
296
+
297
+
298
+ IndexBinaryMultiHash::IndexBinaryMultiHash(int d, int nhash, int b):
299
+ IndexBinary(d),
300
+ storage(new IndexBinaryFlat(d)), own_fields(true),
301
+ maps(nhash), nhash(nhash), b(b), nflip(0)
302
+ {
303
+ FAISS_THROW_IF_NOT(nhash * b <= d);
304
+ }
305
+
306
+ IndexBinaryMultiHash::IndexBinaryMultiHash():
307
+ storage(nullptr), own_fields(true),
308
+ nhash(0), b(0), nflip(0)
309
+ {}
310
+
311
+ IndexBinaryMultiHash::~IndexBinaryMultiHash()
312
+ {
313
+ if (own_fields) {
314
+ delete storage;
315
+ }
316
+ }
317
+
318
+
319
+ void IndexBinaryMultiHash::reset()
320
+ {
321
+ storage->reset();
322
+ ntotal = 0;
323
+ for(auto map: maps) {
324
+ map.clear();
325
+ }
326
+ }
327
+
328
+ void IndexBinaryMultiHash::add(idx_t n, const uint8_t *x)
329
+ {
330
+ storage->add(n, x);
331
+ // populate maps
332
+ uint64_t mask = ((uint64_t)1 << b) - 1;
333
+
334
+ for(idx_t i = 0; i < n; i++) {
335
+ const uint8_t *xi = x + i * code_size;
336
+ int ho = 0;
337
+ for(int h = 0; h < nhash; h++) {
338
+ uint64_t hash = *(uint64_t*)(xi + (ho >> 3)) >> (ho & 7);
339
+ hash &= mask;
340
+ maps[h][hash].push_back(i + ntotal);
341
+ ho += b;
342
+ }
343
+ }
344
+ ntotal += n;
345
+ }
346
+
347
+
348
+ namespace {
349
+
350
+ template <class HammingComputer, class SearchResults>
351
+ static
352
+ void verify_shortlist(
353
+ const IndexBinaryFlat & index,
354
+ const uint8_t * q,
355
+ const std::unordered_set<Index::idx_t> & shortlist,
356
+ SearchResults &res)
357
+ {
358
+ size_t code_size = index.code_size;
359
+ size_t nlist = 0, ndis = 0, n0 = 0;
360
+
361
+ HammingComputer hc (q, code_size);
362
+ const uint8_t *codes = index.xb.data();
363
+
364
+ for (auto i: shortlist) {
365
+ int dis = hc.hamming (codes + i * code_size);
366
+ res.add(dis, i);
367
+ }
368
+ }
369
+
370
+ template<class SearchResults>
371
+ void
372
+ search_1_query_multihash(const IndexBinaryMultiHash & index, const uint8_t *xi,
373
+ SearchResults &res,
374
+ size_t &n0, size_t &nlist, size_t &ndis)
375
+ {
376
+
377
+ std::unordered_set<idx_t> shortlist;
378
+ int b = index.b;
379
+ uint64_t mask = ((uint64_t)1 << b) - 1;
380
+
381
+ int ho = 0;
382
+ for(int h = 0; h < index.nhash; h++) {
383
+ uint64_t qhash = *(uint64_t*)(xi + (ho >> 3)) >> (ho & 7);
384
+ qhash &= mask;
385
+ const IndexBinaryMultiHash::Map & map = index.maps[h];
386
+
387
+ FlipEnumerator fe(index.b, index.nflip);
388
+ // loop over neighbors that are at most at nflip bits
389
+ do {
390
+ uint64_t hash = qhash ^ fe.x;
391
+ auto it = map.find (hash);
392
+
393
+ if (it != map.end()) {
394
+ const std::vector<idx_t> & v = it->second;
395
+ for (auto i: v) {
396
+ shortlist.insert(i);
397
+ }
398
+ nlist++;
399
+ } else {
400
+ n0++;
401
+ }
402
+ } while(fe.next());
403
+
404
+ ho += b;
405
+ }
406
+ ndis += shortlist.size();
407
+
408
+ // verify shortlist
409
+
410
+ #define HC(name) verify_shortlist<name> (*index.storage, xi, shortlist, res)
411
+ switch(index.code_size) {
412
+ case 4: HC(HammingComputer4); break;
413
+ case 8: HC(HammingComputer8); break;
414
+ case 16: HC(HammingComputer16); break;
415
+ case 20: HC(HammingComputer20); break;
416
+ case 32: HC(HammingComputer32); break;
417
+ default:
418
+ if (index.code_size % 8 == 0) {
419
+ HC(HammingComputerM8);
420
+ } else {
421
+ HC(HammingComputerDefault);
422
+ }
423
+ }
424
+ #undef HC
425
+ }
426
+
427
+ } // anonymous namespace
428
+
429
+ void IndexBinaryMultiHash::range_search(idx_t n, const uint8_t *x, int radius,
430
+ RangeSearchResult *result) const
431
+ {
432
+
433
+ size_t nlist = 0, ndis = 0, n0 = 0;
434
+
435
+ #pragma omp parallel if(n > 100) reduction(+: ndis, n0, nlist)
436
+ {
437
+ RangeSearchPartialResult pres (result);
438
+
439
+ #pragma omp for
440
+ for (size_t i = 0; i < n; i++) { // loop queries
441
+ RangeQueryResult & qres = pres.new_result (i);
442
+ RangeSearchResults res = {radius, qres};
443
+ const uint8_t *q = x + i * code_size;
444
+
445
+ search_1_query_multihash (*this, q, res, n0, nlist, ndis);
446
+
447
+ }
448
+ pres.finalize ();
449
+ }
450
+ indexBinaryHash_stats.nq += n;
451
+ indexBinaryHash_stats.n0 += n0;
452
+ indexBinaryHash_stats.nlist += nlist;
453
+ indexBinaryHash_stats.ndis += ndis;
454
+ }
455
+
456
+ void IndexBinaryMultiHash::search(idx_t n, const uint8_t *x, idx_t k,
457
+ int32_t *distances, idx_t *labels) const
458
+ {
459
+
460
+ using HeapForL2 = CMax<int32_t, idx_t>;
461
+ size_t nlist = 0, ndis = 0, n0 = 0;
462
+
463
+ #pragma omp parallel for if(n > 100) reduction(+: nlist, ndis, n0)
464
+ for (size_t i = 0; i < n; i++) {
465
+ int32_t * simi = distances + k * i;
466
+ idx_t * idxi = labels + k * i;
467
+
468
+ heap_heapify<HeapForL2> (k, simi, idxi);
469
+ KnnSearchResults res = {k, simi, idxi};
470
+ const uint8_t *q = x + i * code_size;
471
+
472
+ search_1_query_multihash (*this, q, res, n0, nlist, ndis);
473
+
474
+ }
475
+ indexBinaryHash_stats.nq += n;
476
+ indexBinaryHash_stats.n0 += n0;
477
+ indexBinaryHash_stats.nlist += nlist;
478
+ indexBinaryHash_stats.ndis += ndis;
479
+ }
480
+
481
+ size_t IndexBinaryMultiHash::hashtable_size() const
482
+ {
483
+ size_t tot = 0;
484
+ for (auto map: maps) {
485
+ tot += map.size();
486
+ }
487
+
488
+ return tot;
489
+ }
490
+
491
+
492
+ }