faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -5,13 +5,12 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/utils/distances.h>
11
9
 
12
10
  #include <algorithm>
13
11
  #include <cassert>
14
12
  #include <cmath>
13
+ #include <cstddef>
15
14
  #include <cstdio>
16
15
  #include <cstring>
17
16
 
@@ -64,7 +63,7 @@ void fvec_norms_L2(
64
63
  const float* __restrict x,
65
64
  size_t d,
66
65
  size_t nx) {
67
- #pragma omp parallel for
66
+ #pragma omp parallel for if (nx > 10000)
68
67
  for (int64_t i = 0; i < nx; i++) {
69
68
  nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
70
69
  }
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
75
74
  const float* __restrict x,
76
75
  size_t d,
77
76
  size_t nx) {
78
- #pragma omp parallel for
77
+ #pragma omp parallel for if (nx > 10000)
79
78
  for (int64_t i = 0; i < nx; i++)
80
79
  nr[i] = fvec_norm_L2sqr(x + i * d, d);
81
80
  }
82
81
 
83
- void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
84
- #pragma omp parallel for
82
+ // The following is a workaround to a problem
83
+ // in OpenMP in fbcode. The crash occurs
84
+ // inside OMP when IndexIVFSpectralHash::set_query()
85
+ // calls fvec_renorm_L2. set_query() is always
86
+ // calling this function with nx == 1, so even
87
+ // the omp version should run single threaded,
88
+ // as per the if condition of the omp pragma.
89
+ // Instead, the omp version crashes inside OMP.
90
+ // The workaround below is explicitly branching
91
+ // off to a codepath without omp.
92
+
93
+ #define FVEC_RENORM_L2_IMPL \
94
+ float* __restrict xi = x + i * d; \
95
+ \
96
+ float nr = fvec_norm_L2sqr(xi, d); \
97
+ \
98
+ if (nr > 0) { \
99
+ size_t j; \
100
+ const float inv_nr = 1.0 / sqrtf(nr); \
101
+ for (j = 0; j < d; j++) \
102
+ xi[j] *= inv_nr; \
103
+ }
104
+
105
+ void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
85
106
  for (int64_t i = 0; i < nx; i++) {
86
- float* __restrict xi = x + i * d;
107
+ FVEC_RENORM_L2_IMPL
108
+ }
109
+ }
87
110
 
88
- float nr = fvec_norm_L2sqr(xi, d);
111
+ void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
112
+ #pragma omp parallel for if (nx > 10000)
113
+ for (int64_t i = 0; i < nx; i++) {
114
+ FVEC_RENORM_L2_IMPL
115
+ }
116
+ }
89
117
 
90
- if (nr > 0) {
91
- size_t j;
92
- const float inv_nr = 1.0 / sqrtf(nr);
93
- for (j = 0; j < d; j++)
94
- xi[j] *= inv_nr;
95
- }
118
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
119
+ if (nx <= 10000) {
120
+ fvec_renorm_L2_noomp(d, nx, x);
121
+ } else {
122
+ fvec_renorm_L2_omp(d, nx, x);
96
123
  }
97
124
  }
98
125
 
