faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -9,80 +9,100 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFlat.h>
11
11
 
12
- #include <cstring>
12
+ #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/Heap.h>
13
15
  #include <faiss/utils/hamming.h>
14
16
  #include <faiss/utils/utils.h>
15
- #include <faiss/utils/Heap.h>
16
- #include <faiss/impl/FaissAssert.h>
17
- #include <faiss/impl/AuxIndexStructures.h>
17
+ #include <cstring>
18
18
 
19
19
  namespace faiss {
20
20
 
21
- IndexBinaryFlat::IndexBinaryFlat(idx_t d)
22
- : IndexBinary(d) {}
21
+ IndexBinaryFlat::IndexBinaryFlat(idx_t d) : IndexBinary(d) {}
23
22
 
24
- void IndexBinaryFlat::add(idx_t n, const uint8_t *x) {
25
- xb.insert(xb.end(), x, x + n * code_size);
26
- ntotal += n;
23
+ void IndexBinaryFlat::add(idx_t n, const uint8_t* x) {
24
+ xb.insert(xb.end(), x, x + n * code_size);
25
+ ntotal += n;
27
26
  }
28
27
 
29
28
  void IndexBinaryFlat::reset() {
30
- xb.clear();
31
- ntotal = 0;
29
+ xb.clear();
30
+ ntotal = 0;
32
31
  }
33
32
 
34
- void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
35
- int32_t *distances, idx_t *labels) const {
36
- const idx_t block_size = query_batch_size;
37
- for (idx_t s = 0; s < n; s += block_size) {
38
- idx_t nn = block_size;
39
- if (s + block_size > n) {
40
- nn = n - s;
41
- }
33
+ void IndexBinaryFlat::search(
34
+ idx_t n,
35
+ const uint8_t* x,
36
+ idx_t k,
37
+ int32_t* distances,
38
+ idx_t* labels) const {
39
+ FAISS_THROW_IF_NOT(k > 0);
42
40
 
43
- if (use_heap) {
44
- // We see the distances and labels as heaps.
45
- int_maxheap_array_t res = {
46
- size_t(nn), size_t(k), labels + s * k, distances + s * k
47
- };
41
+ const idx_t block_size = query_batch_size;
42
+ for (idx_t s = 0; s < n; s += block_size) {
43
+ idx_t nn = block_size;
44
+ if (s + block_size > n) {
45
+ nn = n - s;
46
+ }
48
47
 
49
- hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
50
- /* ordered = */ true);
51
- } else {
52
- hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
53
- distances + s * k, labels + s * k);
48
+ if (use_heap) {
49
+ // We see the distances and labels as heaps.
50
+ int_maxheap_array_t res = {
51
+ size_t(nn), size_t(k), labels + s * k, distances + s * k};
52
+
53
+ hammings_knn_hc(
54
+ &res,
55
+ x + s * code_size,
56
+ xb.data(),
57
+ ntotal,
58
+ code_size,
59
+ /* ordered = */ true);
60
+ } else {
61
+ hammings_knn_mc(
62
+ x + s * code_size,
63
+ xb.data(),
64
+ nn,
65
+ ntotal,
66
+ k,
67
+ code_size,
68
+ distances + s * k,
69
+ labels + s * k);
70
+ }
54
71
  }
55
- }
56
72
  }
57
73
 
58
74
  size_t IndexBinaryFlat::remove_ids(const IDSelector& sel) {
59
- idx_t j = 0;
60
- for (idx_t i = 0; i < ntotal; i++) {
61
- if (sel.is_member(i)) {
62
- // should be removed
63
- } else {
64
- if (i > j) {
65
- memmove(&xb[code_size * j], &xb[code_size * i], sizeof(xb[0]) * code_size);
66
- }
67
- j++;
75
+ idx_t j = 0;
76
+ for (idx_t i = 0; i < ntotal; i++) {
77
+ if (sel.is_member(i)) {
78
+ // should be removed
79
+ } else {
80
+ if (i > j) {
81
+ memmove(&xb[code_size * j],
82
+ &xb[code_size * i],
83
+ sizeof(xb[0]) * code_size);
84
+ }
85
+ j++;
86
+ }
87
+ }
88
+ long nremove = ntotal - j;
89
+ if (nremove > 0) {
90
+ ntotal = j;
91
+ xb.resize(ntotal * code_size);
68
92
  }
69
- }
70
- long nremove = ntotal - j;
71
- if (nremove > 0) {
72
- ntotal = j;
73
- xb.resize(ntotal * code_size);
74
- }
75
- return nremove;
93
+ return nremove;
76
94
  }
77
95
 
78
- void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const {
79
- memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
96
+ void IndexBinaryFlat::reconstruct(idx_t key, uint8_t* recons) const {
97
+ memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
80
98
  }
81
99
 
82
- void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
83
- RangeSearchResult *result) const
84
- {
85
- hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
100
+ void IndexBinaryFlat::range_search(
101
+ idx_t n,
102
+ const uint8_t* x,
103
+ int radius,
104
+ RangeSearchResult* result) const {
105
+ hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
86
106
  }
87
107
 
88
- } // namespace faiss
108
+ } // namespace faiss
@@ -16,42 +16,47 @@
16
16
 
17
17
  namespace faiss {
18
18
 
19
-
20
19
  /** Index that stores the full vectors and performs exhaustive search. */
21
20
  struct IndexBinaryFlat : IndexBinary {
22
- /// database vectors, size ntotal * d / 8
23
- std::vector<uint8_t> xb;
21
+ /// database vectors, size ntotal * d / 8
22
+ std::vector<uint8_t> xb;
24
23
 
25
- /** Select between using a heap or counting to select the k smallest values
26
- * when scanning inverted lists.
27
- */
28
- bool use_heap = true;
24
+ /** Select between using a heap or counting to select the k smallest values
25
+ * when scanning inverted lists.
26
+ */
27
+ bool use_heap = true;
29
28
 
30
- size_t query_batch_size = 32;
29
+ size_t query_batch_size = 32;
31
30
 
32
- explicit IndexBinaryFlat(idx_t d);
31
+ explicit IndexBinaryFlat(idx_t d);
33
32
 
34
- void add(idx_t n, const uint8_t *x) override;
33
+ void add(idx_t n, const uint8_t* x) override;
35
34
 
36
- void reset() override;
35
+ void reset() override;
37
36
 
38
- void search(idx_t n, const uint8_t *x, idx_t k,
39
- int32_t *distances, idx_t *labels) const override;
37
+ void search(
38
+ idx_t n,
39
+ const uint8_t* x,
40
+ idx_t k,
41
+ int32_t* distances,
42
+ idx_t* labels) const override;
40
43
 
41
- void range_search(idx_t n, const uint8_t *x, int radius,
42
- RangeSearchResult *result) const override;
44
+ void range_search(
45
+ idx_t n,
46
+ const uint8_t* x,
47
+ int radius,
48
+ RangeSearchResult* result) const override;
43
49
 
44
- void reconstruct(idx_t key, uint8_t *recons) const override;
50
+ void reconstruct(idx_t key, uint8_t* recons) const override;
45
51
 
46
- /** Remove some ids. Note that because of the indexing structure,
47
- * the semantics of this operation are different from the usual ones:
48
- * the new ids are shifted. */
49
- size_t remove_ids(const IDSelector& sel) override;
52
+ /** Remove some ids. Note that because of the indexing structure,
53
+ * the semantics of this operation are different from the usual ones:
54
+ * the new ids are shifted. */
55
+ size_t remove_ids(const IDSelector& sel) override;
50
56
 
51
- IndexBinaryFlat() {}
57
+ IndexBinaryFlat() {}
52
58
  };
53
59
 
60
+ } // namespace faiss
54
61
 
55
- } // namespace faiss
56
-
57
- #endif // INDEX_BINARY_FLAT_H
62
+ #endif // INDEX_BINARY_FLAT_H
@@ -9,71 +9,74 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFromFloat.h>
11
11
 
