faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
37
37
 
38
38
  void IndexFastScan::init_fastscan(
39
39
  int d,
40
- size_t M,
41
- size_t nbits,
40
+ size_t M_2,
41
+ size_t nbits_2,
42
42
  MetricType metric,
43
43
  int bbs) {
44
- FAISS_THROW_IF_NOT(nbits == 4);
44
+ FAISS_THROW_IF_NOT(nbits_2 == 4);
45
45
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
46
46
  this->d = d;
47
- this->M = M;
48
- this->nbits = nbits;
47
+ this->M = M_2;
48
+ this->nbits = nbits_2;
49
49
  this->metric_type = metric;
50
50
  this->bbs = bbs;
51
- ksub = (1 << nbits);
51
+ ksub = (1 << nbits_2);
52
52
 
53
- code_size = (M * nbits + 7) / 8;
53
+ code_size = (M_2 * nbits_2 + 7) / 8;
54
54
  ntotal = ntotal2 = 0;
55
- M2 = roundup(M, 2);
55
+ M2 = roundup(M_2, 2);
56
56
  is_trained = false;
57
57
  }
58
58
 
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
158
158
 
159
159
  namespace {
160
160
 
161
- template <class C, typename dis_t, class Scaler>
161
+ template <class C, typename dis_t>
162
162
  void estimators_from_tables_generic(
163
163
  const IndexFastScan& index,
164
164
  const uint8_t* codes,
@@ -167,23 +167,27 @@ void estimators_from_tables_generic(
167
167
  size_t k,
168
168
  typename C::T* heap_dis,
169
169
  int64_t* heap_ids,
170
- const Scaler& scaler) {
170
+ const NormTableScaler* scaler) {
171
171
  using accu_t = typename C::T;
172
172
 
173
173
  for (size_t j = 0; j < ncodes; ++j) {
174
174
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
175
175
  accu_t dis = 0;
176
176
  const dis_t* dt = dis_table;
177
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
177
+ int nscale = scaler ? scaler->nscale : 0;
178
+
179
+ for (size_t m = 0; m < index.M - nscale; m++) {
178
180
  uint64_t c = bsr.read(index.nbits);
179
181
  dis += dt[c];
180
182
  dt += index.ksub;
181
183
  }
182
184
 
183
- for (size_t m = 0; m < scaler.nscale; m++) {
184
- uint64_t c = bsr.read(index.nbits);
185
- dis += scaler.scale_one(dt[c]);
186
- dt += index.ksub;
185
+ if (nscale) {
186
+ for (size_t m = 0; m < nscale; m++) {
187
+ uint64_t c = bsr.read(index.nbits);
188
+ dis += scaler->scale_one(dt[c]);
189
+ dt += index.ksub;
190
+ }
187
191
  }
188
192
 
189
193
  if (C::cmp(heap_dis[0], dis)) {
@@ -193,6 +197,28 @@ void estimators_from_tables_generic(
193
197
  }
194
198
  }
195
199
 
200
+ template <class C>
201
+ ResultHandlerCompare<C, false>* make_knn_handler(
202
+ int impl,
203
+ idx_t n,
204
+ idx_t k,
205
+ size_t ntotal,
206
+ float* distances,
207
+ idx_t* labels,
208
+ const IDSelector* sel = nullptr) {
209
+ using HeapHC = HeapHandler<C, false>;
210
+ using ReservoirHC = ReservoirHandler<C, false>;
211
+ using SingleResultHC = SingleResultHandler<C, false>;
212
+
213
+ if (k == 1) {
214
+ return new SingleResultHC(n, ntotal, distances, labels, sel);
215
+ } else if (impl % 2 == 0) {
216
+ return new HeapHC(n, ntotal, k, distances, labels, sel);
217
+ } else /* if (impl % 2 == 1) */ {
218
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
219
+ }
220
+ }
221
+
196
222
  } // anonymous namespace
197
223
 
198
224
  using namespace quantize_lut;
@@ -241,22 +267,21 @@ void IndexFastScan::search(
241
267
  !params, "search params not supported for this index");
242
268
  FAISS_THROW_IF_NOT(k > 0);
243
269
 
244
- DummyScaler scaler;
245
270
  if (metric_type == METRIC_L2) {
246
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
271
+ search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
247
272
  } else {
248
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
273
+ search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
249
274
  }
250
275
  }
251
276
 
252
- template <bool is_max, class Scaler>
277
+ template <bool is_max>
253
278
  void IndexFastScan::search_dispatch_implem(
254
279
  idx_t n,
255
280
  const float* x,
256
281
  idx_t k,
257
282
  float* distances,
258
283
  idx_t* labels,
259
- const Scaler& scaler) const {
284
+ const NormTableScaler* scaler) const {
260
285
  using Cfloat = typename std::conditional<
261
286
  is_max,
262
287
  CMax<float, int64_t>,
@@ -319,14 +344,14 @@ void IndexFastScan::search_dispatch_implem(
319
344
  }
320
345
  }
321
346
 
322
- template <class Cfloat, class Scaler>
347
+ template <class Cfloat>
323
348
  void IndexFastScan::search_implem_234(
324
349
  idx_t n,
325
350
  const float* x,
326
351
  idx_t k,
327
352
  float* distances,
328
353
  idx_t* labels,
329
- const Scaler& scaler) const {
354
+ const NormTableScaler* scaler) const {
330
355
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
331
356
 
332
357
  const size_t dim12 = ksub * M;
@@ -378,7 +403,7 @@ void IndexFastScan::search_implem_234(
378
403
  }
379
404
  }
380
405
 
381
- template <class C, class Scaler>
406
+ template <class C>
382
407
  void IndexFastScan::search_implem_12(
383
408
  idx_t n,
384
409
  const float* x,
@@ -386,7 +411,8 @@ void IndexFastScan::search_implem_12(
386
411
  float* distances,
387
412
  idx_t* labels,
388
413
  int impl,
389
- const Scaler& scaler) const {
414
+ const NormTableScaler* scaler) const {
415
+ using RH = ResultHandlerCompare<C, false>;
390
416
  FAISS_THROW_IF_NOT(bbs == 32);
391
417
 
392
418
  // handle qbs2 blocking by recursive call
@@ -432,63 +458,31 @@ void IndexFastScan::search_implem_12(
432
458
  pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
433
459
  FAISS_THROW_IF_NOT(LUT_nq == n);
434
460
 
435
- if (k == 1) {
436
- SingleResultHandler<C> handler(n, ntotal);
437
- if (skip & 4) {
438
- // pass
439
- } else {
440
- handler.disable = bool(skip & 2);
441
- pq4_accumulate_loop_qbs(
442
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
443
- }
444
-
445
- handler.to_flat_arrays(distances, labels, normalizers.get());
446
-
447
- } else if (impl == 12) {
448
- std::vector<uint16_t> tmp_dis(n * k);
449
- std::vector<int32_t> tmp_ids(n * k);
450
-
451
- if (skip & 4) {
452
- // skip
453
- } else {
454
- HeapHandler<C> handler(
455
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
456
- handler.disable = bool(skip & 2);
457
-
458
- pq4_accumulate_loop_qbs(
459
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
460
-
461
- if (!(skip & 8)) {
462
- handler.to_flat_arrays(distances, labels, normalizers.get());
463
- }
464
- }
465
-
466
- } else { // impl == 13
467
-
468
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
469
- handler.disable = bool(skip & 2);
461
+ std::unique_ptr<RH> handler(
462
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
463
+ handler->disable = bool(skip & 2);
464
+ handler->normalizers = normalizers.get();
470
465
 
471
- if (skip & 4) {
472
- // skip
473
- } else {
474
- pq4_accumulate_loop_qbs(
475
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
476
- }
477
-
478
- if (!(skip & 8)) {
479
- handler.to_flat_arrays(distances, labels, normalizers.get());
480
- }
481
-
482
- FastScan_stats.t0 += handler.times[0];
483
- FastScan_stats.t1 += handler.times[1];
484
- FastScan_stats.t2 += handler.times[2];
485
- FastScan_stats.t3 += handler.times[3];
466
+ if (skip & 4) {
467
+ // pass
468
+ } else {
469
+ pq4_accumulate_loop_qbs(
470
+ qbs,
471
+ ntotal2,
472
+ M2,
473
+ codes.get(),
474
+ LUT.get(),
475
+ *handler.get(),
476
+ scaler);
477
+ }
478
+ if (!(skip & 8)) {
479
+ handler->end();
486
480
  }
487
481
  }
488
482
 
489
483
  FastScanStats FastScan_stats;
490
484
 
491
- template <class C, class Scaler>
485
+ template <class C>
492
486
  void IndexFastScan::search_implem_14(
493
487
  idx_t n,
494
488
  const float* x,
@@ -496,7 +490,8 @@ void IndexFastScan::search_implem_14(
496
490
  float* distances,
497
491
  idx_t* labels,
498
492
  int impl,
499
- const Scaler& scaler) const {
493
+ const NormTableScaler* scaler) const {
494
+ using RH = ResultHandlerCompare<C, false>;
500
495
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
501
496
 
502
497
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -531,90 +526,44 @@ void IndexFastScan::search_implem_14(
531
526
  AlignedTable<uint8_t> LUT(n * dim12);
532
527
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
533
528
 
534
- if (k == 1) {
535
- SingleResultHandler<C> handler(n, ntotal);
536
- if (skip & 4) {
537
- // pass
538
- } else {
539
- handler.disable = bool(skip & 2);
540
- pq4_accumulate_loop(
541
- n,
542
- ntotal2,
543
- bbs,
544
- M2,
545
- codes.get(),
546
- LUT.get(),
547
- handler,
548
- scaler);
549
- }
550
- handler.to_flat_arrays(distances, labels, normalizers.get());
551
-
552
- } else if (impl == 14) {
553
- std::vector<uint16_t> tmp_dis(n * k);
554
- std::vector<int32_t> tmp_ids(n * k);
555
-
556
- if (skip & 4) {
557
- // skip
558
- } else if (k > 1) {
559
- HeapHandler<C> handler(
560
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
561
- handler.disable = bool(skip & 2);
562
-
563
- pq4_accumulate_loop(
564
- n,
565
- ntotal2,
566
- bbs,
567
- M2,
568
- codes.get(),
569
- LUT.get(),
570
- handler,
571
- scaler);
572
-
573
- if (!(skip & 8)) {
574
- handler.to_flat_arrays(distances, labels, normalizers.get());
575
- }
576
- }
577
-
578
- } else { // impl == 15
529
+ std::unique_ptr<RH> handler(
530
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
531
+ handler->disable = bool(skip & 2);
532
+ handler->normalizers = normalizers.get();
579
533
 
580
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
581
- handler.disable = bool(skip & 2);
582
-
583
- if (skip & 4) {
584
- // skip
585
- } else {
586
- pq4_accumulate_loop(
587
- n,
588
- ntotal2,
589
- bbs,
590
- M2,
591
- codes.get(),
592
- LUT.get(),
593
- handler,
594
- scaler);
595
- }
596
-
597
- if (!(skip & 8)) {
598
- handler.to_flat_arrays(distances, labels, normalizers.get());
599
- }
534
+ if (skip & 4) {
535
+ // pass
536
+ } else {
537
+ pq4_accumulate_loop(
538
+ n,
539
+ ntotal2,
540
+ bbs,
541
+ M2,
542
+ codes.get(),
543
+ LUT.get(),
544
+ *handler.get(),
545
+ scaler);
546
+ }
547
+ if (!(skip & 8)) {
548
+ handler->end();
600
549
  }
601
550
  }
602
551
 
603
- template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
552
+ template void IndexFastScan::search_dispatch_implem<true>(
604
553
  idx_t n,
605
554
  const float* x,
606
555
  idx_t k,
607
556
  float* distances,
608
557
  idx_t* labels,
609
- const NormTableScaler& scaler) const;
558
+ const NormTableScaler* scaler) const;
610
559
 
611
- template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
560
+ template void IndexFastScan::search_dispatch_implem<false>(
612
561
  idx_t n,
613
562
  const float* x,
614
563
  idx_t k,
615
564
  float* distances,
616
565
  idx_t* labels,
617
- const NormTableScaler& scaler) const;
566
+ const NormTableScaler* scaler) const;
618
567
 
619
568
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
620
569
  std::vector<uint8_t> code(code_size, 0);
@@ -13,6 +13,7 @@
13
13
  namespace faiss {
14
14
 
15
15
  struct CodePacker;
16
+ struct NormTableScaler;
16
17
 
17
18
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
18
19
  *
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
87
88
  uint8_t* lut,
88
89
  float* normalizers) const;
89
90
 
90
- template <bool is_max, class Scaler>
91
+ template <bool is_max>
91
92
  void search_dispatch_implem(
92
93
  idx_t n,
93
94
  const float* x,
94
95
  idx_t k,
95
96
  float* distances,
96
97
  idx_t* labels,
97
- const Scaler& scaler) const;
98
+ const NormTableScaler* scaler) const;
98
99
 
99
- template <class Cfloat, class Scaler>
100
+ template <class Cfloat>
100
101
  void search_implem_234(
101
102
  idx_t n,
102
103
  const float* x,
103
104
  idx_t k,
104
105
  float* distances,
105
106
  idx_t* labels,
106
- const Scaler& scaler) const;
107
+ const NormTableScaler* scaler) const;
107
108
 
108
- template <class C, class Scaler>
109
+ template <class C>
109
110
  void search_implem_12(
110
111
  idx_t n,
111
112
  const float* x,
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
113
114
  float* distances,
114
115
  idx_t* labels,
115
116
  int impl,
116
- const Scaler& scaler) const;
117
+ const NormTableScaler* scaler) const;
117
118
 
118
- template <class C, class Scaler>
119
+ template <class C>
119
120
  void search_implem_14(
120
121
  idx_t n,
121
122
  const float* x,
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
123
124
  float* distances,
124
125
  idx_t* labels,
125
126
  int impl,
126
- const Scaler& scaler) const;
127
+ const NormTableScaler* scaler) const;
127
128
 
128
129
  void reconstruct(idx_t key, float* recons) const override;
129
130
  size_t remove_ids(const IDSelector& sel) override;
@@ -14,6 +14,7 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
16
  #include <faiss/utils/extra_distances.h>
17
+ #include <faiss/utils/prefetch.h>
17
18
  #include <faiss/utils/sorting.h>
18
19
  #include <faiss/utils/utils.h>
19
20
  #include <cstring>
@@ -40,15 +41,19 @@ void IndexFlat::search(
40
41
  } else if (metric_type == METRIC_L2) {
41
42
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
42
43
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
43
- } else if (is_similarity_metric(metric_type)) {
44
- float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
- knn_extra_metrics(
46
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
44
  } else {
48
- FAISS_THROW_IF_NOT(!sel);
49
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
+ FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
50
46
  knn_extra_metrics(
51
- x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47
+ x,
48
+ get_xb(),
49
+ d,
50
+ n,
51
+ ntotal,
52
+ metric_type,
53
+ metric_arg,
54
+ k,
55
+ distances,
56
+ labels);
52
57
  }
53
58
  }
54
59
 
@@ -122,6 +127,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122
127
  void set_query(const float* x) override {
123
128
  q = x;
124
129
  }
130
+
131
+ // compute four distances
132
+ void distances_batch_4(
133
+ const idx_t idx0,
134
+ const idx_t idx1,
135
+ const idx_t idx2,
136
+ const idx_t idx3,
137
+ float& dis0,
138
+ float& dis1,
139
+ float& dis2,
140
+ float& dis3) final override {
141
+ ndis += 4;
142
+
143
+ // compute first, assign next
144
+ const float* __restrict y0 =
145
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
146
+ const float* __restrict y1 =
147
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
148
+ const float* __restrict y2 =
149
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
150
+ const float* __restrict y3 =
151
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
152
+
153
+ float dp0 = 0;
154
+ float dp1 = 0;
155
+ float dp2 = 0;
156
+ float dp3 = 0;
157
+ fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
158
+ dis0 = dp0;
159
+ dis1 = dp1;
160
+ dis2 = dp2;
161
+ dis3 = dp3;
162
+ }
125
163
  };
126
164
 
127
165
  struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +169,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131
169
  const float* b;
132
170
  size_t ndis;
133
171
 
134
- float symmetric_dis(idx_t i, idx_t j) override {
172
+ float symmetric_dis(idx_t i, idx_t j) final override {
135
173
  return fvec_inner_product(b + j * d, b + i * d, d);
136
174
  }
137
175
 
138
- float distance_to_code(const uint8_t* code) final {
176
+ float distance_to_code(const uint8_t* code) final override {
139
177
  ndis++;
140
- return fvec_inner_product(q, (float*)code, d);
178
+ return fvec_inner_product(q, (const float*)code, d);
141
179
  }
142
180
 
143
181
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -153,6 +191,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153
191
  void set_query(const float* x) override {
154
192
  q = x;
155
193
  }
194
+
195
+ // compute four distances
196
+ void distances_batch_4(
197
+ const idx_t idx0,
198
+ const idx_t idx1,
199
+ const idx_t idx2,
200
+ const idx_t idx3,
201
+ float& dis0,
202
+ float& dis1,
203
+ float& dis2,
204
+ float& dis3) final override {
205
+ ndis += 4;
206
+
207
+ // compute first, assign next
208
+ const float* __restrict y0 =
209
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
210
+ const float* __restrict y1 =
211
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
212
+ const float* __restrict y2 =
213
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
214
+ const float* __restrict y3 =
215
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
216
+
217
+ float dp0 = 0;
218
+ float dp1 = 0;
219
+ float dp2 = 0;
220
+ float dp3 = 0;
221
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
222
+ dis0 = dp0;
223
+ dis1 = dp1;
224
+ dis2 = dp2;
225
+ dis3 = dp3;
226
+ }
156
227
  };
157
228
 
158
229
  } // namespace
@@ -184,6 +255,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184
255
  }
185
256
  }
186
257
 
258
+ /***************************************************
259
+ * IndexFlatL2
260
+ ***************************************************/
261
+
262
+ namespace {
263
+ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
264
+ size_t d;
265
+ idx_t nb;
266
+ const float* q;
267
+ const float* b;
268
+ size_t ndis;
269
+
270
+ const float* l2norms;
271
+ float query_l2norm;
272
+
273
+ float distance_to_code(const uint8_t* code) final override {
274
+ ndis++;
275
+ return fvec_L2sqr(q, (float*)code, d);
276
+ }
277
+
278
+ float operator()(const idx_t i) final override {
279
+ const float* __restrict y =
280
+ reinterpret_cast<const float*>(codes + i * code_size);
281
+
282
+ prefetch_L2(l2norms + i);
283
+ const float dp0 = fvec_inner_product(q, y, d);
284
+ return query_l2norm + l2norms[i] - 2 * dp0;
285
+ }
286
+
287
+ float symmetric_dis(idx_t i, idx_t j) final override {
288
+ const float* __restrict yi =
289
+ reinterpret_cast<const float*>(codes + i * code_size);
290
+ const float* __restrict yj =
291
+ reinterpret_cast<const float*>(codes + j * code_size);
292
+
293
+ prefetch_L2(l2norms + i);
294
+ prefetch_L2(l2norms + j);
295
+ const float dp0 = fvec_inner_product(yi, yj, d);
296
+ return l2norms[i] + l2norms[j] - 2 * dp0;
297
+ }
298
+
299
+ explicit FlatL2WithNormsDis(
300
+ const IndexFlatL2& storage,
301
+ const float* q = nullptr)
302
+ : FlatCodesDistanceComputer(
303
+ storage.codes.data(),
304
+ storage.code_size),
305
+ d(storage.d),
306
+ nb(storage.ntotal),
307
+ q(q),
308
+ b(storage.get_xb()),
309
+ ndis(0),
310
+ l2norms(storage.cached_l2norms.data()),
311
+ query_l2norm(0) {}
312
+
313
+ void set_query(const float* x) override {
314
+ q = x;
315
+ query_l2norm = fvec_norm_L2sqr(q, d);
316
+ }
317
+
318
+ // compute four distances
319
+ void distances_batch_4(
320
+ const idx_t idx0,
321
+ const idx_t idx1,
322
+ const idx_t idx2,
323
+ const idx_t idx3,
324
+ float& dis0,
325
+ float& dis1,
326
+ float& dis2,
327
+ float& dis3) final override {
328
+ ndis += 4;
329
+
330
+ // compute first, assign next
331
+ const float* __restrict y0 =
332
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
333
+ const float* __restrict y1 =
334
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
335
+ const float* __restrict y2 =
336
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
337
+ const float* __restrict y3 =
338
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
339
+
340
+ prefetch_L2(l2norms + idx0);
341
+ prefetch_L2(l2norms + idx1);
342
+ prefetch_L2(l2norms + idx2);
343
+ prefetch_L2(l2norms + idx3);
344
+
345
+ float dp0 = 0;
346
+ float dp1 = 0;
347
+ float dp2 = 0;
348
+ float dp3 = 0;
349
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
350
+ dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
351
+ dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
352
+ dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
353
+ dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
354
+ }
355
+ };
356
+
357
+ } // namespace
358
+
359
+ void IndexFlatL2::sync_l2norms() {
360
+ cached_l2norms.resize(ntotal);
361
+ fvec_norms_L2sqr(
362
+ cached_l2norms.data(),
363
+ reinterpret_cast<const float*>(codes.data()),
364
+ d,
365
+ ntotal);
366
+ }
367
+
368
+ void IndexFlatL2::clear_l2norms() {
369
+ cached_l2norms.clear();
370
+ cached_l2norms.shrink_to_fit();
371
+ }
372
+
373
+ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
374
+ if (metric_type == METRIC_L2) {
375
+ if (!cached_l2norms.empty()) {
376
+ return new FlatL2WithNormsDis(*this);
377
+ }
378
+ }
379
+
380
+ return IndexFlat::get_FlatCodesDistanceComputer();
381
+ }
382
+
187
383
  /***************************************************
188
384
  * IndexFlat1D
189
385
  ***************************************************/