@@ -103,16 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
103
130
  namespace {
104
131
 
105
132
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
106
- template <class ResultHandler, bool use_sel = false>
133
+ template <class BlockResultHandler, bool use_sel = false>
107
134
  void exhaustive_inner_product_seq(
108
135
  const float* x,
109
136
  const float* y,
110
137
  size_t d,
111
138
  size_t nx,
112
139
  size_t ny,
113
- ResultHandler& res,
140
+ BlockResultHandler& res,
114
141
  const IDSelector* sel = nullptr) {
115
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
142
+ using SingleResultHandler =
143
+ typename BlockResultHandler::SingleResultHandler;
116
144
  int nt = std::min(int(nx), omp_get_max_threads());
117
145
 
118
146
  FAISS_ASSERT(use_sel == (sel != nullptr));
@@ -139,16 +167,17 @@ void exhaustive_inner_product_seq(
139
167
  }
140
168
  }
141
169
 
142
- template <class ResultHandler, bool use_sel = false>
170
+ template <class BlockResultHandler, bool use_sel = false>
143
171
  void exhaustive_L2sqr_seq(
144
172
  const float* x,
145
173
  const float* y,
146
174
  size_t d,
147
175
  size_t nx,
148
176
  size_t ny,
149
- ResultHandler& res,
177
+ BlockResultHandler& res,
150
178
  const IDSelector* sel = nullptr) {
151
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
179
+ using SingleResultHandler =
180
+ typename BlockResultHandler::SingleResultHandler;
152
181
  int nt = std::min(int(nx), omp_get_max_threads());
153
182
 
154
183
  FAISS_ASSERT(use_sel == (sel != nullptr));
@@ -174,14 +203,14 @@ void exhaustive_L2sqr_seq(
174
203
  }
175
204
 
176
205
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
177
- template <class ResultHandler>
206
+ template <class BlockResultHandler>
178
207
  void exhaustive_inner_product_blas(
179
208
  const float* x,
180
209
  const float* y,
181
210
  size_t d,
182
211
  size_t nx,
183
212
  size_t ny,
184
- ResultHandler& res) {
213
+ BlockResultHandler& res) {
185
214
  // BLAS does not like empty matrices
186
215
  if (nx == 0 || ny == 0)
187
216
  return;
@@ -230,14 +259,14 @@ void exhaustive_inner_product_blas(
230
259
 
231
260
  // distance correction is an operator that can be applied to transform
232
261
  // the distances
233
- template <class ResultHandler>
262
+ template <class BlockResultHandler>
234
263
  void exhaustive_L2sqr_blas_default_impl(
235
264
  const float* x,
236
265
  const float* y,
237
266
  size_t d,
238
267
  size_t nx,
239
268
  size_t ny,
240
- ResultHandler& res,
269
+ BlockResultHandler& res,
241
270
  const float* y_norms = nullptr) {
242
271
  // BLAS does not like empty matrices
243
272
  if (nx == 0 || ny == 0)
@@ -313,14 +342,14 @@ void exhaustive_L2sqr_blas_default_impl(
313
342
  }
314
343
  }
315
344
 
316
- template <class ResultHandler>
345
+ template <class BlockResultHandler>
317
346
  void exhaustive_L2sqr_blas(
318
347
  const float* x,
319
348
  const float* y,
320
349
  size_t d,
321
350
  size_t nx,
322
351
  size_t ny,
323
- ResultHandler& res,
352
+ BlockResultHandler& res,
324
353
  const float* y_norms = nullptr) {
325
354
  exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
326
355
  }
@@ -332,7 +361,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
332
361
  size_t d,
333
362
  size_t nx,
334
363
  size_t ny,
335
- SingleBestResultHandler<CMax<float, int64_t>>& res,
364
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
336
365
  const float* y_norms) {
337
366
  // BLAS does not like empty matrices
338
367
  if (nx == 0 || ny == 0)
@@ -388,8 +417,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
388
417
  for (int64_t i = i0; i < i1; i++) {
389
418
  float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
390
419
 
391
- _mm_prefetch(ip_line, _MM_HINT_NTA);
392
- _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
420
+ _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
421
+ _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
393
422
 
394
423
  // constant
395
424
  const __m256 mul_minus2 = _mm256_set1_ps(-2);
@@ -416,8 +445,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
416
445
 
417
446
  // process 16 elements per loop
418
447
  for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
419
- _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
420
- _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
448
+ _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
449
+ _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
421
450
 
422
451
  // load values for norms
423
452
  const __m256 y_norm_0 =
@@ -535,13 +564,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
535
564
 
536
565
  // an override if only a single closest point is needed
537
566
  template <>
538
- void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
567
+ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
539
568
  const float* x,
540
569
  const float* y,
541
570
  size_t d,
542
571
  size_t nx,
543
572
  size_t ny,
544
- SingleBestResultHandler<CMax<float, int64_t>>& res,
573
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
545
574
  const float* y_norms) {
546
575
  #if defined(__AVX2__)
547
576
  // use a faster fused kernel if available
@@ -562,28 +591,29 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
562
591
 
563
592
  // run the default implementation
564
593
  exhaustive_L2sqr_blas_default_impl<
565
- SingleBestResultHandler<CMax<float, int64_t>>>(
594
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
566
595
  x, y, d, nx, ny, res, y_norms);
567
596
  #else
568
597
  // run the default implementation
569
598
  exhaustive_L2sqr_blas_default_impl<
570
- SingleBestResultHandler<CMax<float, int64_t>>>(
599
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
571
600
  x, y, d, nx, ny, res, y_norms);
572
601
  #endif
573
602
  }
574
603
 
575
- template <class ResultHandler>
604
+ template <class BlockResultHandler>
576
605
  void knn_L2sqr_select(
577
606
  const float* x,
578
607
  const float* y,
579
608
  size_t d,
580
609
  size_t nx,
581
610
  size_t ny,
582
- ResultHandler& res,
611
+ BlockResultHandler& res,
583
612
  const float* y_norm2,
584
613
  const IDSelector* sel) {
585
614
  if (sel) {
586
- exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
615
+ exhaustive_L2sqr_seq<BlockResultHandler, true>(
616
+ x, y, d, nx, ny, res, sel);
587
617
  } else if (nx < distance_compute_blas_threshold) {
588
618
  exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
589
619
  } else {
@@ -591,6 +621,25 @@ void knn_L2sqr_select(
591
621
  }
592
622
  }
593
623
 
624
+ template <class BlockResultHandler>
625
+ void knn_inner_product_select(
626
+ const float* x,
627
+ const float* y,
628
+ size_t d,
629
+ size_t nx,
630
+ size_t ny,
631
+ BlockResultHandler& res,
632
+ const IDSelector* sel) {
633
+ if (sel) {
634
+ exhaustive_inner_product_seq<BlockResultHandler, true>(
635
+ x, y, d, nx, ny, res, sel);
636
+ } else if (nx < distance_compute_blas_threshold) {
637
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
638
+ } else {
639
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
640
+ }
641
+ }
642
+
594
643
  } // anonymous namespace
595
644
 
596
645
  /*******************************************************
@@ -609,7 +658,7 @@ void knn_inner_product(
609
658
  size_t nx,
610
659
  size_t ny,
611
660
  size_t k,
612
- float* val,
661
+ float* vals,
613
662
  int64_t* ids,
614
663
  const IDSelector* sel) {
615
664
  int64_t imin = 0;
@@ -622,30 +671,21 @@ void knn_inner_product(
622
671
  }
623
672
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
624
673
  knn_inner_products_by_idx(
625
- x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
674
+ x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
626
675
  return;
627
676
  }
628
- if (k < distance_compute_min_k_reservoir) {
629
- using RH = HeapResultHandler<CMin<float, int64_t>>;
630
- RH res(nx, val, ids, k);
631
- if (sel) {
632
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
633
- } else if (nx < distance_compute_blas_threshold) {
634
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
635
- } else {
636
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
637
- }
677
+
678
+ if (k == 1) {
679
+ Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
680
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
681
+ } else if (k < distance_compute_min_k_reservoir) {
682
+ HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
683
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
638
684
  } else {
639
- using RH = ReservoirResultHandler<CMin<float, int64_t>>;
640
- RH res(nx, val, ids, k);
641
- if (sel) {
642
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
643
- } else if (nx < distance_compute_blas_threshold) {
644
- exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
645
- } else {
646
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
647
- }
685
+ ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
686
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
648
687
  }
688
+
649
689
  if (imin != 0) {
650
690
  for (size_t i = 0; i < nx * k; i++) {
651
691
  if (ids[i] >= 0) {
@@ -687,17 +727,17 @@ void knn_L2sqr(
687
727
  sel = nullptr;
688
728
  }
689
729
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
690
- knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
730
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
691
731
  return;
692
732
  }
693
733
  if (k == 1) {
694
- SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
734
+ Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
695
735
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
696
736
  } else if (k < distance_compute_min_k_reservoir) {
697
- HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
737
+ HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
698
738
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
699
739
  } else {
700
- ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
740
+ ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
701
741
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
702
742
  }
703
743
  if (imin != 0) {
@@ -735,7 +775,7 @@ void range_search_L2sqr(
735
775
  float radius,
736
776
  RangeSearchResult* res,
737
777
  const IDSelector* sel) {
738
- using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
778
+ using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
739
779
  RH resh(res, radius);
740
780
  if (sel) {
741
781
  exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
@@ -755,7 +795,7 @@ void range_search_inner_product(
755
795
  float radius,
756
796
  RangeSearchResult* res,
757
797
  const IDSelector* sel) {
758
- using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
798
+ using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
759
799
  RH resh(res, radius);
760
800
  if (sel) {
761
801
  exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
@@ -786,9 +826,11 @@ void fvec_inner_products_by_idx(
786
826
  const float* xj = x + j * d;
787
827
  float* __restrict ipj = ip + j * ny;
788
828
  for (size_t i = 0; i < ny; i++) {
789
- if (idsj[i] < 0)
790
- continue;
791
- ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
829
+ if (idsj[i] < 0) {
830
+ ipj[i] = -INFINITY;
831
+ } else {
832
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
833
+ }
792
834
  }
793
835
  }
794
836
  }
@@ -809,9 +851,11 @@ void fvec_L2sqr_by_idx(
809
851
  const float* xj = x + j * d;
810
852
  float* __restrict disj = dis + j * ny;
811
853
  for (size_t i = 0; i < ny; i++) {
812
- if (idsj[i] < 0)
813
- continue;
814
- disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
854
+ if (idsj[i] < 0) {
855
+ disj[i] = INFINITY;
856
+ } else {
857
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
858
+ }
815
859
  }
816
860
  }
817
861
  }
@@ -828,6 +872,8 @@ void pairwise_indexed_L2sqr(
828
872
  for (int64_t j = 0; j < n; j++) {
829
873
  if (ix[j] >= 0 && iy[j] >= 0) {
830
874
  dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
875
+ } else {
876
+ dis[j] = INFINITY;
831
877
  }
832
878
  }
833
879
  }
@@ -844,6 +890,8 @@ void pairwise_indexed_inner_product(
844
890
  for (int64_t j = 0; j < n; j++) {
845
891
  if (ix[j] >= 0 && iy[j] >= 0) {
846
892
  dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
893
+ } else {
894
+ dis[j] = -INFINITY;
847
895
  }
848
896
  }
849
897
  }
@@ -857,6 +905,7 @@ void knn_inner_products_by_idx(
857
905
  size_t d,
858
906
  size_t nx,
859
907
  size_t ny,
908
+ size_t nsubset,
860
909
  size_t k,
861
910
  float* res_vals,
862
911
  int64_t* res_ids,
@@ -874,9 +923,10 @@ void knn_inner_products_by_idx(
874
923
  int64_t* __restrict idxi = res_ids + i * k;
875
924
  minheap_heapify(k, simi, idxi);
876
925
 
877
- for (j = 0; j < ny; j++) {
878
- if (idsi[j] < 0)
926
+ for (j = 0; j < nsubset; j++) {
927
+ if (idsi[j] < 0 || idsi[j] >= ny) {
879
928
  break;
929
+ }
880
930
  float ip = fvec_inner_product(x_, y + d * idsi[j], d);
881
931
 
882
932
  if (ip > simi[0]) {
@@ -894,6 +944,7 @@ void knn_L2sqr_by_idx(
894
944
  size_t d,
895
945
  size_t nx,
896
946
  size_t ny,
947
+ size_t nsubset,
897
948
  size_t k,
898
949
  float* res_vals,
899
950
  int64_t* res_ids,
@@ -908,7 +959,10 @@ void knn_L2sqr_by_idx(
908
959
  float* __restrict simi = res_vals + i * k;
909
960
  int64_t* __restrict idxi = res_ids + i * k;
910
961
  maxheap_heapify(k, simi, idxi);
911
- for (size_t j = 0; j < ny; j++) {
962
+ for (size_t j = 0; j < nsubset; j++) {
963
+ if (idsi[j] < 0 || idsi[j] >= ny) {
964
+ break;
965
+ }
912
966
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
913
967
 
914
968
  if (disij < simi[0]) {
@@ -36,6 +36,34 @@ float fvec_L1(const float* x, const float* y, size_t d);
36
36
  /// infinity distance
37
37
  float fvec_Linf(const float* x, const float* y, size_t d);
38
38
 
39
+ /// Performance-oriented batched version of the inner product that computes
40
+ /// 4 distances at once: between x and each of y0, y1, y2, y3.
41
+ void fvec_inner_product_batch_4(
42
+ const float* x,
43
+ const float* y0,
44
+ const float* y1,
45
+ const float* y2,
46
+ const float* y3,
47
+ const size_t d,
48
+ float& dis0,
49
+ float& dis1,
50
+ float& dis2,
51
+ float& dis3);
52
+
53
+ /// Performance-oriented batched version of L2sqr that computes
54
+ /// 4 distances at once: between x and each of y0, y1, y2, y3.
55
+ void fvec_L2sqr_batch_4(
56
+ const float* x,
57
+ const float* y0,
58
+ const float* y1,
59
+ const float* y2,
60
+ const float* y3,
61
+ const size_t d,
62
+ float& dis0,
63
+ float& dis1,
64
+ float& dis2,
65
+ float& dis3);
66
+
39
67
  /** Compute pairwise distances between sets of vectors
40
68
  *
41
69
  * @param d dimension of the vectors
@@ -170,8 +198,16 @@ void fvec_sub(size_t d, const float* a, const float* b, float* c);
170
198
  * Compute a subset of distances
171
199
  ***************************************************************************/
172
200
 
173
- /* compute the inner product between x and a subset y of ny vectors,
174
- whose indices are given by idy. */
201
+ /** compute the inner product between x and a subset y of ny vectors defined by
202
+ * ids
203
+ *
204
+ * ip(i, j) = inner_product(x(i, :), y(ids(i, j), :))
205
+ *
206
+ * @param ip output array, size nx * ny
207
+ * @param x first-term vector, size nx * d
208
+ * @param y second-term vector, size (max(ids) + 1) * d
209
+ * @param ids ids to sample from y, size nx * ny
210
+ */
175
211
  void fvec_inner_products_by_idx(
176
212
  float* ip,
177
213
  const float* x,
@@ -181,7 +217,16 @@ void fvec_inner_products_by_idx(
181
217
  size_t nx,
182
218
  size_t ny);
183
219
 
184
- /* same but for a subset in y indexed by idsy (ny vectors in total) */
220
+ /** compute the squared L2 distances between x and a subset y of ny vectors
221
+ * defined by ids
222
+ *
223
+ * dis(i, j) = L2sqr(x(i, :), y(ids(i, j), :))
224
+ *
225
+ * @param dis output array, size nx * ny
226
+ * @param x first-term vector, size nx * d
227
+ * @param y second-term vector, size (max(ids) + 1) * d
228
+ * @param ids ids to sample from y, size nx * ny
229
+ */
185
230
  void fvec_L2sqr_by_idx(
186
231
  float* dis,
187
232
  const float* x,
@@ -208,7 +253,14 @@ void pairwise_indexed_L2sqr(
208
253
  const int64_t* iy,
209
254
  float* dis);
210
255
 
211
- /* same for inner product */
256
+ /** compute dis[j] = inner_product(x[ix[j]], y[iy[j]]) forall j=0..n-1
257
+ *
258
+ * @param x size (max(ix) + 1, d)
259
+ * @param y size (max(iy) + 1, d)
260
+ * @param ix size n
261
+ * @param iy size n
262
+ * @param dis size n
263
+ */
212
264
  void pairwise_indexed_inner_product(
213
265
  size_t d,
214
266
  size_t n,
@@ -324,6 +376,7 @@ void knn_inner_products_by_idx(
324
376
  const int64_t* subset,
325
377
  size_t d,
326
378
  size_t nx,
379
+ size_t ny,
327
380
  size_t nsubset,
328
381
  size_t k,
329
382
  float* vals,
@@ -346,6 +399,7 @@ void knn_L2sqr_by_idx(
346
399
  const int64_t* subset,
347
400
  size_t d,
348
401
  size_t nx,
402
+ size_t ny,
349
403
  size_t nsubset,
350
404
  size_t k,
351
405
  float* vals,
@@ -406,4 +460,27 @@ void compute_PQ_dis_tables_dsub2(
406
460
  * Templatized versions of distance functions
407
461
  ***************************************************************************/
408
462
 
463
+ /***************************************************************************
464
+ * Misc matrix and vector manipulation functions
465
+ ***************************************************************************/
466
+
467
+ /** compute c := a + bf * b for a, b and c tables
468
+ *
469
+ * @param n size of the tables
470
+ * @param a size n
471
+ * @param b size n
472
+ * @param c result table, size n
473
+ */
474
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
475
+
476
+ /** same as fvec_madd, also return index of the min of the result table
477
+ * @return index of the min of table c
478
+ */
479
+ int fvec_madd_and_argmin(
480
+ size_t n,
481
+ const float* a,
482
+ float bf,
483
+ const float* b,
484
+ float* c);
485
+
409
486
  } // namespace faiss
@@ -9,7 +9,7 @@
9
9
 
10
10
  #include <faiss/utils/distances_fused/avx512.h>
11
11
 
12
- #ifdef __AVX512__
12
+ #ifdef __AVX512F__
13
13
 
14
14
  #include <immintrin.h>
15
15
 
@@ -68,7 +68,7 @@ void kernel(
68
68
  const float* const __restrict y,
69
69
  const float* const __restrict y_transposed,
70
70
  size_t ny,
71
- SingleBestResultHandler<CMax<float, int64_t>>& res,
71
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
72
72
  const float* __restrict y_norms,
73
73
  size_t i) {
74
74
  const size_t ny_p =
@@ -231,7 +231,7 @@ void exhaustive_L2sqr_fused_cmax(
231
231
  const float* const __restrict y,
232
232
  size_t nx,
233
233
  size_t ny,
234
- SingleBestResultHandler<CMax<float, int64_t>>& res,
234
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
235
235
  const float* __restrict y_norms) {
236
236
  // BLAS does not like empty matrices
237
237
  if (nx == 0 || ny == 0) {
@@ -275,7 +275,7 @@ void exhaustive_L2sqr_fused_cmax(
275
275
  x, y, y_transposed.data(), ny, res, y_norms, i);
276
276
  }
277
277
 
278
- // Does nothing for SingleBestResultHandler, but
278
+ // Does nothing for Top1BlockResultHandler, but
279
279
  // keeping the call for the consistency.
280
280
  res.end_multiple();
281
281
  InterruptCallback::check();
@@ -289,7 +289,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
289
289
  size_t d,
290
290
  size_t nx,
291
291
  size_t ny,
292
- SingleBestResultHandler<CMax<float, int64_t>>& res,
292
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
293
293
  const float* y_norms) {
294
294
  // process only cases with certain dimensionalities
295
295
 
@@ -16,7 +16,7 @@
16
16
 
17
17
  #include <faiss/utils/Heap.h>
18
18
 
19
- #ifdef __AVX512__
19
+ #ifdef __AVX512F__
20
20
 
21
21
  namespace faiss {
22
22
 
@@ -28,7 +28,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
28
28
  size_t d,
29
29
  size_t nx,
30
30
  size_t ny,
31
- SingleBestResultHandler<CMax<float, int64_t>>& res,
31
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
32
32
  const float* y_norms);
33
33
 
34
34
  } // namespace faiss
@@ -20,14 +20,14 @@ bool exhaustive_L2sqr_fused_cmax(
20
20
  size_t d,
21
21
  size_t nx,
22
22
  size_t ny,
23
- SingleBestResultHandler<CMax<float, int64_t>>& res,
23
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
24
24
  const float* y_norms) {
25
25
  if (nx == 0 || ny == 0) {
26
26
  // nothing to do
27
27
  return true;
28
28
  }
29
29
 
30
- #ifdef __AVX512__
30
+ #ifdef __AVX512F__
31
31
  // avx512 kernel
32
32
  return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
33
33
  #elif defined(__AVX2__) || defined(__aarch64__)
@@ -34,7 +34,7 @@ bool exhaustive_L2sqr_fused_cmax(
34
34
  size_t d,
35
35
  size_t nx,
36
36
  size_t ny,
37
- SingleBestResultHandler<CMax<float, int64_t>>& res,
37
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
38
38
  const float* y_norms);
39
39
 
40
40
  } // namespace faiss