12
+ #include <faiss/utils/utils.h>
12
13
  #include <algorithm>
13
14
  #include <memory>
14
- #include <faiss/utils/utils.h>
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  IndexBinaryFromFloat::IndexBinaryFromFloat() {}
20
19
 
21
- IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index)
22
- : IndexBinary(index->d),
23
- index(index),
24
- own_fields(false) {
25
- is_trained = index->is_trained;
26
- ntotal = index->ntotal;
20
+ IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index)
21
+ : IndexBinary(index->d), index(index), own_fields(false) {
22
+ is_trained = index->is_trained;
23
+ ntotal = index->ntotal;
27
24
  }
28
25
 
29
26
  IndexBinaryFromFloat::~IndexBinaryFromFloat() {
30
- if (own_fields) {
31
- delete index;
32
- }
27
+ if (own_fields) {
28
+ delete index;
29
+ }
33
30
  }
34
31
 
35
- void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) {
36
- constexpr idx_t bs = 32768;
37
- std::unique_ptr<float[]> xf(new float[bs * d]);
32
+ void IndexBinaryFromFloat::add(idx_t n, const uint8_t* x) {
33
+ constexpr idx_t bs = 32768;
34
+ std::unique_ptr<float[]> xf(new float[bs * d]);
38
35
 
39
- for (idx_t b = 0; b < n; b += bs) {
40
- idx_t bn = std::min(bs, n - b);
41
- binary_to_real(bn * d, x + b * code_size, xf.get());
36
+ for (idx_t b = 0; b < n; b += bs) {
37
+ idx_t bn = std::min(bs, n - b);
38
+ binary_to_real(bn * d, x + b * code_size, xf.get());
42
39
 
43
- index->add(bn, xf.get());
44
- }
45
- ntotal = index->ntotal;
40
+ index->add(bn, xf.get());
41
+ }
42
+ ntotal = index->ntotal;
46
43
  }
47
44
 
48
45
  void IndexBinaryFromFloat::reset() {
49
- index->reset();
50
- ntotal = index->ntotal;
46
+ index->reset();
47
+ ntotal = index->ntotal;
51
48
  }
52
49
 
53
- void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
54
- int32_t *distances, idx_t *labels) const {
55
- constexpr idx_t bs = 32768;
56
- std::unique_ptr<float[]> xf(new float[bs * d]);
57
- std::unique_ptr<float[]> df(new float[bs * k]);
58
-
59
- for (idx_t b = 0; b < n; b += bs) {
60
- idx_t bn = std::min(bs, n - b);
61
- binary_to_real(bn * d, x + b * code_size, xf.get());
62
-
63
- index->search(bn, xf.get(), k, df.get(), labels + b * k);
64
- for (int i = 0; i < bn * k; ++i) {
65
- distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
50
+ void IndexBinaryFromFloat::search(
51
+ idx_t n,
52
+ const uint8_t* x,
53
+ idx_t k,
54
+ int32_t* distances,
55
+ idx_t* labels) const {
56
+ FAISS_THROW_IF_NOT(k > 0);
57
+
58
+ constexpr idx_t bs = 32768;
59
+ std::unique_ptr<float[]> xf(new float[bs * d]);
60
+ std::unique_ptr<float[]> df(new float[bs * k]);
61
+
62
+ for (idx_t b = 0; b < n; b += bs) {
63
+ idx_t bn = std::min(bs, n - b);
64
+ binary_to_real(bn * d, x + b * code_size, xf.get());
65
+
66
+ index->search(bn, xf.get(), k, df.get(), labels + b * k);
67
+ for (int i = 0; i < bn * k; ++i) {
68
+ distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
69
+ }
66
70
  }
67
- }
68
71
  }
69
72
 
70
- void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) {
71
- std::unique_ptr<float[]> xf(new float[n * d]);
72
- binary_to_real(n * d, x, xf.get());
73
+ void IndexBinaryFromFloat::train(idx_t n, const uint8_t* x) {
74
+ std::unique_ptr<float[]> xf(new float[n * d]);
75
+ binary_to_real(n * d, x, xf.get());
73
76
 
74
- index->train(n, xf.get());
75
- is_trained = true;
76
- ntotal = index->ntotal;
77
+ index->train(n, xf.get());
78
+ is_trained = true;
79
+ ntotal = index->ntotal;
77
80
  }
78
81
 
79
- } // namespace faiss
82
+ } // namespace faiss
@@ -12,10 +12,8 @@
12
12
 
