faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -5,24 +5,21 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexPQFastScan.h>
10
9
 
10
+ #include <limits.h>
11
11
  #include <cassert>
12
12
  #include <memory>
13
- #include <limits.h>
14
13
 
15
14
  #include <omp.h>
16
15
 
17
-
18
16
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/utils/utils.h>
20
17
  #include <faiss/utils/random.h>
18
+ #include <faiss/utils/utils.h>
21
19
 
20
+ #include <faiss/impl/pq4_fast_scan.h>
22
21
  #include <faiss/impl/simd_result_handlers.h>
23
22
  #include <faiss/utils/quantize_lut.h>
24
- #include <faiss/impl/pq4_fast_scan.h>
25
-
26
23
 
27
24
  namespace faiss {
28
25
 
@@ -33,25 +30,24 @@ inline size_t roundup(size_t a, size_t b) {
33
30
  }
34
31
 
35
32
  IndexPQFastScan::IndexPQFastScan(
36
- int d, size_t M, size_t nbits,
33
+ int d,
34
+ size_t M,
35
+ size_t nbits,
37
36
  MetricType metric,
38
- int bbs):
39
- Index(d, metric), pq(d, M, nbits),
40
- bbs(bbs), ntotal2(0), M2(roundup(M, 2))
41
- {
37
+ int bbs)
38
+ : Index(d, metric),
39
+ pq(d, M, nbits),
40
+ bbs(bbs),
41
+ ntotal2(0),
42
+ M2(roundup(M, 2)) {
42
43
  FAISS_THROW_IF_NOT(nbits == 4);
43
44
  is_trained = false;
44
45
  }
45
46
 
46
- IndexPQFastScan::IndexPQFastScan():
47
- bbs(0), ntotal2(0), M2(0)
48
- {}
47
+ IndexPQFastScan::IndexPQFastScan() : bbs(0), ntotal2(0), M2(0) {}
49
48
 
50
- IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
51
- Index(orig.d, orig.metric_type),
52
- pq(orig.pq),
53
- bbs(bbs)
54
- {
49
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
50
+ : Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
55
51
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
56
52
  ntotal = orig.ntotal;
57
53
  is_trained = orig.is_trained;
@@ -70,16 +66,10 @@ IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
70
66
  codes.resize(ntotal2 * M2 / 2);
71
67
 
72
68
  // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
73
- pq4_pack_codes(
74
- orig.codes.data(),
75
- ntotal, M,
76
- ntotal2, bbs, M2,
77
- codes.get()
78
- );
69
+ pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
79
70
  }
80
71
 
81
- void IndexPQFastScan::train (idx_t n, const float *x)
82
- {
72
+ void IndexPQFastScan::train(idx_t n, const float* x) {
83
73
  if (is_trained) {
84
74
  return;
85
75
  }
@@ -87,11 +77,10 @@ void IndexPQFastScan::train (idx_t n, const float *x)
87
77
  is_trained = true;
88
78
  }
89
79
 
90
-
91
- void IndexPQFastScan::add (idx_t n, const float *x) {
92
- FAISS_THROW_IF_NOT (is_trained);
80
+ void IndexPQFastScan::add(idx_t n, const float* x) {
81
+ FAISS_THROW_IF_NOT(is_trained);
93
82
  AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
94
- pq.compute_codes (x, tmp_codes.get(), n);
83
+ pq.compute_codes(x, tmp_codes.get(), n);
95
84
  ntotal2 = roundup(ntotal + n, bbs);
96
85
  size_t new_size = ntotal2 * M2 / 2;
97
86
  size_t old_size = codes.size();
@@ -100,39 +89,35 @@ void IndexPQFastScan::add (idx_t n, const float *x) {
100
89
  memset(codes.get() + old_size, 0, new_size - old_size);
101
90
  }
102
91
  pq4_pack_codes_range(
103
- tmp_codes.get(), pq.M, ntotal, ntotal + n,
104
- bbs, M2, codes.get()
105
- );
92
+ tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
106
93
  ntotal += n;
107
94
  }
108
95
 
109
- void IndexPQFastScan::reset()
110
- {
96
+ void IndexPQFastScan::reset() {
111
97
  codes.resize(0);
112
98
  ntotal = 0;
113
- }
114
-
115
-
99
+ }
116
100
 
117
101
  namespace {
118
102
 
119
103
  // from impl/ProductQuantizer.cpp
120
104
  template <class C, typename dis_t>
121
105
  void pq_estimators_from_tables_generic(
122
- const ProductQuantizer& pq, size_t nbits,
123
- const uint8_t *codes, size_t ncodes,
124
- const dis_t *dis_table, size_t k,
125
- typename C::T *heap_dis, int64_t *heap_ids)
126
- {
106
+ const ProductQuantizer& pq,
107
+ size_t nbits,
108
+ const uint8_t* codes,
109
+ size_t ncodes,
110
+ const dis_t* dis_table,
111
+ size_t k,
112
+ typename C::T* heap_dis,
113
+ int64_t* heap_ids) {
127
114
  using accu_t = typename C::T;
128
115
  const size_t M = pq.M;
129
116
  const size_t ksub = pq.ksub;
130
117
  for (size_t j = 0; j < ncodes; ++j) {
131
- PQDecoderGeneric decoder(
132
- codes + j * pq.code_size, nbits
133
- );
118
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
134
119
  accu_t dis = 0;
135
- const dis_t * __restrict dt = dis_table;
120
+ const dis_t* __restrict dt = dis_table;
136
121
  for (size_t m = 0; m < M; m++) {
137
122
  uint64_t c = decoder.decode();
138
123
  dis += dt[c];
@@ -146,53 +131,55 @@ void pq_estimators_from_tables_generic(
146
131
  }
147
132
  }
148
133
 
149
-
150
134
  } // anonymous namespace
151
135
 
152
-
153
136
  using namespace quantize_lut;
154
137
 
155
138
  void IndexPQFastScan::compute_quantized_LUT(
156
- idx_t n, const float* x,
157
- uint8_t *lut, float *normalizers) const
158
- {
139
+ idx_t n,
140
+ const float* x,
141
+ uint8_t* lut,
142
+ float* normalizers) const {
159
143
  size_t dim12 = pq.ksub * pq.M;
160
- std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
144
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
161
145
  if (metric_type == METRIC_L2) {
162
- pq.compute_distance_tables (n, x, dis_tables.get());
146
+ pq.compute_distance_tables(n, x, dis_tables.get());
163
147
  } else {
164
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
148
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
165
149
  }
166
150
 
167
- for(uint64_t i = 0; i < n; i++) {
151
+ for (uint64_t i = 0; i < n; i++) {
168
152
  round_uint8_per_column(
169
- dis_tables.get() + i * dim12, pq.M, pq.ksub,
170
- &normalizers[2 * i], &normalizers[2 * i + 1]
171
- );
153
+ dis_tables.get() + i * dim12,
154
+ pq.M,
155
+ pq.ksub,
156
+ &normalizers[2 * i],
157
+ &normalizers[2 * i + 1]);
172
158
  }
173
159
 
174
- for(uint64_t i = 0; i < n; i++) {
175
- const float *t_in = dis_tables.get() + i * dim12;
176
- uint8_t *t_out = lut + i * M2 * pq.ksub;
160
+ for (uint64_t i = 0; i < n; i++) {
161
+ const float* t_in = dis_tables.get() + i * dim12;
162
+ uint8_t* t_out = lut + i * M2 * pq.ksub;
177
163
 
178
- for(int j = 0; j < dim12; j++) {
164
+ for (int j = 0; j < dim12; j++) {
179
165
  t_out[j] = int(t_in[j]);
180
166
  }
181
167
  memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
182
168
  }
183
169
  }
184
170
 
185
-
186
-
187
171
  /******************************************************************************
188
172
  * Search driver routine
189
173
  ******************************************************************************/
190
174
 
191
-
192
175
  void IndexPQFastScan::search(
193
- idx_t n, const float* x, idx_t k,
194
- float* distances, idx_t* labels) const
195
- {
176
+ idx_t n,
177
+ const float* x,
178
+ idx_t k,
179
+ float* distances,
180
+ idx_t* labels) const {
181
+ FAISS_THROW_IF_NOT(k > 0);
182
+
196
183
  if (metric_type == METRIC_L2) {
197
184
  search_dispatch_implem<true>(n, x, k, distances, labels);
198
185
  } else {
@@ -200,20 +187,20 @@ void IndexPQFastScan::search(
200
187
  }
201
188
  }
202
189
 
203
-
204
- template<bool is_max>
190
+ template <bool is_max>
205
191
  void IndexPQFastScan::search_dispatch_implem(
206
- idx_t n,
207
- const float* x,
208
- idx_t k,
209
- float* distances,
210
- idx_t* labels) const
211
- {
212
- using Cfloat = typename std::conditional<is_max,
213
- CMax<float, int64_t>, CMin<float, int64_t> >::type;
214
-
215
- using C = typename std::conditional<is_max,
216
- CMax<uint16_t, int>, CMin<uint16_t, int> >::type;
192
+ idx_t n,
193
+ const float* x,
194
+ idx_t k,
195
+ float* distances,
196
+ idx_t* labels) const {
197
+ using Cfloat = typename std::conditional<
198
+ is_max,
199
+ CMax<float, int64_t>,
200
+ CMin<float, int64_t>>::type;
201
+
202
+ using C = typename std::
203
+ conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
217
204
 
218
205
  if (n == 0) {
219
206
  return;
@@ -229,26 +216,24 @@ void IndexPQFastScan::search_dispatch_implem(
229
216
  impl = 14;
230
217
  }
231
218
  if (k > 20) {
232
- impl ++;
219
+ impl++;
233
220
  }
234
221
  }
235
222
 
236
-
237
223
  if (implem == 1) {
238
224
  FAISS_THROW_IF_NOT(orig_codes);
239
225
  FAISS_THROW_IF_NOT(is_max);
240
- float_maxheap_array_t res = {
241
- size_t(n), size_t(k), labels, distances };
242
- pq.search (x, n, orig_codes, ntotal, &res, true);
226
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
227
+ pq.search(x, n, orig_codes, ntotal, &res, true);
243
228
  } else if (implem == 2 || implem == 3 || implem == 4) {
244
229
  FAISS_THROW_IF_NOT(orig_codes);
245
230
 
246
231
  size_t dim12 = pq.ksub * pq.M;
247
- std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
232
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
248
233
  if (is_max) {
249
- pq.compute_distance_tables (n, x, dis_tables.get());
234
+ pq.compute_distance_tables(n, x, dis_tables.get());
250
235
  } else {
251
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
236
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
252
237
  }
253
238
 
254
239
  std::vector<float> normalizers(n * 2);
@@ -256,34 +241,39 @@ void IndexPQFastScan::search_dispatch_implem(
256
241
  if (implem == 2) {
257
242
  // default float
258
243
  } else if (implem == 3 || implem == 4) {
259
- for(uint64_t i = 0; i < n; i++) {
244
+ for (uint64_t i = 0; i < n; i++) {
260
245
  round_uint8_per_column(
261
- dis_tables.get() + i * dim12, pq.M,
246
+ dis_tables.get() + i * dim12,
247
+ pq.M,
262
248
  pq.ksub,
263
- &normalizers[2 * i], &normalizers[2 * i + 1]
264
- );
249
+ &normalizers[2 * i],
250
+ &normalizers[2 * i + 1]);
265
251
  }
266
252
  }
267
253
 
268
254
  for (int64_t i = 0; i < n; i++) {
269
- int64_t *heap_ids = labels + i * k;
270
- float *heap_dis = distances + i * k;
255
+ int64_t* heap_ids = labels + i * k;
256
+ float* heap_dis = distances + i * k;
271
257
 
272
- heap_heapify<Cfloat> (k, heap_dis, heap_ids);
258
+ heap_heapify<Cfloat>(k, heap_dis, heap_ids);
273
259
 
274
260
  pq_estimators_from_tables_generic<Cfloat>(
275
- pq, pq.nbits, orig_codes, ntotal,
276
- dis_tables.get() + i * dim12,
277
- k, heap_dis, heap_ids
278
- );
261
+ pq,
262
+ pq.nbits,
263
+ orig_codes,
264
+ ntotal,
265
+ dis_tables.get() + i * dim12,
266
+ k,
267
+ heap_dis,
268
+ heap_ids);
279
269
 
280
- heap_reorder<Cfloat> (k, heap_dis, heap_ids);
270
+ heap_reorder<Cfloat>(k, heap_dis, heap_ids);
281
271
 
282
272
  if (implem == 4) {
283
273
  float a = normalizers[2 * i];
284
274
  float b = normalizers[2 * i + 1];
285
275
 
286
- for(int j = 0; j < k; j++) {
276
+ for (int j = 0; j < k; j++) {
287
277
  heap_dis[j] = heap_dis[j] / a + b;
288
278
  }
289
279
  }
@@ -303,30 +293,30 @@ void IndexPQFastScan::search_dispatch_implem(
303
293
  for (int slice = 0; slice < nt; slice++) {
304
294
  idx_t i0 = n * slice / nt;
305
295
  idx_t i1 = n * (slice + 1) / nt;
306
- float *dis_i = distances + i0 * k;
307
- idx_t *lab_i = labels + i0 * k;
296
+ float* dis_i = distances + i0 * k;
297
+ idx_t* lab_i = labels + i0 * k;
308
298
  if (impl == 12 || impl == 13) {
309
299
  search_implem_12<C>(
310
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
300
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
311
301
  } else {
312
302
  search_implem_14<C>(
313
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
303
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
314
304
  }
315
305
  }
316
306
  }
317
307
  } else {
318
308
  FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
319
309
  }
320
-
321
310
  }
322
311
 
323
- template<class C>
312
+ template <class C>
324
313
  void IndexPQFastScan::search_implem_12(
325
- idx_t n, const float* x, idx_t k,
326
- float* distances, idx_t* labels,
327
- int impl) const
328
- {
329
-
314
+ idx_t n,
315
+ const float* x,
316
+ idx_t k,
317
+ float* distances,
318
+ idx_t* labels,
319
+ int impl) const {
330
320
  FAISS_THROW_IF_NOT(bbs == 32);
331
321
 
332
322
  // handle qbs2 blocking by recursive call
@@ -335,23 +325,25 @@ void IndexPQFastScan::search_implem_12(
335
325
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
336
326
  int64_t i1 = std::min(i0 + qbs2, n);
337
327
  search_implem_12<C>(
338
- i1 - i0, x + d * i0, k,
339
- distances + i0 * k, labels + i0 * k, impl
340
- );
328
+ i1 - i0,
329
+ x + d * i0,
330
+ k,
331
+ distances + i0 * k,
332
+ labels + i0 * k,
333
+ impl);
341
334
  }
342
335
  return;
343
336
  }
344
337
 
345
338
  size_t dim12 = pq.ksub * M2;
346
339
  AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
347
- std::unique_ptr<float []> normalizers(new float[2 * n]);
340
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
348
341
 
349
342
  if (skip & 1) {
350
343
  quantized_dis_tables.clear();
351
344
  } else {
352
345
  compute_quantized_LUT(
353
- n, x, quantized_dis_tables.get(), normalizers.get()
354
- );
346
+ n, x, quantized_dis_tables.get(), normalizers.get());
355
347
  }
356
348
 
357
349
  AlignedTable<uint8_t> LUT(n * dim12);
@@ -365,9 +357,8 @@ void IndexPQFastScan::search_implem_12(
365
357
  qbs = pq4_preferred_qbs(n);
366
358
  }
367
359
 
368
- int LUT_nq = pq4_pack_LUT_qbs(
369
- qbs, M2, quantized_dis_tables.get(), LUT.get()
370
- );
360
+ int LUT_nq =
361
+ pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
371
362
  FAISS_THROW_IF_NOT(LUT_nq == n);
372
363
 
373
364
  if (k == 1) {
@@ -377,37 +368,30 @@ void IndexPQFastScan::search_implem_12(
377
368
  } else {
378
369
  handler.disable = bool(skip & 2);
379
370
  pq4_accumulate_loop_qbs(
380
- qbs, ntotal2, M2,
381
- codes.get(), LUT.get(),
382
- handler
383
- );
371
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
384
372
  }
385
373
 
386
374
  handler.to_flat_arrays(distances, labels, normalizers.get());
387
375
 
388
376
  } else if (impl == 12) {
389
-
390
377
  std::vector<uint16_t> tmp_dis(n * k);
391
378
  std::vector<int32_t> tmp_ids(n * k);
392
379
 
393
380
  if (skip & 4) {
394
381
  // skip
395
382
  } else {
396
- HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
383
+ HeapHandler<C> handler(
384
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
397
385
  handler.disable = bool(skip & 2);
398
386
 
399
387
  pq4_accumulate_loop_qbs(
400
- qbs, ntotal2, M2,
401
- codes.get(), LUT.get(),
402
- handler
403
- );
388
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
404
389
 
405
390
  if (!(skip & 8)) {
406
391
  handler.to_flat_arrays(distances, labels, normalizers.get());
407
392
  }
408
393
  }
409
394
 
410
-
411
395
  } else { // impl == 13
412
396
 
413
397
  ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
@@ -417,10 +401,7 @@ void IndexPQFastScan::search_implem_12(
417
401
  // skip
418
402
  } else {
419
403
  pq4_accumulate_loop_qbs(
420
- qbs, ntotal2, M2,
421
- codes.get(), LUT.get(),
422
- handler
423
- );
404
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
424
405
  }
425
406
 
426
407
  if (!(skip & 8)) {
@@ -431,18 +412,19 @@ void IndexPQFastScan::search_implem_12(
431
412
  FastScan_stats.t1 += handler.times[1];
432
413
  FastScan_stats.t2 += handler.times[2];
433
414
  FastScan_stats.t3 += handler.times[3];
434
-
435
415
  }
436
416
  }
437
417
 
438
418
  FastScanStats FastScan_stats;
439
419
 
440
- template<class C>
420
+ template <class C>
441
421
  void IndexPQFastScan::search_implem_14(
442
- idx_t n, const float* x, idx_t k,
443
- float* distances, idx_t* labels, int impl) const
444
- {
445
-
422
+ idx_t n,
423
+ const float* x,
424
+ idx_t k,
425
+ float* distances,
426
+ idx_t* labels,
427
+ int impl) const {
446
428
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
447
429
 
448
430
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -452,23 +434,25 @@ void IndexPQFastScan::search_implem_14(
452
434
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
453
435
  int64_t i1 = std::min(i0 + qbs2, n);
454
436
  search_implem_14<C>(
455
- i1 - i0, x + d * i0, k,
456
- distances + i0 * k, labels + i0 * k, impl
457
- );
437
+ i1 - i0,
438
+ x + d * i0,
439
+ k,
440
+ distances + i0 * k,
441
+ labels + i0 * k,
442
+ impl);
458
443
  }
459
444
  return;
460
445
  }
461
446
 
462
447
  size_t dim12 = pq.ksub * M2;
463
448
  AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
464
- std::unique_ptr<float []> normalizers(new float[2 * n]);
449
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
465
450
 
466
451
  if (skip & 1) {
467
452
  quantized_dis_tables.clear();
468
453
  } else {
469
454
  compute_quantized_LUT(
470
- n, x, quantized_dis_tables.get(), normalizers.get()
471
- );
455
+ n, x, quantized_dis_tables.get(), normalizers.get());
472
456
  }
473
457
 
474
458
  AlignedTable<uint8_t> LUT(n * dim12);
@@ -480,37 +464,30 @@ void IndexPQFastScan::search_implem_14(
480
464
  // pass
481
465
  } else {
482
466
  handler.disable = bool(skip & 2);
483
- pq4_accumulate_loop (
484
- n, ntotal2, bbs, M2,
485
- codes.get(), LUT.get(),
486
- handler
487
- );
467
+ pq4_accumulate_loop(
468
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
488
469
  }
489
470
  handler.to_flat_arrays(distances, labels, normalizers.get());
490
471
 
491
472
  } else if (impl == 14) {
492
-
493
473
  std::vector<uint16_t> tmp_dis(n * k);
494
474
  std::vector<int32_t> tmp_ids(n * k);
495
475
 
496
476
  if (skip & 4) {
497
477
  // skip
498
478
  } else if (k > 1) {
499
- HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
479
+ HeapHandler<C> handler(
480
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
500
481
  handler.disable = bool(skip & 2);
501
482
 
502
- pq4_accumulate_loop (
503
- n, ntotal2, bbs, M2,
504
- codes.get(), LUT.get(),
505
- handler
506
- );
483
+ pq4_accumulate_loop(
484
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
507
485
 
508
486
  if (!(skip & 8)) {
509
487
  handler.to_flat_arrays(distances, labels, normalizers.get());
510
488
  }
511
489
  }
512
490
 
513
-
514
491
  } else { // impl == 15
515
492
 
516
493
  ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
@@ -519,11 +496,8 @@ void IndexPQFastScan::search_implem_14(
519
496
  if (skip & 4) {
520
497
  // skip
521
498
  } else {
522
- pq4_accumulate_loop (
523
- n, ntotal2, bbs, M2,
524
- codes.get(), LUT.get(),
525
- handler
526
- );
499
+ pq4_accumulate_loop(
500
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
527
501
  }
528
502
 
529
503
  if (!(skip & 8)) {
@@ -532,5 +506,4 @@ void IndexPQFastScan::search_implem_14(
532
506
  }
533
507
  }
534
508
 
535
-
536
509
  } // namespace faiss