faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -5,24 +5,21 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
-
9
8
  #include <faiss/IndexPQFastScan.h>
10
9
 
10
+ #include <limits.h>
11
11
  #include <cassert>
12
12
  #include <memory>
13
- #include <limits.h>
14
13
 
15
14
  #include <omp.h>
16
15
 
17
-
18
16
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/utils/utils.h>
20
17
  #include <faiss/utils/random.h>
18
+ #include <faiss/utils/utils.h>
21
19
 
20
+ #include <faiss/impl/pq4_fast_scan.h>
22
21
  #include <faiss/impl/simd_result_handlers.h>
23
22
  #include <faiss/utils/quantize_lut.h>
24
- #include <faiss/impl/pq4_fast_scan.h>
25
-
26
23
 
27
24
  namespace faiss {
28
25
 
@@ -33,25 +30,24 @@ inline size_t roundup(size_t a, size_t b) {
33
30
  }
34
31
 
35
32
  IndexPQFastScan::IndexPQFastScan(
36
- int d, size_t M, size_t nbits,
33
+ int d,
34
+ size_t M,
35
+ size_t nbits,
37
36
  MetricType metric,
38
- int bbs):
39
- Index(d, metric), pq(d, M, nbits),
40
- bbs(bbs), ntotal2(0), M2(roundup(M, 2))
41
- {
37
+ int bbs)
38
+ : Index(d, metric),
39
+ pq(d, M, nbits),
40
+ bbs(bbs),
41
+ ntotal2(0),
42
+ M2(roundup(M, 2)) {
42
43
  FAISS_THROW_IF_NOT(nbits == 4);
43
44
  is_trained = false;
44
45
  }
45
46
 
46
- IndexPQFastScan::IndexPQFastScan():
47
- bbs(0), ntotal2(0), M2(0)
48
- {}
47
+ IndexPQFastScan::IndexPQFastScan() : bbs(0), ntotal2(0), M2(0) {}
49
48
 
50
- IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
51
- Index(orig.d, orig.metric_type),
52
- pq(orig.pq),
53
- bbs(bbs)
54
- {
49
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
50
+ : Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
55
51
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
56
52
  ntotal = orig.ntotal;
57
53
  is_trained = orig.is_trained;
@@ -70,16 +66,10 @@ IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
70
66
  codes.resize(ntotal2 * M2 / 2);
71
67
 
72
68
  // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
73
- pq4_pack_codes(
74
- orig.codes.data(),
75
- ntotal, M,
76
- ntotal2, bbs, M2,
77
- codes.get()
78
- );
69
+ pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
79
70
  }
80
71
 
81
- void IndexPQFastScan::train (idx_t n, const float *x)
82
- {
72
+ void IndexPQFastScan::train(idx_t n, const float* x) {
83
73
  if (is_trained) {
84
74
  return;
85
75
  }
@@ -87,11 +77,10 @@ void IndexPQFastScan::train (idx_t n, const float *x)
87
77
  is_trained = true;
88
78
  }
89
79
 
90
-
91
- void IndexPQFastScan::add (idx_t n, const float *x) {
92
- FAISS_THROW_IF_NOT (is_trained);
80
+ void IndexPQFastScan::add(idx_t n, const float* x) {
81
+ FAISS_THROW_IF_NOT(is_trained);
93
82
  AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
94
- pq.compute_codes (x, tmp_codes.get(), n);
83
+ pq.compute_codes(x, tmp_codes.get(), n);
95
84
  ntotal2 = roundup(ntotal + n, bbs);
96
85
  size_t new_size = ntotal2 * M2 / 2;
97
86
  size_t old_size = codes.size();
@@ -100,39 +89,35 @@ void IndexPQFastScan::add (idx_t n, const float *x) {
100
89
  memset(codes.get() + old_size, 0, new_size - old_size);
101
90
  }
102
91
  pq4_pack_codes_range(
103
- tmp_codes.get(), pq.M, ntotal, ntotal + n,
104
- bbs, M2, codes.get()
105
- );
92
+ tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
106
93
  ntotal += n;
107
94
  }
108
95
 
109
- void IndexPQFastScan::reset()
110
- {
96
+ void IndexPQFastScan::reset() {
111
97
  codes.resize(0);
112
98
  ntotal = 0;
113
- }
114
-
115
-
99
+ }
116
100
 
117
101
  namespace {
118
102
 
119
103
  // from impl/ProductQuantizer.cpp
120
104
  template <class C, typename dis_t>
121
105
  void pq_estimators_from_tables_generic(
122
- const ProductQuantizer& pq, size_t nbits,
123
- const uint8_t *codes, size_t ncodes,
124
- const dis_t *dis_table, size_t k,
125
- typename C::T *heap_dis, int64_t *heap_ids)
126
- {
106
+ const ProductQuantizer& pq,
107
+ size_t nbits,
108
+ const uint8_t* codes,
109
+ size_t ncodes,
110
+ const dis_t* dis_table,
111
+ size_t k,
112
+ typename C::T* heap_dis,
113
+ int64_t* heap_ids) {
127
114
  using accu_t = typename C::T;
128
115
  const size_t M = pq.M;
129
116
  const size_t ksub = pq.ksub;
130
117
  for (size_t j = 0; j < ncodes; ++j) {
131
- PQDecoderGeneric decoder(
132
- codes + j * pq.code_size, nbits
133
- );
118
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
134
119
  accu_t dis = 0;
135
- const dis_t * __restrict dt = dis_table;
120
+ const dis_t* __restrict dt = dis_table;
136
121
  for (size_t m = 0; m < M; m++) {
137
122
  uint64_t c = decoder.decode();
138
123
  dis += dt[c];
@@ -146,53 +131,55 @@ void pq_estimators_from_tables_generic(
146
131
  }
147
132
  }
148
133
 
149
-
150
134
  } // anonymous namespace
151
135
 
152
-
153
136
  using namespace quantize_lut;
154
137
 
155
138
  void IndexPQFastScan::compute_quantized_LUT(
156
- idx_t n, const float* x,
157
- uint8_t *lut, float *normalizers) const
158
- {
139
+ idx_t n,
140
+ const float* x,
141
+ uint8_t* lut,
142
+ float* normalizers) const {
159
143
  size_t dim12 = pq.ksub * pq.M;
160
- std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
144
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
161
145
  if (metric_type == METRIC_L2) {
162
- pq.compute_distance_tables (n, x, dis_tables.get());
146
+ pq.compute_distance_tables(n, x, dis_tables.get());
163
147
  } else {
164
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
148
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
165
149
  }
166
150
 
167
- for(uint64_t i = 0; i < n; i++) {
151
+ for (uint64_t i = 0; i < n; i++) {
168
152
  round_uint8_per_column(
169
- dis_tables.get() + i * dim12, pq.M, pq.ksub,
170
- &normalizers[2 * i], &normalizers[2 * i + 1]
171
- );
153
+ dis_tables.get() + i * dim12,
154
+ pq.M,
155
+ pq.ksub,
156
+ &normalizers[2 * i],
157
+ &normalizers[2 * i + 1]);
172
158
  }
173
159
 
174
- for(uint64_t i = 0; i < n; i++) {
175
- const float *t_in = dis_tables.get() + i * dim12;
176
- uint8_t *t_out = lut + i * M2 * pq.ksub;
160
+ for (uint64_t i = 0; i < n; i++) {
161
+ const float* t_in = dis_tables.get() + i * dim12;
162
+ uint8_t* t_out = lut + i * M2 * pq.ksub;
177
163
 
178
- for(int j = 0; j < dim12; j++) {
164
+ for (int j = 0; j < dim12; j++) {
179
165
  t_out[j] = int(t_in[j]);
180
166
  }
181
167
  memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
182
168
  }
183
169
  }
184
170
 
185
-
186
-
187
171
  /******************************************************************************
188
172
  * Search driver routine
189
173
  ******************************************************************************/
190
174
 
191
-
192
175
  void IndexPQFastScan::search(
193
- idx_t n, const float* x, idx_t k,
194
- float* distances, idx_t* labels) const
195
- {
176
+ idx_t n,
177
+ const float* x,
178
+ idx_t k,
179
+ float* distances,
180
+ idx_t* labels) const {
181
+ FAISS_THROW_IF_NOT(k > 0);
182
+
196
183
  if (metric_type == METRIC_L2) {
197
184
  search_dispatch_implem<true>(n, x, k, distances, labels);
198
185
  } else {
@@ -200,20 +187,20 @@ void IndexPQFastScan::search(
200
187
  }
201
188
  }
202
189
 
203
-
204
- template<bool is_max>
190
+ template <bool is_max>
205
191
  void IndexPQFastScan::search_dispatch_implem(
206
- idx_t n,
207
- const float* x,
208
- idx_t k,
209
- float* distances,
210
- idx_t* labels) const
211
- {
212
- using Cfloat = typename std::conditional<is_max,
213
- CMax<float, int64_t>, CMin<float, int64_t> >::type;
214
-
215
- using C = typename std::conditional<is_max,
216
- CMax<uint16_t, int>, CMin<uint16_t, int> >::type;
192
+ idx_t n,
193
+ const float* x,
194
+ idx_t k,
195
+ float* distances,
196
+ idx_t* labels) const {
197
+ using Cfloat = typename std::conditional<
198
+ is_max,
199
+ CMax<float, int64_t>,
200
+ CMin<float, int64_t>>::type;
201
+
202
+ using C = typename std::
203
+ conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
217
204
 
218
205
  if (n == 0) {
219
206
  return;
@@ -229,26 +216,24 @@ void IndexPQFastScan::search_dispatch_implem(
229
216
  impl = 14;
230
217
  }
231
218
  if (k > 20) {
232
- impl ++;
219
+ impl++;
233
220
  }
234
221
  }
235
222
 
236
-
237
223
  if (implem == 1) {
238
224
  FAISS_THROW_IF_NOT(orig_codes);
239
225
  FAISS_THROW_IF_NOT(is_max);
240
- float_maxheap_array_t res = {
241
- size_t(n), size_t(k), labels, distances };
242
- pq.search (x, n, orig_codes, ntotal, &res, true);
226
+ float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
227
+ pq.search(x, n, orig_codes, ntotal, &res, true);
243
228
  } else if (implem == 2 || implem == 3 || implem == 4) {
244
229
  FAISS_THROW_IF_NOT(orig_codes);
245
230
 
246
231
  size_t dim12 = pq.ksub * pq.M;
247
- std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
232
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
248
233
  if (is_max) {
249
- pq.compute_distance_tables (n, x, dis_tables.get());
234
+ pq.compute_distance_tables(n, x, dis_tables.get());
250
235
  } else {
251
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
236
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
252
237
  }
253
238
 
254
239
  std::vector<float> normalizers(n * 2);
@@ -256,34 +241,39 @@ void IndexPQFastScan::search_dispatch_implem(
256
241
  if (implem == 2) {
257
242
  // default float
258
243
  } else if (implem == 3 || implem == 4) {
259
- for(uint64_t i = 0; i < n; i++) {
244
+ for (uint64_t i = 0; i < n; i++) {
260
245
  round_uint8_per_column(
261
- dis_tables.get() + i * dim12, pq.M,
246
+ dis_tables.get() + i * dim12,
247
+ pq.M,
262
248
  pq.ksub,
263
- &normalizers[2 * i], &normalizers[2 * i + 1]
264
- );
249
+ &normalizers[2 * i],
250
+ &normalizers[2 * i + 1]);
265
251
  }
266
252
  }
267
253
 
268
254
  for (int64_t i = 0; i < n; i++) {
269
- int64_t *heap_ids = labels + i * k;
270
- float *heap_dis = distances + i * k;
255
+ int64_t* heap_ids = labels + i * k;
256
+ float* heap_dis = distances + i * k;
271
257
 
272
- heap_heapify<Cfloat> (k, heap_dis, heap_ids);
258
+ heap_heapify<Cfloat>(k, heap_dis, heap_ids);
273
259
 
274
260
  pq_estimators_from_tables_generic<Cfloat>(
275
- pq, pq.nbits, orig_codes, ntotal,
276
- dis_tables.get() + i * dim12,
277
- k, heap_dis, heap_ids
278
- );
261
+ pq,
262
+ pq.nbits,
263
+ orig_codes,
264
+ ntotal,
265
+ dis_tables.get() + i * dim12,
266
+ k,
267
+ heap_dis,
268
+ heap_ids);
279
269
 
280
- heap_reorder<Cfloat> (k, heap_dis, heap_ids);
270
+ heap_reorder<Cfloat>(k, heap_dis, heap_ids);
281
271
 
282
272
  if (implem == 4) {
283
273
  float a = normalizers[2 * i];
284
274
  float b = normalizers[2 * i + 1];
285
275
 
286
- for(int j = 0; j < k; j++) {
276
+ for (int j = 0; j < k; j++) {
287
277
  heap_dis[j] = heap_dis[j] / a + b;
288
278
  }
289
279
  }
@@ -303,30 +293,30 @@ void IndexPQFastScan::search_dispatch_implem(
303
293
  for (int slice = 0; slice < nt; slice++) {
304
294
  idx_t i0 = n * slice / nt;
305
295
  idx_t i1 = n * (slice + 1) / nt;
306
- float *dis_i = distances + i0 * k;
307
- idx_t *lab_i = labels + i0 * k;
296
+ float* dis_i = distances + i0 * k;
297
+ idx_t* lab_i = labels + i0 * k;
308
298
  if (impl == 12 || impl == 13) {
309
299
  search_implem_12<C>(
310
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
300
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
311
301
  } else {
312
302
  search_implem_14<C>(
313
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
303
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
314
304
  }
315
305
  }
316
306
  }
317
307
  } else {
318
308
  FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
319
309
  }
320
-
321
310
  }
322
311
 
323
- template<class C>
312
+ template <class C>
324
313
  void IndexPQFastScan::search_implem_12(
325
- idx_t n, const float* x, idx_t k,
326
- float* distances, idx_t* labels,
327
- int impl) const
328
- {
329
-
314
+ idx_t n,
315
+ const float* x,
316
+ idx_t k,
317
+ float* distances,
318
+ idx_t* labels,
319
+ int impl) const {
330
320
  FAISS_THROW_IF_NOT(bbs == 32);
331
321
 
332
322
  // handle qbs2 blocking by recursive call
@@ -335,23 +325,25 @@ void IndexPQFastScan::search_implem_12(
335
325
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
336
326
  int64_t i1 = std::min(i0 + qbs2, n);
337
327
  search_implem_12<C>(
338
- i1 - i0, x + d * i0, k,
339
- distances + i0 * k, labels + i0 * k, impl
340
- );
328
+ i1 - i0,
329
+ x + d * i0,
330
+ k,
331
+ distances + i0 * k,
332
+ labels + i0 * k,
333
+ impl);
341
334
  }
342
335
  return;
343
336
  }
344
337
 
345
338
  size_t dim12 = pq.ksub * M2;
346
339
  AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
347
- std::unique_ptr<float []> normalizers(new float[2 * n]);
340
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
348
341
 
349
342
  if (skip & 1) {
350
343
  quantized_dis_tables.clear();
351
344
  } else {
352
345
  compute_quantized_LUT(
353
- n, x, quantized_dis_tables.get(), normalizers.get()
354
- );
346
+ n, x, quantized_dis_tables.get(), normalizers.get());
355
347
  }
356
348
 
357
349
  AlignedTable<uint8_t> LUT(n * dim12);
@@ -365,9 +357,8 @@ void IndexPQFastScan::search_implem_12(
365
357
  qbs = pq4_preferred_qbs(n);
366
358
  }
367
359
 
368
- int LUT_nq = pq4_pack_LUT_qbs(
369
- qbs, M2, quantized_dis_tables.get(), LUT.get()
370
- );
360
+ int LUT_nq =
361
+ pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
371
362
  FAISS_THROW_IF_NOT(LUT_nq == n);
372
363
 
373
364
  if (k == 1) {
@@ -377,37 +368,30 @@ void IndexPQFastScan::search_implem_12(
377
368
  } else {
378
369
  handler.disable = bool(skip & 2);
379
370
  pq4_accumulate_loop_qbs(
380
- qbs, ntotal2, M2,
381
- codes.get(), LUT.get(),
382
- handler
383
- );
371
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
384
372
  }
385
373
 
386
374
  handler.to_flat_arrays(distances, labels, normalizers.get());
387
375
 
388
376
  } else if (impl == 12) {
389
-
390
377
  std::vector<uint16_t> tmp_dis(n * k);
391
378
  std::vector<int32_t> tmp_ids(n * k);
392
379
 
393
380
  if (skip & 4) {
394
381
  // skip
395
382
  } else {
396
- HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
383
+ HeapHandler<C> handler(
384
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
397
385
  handler.disable = bool(skip & 2);
398
386
 
399
387
  pq4_accumulate_loop_qbs(
400
- qbs, ntotal2, M2,
401
- codes.get(), LUT.get(),
402
- handler
403
- );
388
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
404
389
 
405
390
  if (!(skip & 8)) {
406
391
  handler.to_flat_arrays(distances, labels, normalizers.get());
407
392
  }
408
393
  }
409
394
 
410
-
411
395
  } else { // impl == 13
412
396
 
413
397
  ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
@@ -417,10 +401,7 @@ void IndexPQFastScan::search_implem_12(
417
401
  // skip
418
402
  } else {
419
403
  pq4_accumulate_loop_qbs(
420
- qbs, ntotal2, M2,
421
- codes.get(), LUT.get(),
422
- handler
423
- );
404
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
424
405
  }
425
406
 
426
407
  if (!(skip & 8)) {
@@ -431,18 +412,19 @@ void IndexPQFastScan::search_implem_12(
431
412
  FastScan_stats.t1 += handler.times[1];
432
413
  FastScan_stats.t2 += handler.times[2];
433
414
  FastScan_stats.t3 += handler.times[3];
434
-
435
415
  }
436
416
  }
437
417
 
438
418
  FastScanStats FastScan_stats;
439
419
 
440
- template<class C>
420
+ template <class C>
441
421
  void IndexPQFastScan::search_implem_14(
442
- idx_t n, const float* x, idx_t k,
443
- float* distances, idx_t* labels, int impl) const
444
- {
445
-
422
+ idx_t n,
423
+ const float* x,
424
+ idx_t k,
425
+ float* distances,
426
+ idx_t* labels,
427
+ int impl) const {
446
428
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
447
429
 
448
430
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -452,23 +434,25 @@ void IndexPQFastScan::search_implem_14(
452
434
  for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
453
435
  int64_t i1 = std::min(i0 + qbs2, n);
454
436
  search_implem_14<C>(
455
- i1 - i0, x + d * i0, k,
456
- distances + i0 * k, labels + i0 * k, impl
457
- );
437
+ i1 - i0,
438
+ x + d * i0,
439
+ k,
440
+ distances + i0 * k,
441
+ labels + i0 * k,
442
+ impl);
458
443
  }
459
444
  return;
460
445
  }
461
446
 
462
447
  size_t dim12 = pq.ksub * M2;
463
448
  AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
464
- std::unique_ptr<float []> normalizers(new float[2 * n]);
449
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
465
450
 
466
451
  if (skip & 1) {
467
452
  quantized_dis_tables.clear();
468
453
  } else {
469
454
  compute_quantized_LUT(
470
- n, x, quantized_dis_tables.get(), normalizers.get()
471
- );
455
+ n, x, quantized_dis_tables.get(), normalizers.get());
472
456
  }
473
457
 
474
458
  AlignedTable<uint8_t> LUT(n * dim12);
@@ -480,37 +464,30 @@ void IndexPQFastScan::search_implem_14(
480
464
  // pass
481
465
  } else {
482
466
  handler.disable = bool(skip & 2);
483
- pq4_accumulate_loop (
484
- n, ntotal2, bbs, M2,
485
- codes.get(), LUT.get(),
486
- handler
487
- );
467
+ pq4_accumulate_loop(
468
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
488
469
  }
489
470
  handler.to_flat_arrays(distances, labels, normalizers.get());
490
471
 
491
472
  } else if (impl == 14) {
492
-
493
473
  std::vector<uint16_t> tmp_dis(n * k);
494
474
  std::vector<int32_t> tmp_ids(n * k);
495
475
 
496
476
  if (skip & 4) {
497
477
  // skip
498
478
  } else if (k > 1) {
499
- HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
479
+ HeapHandler<C> handler(
480
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
500
481
  handler.disable = bool(skip & 2);
501
482
 
502
- pq4_accumulate_loop (
503
- n, ntotal2, bbs, M2,
504
- codes.get(), LUT.get(),
505
- handler
506
- );
483
+ pq4_accumulate_loop(
484
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
507
485
 
508
486
  if (!(skip & 8)) {
509
487
  handler.to_flat_arrays(distances, labels, normalizers.get());
510
488
  }
511
489
  }
512
490
 
513
-
514
491
  } else { // impl == 15
515
492
 
516
493
  ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
@@ -519,11 +496,8 @@ void IndexPQFastScan::search_implem_14(
519
496
  if (skip & 4) {
520
497
  // skip
521
498
  } else {
522
- pq4_accumulate_loop (
523
- n, ntotal2, bbs, M2,
524
- codes.get(), LUT.get(),
525
- handler
526
- );
499
+ pq4_accumulate_loop(
500
+ n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
527
501
  }
528
502
 
529
503
  if (!(skip & 8)) {
@@ -532,5 +506,4 @@ void IndexPQFastScan::search_implem_14(
532
506
  }
533
507
  }
534
508
 
535
-
536
509
  } // namespace faiss