13
13
  #include <faiss/IndexBinary.h>
14
14
 
15
-
16
15
  namespace faiss {
17
16
 
18
-
19
17
  struct Index;
20
18
 
21
19
  /** IndexBinary backed by a float Index.
@@ -26,27 +24,30 @@ struct Index;
26
24
  * vectors.
27
25
  */
28
26
  struct IndexBinaryFromFloat : IndexBinary {
29
- Index *index = nullptr;
27
+ Index* index = nullptr;
30
28
 
31
- bool own_fields = false; ///< Whether object owns the index pointer.
29
+ bool own_fields = false; ///< Whether object owns the index pointer.
32
30
 
33
- IndexBinaryFromFloat();
31
+ IndexBinaryFromFloat();
34
32
 
35
- explicit IndexBinaryFromFloat(Index *index);
33
+ explicit IndexBinaryFromFloat(Index* index);
36
34
 
37
- ~IndexBinaryFromFloat();
35
+ ~IndexBinaryFromFloat();
38
36
 
39
- void add(idx_t n, const uint8_t *x) override;
37
+ void add(idx_t n, const uint8_t* x) override;
40
38
 
41
- void reset() override;
39
+ void reset() override;
42
40
 
43
- void search(idx_t n, const uint8_t *x, idx_t k,
44
- int32_t *distances, idx_t *labels) const override;
41
+ void search(
42
+ idx_t n,
43
+ const uint8_t* x,
44
+ idx_t k,
45
+ int32_t* distances,
46
+ idx_t* labels) const override;
45
47
 
46
- void train(idx_t n, const uint8_t *x) override;
48
+ void train(idx_t n, const uint8_t* x) override;
47
49
  };
48
50
 
51
+ } // namespace faiss
49
52
 
