faiss 0.1.4 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -1
  3. data/README.md +15 -3
  4. data/ext/faiss/ext.cpp +12 -308
  5. data/ext/faiss/extconf.rb +5 -2
  6. data/ext/faiss/index.cpp +189 -0
  7. data/ext/faiss/index_binary.cpp +75 -0
  8. data/ext/faiss/kmeans.cpp +40 -0
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +33 -0
  11. data/ext/faiss/product_quantizer.cpp +53 -0
  12. data/ext/faiss/utils.cpp +13 -0
  13. data/ext/faiss/utils.h +5 -0
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +31 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -9,80 +9,100 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFlat.h>
11
11
 
12
- #include <cstring>
12
+ #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/Heap.h>
13
15
  #include <faiss/utils/hamming.h>
14
16
  #include <faiss/utils/utils.h>
15
- #include <faiss/utils/Heap.h>
16
- #include <faiss/impl/FaissAssert.h>
17
- #include <faiss/impl/AuxIndexStructures.h>
17
+ #include <cstring>
18
18
 
19
19
  namespace faiss {
20
20
 
21
- IndexBinaryFlat::IndexBinaryFlat(idx_t d)
22
- : IndexBinary(d) {}
21
+ IndexBinaryFlat::IndexBinaryFlat(idx_t d) : IndexBinary(d) {}
23
22
 
24
- void IndexBinaryFlat::add(idx_t n, const uint8_t *x) {
25
- xb.insert(xb.end(), x, x + n * code_size);
26
- ntotal += n;
23
+ void IndexBinaryFlat::add(idx_t n, const uint8_t* x) {
24
+ xb.insert(xb.end(), x, x + n * code_size);
25
+ ntotal += n;
27
26
  }
28
27
 
29
28
  void IndexBinaryFlat::reset() {
30
- xb.clear();
31
- ntotal = 0;
29
+ xb.clear();
30
+ ntotal = 0;
32
31
  }
33
32
 
34
- void IndexBinaryFlat::search(idx_t n, const uint8_t *x, idx_t k,
35
- int32_t *distances, idx_t *labels) const {
36
- const idx_t block_size = query_batch_size;
37
- for (idx_t s = 0; s < n; s += block_size) {
38
- idx_t nn = block_size;
39
- if (s + block_size > n) {
40
- nn = n - s;
41
- }
33
+ void IndexBinaryFlat::search(
34
+ idx_t n,
35
+ const uint8_t* x,
36
+ idx_t k,
37
+ int32_t* distances,
38
+ idx_t* labels) const {
39
+ FAISS_THROW_IF_NOT(k > 0);
42
40
 
43
- if (use_heap) {
44
- // We see the distances and labels as heaps.
45
- int_maxheap_array_t res = {
46
- size_t(nn), size_t(k), labels + s * k, distances + s * k
47
- };
41
+ const idx_t block_size = query_batch_size;
42
+ for (idx_t s = 0; s < n; s += block_size) {
43
+ idx_t nn = block_size;
44
+ if (s + block_size > n) {
45
+ nn = n - s;
46
+ }
48
47
 
49
- hammings_knn_hc(&res, x + s * code_size, xb.data(), ntotal, code_size,
50
- /* ordered = */ true);
51
- } else {
52
- hammings_knn_mc(x + s * code_size, xb.data(), nn, ntotal, k, code_size,
53
- distances + s * k, labels + s * k);
48
+ if (use_heap) {
49
+ // We see the distances and labels as heaps.
50
+ int_maxheap_array_t res = {
51
+ size_t(nn), size_t(k), labels + s * k, distances + s * k};
52
+
53
+ hammings_knn_hc(
54
+ &res,
55
+ x + s * code_size,
56
+ xb.data(),
57
+ ntotal,
58
+ code_size,
59
+ /* ordered = */ true);
60
+ } else {
61
+ hammings_knn_mc(
62
+ x + s * code_size,
63
+ xb.data(),
64
+ nn,
65
+ ntotal,
66
+ k,
67
+ code_size,
68
+ distances + s * k,
69
+ labels + s * k);
70
+ }
54
71
  }
55
- }
56
72
  }
57
73
 
58
74
  size_t IndexBinaryFlat::remove_ids(const IDSelector& sel) {
59
- idx_t j = 0;
60
- for (idx_t i = 0; i < ntotal; i++) {
61
- if (sel.is_member(i)) {
62
- // should be removed
63
- } else {
64
- if (i > j) {
65
- memmove(&xb[code_size * j], &xb[code_size * i], sizeof(xb[0]) * code_size);
66
- }
67
- j++;
75
+ idx_t j = 0;
76
+ for (idx_t i = 0; i < ntotal; i++) {
77
+ if (sel.is_member(i)) {
78
+ // should be removed
79
+ } else {
80
+ if (i > j) {
81
+ memmove(&xb[code_size * j],
82
+ &xb[code_size * i],
83
+ sizeof(xb[0]) * code_size);
84
+ }
85
+ j++;
86
+ }
87
+ }
88
+ long nremove = ntotal - j;
89
+ if (nremove > 0) {
90
+ ntotal = j;
91
+ xb.resize(ntotal * code_size);
68
92
  }
69
- }
70
- long nremove = ntotal - j;
71
- if (nremove > 0) {
72
- ntotal = j;
73
- xb.resize(ntotal * code_size);
74
- }
75
- return nremove;
93
+ return nremove;
76
94
  }
77
95
 
78
- void IndexBinaryFlat::reconstruct(idx_t key, uint8_t *recons) const {
79
- memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
96
+ void IndexBinaryFlat::reconstruct(idx_t key, uint8_t* recons) const {
97
+ memcpy(recons, &(xb[code_size * key]), sizeof(*recons) * code_size);
80
98
  }
81
99
 
82
- void IndexBinaryFlat::range_search(idx_t n, const uint8_t *x, int radius,
83
- RangeSearchResult *result) const
84
- {
85
- hamming_range_search (x, xb.data(), n, ntotal, radius, code_size, result);
100
+ void IndexBinaryFlat::range_search(
101
+ idx_t n,
102
+ const uint8_t* x,
103
+ int radius,
104
+ RangeSearchResult* result) const {
105
+ hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
86
106
  }
87
107
 
88
- } // namespace faiss
108
+ } // namespace faiss
@@ -16,42 +16,47 @@
16
16
 
17
17
  namespace faiss {
18
18
 
19
-
20
19
  /** Index that stores the full vectors and performs exhaustive search. */
21
20
  struct IndexBinaryFlat : IndexBinary {
22
- /// database vectors, size ntotal * d / 8
23
- std::vector<uint8_t> xb;
21
+ /// database vectors, size ntotal * d / 8
22
+ std::vector<uint8_t> xb;
24
23
 
25
- /** Select between using a heap or counting to select the k smallest values
26
- * when scanning inverted lists.
27
- */
28
- bool use_heap = true;
24
+ /** Select between using a heap or counting to select the k smallest values
25
+ * when scanning inverted lists.
26
+ */
27
+ bool use_heap = true;
29
28
 
30
- size_t query_batch_size = 32;
29
+ size_t query_batch_size = 32;
31
30
 
32
- explicit IndexBinaryFlat(idx_t d);
31
+ explicit IndexBinaryFlat(idx_t d);
33
32
 
34
- void add(idx_t n, const uint8_t *x) override;
33
+ void add(idx_t n, const uint8_t* x) override;
35
34
 
36
- void reset() override;
35
+ void reset() override;
37
36
 
38
- void search(idx_t n, const uint8_t *x, idx_t k,
39
- int32_t *distances, idx_t *labels) const override;
37
+ void search(
38
+ idx_t n,
39
+ const uint8_t* x,
40
+ idx_t k,
41
+ int32_t* distances,
42
+ idx_t* labels) const override;
40
43
 
41
- void range_search(idx_t n, const uint8_t *x, int radius,
42
- RangeSearchResult *result) const override;
44
+ void range_search(
45
+ idx_t n,
46
+ const uint8_t* x,
47
+ int radius,
48
+ RangeSearchResult* result) const override;
43
49
 
44
- void reconstruct(idx_t key, uint8_t *recons) const override;
50
+ void reconstruct(idx_t key, uint8_t* recons) const override;
45
51
 
46
- /** Remove some ids. Note that because of the indexing structure,
47
- * the semantics of this operation are different from the usual ones:
48
- * the new ids are shifted. */
49
- size_t remove_ids(const IDSelector& sel) override;
52
+ /** Remove some ids. Note that because of the indexing structure,
53
+ * the semantics of this operation are different from the usual ones:
54
+ * the new ids are shifted. */
55
+ size_t remove_ids(const IDSelector& sel) override;
50
56
 
51
- IndexBinaryFlat() {}
57
+ IndexBinaryFlat() {}
52
58
  };
53
59
 
60
+ } // namespace faiss
54
61
 
55
- } // namespace faiss
56
-
57
- #endif // INDEX_BINARY_FLAT_H
62
+ #endif // INDEX_BINARY_FLAT_H
@@ -9,71 +9,74 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFromFloat.h>
11
11
 
12
+ #include <faiss/utils/utils.h>
12
13
  #include <algorithm>
13
14
  #include <memory>
14
- #include <faiss/utils/utils.h>
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  IndexBinaryFromFloat::IndexBinaryFromFloat() {}
20
19
 
21
- IndexBinaryFromFloat::IndexBinaryFromFloat(Index *index)
22
- : IndexBinary(index->d),
23
- index(index),
24
- own_fields(false) {
25
- is_trained = index->is_trained;
26
- ntotal = index->ntotal;
20
+ IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index)
21
+ : IndexBinary(index->d), index(index), own_fields(false) {
22
+ is_trained = index->is_trained;
23
+ ntotal = index->ntotal;
27
24
  }
28
25
 
29
26
  IndexBinaryFromFloat::~IndexBinaryFromFloat() {
30
- if (own_fields) {
31
- delete index;
32
- }
27
+ if (own_fields) {
28
+ delete index;
29
+ }
33
30
  }
34
31
 
35
- void IndexBinaryFromFloat::add(idx_t n, const uint8_t *x) {
36
- constexpr idx_t bs = 32768;
37
- std::unique_ptr<float[]> xf(new float[bs * d]);
32
+ void IndexBinaryFromFloat::add(idx_t n, const uint8_t* x) {
33
+ constexpr idx_t bs = 32768;
34
+ std::unique_ptr<float[]> xf(new float[bs * d]);
38
35
 
39
- for (idx_t b = 0; b < n; b += bs) {
40
- idx_t bn = std::min(bs, n - b);
41
- binary_to_real(bn * d, x + b * code_size, xf.get());
36
+ for (idx_t b = 0; b < n; b += bs) {
37
+ idx_t bn = std::min(bs, n - b);
38
+ binary_to_real(bn * d, x + b * code_size, xf.get());
42
39
 
43
- index->add(bn, xf.get());
44
- }
45
- ntotal = index->ntotal;
40
+ index->add(bn, xf.get());
41
+ }
42
+ ntotal = index->ntotal;
46
43
  }
47
44
 
48
45
  void IndexBinaryFromFloat::reset() {
49
- index->reset();
50
- ntotal = index->ntotal;
46
+ index->reset();
47
+ ntotal = index->ntotal;
51
48
  }
52
49
 
53
- void IndexBinaryFromFloat::search(idx_t n, const uint8_t *x, idx_t k,
54
- int32_t *distances, idx_t *labels) const {
55
- constexpr idx_t bs = 32768;
56
- std::unique_ptr<float[]> xf(new float[bs * d]);
57
- std::unique_ptr<float[]> df(new float[bs * k]);
58
-
59
- for (idx_t b = 0; b < n; b += bs) {
60
- idx_t bn = std::min(bs, n - b);
61
- binary_to_real(bn * d, x + b * code_size, xf.get());
62
-
63
- index->search(bn, xf.get(), k, df.get(), labels + b * k);
64
- for (int i = 0; i < bn * k; ++i) {
65
- distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
50
+ void IndexBinaryFromFloat::search(
51
+ idx_t n,
52
+ const uint8_t* x,
53
+ idx_t k,
54
+ int32_t* distances,
55
+ idx_t* labels) const {
56
+ FAISS_THROW_IF_NOT(k > 0);
57
+
58
+ constexpr idx_t bs = 32768;
59
+ std::unique_ptr<float[]> xf(new float[bs * d]);
60
+ std::unique_ptr<float[]> df(new float[bs * k]);
61
+
62
+ for (idx_t b = 0; b < n; b += bs) {
63
+ idx_t bn = std::min(bs, n - b);
64
+ binary_to_real(bn * d, x + b * code_size, xf.get());
65
+
66
+ index->search(bn, xf.get(), k, df.get(), labels + b * k);
67
+ for (int i = 0; i < bn * k; ++i) {
68
+ distances[b * k + i] = int32_t(std::round(df[i] / 4.0));
69
+ }
66
70
  }
67
- }
68
71
  }
69
72
 
70
- void IndexBinaryFromFloat::train(idx_t n, const uint8_t *x) {
71
- std::unique_ptr<float[]> xf(new float[n * d]);
72
- binary_to_real(n * d, x, xf.get());
73
+ void IndexBinaryFromFloat::train(idx_t n, const uint8_t* x) {
74
+ std::unique_ptr<float[]> xf(new float[n * d]);
75
+ binary_to_real(n * d, x, xf.get());
73
76
 
74
- index->train(n, xf.get());
75
- is_trained = true;
76
- ntotal = index->ntotal;
77
+ index->train(n, xf.get());
78
+ is_trained = true;
79
+ ntotal = index->ntotal;
77
80
  }
78
81
 
79
- } // namespace faiss
82
+ } // namespace faiss
@@ -12,10 +12,8 @@
12
12
 
13
13
  #include <faiss/IndexBinary.h>
14
14
 
15
-
16
15
  namespace faiss {
17
16
 
18
-
19
17
  struct Index;
20
18
 
21
19
  /** IndexBinary backed by a float Index.
@@ -26,27 +24,30 @@ struct Index;
26
24
  * vectors.
27
25
  */
28
26
  struct IndexBinaryFromFloat : IndexBinary {
29
- Index *index = nullptr;
27
+ Index* index = nullptr;
30
28
 
31
- bool own_fields = false; ///< Whether object owns the index pointer.
29
+ bool own_fields = false; ///< Whether object owns the index pointer.
32
30
 
33
- IndexBinaryFromFloat();
31
+ IndexBinaryFromFloat();
34
32
 
35
- explicit IndexBinaryFromFloat(Index *index);
33
+ explicit IndexBinaryFromFloat(Index* index);
36
34
 
37
- ~IndexBinaryFromFloat();
35
+ ~IndexBinaryFromFloat();
38
36
 
39
- void add(idx_t n, const uint8_t *x) override;
37
+ void add(idx_t n, const uint8_t* x) override;
40
38
 
41
- void reset() override;
39
+ void reset() override;
42
40
 
43
- void search(idx_t n, const uint8_t *x, idx_t k,
44
- int32_t *distances, idx_t *labels) const override;
41
+ void search(
42
+ idx_t n,
43
+ const uint8_t* x,
44
+ idx_t k,
45
+ int32_t* distances,
46
+ idx_t* labels) const override;
45
47
 
46
- void train(idx_t n, const uint8_t *x) override;
48
+ void train(idx_t n, const uint8_t* x) override;
47
49
  };
48
50
 
51
+ } // namespace faiss
49
52
 
50
- } // namespace faiss
51
-
52
- #endif // FAISS_INDEX_BINARY_FROM_FLOAT_H
53
+ #endif // FAISS_INDEX_BINARY_FROM_FLOAT_H
@@ -9,316 +9,299 @@
9
9
 
10
10
  #include <faiss/IndexBinaryHNSW.h>
11
11
 
12
-
13
- #include <memory>
14
- #include <cstdlib>
12
+ #include <omp.h>
15
13
  #include <cassert>
16
- #include <cstring>
17
- #include <cstdio>
18
14
  #include <cmath>
19
- #include <omp.h>
15
+ #include <cstdio>
16
+ #include <cstdlib>
17
+ #include <cstring>
18
+ #include <memory>
20
19
 
21
- #include <unordered_set>
22
20
  #include <queue>
21
+ #include <unordered_set>
23
22
 
24
- #include <sys/types.h>
25
- #include <sys/stat.h>
26
23
  #include <stdint.h>
24
+ #include <sys/stat.h>
25
+ #include <sys/types.h>
27
26
 
28
- #include <faiss/utils/random.h>
29
- #include <faiss/utils/Heap.h>
30
- #include <faiss/impl/FaissAssert.h>
31
27
  #include <faiss/IndexBinaryFlat.h>
32
- #include <faiss/utils/hamming.h>
33
28
  #include <faiss/impl/AuxIndexStructures.h>
29
+ #include <faiss/impl/FaissAssert.h>
30
+ #include <faiss/utils/Heap.h>
31
+ #include <faiss/utils/hamming.h>
32
+ #include <faiss/utils/random.h>
34
33
 
35
34
  namespace faiss {
36
35
 
37
-
38
36
  /**************************************************************
39
37
  * add / search blocks of descriptors
40
38
  **************************************************************/
41
39
 
42
40
  namespace {
43
41
 
42
+ void hnsw_add_vertices(
43
+ IndexBinaryHNSW& index_hnsw,
44
+ size_t n0,
45
+ size_t n,
46
+ const uint8_t* x,
47
+ bool verbose,
48
+ bool preset_levels = false) {
49
+ HNSW& hnsw = index_hnsw.hnsw;
50
+ size_t ntotal = n0 + n;
51
+ double t0 = getmillisecs();
52
+ if (verbose) {
53
+ printf("hnsw_add_vertices: adding %zd elements on top of %zd "
54
+ "(preset_levels=%d)\n",
55
+ n,
56
+ n0,
57
+ int(preset_levels));
58
+ }
59
+
60
+ int max_level = hnsw.prepare_level_tab(n, preset_levels);
44
61
 
45
- void hnsw_add_vertices(IndexBinaryHNSW& index_hnsw,
46
- size_t n0,
47
- size_t n, const uint8_t *x,
48
- bool verbose,
49
- bool preset_levels = false) {
50
- HNSW& hnsw = index_hnsw.hnsw;
51
- size_t ntotal = n0 + n;
52
- double t0 = getmillisecs();
53
- if (verbose) {
54
- printf("hnsw_add_vertices: adding %zd elements on top of %zd "
55
- "(preset_levels=%d)\n",
56
- n, n0, int(preset_levels));
57
- }
58
-
59
- int max_level = hnsw.prepare_level_tab(n, preset_levels);
60
-
61
- if (verbose) {
62
- printf(" max_level = %d\n", max_level);
63
- }
64
-
65
- std::vector<omp_lock_t> locks(ntotal);
66
- for(int i = 0; i < ntotal; i++) {
67
- omp_init_lock(&locks[i]);
68
- }
69
-
70
- // add vectors from highest to lowest level
71
- std::vector<int> hist;
72
- std::vector<int> order(n);
73
-
74
- { // make buckets with vectors of the same level
75
-
76
- // build histogram
77
- for (int i = 0; i < n; i++) {
78
- HNSW::storage_idx_t pt_id = i + n0;
79
- int pt_level = hnsw.levels[pt_id] - 1;
80
- while (pt_level >= hist.size()) {
81
- hist.push_back(0);
82
- }
83
- hist[pt_level] ++;
62
+ if (verbose) {
63
+ printf(" max_level = %d\n", max_level);
84
64
  }
85
65
 
86
- // accumulate
87
- std::vector<int> offsets(hist.size() + 1, 0);
88
- for (int i = 0; i < hist.size() - 1; i++) {
89
- offsets[i + 1] = offsets[i] + hist[i];
66
+ std::vector<omp_lock_t> locks(ntotal);
67
+ for (int i = 0; i < ntotal; i++) {
68
+ omp_init_lock(&locks[i]);
90
69
  }
91
70
 
92
- // bucket sort
93
- for (int i = 0; i < n; i++) {
94
- HNSW::storage_idx_t pt_id = i + n0;
95
- int pt_level = hnsw.levels[pt_id] - 1;
96
- order[offsets[pt_level]++] = pt_id;
71
+ // add vectors from highest to lowest level
72
+ std::vector<int> hist;
73
+ std::vector<int> order(n);
74
+
75
+ { // make buckets with vectors of the same level
76
+
77
+ // build histogram
78
+ for (int i = 0; i < n; i++) {
79
+ HNSW::storage_idx_t pt_id = i + n0;
80
+ int pt_level = hnsw.levels[pt_id] - 1;
81
+ while (pt_level >= hist.size()) {
82
+ hist.push_back(0);
83
+ }
84
+ hist[pt_level]++;
85
+ }
86
+
87
+ // accumulate
88
+ std::vector<int> offsets(hist.size() + 1, 0);
89
+ for (int i = 0; i < hist.size() - 1; i++) {
90
+ offsets[i + 1] = offsets[i] + hist[i];
91
+ }
92
+
93
+ // bucket sort
94
+ for (int i = 0; i < n; i++) {
95
+ HNSW::storage_idx_t pt_id = i + n0;
96
+ int pt_level = hnsw.levels[pt_id] - 1;
97
+ order[offsets[pt_level]++] = pt_id;
98
+ }
97
99
  }
98
- }
99
100
 
100
- { // perform add
101
- RandomGenerator rng2(789);
101
+ { // perform add
102
+ RandomGenerator rng2(789);
102
103
 
103
- int i1 = n;
104
+ int i1 = n;
104
105
 
105
- for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
106
- int i0 = i1 - hist[pt_level];
106
+ for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
107
+ int i0 = i1 - hist[pt_level];
107
108
 
108
- if (verbose) {
109
- printf("Adding %d elements at level %d\n",
110
- i1 - i0, pt_level);
111
- }
109
+ if (verbose) {
110
+ printf("Adding %d elements at level %d\n", i1 - i0, pt_level);
111
+ }
112
112
 
113
- // random permutation to get rid of dataset order bias
114
- for (int j = i0; j < i1; j++) {
115
- std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
116
- }
113
+ // random permutation to get rid of dataset order bias
114
+ for (int j = i0; j < i1; j++) {
115
+ std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
116
+ }
117
117
 
118
118
  #pragma omp parallel
119
- {
120
- VisitedTable vt (ntotal);
121
-
122
- std::unique_ptr<DistanceComputer> dis(
123
- index_hnsw.get_distance_computer()
124
- );
125
- int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
126
-
127
- #pragma omp for schedule(dynamic)
128
- for (int i = i0; i < i1; i++) {
129
- HNSW::storage_idx_t pt_id = order[i];
130
- dis->set_query((float *)(x + (pt_id - n0) * index_hnsw.code_size));
131
-
132
- hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
133
-
134
- if (prev_display >= 0 && i - i0 > prev_display + 10000) {
135
- prev_display = i - i0;
136
- printf(" %d / %d\r", i - i0, i1 - i0);
137
- fflush(stdout);
138
- }
119
+ {
120
+ VisitedTable vt(ntotal);
121
+
122
+ std::unique_ptr<DistanceComputer> dis(
123
+ index_hnsw.get_distance_computer());
124
+ int prev_display =
125
+ verbose && omp_get_thread_num() == 0 ? 0 : -1;
126
+
127
+ #pragma omp for schedule(dynamic)
128
+ for (int i = i0; i < i1; i++) {
129
+ HNSW::storage_idx_t pt_id = order[i];
130
+ dis->set_query(
131
+ (float*)(x + (pt_id - n0) * index_hnsw.code_size));
132
+
133
+ hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
134
+
135
+ if (prev_display >= 0 && i - i0 > prev_display + 10000) {
136
+ prev_display = i - i0;
137
+ printf(" %d / %d\r", i - i0, i1 - i0);
138
+ fflush(stdout);
139
+ }
140
+ }
141
+ }
142
+ i1 = i0;
139
143
  }
140
- }
141
- i1 = i0;
144
+ FAISS_ASSERT(i1 == 0);
145
+ }
146
+ if (verbose) {
147
+ printf("Done in %.3f ms\n", getmillisecs() - t0);
142
148
  }
143
- FAISS_ASSERT(i1 == 0);
144
- }
145
- if (verbose) {
146
- printf("Done in %.3f ms\n", getmillisecs() - t0);
147
- }
148
-
149
- for(int i = 0; i < ntotal; i++)
150
- omp_destroy_lock(&locks[i]);
151
- }
152
149
 
150
+ for (int i = 0; i < ntotal; i++)
151
+ omp_destroy_lock(&locks[i]);
152
+ }
153
153
 
154
154
  } // anonymous namespace
155
155
 
156
-
157
156
  /**************************************************************
158
157
  * IndexBinaryHNSW implementation
159
158
  **************************************************************/
160
159
 
161
- IndexBinaryHNSW::IndexBinaryHNSW()
162
- {
163
- is_trained = true;
160
+ IndexBinaryHNSW::IndexBinaryHNSW() {
161
+ is_trained = true;
164
162
  }
165
163
 
166
164
  IndexBinaryHNSW::IndexBinaryHNSW(int d, int M)
167
- : IndexBinary(d),
168
- hnsw(M),
169
- own_fields(true),
170
- storage(new IndexBinaryFlat(d))
171
- {
172
- is_trained = true;
165
+ : IndexBinary(d),
166
+ hnsw(M),
167
+ own_fields(true),
168
+ storage(new IndexBinaryFlat(d)) {
169
+ is_trained = true;
173
170
  }
174
171
 
175
- IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary *storage, int M)
176
- : IndexBinary(storage->d),
177
- hnsw(M),
178
- own_fields(false),
179
- storage(storage)
180
- {
181
- is_trained = true;
172
+ IndexBinaryHNSW::IndexBinaryHNSW(IndexBinary* storage, int M)
173
+ : IndexBinary(storage->d),
174
+ hnsw(M),
175
+ own_fields(false),
176
+ storage(storage) {
177
+ is_trained = true;
182
178
  }
183
179
 
184
180
  IndexBinaryHNSW::~IndexBinaryHNSW() {
185
- if (own_fields) {
186
- delete storage;
187
- }
181
+ if (own_fields) {
182
+ delete storage;
183
+ }
188
184
  }
189
185
 
190
- void IndexBinaryHNSW::train(idx_t n, const uint8_t *x)
191
- {
192
- // hnsw structure does not require training
193
- storage->train(n, x);
194
- is_trained = true;
186
+ void IndexBinaryHNSW::train(idx_t n, const uint8_t* x) {
187
+ // hnsw structure does not require training
188
+ storage->train(n, x);
189
+ is_trained = true;
195
190
  }
196
191
 
197
- void IndexBinaryHNSW::search(idx_t n, const uint8_t *x, idx_t k,
198
- int32_t *distances, idx_t *labels) const
199
- {
192
+ void IndexBinaryHNSW::search(
193
+ idx_t n,
194
+ const uint8_t* x,
195
+ idx_t k,
196
+ int32_t* distances,
197
+ idx_t* labels) const {
198
+ FAISS_THROW_IF_NOT(k > 0);
199
+
200
200
  #pragma omp parallel
201
- {
202
- VisitedTable vt(ntotal);
203
- std::unique_ptr<DistanceComputer> dis(get_distance_computer());
201
+ {
202
+ VisitedTable vt(ntotal);
203
+ std::unique_ptr<DistanceComputer> dis(get_distance_computer());
204
204
 
205
205
  #pragma omp for
206
- for(idx_t i = 0; i < n; i++) {
207
- idx_t *idxi = labels + i * k;
208
- float *simi = (float *)(distances + i * k);
206
+ for (idx_t i = 0; i < n; i++) {
207
+ idx_t* idxi = labels + i * k;
208
+ float* simi = (float*)(distances + i * k);
209
209
 
210
- dis->set_query((float *)(x + i * code_size));
210
+ dis->set_query((float*)(x + i * code_size));
211
211
 
212
- maxheap_heapify(k, simi, idxi);
213
- hnsw.search(*dis, k, idxi, simi, vt);
214
- maxheap_reorder(k, simi, idxi);
212
+ maxheap_heapify(k, simi, idxi);
213
+ hnsw.search(*dis, k, idxi, simi, vt);
214
+ maxheap_reorder(k, simi, idxi);
215
+ }
215
216
  }
216
- }
217
217
 
218
218
  #pragma omp parallel for
219
- for (int i = 0; i < n * k; ++i) {
220
- distances[i] = std::round(((float *)distances)[i]);
221
- }
219
+ for (int i = 0; i < n * k; ++i) {
220
+ distances[i] = std::round(((float*)distances)[i]);
221
+ }
222
222
  }
223
223
 
224
+ void IndexBinaryHNSW::add(idx_t n, const uint8_t* x) {
225
+ FAISS_THROW_IF_NOT(is_trained);
226
+ int n0 = ntotal;
227
+ storage->add(n, x);
228
+ ntotal = storage->ntotal;
224
229
 
225
- void IndexBinaryHNSW::add(idx_t n, const uint8_t *x)
226
- {
227
- FAISS_THROW_IF_NOT(is_trained);
228
- int n0 = ntotal;
229
- storage->add(n, x);
230
- ntotal = storage->ntotal;
231
-
232
- hnsw_add_vertices(*this, n0, n, x, verbose,
233
- hnsw.levels.size() == ntotal);
230
+ hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal);
234
231
  }
235
232
 
236
- void IndexBinaryHNSW::reset()
237
- {
238
- hnsw.reset();
239
- storage->reset();
240
- ntotal = 0;
233
+ void IndexBinaryHNSW::reset() {
234
+ hnsw.reset();
235
+ storage->reset();
236
+ ntotal = 0;
241
237
  }
242
238
 
243
- void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t *recons) const
244
- {
245
- storage->reconstruct(key, recons);
239
+ void IndexBinaryHNSW::reconstruct(idx_t key, uint8_t* recons) const {
240
+ storage->reconstruct(key, recons);
246
241
  }
247
242
 
248
-
249
243
  namespace {
250
244
 
251
-
252
- template<class HammingComputer>
245
+ template <class HammingComputer>
253
246
  struct FlatHammingDis : DistanceComputer {
254
- const int code_size;
255
- const uint8_t *b;
256
- size_t ndis;
257
- HammingComputer hc;
258
-
259
- float operator () (idx_t i) override {
260
- ndis++;
261
- return hc.hamming(b + i * code_size);
262
- }
263
-
264
- float symmetric_dis(idx_t i, idx_t j) override {
265
- return HammingComputerDefault(b + j * code_size, code_size)
266
- .hamming(b + i * code_size);
267
- }
268
-
269
-
270
- explicit FlatHammingDis(const IndexBinaryFlat& storage)
271
- : code_size(storage.code_size),
272
- b(storage.xb.data()),
273
- ndis(0),
274
- hc() {}
275
-
276
- // NOTE: Pointers are cast from float in order to reuse the floating-point
277
- // DistanceComputer.
278
- void set_query(const float *x) override {
279
- hc.set((uint8_t *)x, code_size);
280
- }
281
-
282
- ~FlatHammingDis() override {
247
+ const int code_size;
248
+ const uint8_t* b;
249
+ size_t ndis;
250
+ HammingComputer hc;
251
+
252
+ float operator()(idx_t i) override {
253
+ ndis++;
254
+ return hc.hamming(b + i * code_size);
255
+ }
256
+
257
+ float symmetric_dis(idx_t i, idx_t j) override {
258
+ return HammingComputerDefault(b + j * code_size, code_size)
259
+ .hamming(b + i * code_size);
260
+ }
261
+
262
+ explicit FlatHammingDis(const IndexBinaryFlat& storage)
263
+ : code_size(storage.code_size),
264
+ b(storage.xb.data()),
265
+ ndis(0),
266
+ hc() {}
267
+
268
+ // NOTE: Pointers are cast from float in order to reuse the floating-point
269
+ // DistanceComputer.
270
+ void set_query(const float* x) override {
271
+ hc.set((uint8_t*)x, code_size);
272
+ }
273
+
274
+ ~FlatHammingDis() override {
283
275
  #pragma omp critical
284
- {
285
- hnsw_stats.ndis += ndis;
276
+ { hnsw_stats.ndis += ndis; }
286
277
  }
287
- }
288
278
  };
289
279
 
280
+ } // namespace
281
+
282
+ DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
283
+ IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
284
+
285
+ FAISS_ASSERT(flat_storage != nullptr);
286
+
287
+ switch (code_size) {
288
+ case 4:
289
+ return new FlatHammingDis<HammingComputer4>(*flat_storage);
290
+ case 8:
291
+ return new FlatHammingDis<HammingComputer8>(*flat_storage);
292
+ case 16:
293
+ return new FlatHammingDis<HammingComputer16>(*flat_storage);
294
+ case 20:
295
+ return new FlatHammingDis<HammingComputer20>(*flat_storage);
296
+ case 32:
297
+ return new FlatHammingDis<HammingComputer32>(*flat_storage);
298
+ case 64:
299
+ return new FlatHammingDis<HammingComputer64>(*flat_storage);
300
+ default:
301
+ break;
302
+ }
290
303
 
291
- } // namespace
292
-
293
-
294
- DistanceComputer *IndexBinaryHNSW::get_distance_computer() const {
295
- IndexBinaryFlat *flat_storage = dynamic_cast<IndexBinaryFlat *>(storage);
296
-
297
- FAISS_ASSERT(flat_storage != nullptr);
298
-
299
- switch(code_size) {
300
- case 4:
301
- return new FlatHammingDis<HammingComputer4>(*flat_storage);
302
- case 8:
303
- return new FlatHammingDis<HammingComputer8>(*flat_storage);
304
- case 16:
305
- return new FlatHammingDis<HammingComputer16>(*flat_storage);
306
- case 20:
307
- return new FlatHammingDis<HammingComputer20>(*flat_storage);
308
- case 32:
309
- return new FlatHammingDis<HammingComputer32>(*flat_storage);
310
- case 64:
311
- return new FlatHammingDis<HammingComputer64>(*flat_storage);
312
- default:
313
- if (code_size % 8 == 0) {
314
- return new FlatHammingDis<HammingComputerM8>(*flat_storage);
315
- } else if (code_size % 4 == 0) {
316
- return new FlatHammingDis<HammingComputerM4>(*flat_storage);
317
- }
318
- }
319
-
320
- return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
304
+ return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
321
305
  }
322
306
 
323
-
324
307
  } // namespace faiss