faiss 0.3.0 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
37
37
 
38
38
  void IndexFastScan::init_fastscan(
39
39
  int d,
40
- size_t M,
41
- size_t nbits,
40
+ size_t M_2,
41
+ size_t nbits_2,
42
42
  MetricType metric,
43
43
  int bbs) {
44
- FAISS_THROW_IF_NOT(nbits == 4);
44
+ FAISS_THROW_IF_NOT(nbits_2 == 4);
45
45
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
46
46
  this->d = d;
47
- this->M = M;
48
- this->nbits = nbits;
47
+ this->M = M_2;
48
+ this->nbits = nbits_2;
49
49
  this->metric_type = metric;
50
50
  this->bbs = bbs;
51
- ksub = (1 << nbits);
51
+ ksub = (1 << nbits_2);
52
52
 
53
- code_size = (M * nbits + 7) / 8;
53
+ code_size = (M_2 * nbits_2 + 7) / 8;
54
54
  ntotal = ntotal2 = 0;
55
- M2 = roundup(M, 2);
55
+ M2 = roundup(M_2, 2);
56
56
  is_trained = false;
57
57
  }
58
58
 
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
158
158
 
159
159
  namespace {
160
160
 
161
- template <class C, typename dis_t, class Scaler>
161
+ template <class C, typename dis_t>
162
162
  void estimators_from_tables_generic(
163
163
  const IndexFastScan& index,
164
164
  const uint8_t* codes,
@@ -167,23 +167,27 @@ void estimators_from_tables_generic(
167
167
  size_t k,
168
168
  typename C::T* heap_dis,
169
169
  int64_t* heap_ids,
170
- const Scaler& scaler) {
170
+ const NormTableScaler* scaler) {
171
171
  using accu_t = typename C::T;
172
172
 
173
173
  for (size_t j = 0; j < ncodes; ++j) {
174
174
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
175
175
  accu_t dis = 0;
176
176
  const dis_t* dt = dis_table;
177
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
177
+ int nscale = scaler ? scaler->nscale : 0;
178
+
179
+ for (size_t m = 0; m < index.M - nscale; m++) {
178
180
  uint64_t c = bsr.read(index.nbits);
179
181
  dis += dt[c];
180
182
  dt += index.ksub;
181
183
  }
182
184
 
183
- for (size_t m = 0; m < scaler.nscale; m++) {
184
- uint64_t c = bsr.read(index.nbits);
185
- dis += scaler.scale_one(dt[c]);
186
- dt += index.ksub;
185
+ if (nscale) {
186
+ for (size_t m = 0; m < nscale; m++) {
187
+ uint64_t c = bsr.read(index.nbits);
188
+ dis += scaler->scale_one(dt[c]);
189
+ dt += index.ksub;
190
+ }
187
191
  }
188
192
 
189
193
  if (C::cmp(heap_dis[0], dis)) {
@@ -193,6 +197,28 @@ void estimators_from_tables_generic(
193
197
  }
194
198
  }
195
199
 
200
+ template <class C>
201
+ ResultHandlerCompare<C, false>* make_knn_handler(
202
+ int impl,
203
+ idx_t n,
204
+ idx_t k,
205
+ size_t ntotal,
206
+ float* distances,
207
+ idx_t* labels,
208
+ const IDSelector* sel = nullptr) {
209
+ using HeapHC = HeapHandler<C, false>;
210
+ using ReservoirHC = ReservoirHandler<C, false>;
211
+ using SingleResultHC = SingleResultHandler<C, false>;
212
+
213
+ if (k == 1) {
214
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
215
+ } else if (impl % 2 == 0) {
216
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
217
+ } else /* if (impl % 2 == 1) */ {
218
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
219
+ }
220
+ }
221
+
196
222
  } // anonymous namespace
197
223
 
198
224
  using namespace quantize_lut;
@@ -241,22 +267,21 @@ void IndexFastScan::search(
241
267
  !params, "search params not supported for this index");
242
268
  FAISS_THROW_IF_NOT(k > 0);
243
269
 
244
- DummyScaler scaler;
245
270
  if (metric_type == METRIC_L2) {
246
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
271
+ search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
247
272
  } else {
248
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
273
+ search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
249
274
  }
250
275
  }
251
276
 
252
- template <bool is_max, class Scaler>
277
+ template <bool is_max>
253
278
  void IndexFastScan::search_dispatch_implem(
254
279
  idx_t n,
255
280
  const float* x,
256
281
  idx_t k,
257
282
  float* distances,
258
283
  idx_t* labels,
259
- const Scaler& scaler) const {
284
+ const NormTableScaler* scaler) const {
260
285
  using Cfloat = typename std::conditional<
261
286
  is_max,
262
287
  CMax<float, int64_t>,
@@ -319,14 +344,14 @@ void IndexFastScan::search_dispatch_implem(
319
344
  }
320
345
  }
321
346
 
322
- template <class Cfloat, class Scaler>
347
+ template <class Cfloat>
323
348
  void IndexFastScan::search_implem_234(
324
349
  idx_t n,
325
350
  const float* x,
326
351
  idx_t k,
327
352
  float* distances,
328
353
  idx_t* labels,
329
- const Scaler& scaler) const {
354
+ const NormTableScaler* scaler) const {
330
355
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
331
356
 
332
357
  const size_t dim12 = ksub * M;
@@ -378,7 +403,7 @@ void IndexFastScan::search_implem_234(
378
403
  }
379
404
  }
380
405
 
381
- template <class C, class Scaler>
406
+ template <class C>
382
407
  void IndexFastScan::search_implem_12(
383
408
  idx_t n,
384
409
  const float* x,
@@ -386,7 +411,8 @@ void IndexFastScan::search_implem_12(
386
411
  float* distances,
387
412
  idx_t* labels,
388
413
  int impl,
389
- const Scaler& scaler) const {
414
+ const NormTableScaler* scaler) const {
415
+ using RH = ResultHandlerCompare<C, false>;
390
416
  FAISS_THROW_IF_NOT(bbs == 32);
391
417
 
392
418
  // handle qbs2 blocking by recursive call
@@ -432,63 +458,31 @@ void IndexFastScan::search_implem_12(
432
458
  pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
433
459
  FAISS_THROW_IF_NOT(LUT_nq == n);
434
460
 
435
- if (k == 1) {
436
- SingleResultHandler<C> handler(n, ntotal);
437
- if (skip & 4) {
438
- // pass
439
- } else {
440
- handler.disable = bool(skip & 2);
441
- pq4_accumulate_loop_qbs(
442
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
443
- }
444
-
445
- handler.to_flat_arrays(distances, labels, normalizers.get());
446
-
447
- } else if (impl == 12) {
448
- std::vector<uint16_t> tmp_dis(n * k);
449
- std::vector<int32_t> tmp_ids(n * k);
450
-
451
- if (skip & 4) {
452
- // skip
453
- } else {
454
- HeapHandler<C> handler(
455
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
456
- handler.disable = bool(skip & 2);
457
-
458
- pq4_accumulate_loop_qbs(
459
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
460
-
461
- if (!(skip & 8)) {
462
- handler.to_flat_arrays(distances, labels, normalizers.get());
463
- }
464
- }
465
-
466
- } else { // impl == 13
467
-
468
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
469
- handler.disable = bool(skip & 2);
461
+ std::unique_ptr<RH> handler(
462
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
463
+ handler->disable = bool(skip & 2);
464
+ handler->normalizers = normalizers.get();
470
465
 
471
- if (skip & 4) {
472
- // skip
473
- } else {
474
- pq4_accumulate_loop_qbs(
475
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
476
- }
477
-
478
- if (!(skip & 8)) {
479
- handler.to_flat_arrays(distances, labels, normalizers.get());
480
- }
481
-
482
- FastScan_stats.t0 += handler.times[0];
483
- FastScan_stats.t1 += handler.times[1];
484
- FastScan_stats.t2 += handler.times[2];
485
- FastScan_stats.t3 += handler.times[3];
466
+ if (skip & 4) {
467
+ // pass
468
+ } else {
469
+ pq4_accumulate_loop_qbs(
470
+ qbs,
471
+ ntotal2,
472
+ M2,
473
+ codes.get(),
474
+ LUT.get(),
475
+ *handler.get(),
476
+ scaler);
477
+ }
478
+ if (!(skip & 8)) {
479
+ handler->end();
486
480
  }
487
481
  }
488
482
 
489
483
  FastScanStats FastScan_stats;
490
484
 
491
- template <class C, class Scaler>
485
+ template <class C>
492
486
  void IndexFastScan::search_implem_14(
493
487
  idx_t n,
494
488
  const float* x,
@@ -496,7 +490,8 @@ void IndexFastScan::search_implem_14(
496
490
  float* distances,
497
491
  idx_t* labels,
498
492
  int impl,
499
- const Scaler& scaler) const {
493
+ const NormTableScaler* scaler) const {
494
+ using RH = ResultHandlerCompare<C, false>;
500
495
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
501
496
 
502
497
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -531,90 +526,44 @@ void IndexFastScan::search_implem_14(
531
526
  AlignedTable<uint8_t> LUT(n * dim12);
532
527
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
533
528
 
534
- if (k == 1) {
535
- SingleResultHandler<C> handler(n, ntotal);
536
- if (skip & 4) {
537
- // pass
538
- } else {
539
- handler.disable = bool(skip & 2);
540
- pq4_accumulate_loop(
541
- n,
542
- ntotal2,
543
- bbs,
544
- M2,
545
- codes.get(),
546
- LUT.get(),
547
- handler,
548
- scaler);
549
- }
550
- handler.to_flat_arrays(distances, labels, normalizers.get());
551
-
552
- } else if (impl == 14) {
553
- std::vector<uint16_t> tmp_dis(n * k);
554
- std::vector<int32_t> tmp_ids(n * k);
555
-
556
- if (skip & 4) {
557
- // skip
558
- } else if (k > 1) {
559
- HeapHandler<C> handler(
560
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
561
- handler.disable = bool(skip & 2);
562
-
563
- pq4_accumulate_loop(
564
- n,
565
- ntotal2,
566
- bbs,
567
- M2,
568
- codes.get(),
569
- LUT.get(),
570
- handler,
571
- scaler);
572
-
573
- if (!(skip & 8)) {
574
- handler.to_flat_arrays(distances, labels, normalizers.get());
575
- }
576
- }
577
-
578
- } else { // impl == 15
529
+ std::unique_ptr<RH> handler(
530
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
531
+ handler->disable = bool(skip & 2);
532
+ handler->normalizers = normalizers.get();
579
533
 
580
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
581
- handler.disable = bool(skip & 2);
582
-
583
- if (skip & 4) {
584
- // skip
585
- } else {
586
- pq4_accumulate_loop(
587
- n,
588
- ntotal2,
589
- bbs,
590
- M2,
591
- codes.get(),
592
- LUT.get(),
593
- handler,
594
- scaler);
595
- }
596
-
597
- if (!(skip & 8)) {
598
- handler.to_flat_arrays(distances, labels, normalizers.get());
599
- }
534
+ if (skip & 4) {
535
+ // pass
536
+ } else {
537
+ pq4_accumulate_loop(
538
+ n,
539
+ ntotal2,
540
+ bbs,
541
+ M2,
542
+ codes.get(),
543
+ LUT.get(),
544
+ *handler.get(),
545
+ scaler);
546
+ }
547
+ if (!(skip & 8)) {
548
+ handler->end();
600
549
  }
601
550
  }
602
551
 
603
- template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
552
+ template void IndexFastScan::search_dispatch_implem<true>(
604
553
  idx_t n,
605
554
  const float* x,
606
555
  idx_t k,
607
556
  float* distances,
608
557
  idx_t* labels,
609
- const NormTableScaler& scaler) const;
558
+ const NormTableScaler* scaler) const;
610
559
 
611
- template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
560
+ template void IndexFastScan::search_dispatch_implem<false>(
612
561
  idx_t n,
613
562
  const float* x,
614
563
  idx_t k,
615
564
  float* distances,
616
565
  idx_t* labels,
617
- const NormTableScaler& scaler) const;
566
+ const NormTableScaler* scaler) const;
618
567
 
619
568
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
620
569
  std::vector<uint8_t> code(code_size, 0);
@@ -13,6 +13,7 @@
13
13
  namespace faiss {
14
14
 
15
15
  struct CodePacker;
16
+ struct NormTableScaler;
16
17
 
17
18
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
18
19
  *
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
87
88
  uint8_t* lut,
88
89
  float* normalizers) const;
89
90
 
90
- template <bool is_max, class Scaler>
91
+ template <bool is_max>
91
92
  void search_dispatch_implem(
92
93
  idx_t n,
93
94
  const float* x,
94
95
  idx_t k,
95
96
  float* distances,
96
97
  idx_t* labels,
97
- const Scaler& scaler) const;
98
+ const NormTableScaler* scaler) const;
98
99
 
99
- template <class Cfloat, class Scaler>
100
+ template <class Cfloat>
100
101
  void search_implem_234(
101
102
  idx_t n,
102
103
  const float* x,
103
104
  idx_t k,
104
105
  float* distances,
105
106
  idx_t* labels,
106
- const Scaler& scaler) const;
107
+ const NormTableScaler* scaler) const;
107
108
 
108
- template <class C, class Scaler>
109
+ template <class C>
109
110
  void search_implem_12(
110
111
  idx_t n,
111
112
  const float* x,
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
113
114
  float* distances,
114
115
  idx_t* labels,
115
116
  int impl,
116
- const Scaler& scaler) const;
117
+ const NormTableScaler* scaler) const;
117
118
 
118
- template <class C, class Scaler>
119
+ template <class C>
119
120
  void search_implem_14(
120
121
  idx_t n,
121
122
  const float* x,
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
123
124
  float* distances,
124
125
  idx_t* labels,
125
126
  int impl,
126
- const Scaler& scaler) const;
127
+ const NormTableScaler* scaler) const;
127
128
 
128
129
  void reconstruct(idx_t key, float* recons) const override;
129
130
  size_t remove_ids(const IDSelector& sel) override;
@@ -14,6 +14,7 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
16
  #include <faiss/utils/extra_distances.h>
17
+ #include <faiss/utils/prefetch.h>
17
18
  #include <faiss/utils/sorting.h>
18
19
  #include <faiss/utils/utils.h>
19
20
  #include <cstring>
@@ -40,15 +41,19 @@ void IndexFlat::search(
40
41
  } else if (metric_type == METRIC_L2) {
41
42
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
42
43
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
43
- } else if (is_similarity_metric(metric_type)) {
44
- float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
- knn_extra_metrics(
46
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
44
  } else {
48
- FAISS_THROW_IF_NOT(!sel);
49
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
+ FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
50
46
  knn_extra_metrics(
51
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
+ x,
48
+ get_xb(),
49
+ d,
50
+ n,
51
+ ntotal,
52
+ metric_type,
53
+ metric_arg,
54
+ k,
55
+ distances,
56
+ labels);
52
57
  }
53
58
  }
54
59
 
@@ -122,6 +127,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122
127
  void set_query(const float* x) override {
123
128
  q = x;
124
129
  }
130
+
131
+ // compute four distances
132
+ void distances_batch_4(
133
+ const idx_t idx0,
134
+ const idx_t idx1,
135
+ const idx_t idx2,
136
+ const idx_t idx3,
137
+ float& dis0,
138
+ float& dis1,
139
+ float& dis2,
140
+ float& dis3) final override {
141
+ ndis += 4;
142
+
143
+ // compute first, assign next
144
+ const float* __restrict y0 =
145
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
146
+ const float* __restrict y1 =
147
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
148
+ const float* __restrict y2 =
149
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
150
+ const float* __restrict y3 =
151
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
152
+
153
+ float dp0 = 0;
154
+ float dp1 = 0;
155
+ float dp2 = 0;
156
+ float dp3 = 0;
157
+ fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
158
+ dis0 = dp0;
159
+ dis1 = dp1;
160
+ dis2 = dp2;
161
+ dis3 = dp3;
162
+ }
125
163
  };
126
164
 
127
165
  struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +169,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131
169
  const float* b;
132
170
  size_t ndis;
133
171
 
134
- float symmetric_dis(idx_t i, idx_t j) override {
172
+ float symmetric_dis(idx_t i, idx_t j) final override {
135
173
  return fvec_inner_product(b + j * d, b + i * d, d);
136
174
  }
137
175
 
138
- float distance_to_code(const uint8_t* code) final {
176
+ float distance_to_code(const uint8_t* code) final override {
139
177
  ndis++;
140
- return fvec_inner_product(q, (float*)code, d);
178
+ return fvec_inner_product(q, (const float*)code, d);
141
179
  }
142
180
 
143
181
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -153,6 +191,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153
191
  void set_query(const float* x) override {
154
192
  q = x;
155
193
  }
194
+
195
+ // compute four distances
196
+ void distances_batch_4(
197
+ const idx_t idx0,
198
+ const idx_t idx1,
199
+ const idx_t idx2,
200
+ const idx_t idx3,
201
+ float& dis0,
202
+ float& dis1,
203
+ float& dis2,
204
+ float& dis3) final override {
205
+ ndis += 4;
206
+
207
+ // compute first, assign next
208
+ const float* __restrict y0 =
209
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
210
+ const float* __restrict y1 =
211
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
212
+ const float* __restrict y2 =
213
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
214
+ const float* __restrict y3 =
215
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
216
+
217
+ float dp0 = 0;
218
+ float dp1 = 0;
219
+ float dp2 = 0;
220
+ float dp3 = 0;
221
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
222
+ dis0 = dp0;
223
+ dis1 = dp1;
224
+ dis2 = dp2;
225
+ dis3 = dp3;
226
+ }
156
227
  };
157
228
 
158
229
  } // namespace
@@ -184,6 +255,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184
255
  }
185
256
  }
186
257
 
258
+ /***************************************************
259
+ * IndexFlatL2
260
+ ***************************************************/
261
+
262
+ namespace {
263
+ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
264
+ size_t d;
265
+ idx_t nb;
266
+ const float* q;
267
+ const float* b;
268
+ size_t ndis;
269
+
270
+ const float* l2norms;
271
+ float query_l2norm;
272
+
273
+ float distance_to_code(const uint8_t* code) final override {
274
+ ndis++;
275
+ return fvec_L2sqr(q, (float*)code, d);
276
+ }
277
+
278
+ float operator()(const idx_t i) final override {
279
+ const float* __restrict y =
280
+ reinterpret_cast<const float*>(codes + i * code_size);
281
+
282
+ prefetch_L2(l2norms + i);
283
+ const float dp0 = fvec_inner_product(q, y, d);
284
+ return query_l2norm + l2norms[i] - 2 * dp0;
285
+ }
286
+
287
+ float symmetric_dis(idx_t i, idx_t j) final override {
288
+ const float* __restrict yi =
289
+ reinterpret_cast<const float*>(codes + i * code_size);
290
+ const float* __restrict yj =
291
+ reinterpret_cast<const float*>(codes + j * code_size);
292
+
293
+ prefetch_L2(l2norms + i);
294
+ prefetch_L2(l2norms + j);
295
+ const float dp0 = fvec_inner_product(yi, yj, d);
296
+ return l2norms[i] + l2norms[j] - 2 * dp0;
297
+ }
298
+
299
+ explicit FlatL2WithNormsDis(
300
+ const IndexFlatL2& storage,
301
+ const float* q = nullptr)
302
+ : FlatCodesDistanceComputer(
303
+ storage.codes.data(),
304
+ storage.code_size),
305
+ d(storage.d),
306
+ nb(storage.ntotal),
307
+ q(q),
308
+ b(storage.get_xb()),
309
+ ndis(0),
310
+ l2norms(storage.cached_l2norms.data()),
311
+ query_l2norm(0) {}
312
+
313
+ void set_query(const float* x) override {
314
+ q = x;
315
+ query_l2norm = fvec_norm_L2sqr(q, d);
316
+ }
317
+
318
+ // compute four distances
319
+ void distances_batch_4(
320
+ const idx_t idx0,
321
+ const idx_t idx1,
322
+ const idx_t idx2,
323
+ const idx_t idx3,
324
+ float& dis0,
325
+ float& dis1,
326
+ float& dis2,
327
+ float& dis3) final override {
328
+ ndis += 4;
329
+
330
+ // compute first, assign next
331
+ const float* __restrict y0 =
332
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
333
+ const float* __restrict y1 =
334
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
335
+ const float* __restrict y2 =
336
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
337
+ const float* __restrict y3 =
338
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
339
+
340
+ prefetch_L2(l2norms + idx0);
341
+ prefetch_L2(l2norms + idx1);
342
+ prefetch_L2(l2norms + idx2);
343
+ prefetch_L2(l2norms + idx3);
344
+
345
+ float dp0 = 0;
346
+ float dp1 = 0;
347
+ float dp2 = 0;
348
+ float dp3 = 0;
349
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
350
+ dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
351
+ dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
352
+ dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
353
+ dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
354
+ }
355
+ };
356
+
357
+ } // namespace
358
+
359
+ void IndexFlatL2::sync_l2norms() {
360
+ cached_l2norms.resize(ntotal);
361
+ fvec_norms_L2sqr(
362
+ cached_l2norms.data(),
363
+ reinterpret_cast<const float*>(codes.data()),
364
+ d,
365
+ ntotal);
366
+ }
367
+
368
+ void IndexFlatL2::clear_l2norms() {
369
+ cached_l2norms.clear();
370
+ cached_l2norms.shrink_to_fit();
371
+ }
372
+
373
+ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
374
+ if (metric_type == METRIC_L2) {
375
+ if (!cached_l2norms.empty()) {
376
+ return new FlatL2WithNormsDis(*this);
377
+ }
378
+ }
379
+
380
+ return IndexFlat::get_FlatCodesDistanceComputer();
381
+ }
382
+
187
383
  /***************************************************
188
384
  * IndexFlat1D
189
385
  ***************************************************/