50
- } // namespace faiss
51
-
52
- #endif // FAISS_INDEX_BINARY_FROM_FLOAT_H
53
+ #endif // FAISS_INDEX_BINARY_FROM_FLOAT_H
@@ -9,316 +9,299 @@
9
9
 
10
10
  #include <faiss/IndexBinaryHNSW.h>
11
11
 
12
-
13
- #include <memory>
14
- #include <cstdlib>
12
+ #include <omp.h>
15
13
  #include <cassert>
16
- #include <cstring>
17
- #include <cstdio>
18
14
  #include <cmath>
19
- #include <omp.h>
15
+ #include <cstdio>
16
+ #include <cstdlib>
17
+ #include <cstring>
18
+ #include <memory>
20
19
 
21
- #include <unordered_set>
22
20
  #include <queue>
21
+ #include <unordered_set>
23
22
 
24
- #include <sys/types.h>
25
- #include <sys/stat.h>
26
23
  #include <stdint.h>
24
+ #include <sys/stat.h>
25
+ #include <sys/types.h>
27
26
 
28
- #include <faiss/utils/random.h>
29
- #include <faiss/utils/Heap.h>
30
- #include <faiss/impl/FaissAssert.h>
31
27
  #include <faiss/IndexBinaryFlat.h>
32
- #include <faiss/utils/hamming.h>
33
28
  #include <faiss/impl/AuxIndexStructures.h>
29
+ #include <faiss/impl/FaissAssert.h>
30
+ #include <faiss/utils/Heap.h>
31
+ #include <faiss/utils/hamming.h>
32
+ #include <faiss/utils/random.h>
34
33
 
35
34
  namespace faiss {
36
35
 
37
-
38
36
  /**************************************************************
39
37
  * add / search blocks of descriptors
40
38
  **************************************************************/
41
39
 
42
40
  namespace {
43
41
 
42
+ void hnsw_add_vertices(
43
+ IndexBinaryHNSW& index_hnsw,
44
+ size_t n0,
45
+ size_t n,
46
+ const uint8_t* x,
47
+ bool verbose,
48
+ bool preset_levels = false) {
49
+ HNSW& hnsw = index_hnsw.hnsw;
50
+ size_t ntotal = n0 + n;
51
+ double t0 = getmillisecs();
52
+ if (verbose) {
53
+ printf("hnsw_add_vertices: adding %zd elements on top of %zd "
54
+ "(preset_levels=%d)\n",
55
+ n,
56
+ n0,
57
+ int(preset_levels));
58
+ }
59
+
60
+ int max_level = hnsw.prepare_level_tab(n, preset_levels);
44
61
 
45
- void hnsw_add_vertices(IndexBinaryHNSW& index_hnsw,
46
- size_t n0,
47
- size_t n, const uint8_t *x,
48
- bool verbose,
49
- bool preset_levels = false) {
50
- HNSW& hnsw = index_hnsw.hnsw;
51
- size_t ntotal = n0 + n;
52
- double t0 = getmillisecs();
53
- if (verbose) {
54
- printf("hnsw_add_vertices: adding %zd elements on top of %zd "
55
- "(preset_levels=%d)\n",
56
- n, n0, int(preset_levels));
57
- }
58
-
59
- int max_level = hnsw.prepare_level_tab(n, preset_levels);
60
-
61
- if (verbose) {
62
- printf(" max_level = %d\n", max_level);
63
- }
64
-
65
- std::vector<omp_lock_t> locks(ntotal);
66
- for(int i = 0; i < ntotal; i++) {
67
- omp_init_lock(&locks[i]);
68
- }
69
-
70
- // add vectors from highest to lowest level
71
- std::vector<int> hist;
72
- std::vector<int> order(n);
73
-
74
- { // make buckets with vectors of the same level
75
-
76
- // build histogram
77
- for (int i = 0; i < n; i++) {
78
- HNSW::storage_idx_t pt_id = i + n0;
79
- int pt_level = hnsw.levels[pt_id] - 1;
80
- while (pt_level >= hist.size()) {
81
- hist.push_back(0);
82
- }
83
- hist[pt_level] ++;
62
+ if (verbose) {
63
+ printf(" max_level = %d\n", max_level);
84
64
  }
85
65
 
86
- // accumulate
87
- std::vector<int> offsets(hist.size() + 1, 0);
88
- for (int i = 0; i < hist.size() - 1; i++) {
89
- offsets[i + 1] = offsets[i] + hist[i];
66
+ std::vector<omp_lock_t> locks(ntotal);
67
+ for (int i = 0; i < ntotal; i++) {
68
+ omp_init_lock(&locks[i]);
90
69
  }
91
70
 
92
- // bucket sort
93
- for (int i = 0; i < n; i++) {
94
- HNSW::storage_idx_t pt_id = i + n0;
95
- int pt_level = hnsw.levels[pt_id] - 1;
96
- order[offsets[pt_level]++] = pt_id;
71
+ // add vectors from highest to lowest level
72
+ std::vector<int> hist;
73
+ std::vector<int> order(n);
74
+
75
+ { // make buckets with vectors of the same level
76
+
77
+ // build histogram
78
+ for (int i = 0; i < n; i++) {
79
+ HNSW::storage_idx_t pt_id = i + n0;
80
+ int pt_level = hnsw.levels[pt_id] - 1;
81
+ while (pt_level >= hist.size()) {
82
+ hist.push_back(0);
83
+ }
84
+ hist[pt_level]++;
85
+ }
86
+
87
+ // accumulate
88
+ std::vector<int> offsets(hist.size() + 1, 0);
89
+ for (int i = 0; i < hist.size() - 1; i++) {
90
+ offsets[i + 1] = offsets[i] + hist[i];
91
+ }
92
+
93
+ // bucket sort
94
+ for (int i = 0; i < n; i++) {
95
+ HNSW::storage_idx_t pt_id = i + n0;
96
+ int pt_level = hnsw.levels[pt_id] - 1;
97
+ order[offsets[pt_level]++] = pt_id;
98
+ }
97
99
  }
98
- }
99
100
 
100
- { // perform add
101
- RandomGenerator rng2(789);
101
+ { // perform add
102
+ RandomGenerator rng2(789);
102
103
 
103
- int i1 = n;
104
+ int i1 = n;
104
105
 
105
- for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
106
- int i0 = i1 - hist[pt_level];
106
+ for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
107
+ int i0 = i1 - hist[pt_level];
107
108
 
108
- if (verbose) {
109
- printf("Adding %d elements at level %d\n",
110
- i1 - i0, pt_level);
111
- }
109
+ if (verbose) {
110
+ printf("Adding %d elements at level %d\n", i1 - i0, pt_level);
111
+ }
112
112
 
113
- // random permutation to get rid of dataset order bias
114
- for (int j = i0; j < i1; j++) {
115
- std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
116
- }
113
+ // random permutation to get rid of dataset order bias
114
+ for (int j = i0; j < i1; j++) {
115
+ std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
116
+ }
117
117
 
118
118
  #pragma omp parallel
119
- {
120
- VisitedTable vt (ntotal);
121
-
122
- std::unique_ptr<DistanceComputer> dis(
123
- index_hnsw.get_distance_computer()
124
- );
125
- int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
126
-
127
- #pragma omp for schedule(dynamic)
128
- for (int i = i0; i < i1; i++) {
129
- HNSW::storage_idx_t pt_id = order[i];
130
- dis->set_query((float *)(x + (pt_id - n0) * index_hnsw.code_size));
131
-
132
- hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
133
-
134
- if (prev_display >= 0 && i - i0 > prev_display + 10000) {
135
- prev_display = i - i0;
136
- printf(" %d / %d\r", i - i0, i1 - i0);
137
- fflush(stdout);
138
- }
119
+ {
120
+ VisitedTable vt(ntotal);
121
+
122
+ std::unique_ptr<DistanceComputer> dis(
123
+ index_hnsw.get_distance_computer());
124
+ int prev_display =
125
+ verbose && omp_get_thread_num() == 0 ? 0 : -1;
126
+
127
+ #pragma omp for schedule(dynamic)
128
+ for (int i = i0; i < i1; i++) {
129
+ HNSW::storage_idx_t pt_id = order[i];
130
+ dis->set_query(
131
+ (float*)(x + (pt_id - n0) * index_hnsw.code_size));
132
+
133
+ hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
134
+
135
+ if (prev_display >= 0 && i - i0 > prev_display + 10000) {
136
+ prev_display = i - i0;
137
+ printf(" %d / %d\r", i - i0, i1 - i0);
138
+ fflush(stdout);
139
+ }
140
+ }
141
+ }
142
+ i1 = i0;
139
143
  }
140
- }
141
- i1 = i0;
144
+ FAISS_ASSERT(i1 == 0);
145
+ }
146
+ if (verbose) {
147
+ printf("Done in %.3f ms\n", getmillisecs() - t0);
142
148
  }
143
- FAISS_ASSERT(i1 == 0);
144
- }
145
- if (verbose) {
146
- printf("Done in %.3f ms\n", getmillisecs() - t0);
147
- }
148
-
149
- for(int i = 0; i < ntotal; i++)
150
- omp_destroy_lock(&locks[i]);
151
- }
152
149
 
150
+ for (int i = 0; i < ntotal; i++)
151
+ omp_destroy_lock(&locks[i]);
152
+ }
153
153
 
154
154
  } // anonymous namespace
155
155
 
156
-
157
156
  /**************************************************************
158
157
  * IndexBinaryHNSW implementation
159
158
  **************************************************************/
160
159
 
161
- IndexBinaryHNSW::IndexBinaryHNSW()
162
- {
163
- is_trained = true;
160
+ IndexBinaryHNSW::IndexBinaryHNSW() {
161
+ is_trained = true;
164
162
  }
165
163
 
166
164
  IndexBinaryHNSW::IndexBinaryHNSW(int d, int M)
167
- : IndexBinary(d),
168
- hnsw(M),
169
- own_fields(true),
170
- storage(new IndexBinaryFlat(d))
171
- {
172
- is_trained = true;
165
+ : IndexBinary(d),
166
+ hnsw(M),
167
+ own_fields(true),
168
+ storage(new IndexBinaryFlat(d)) {
169
+ is_trained = true;
173
170
  }
174
171
 
175
- IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary *storage, int M)
176
- : IndexBinary(storage->d),
177
- hnsw(M),
178
- own_fields(false),
179
- storage(storage)
180
- {
181
- is_trained = true;
172
+ IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary* storage, int M)
173
+ : IndexBinary(storage->d),
174
+ hnsw(M),
175
+ own_fields(false),
176
+ storage(storage) {
177
+ is_trained = true;
182
178
  }
183
179
 
184
180
  IndexBinaryHNSW::~IndexBinaryHNSW() {
185
- if (own_fields) {
186
- delete storage;
187
- }
181
+ if (own_fields) {
182
+ delete storage;
183
+ }
188
184
  }
189
185
 
190
- void IndexBinaryHNSW::train(idx_t n, const uint8_t *x)
191
- {
192
- // hnsw structure does not require training
193
- storage->train(n, x);
194
- is_trained = true;
186
+ void IndexBinaryHNSW::train(idx_t n, const uint8_t* x) {
187
+ // hnsw structure does not require training
188
+ storage->train(n, x);
189
+ is_trained = true;
195
190
  }
196
191
 
197
- void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k,
198
- int32_t *distances, idx_t *labels) const
199
- {
192
+ void IndexBinaryHNSW::search(
193
+ idx_t n,
194
+ const uint8_t* x,
195
+ idx_t k,
196
+ int32_t* distances,
197
+ idx_t* labels) const {
198
+ FAISS_THROW_IF_NOT(k > 0);
199
+
200
200
  #pragma omp parallel
201
- {
202
- VisitedTable vt(ntotal);
203
- std::unique_ptr<DistanceComputer> dis(get_distance_computer());
201
+ {
202
+ VisitedTable vt(ntotal);
203
+ std::unique_ptr<DistanceComputer> dis(get_distance_computer());
204
204
 
205
205
  #pragma omp for
206
- for(idx_t i = 0; i < n; i++) {
207
- idx_t *idxi = labels + i * k;
208
- float *simi = (float *)(distances + i * k);
206
+ for (idx_t i = 0; i < n; i++) {
207
+ idx_t* idxi = labels + i * k;
208
+ float* simi = (float*)(distances + i * k);
209
209
 
210
- dis->set_query((float *)(x + i * code_size));
210
+ dis->set_query((float*)(x + i * code_size));
211
211
 
212
- maxheap_heapify(k, simi, idxi);
213
- hnsw.search(*dis, k, idxi, simi, vt);
214
- maxheap_reorder(k, simi, idxi);
212
+ maxheap_heapify(k, simi, idxi);
213
+ hnsw.search(*dis, k, idxi, simi, vt);
214
+ maxheap_reorder(k, simi, idxi);
215
+ }
215
216
  }
216
- }
217
217
 
218
218
  #pragma omp parallel for
219
- for (int i = 0; i < n * k; ++i) {
220
- distances[i] = std::round(((float *)distances)[i]);
221
- }
219
+ for (int i = 0; i < n * k; ++i) {
220
+ distances[i] = std::round(((float*)distances)[i]);
221
+ }
222
222
  }
223
223
 
224
+ void IndexBinaryHNSW::add(idx_t n, const uint8_t* x) {
225
+ FAISS_THROW_IF_NOT(is_trained);
226
+ int n0 = ntotal;
227
+ storage->add(n, x);
228
+ ntotal = storage->ntotal;
224
229
 
225
- void IndexBinaryHNSW::add(idx_t n, const uint8_t *x)
226
- {
227
- FAISS_THROW_IF_NOT(is_trained);
228
- int n0 = ntotal;
229
- storage->add(n, x);
230
- ntotal = storage->ntotal;
231
-
232
- hnsw_add_vertices(*this, n0, n, x, verbose,
233
- hnsw.levels.size() == ntotal);
230
+ hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal);
234
231
  }
235
232
 
236
- void IndexBinaryHNSW::reset()
237
- {
238
- hnsw.reset();
239
- storage->reset();
240
- ntotal = 0;
233
+ void IndexBinaryHNSW::reset() {
234
+ hnsw.reset();
235
+ storage->reset();
236
+ ntotal = 0;
241
237
  }
242
238
 
243
- void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t *recons) const
244
- {
245
- storage->reconstruct(key, recons);
239
+ void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t* recons) const {
240
+ storage->reconstruct(key, recons);
246
241
  }
247
242
 
248
-
249
243
  namespace {
250
244
 
251
-
252
- template<class HammingComputer>
245
+ template <class HammingComputer>
253
246
  struct FlatHammingDis : DistanceComputer {
254
- const int code_size;
255
- const uint8_t *b;
256
- size_t ndis;
257
- HammingComputer hc;
258
-
259
- float operator () (idx_t i) override {
260
- ndis++;
261
- return hc.hamming(b + i * code_size);
262
- }
263
-
264
- float symmetric_dis(idx_t i, idx_t j) override {
265
- return HammingComputerDefault(b + j * code_size, code_size)
266
- .hamming(b + i * code_size);
267
- }
268
-
269
-
270
- explicit FlatHammingDis(const IndexBinaryFlat& storage)
271
- : code_size(storage.code_size),
272
- b(storage.xb.data()),
273
- ndis(0),
274
- hc() {}
275
-
276
- // NOTE: Pointers are cast from float in order to reuse the floating-point
277
- // DistanceComputer.
278
- void set_query(const float *x) override {
279
- hc.set((uint8_t *)x, code_size);
280
- }
281
-
282
- ~FlatHammingDis() override {
247
+ const int code_size;
248
+ const uint8_t* b;
249
+ size_t ndis;
250
+ HammingComputer hc;
251
+
252
+ float operator()(idx_t i) override {
253
+ ndis++;
254
+ return hc.hamming(b + i * code_size);
255
+ }
256
+
257
+ float symmetric_dis(idx_t i, idx_t j) override {
258
+ return HammingComputerDefault(b + j * code_size, code_size)
259
+ .hamming(b + i * code_size);
260
+ }
261
+
262
+ explicit FlatHammingDis(const IndexBinaryFlat& storage)
263
+ : code_size(storage.code_size),
264
+ b(storage.xb.data()),
265
+ ndis(0),
266
+ hc() {}
267
+
268
+ // NOTE: Pointers are cast from float in order to reuse the floating-point
269
+ // DistanceComputer.
270
+ void set_query(const float* x) override {
271
+ hc.set((uint8_t*)x, code_size);
272
+ }
273
+
274
+ ~FlatHammingDis() override {
283
275
  #pragma omp critical
284
- {
285
- hnsw_stats.ndis += ndis;
276
+ { hnsw_stats.ndis += ndis; }
286
277
  }
287
- }
288
278
  };
289
279
 
280
+ } // namespace
281
+
282
+ DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
283
+ IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
284
+
285
+ FAISS_ASSERT(flat_storage != nullptr);
286
+
287
+ switch (code_size) {
288
+ case 4:
289
+ return new FlatHammingDis<HammingComputer4>(*flat_storage);
290
+ case 8:
291
+ return new FlatHammingDis<HammingComputer8>(*flat_storage);
292
+ case 16:
293
+ return new FlatHammingDis<HammingComputer16>(*flat_storage);
294
+ case 20:
295
+ return new FlatHammingDis<HammingComputer20>(*flat_storage);
296
+ case 32:
297
+ return new FlatHammingDis<HammingComputer32>(*flat_storage);
298
+ case 64:
299
+ return new FlatHammingDis<HammingComputer64>(*flat_storage);
300
+ default:
301
+ break;
302
+ }
290
303
 
291
- } // namespace
292
-
293
-
294
- DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
295
- IndexBinaryFlat *flat_storage = dynamic_cast<IndexBinaryFlat *>(storage);
296
-
297
- FAISS_ASSERT(flat_storage != nullptr);
298
-
299
- switch(code_size) {
300
- case 4:
301
- return new FlatHammingDis<HammingComputer4>(*flat_storage);
302
- case 8:
303
- return new FlatHammingDis<HammingComputer8>(*flat_storage);
304
- case 16:
305
- return new FlatHammingDis<HammingComputer16>(*flat_storage);
306
- case 20:
307
- return new FlatHammingDis<HammingComputer20>(*flat_storage);
308
- case 32:
309
- return new FlatHammingDis<HammingComputer32>(*flat_storage);
310
- case 64:
311
- return new FlatHammingDis<HammingComputer64>(*flat_storage);
312
- default:
313
- if (code_size % 8 == 0) {
314
- return new FlatHammingDis<HammingComputerM8>(*flat_storage);
315
- } else if (code_size % 4 == 0) {
316
- return new FlatHammingDis<HammingComputerM4>(*flat_storage);
317
- }
318
- }
319
-
320
- return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
304
+ return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
321
305
  }
322
306
 
323
-
324
307
  } // namespace faiss