faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -5,13 +5,12 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/utils/distances.h>
11
9
 
12
10
  #include <algorithm>
13
11
  #include <cassert>
14
12
  #include <cmath>
13
+ #include <cstddef>
15
14
  #include <cstdio>
16
15
  #include <cstring>
17
16
 
@@ -64,7 +63,7 @@ void fvec_norms_L2(
64
63
  const float* __restrict x,
65
64
  size_t d,
66
65
  size_t nx) {
67
- #pragma omp parallel for
66
+ #pragma omp parallel for if (nx > 10000)
68
67
  for (int64_t i = 0; i < nx; i++) {
69
68
  nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
70
69
  }
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
75
74
  const float* __restrict x,
76
75
  size_t d,
77
76
  size_t nx) {
78
- #pragma omp parallel for
77
+ #pragma omp parallel for if (nx > 10000)
79
78
  for (int64_t i = 0; i < nx; i++)
80
79
  nr[i] = fvec_norm_L2sqr(x + i * d, d);
81
80
  }
82
81
 
83
- void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
84
- #pragma omp parallel for
82
+ // The following is a workaround to a problem
83
+ // in OpenMP in fbcode. The crash occurs
84
+ // inside OMP when IndexIVFSpectralHash::set_query()
85
+ // calls fvec_renorm_L2. set_query() is always
86
+ // calling this function with nx == 1, so even
87
+ // the omp version should run single threaded,
88
+ // as per the if condition of the omp pragma.
89
+ // Instead, the omp version crashes inside OMP.
90
+ // The workaround below is explicitly branching
91
+ // off to a codepath without omp.
92
+
93
+ #define FVEC_RENORM_L2_IMPL \
94
+ float* __restrict xi = x + i * d; \
95
+ \
96
+ float nr = fvec_norm_L2sqr(xi, d); \
97
+ \
98
+ if (nr > 0) { \
99
+ size_t j; \
100
+ const float inv_nr = 1.0 / sqrtf(nr); \
101
+ for (j = 0; j < d; j++) \
102
+ xi[j] *= inv_nr; \
103
+ }
104
+
105
+ void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
85
106
  for (int64_t i = 0; i < nx; i++) {
86
- float* __restrict xi = x + i * d;
107
+ FVEC_RENORM_L2_IMPL
108
+ }
109
+ }
87
110
 
88
- float nr = fvec_norm_L2sqr(xi, d);
111
+ void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
112
+ #pragma omp parallel for if (nx > 10000)
113
+ for (int64_t i = 0; i < nx; i++) {
114
+ FVEC_RENORM_L2_IMPL
115
+ }
116
+ }
89
117
 
90
- if (nr > 0) {
91
- size_t j;
92
- const float inv_nr = 1.0 / sqrtf(nr);
93
- for (j = 0; j < d; j++)
94
- xi[j] *= inv_nr;
95
- }
118
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
119
+ if (nx <= 10000) {
120
+ fvec_renorm_L2_noomp(d, nx, x);
121
+ } else {
122
+ fvec_renorm_L2_omp(d, nx, x);
96
123
  }
97
124
  }
98
125
 
@@ -103,16 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
103
130
  namespace {
104
131
 
105
132
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
106
- template <class ResultHandler, bool use_sel = false>
133
+ template <class BlockResultHandler, bool use_sel = false>
107
134
  void exhaustive_inner_product_seq(
108
135
  const float* x,
109
136
  const float* y,
110
137
  size_t d,
111
138
  size_t nx,
112
139
  size_t ny,
113
- ResultHandler& res,
140
+ BlockResultHandler& res,
114
141
  const IDSelector* sel = nullptr) {
115
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
142
+ using SingleResultHandler =
143
+ typename BlockResultHandler::SingleResultHandler;
116
144
  int nt = std::min(int(nx), omp_get_max_threads());
117
145
 
118
146
  FAISS_ASSERT(use_sel == (sel != nullptr));
@@ -139,16 +167,17 @@ void exhaustive_inner_product_seq(
139
167
  }
140
168
  }
141
169
 
142
- template <class ResultHandler, bool use_sel = false>
170
+ template <class BlockResultHandler, bool use_sel = false>
143
171
  void exhaustive_L2sqr_seq(
144
172
  const float* x,
145
173
  const float* y,
146
174
  size_t d,
147
175
  size_t nx,
148
176
  size_t ny,
149
- ResultHandler& res,
177
+ BlockResultHandler& res,
150
178
  const IDSelector* sel = nullptr) {
151
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
179
+ using SingleResultHandler =
180
+ typename BlockResultHandler::SingleResultHandler;
152
181
  int nt = std::min(int(nx), omp_get_max_threads());
153
182
 
154
183
  FAISS_ASSERT(use_sel == (sel != nullptr));
@@ -174,14 +203,14 @@ void exhaustive_L2sqr_seq(
174
203
  }
175
204
 
176
205
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
177
- template <class ResultHandler>
206
+ template <class BlockResultHandler>
178
207
  void exhaustive_inner_product_blas(
179
208
  const float* x,
180
209
  const float* y,
181
210
  size_t d,
182
211
  size_t nx,
183
212
  size_t ny,
184
- ResultHandler& res) {
213
+ BlockResultHandler& res) {
185
214
  // BLAS does not like empty matrices
186
215
  if (nx == 0 || ny == 0)
187
216
  return;
@@ -230,14 +259,14 @@ void exhaustive_inner_product_blas(
230
259
 
231
260
  // distance correction is an operator that can be applied to transform
232
261
  // the distances
233
- template <class ResultHandler>
262
+ template <class BlockResultHandler>
234
263
  void exhaustive_L2sqr_blas_default_impl(
235
264
  const float* x,
236
265
  const float* y,
237
266
  size_t d,
238
267
  size_t nx,
239
268
  size_t ny,
240
- ResultHandler& res,
269
+ BlockResultHandler& res,
241
270
  const float* y_norms = nullptr) {
242
271
  // BLAS does not like empty matrices
243
272
  if (nx == 0 || ny == 0)
@@ -313,14 +342,14 @@ void exhaustive_L2sqr_blas_default_impl(
313
342
  }
314
343
  }
315
344
 
316
- template <class ResultHandler>
345
+ template <class BlockResultHandler>
317
346
  void exhaustive_L2sqr_blas(
318
347
  const float* x,
319
348
  const float* y,
320
349
  size_t d,
321
350
  size_t nx,
322
351
  size_t ny,
323
- ResultHandler& res,
352
+ BlockResultHandler& res,
324
353
  const float* y_norms = nullptr) {
325
354
  exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
326
355
  }
@@ -332,7 +361,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
332
361
  size_t d,
333
362
  size_t nx,
334
363
  size_t ny,
335
- SingleBestResultHandler<CMax<float, int64_t>>& res,
364
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
336
365
  const float* y_norms) {
337
366
  // BLAS does not like empty matrices
338
367
  if (nx == 0 || ny == 0)
@@ -388,8 +417,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
388
417
  for (int64_t i = i0; i < i1; i++) {
389
418
  float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
390
419
 
391
- _mm_prefetch(ip_line, _MM_HINT_NTA);
392
- _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
420
+ _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
421
+ _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
393
422
 
394
423
  // constant
395
424
  const __m256 mul_minus2 = _mm256_set1_ps(-2);
@@ -416,8 +445,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
416
445
 
417
446
  // process 16 elements per loop
418
447
  for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
419
- _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
420
- _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
448
+ _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
449
+ _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
421
450
 
422
451
  // load values for norms
423
452
  const __m256 y_norm_0 =
@@ -535,13 +564,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
535
564
 
536
565
  // an override if only a single closest point is needed
537
566
  template <>
538
- void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
567
+ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
539
568
  const float* x,
540
569
  const float* y,
541
570
  size_t d,
542
571
  size_t nx,
543
572
  size_t ny,
544
- SingleBestResultHandler<CMax<float, int64_t>>& res,
573
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
545
574
  const float* y_norms) {
546
575
  #if defined(__AVX2__)
547
576
  // use a faster fused kernel if available
@@ -562,28 +591,29 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
562
591
 
563
592
  // run the default implementation
564
593
  exhaustive_L2sqr_blas_default_impl<
565
- SingleBestResultHandler<CMax<float, int64_t>>>(
594
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
566
595
  x, y, d, nx, ny, res, y_norms);
567
596
  #else
568
597
  // run the default implementation
569
598
  exhaustive_L2sqr_blas_default_impl<
570
- SingleBestResultHandler<CMax<float, int64_t>>>(
599
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
571
600
  x, y, d, nx, ny, res, y_norms);
572
601
  #endif
573
602
  }
574
603
 
575
- template <class ResultHandler>
604
+ template <class BlockResultHandler>
576
605
  void knn_L2sqr_select(
577
606
  const float* x,
578
607
  const float* y,
579
608
  size_t d,
580
609
  size_t nx,
581
610
  size_t ny,
582
- ResultHandler& res,
611
+ BlockResultHandler& res,
583
612
  const float* y_norm2,
584
613
  const IDSelector* sel) {
585
614
  if (sel) {
586
- exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
615
+ exhaustive_L2sqr_seq<BlockResultHandler, true>(
616
+ x, y, d, nx, ny, res, sel);
587
617
  } else if (nx < distance_compute_blas_threshold) {
588
618
  exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
589
619
  } else {
@@ -591,6 +621,25 @@ void knn_L2sqr_select(
591
621
  }
592
622
  }
593
623
 
624
+ template <class BlockResultHandler>
625
+ void knn_inner_product_select(
626
+ const float* x,
627
+ const float* y,
628
+ size_t d,
629
+ size_t nx,
630
+ size_t ny,
631
+ BlockResultHandler& res,
632
+ const IDSelector* sel) {
633
+ if (sel) {
634
+ exhaustive_inner_product_seq<BlockResultHandler, true>(
635
+ x, y, d, nx, ny, res, sel);
636
+ } else if (nx < distance_compute_blas_threshold) {
637
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
638
+ } else {
639
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
640
+ }
641
+ }
642
+
594
643
  } // anonymous namespace
595
644
 
596
645
  /*******************************************************
@@ -609,7 +658,7 @@ void knn_inner_product(
609
658
  size_t nx,
610
659
  size_t ny,
611
660
  size_t k,
612
- float* val,
661
+ float* vals,
613
662
  int64_t* ids,
614
663
  const IDSelector* sel) {
615
664
  int64_t imin = 0;
@@ -622,30 +671,21 @@ void knn_inner_product(
622
671
  }
623
672
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
624
673
  knn_inner_products_by_idx(
625
- x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
674
+ x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
626
675
  return;
627
676
  }
628
- if (k < distance_compute_min_k_reservoir) {
629
- using RH = HeapResultHandler<CMin<float, int64_t>>;
630
- RH res(nx, val, ids, k);
631
- if (sel) {
632
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
633
- } else if (nx < distance_compute_blas_threshold) {
634
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
635
- } else {
636
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
637
- }
677
+
678
+ if (k == 1) {
679
+ Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
680
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
681
+ } else if (k < distance_compute_min_k_reservoir) {
682
+ HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
683
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
638
684
  } else {
639
- using RH = ReservoirResultHandler<CMin<float, int64_t>>;
640
- RH res(nx, val, ids, k);
641
- if (sel) {
642
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
643
- } else if (nx < distance_compute_blas_threshold) {
644
- exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
645
- } else {
646
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
647
- }
685
+ ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
686
+ knn_inner_product_select(x, y, d, nx, ny, res, sel);
648
687
  }
688
+
649
689
  if (imin != 0) {
650
690
  for (size_t i = 0; i < nx * k; i++) {
651
691
  if (ids[i] >= 0) {
@@ -687,17 +727,17 @@ void knn_L2sqr(
687
727
  sel = nullptr;
688
728
  }
689
729
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
690
- knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
730
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
691
731
  return;
692
732
  }
693
733
  if (k == 1) {
694
- SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
734
+ Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
695
735
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
696
736
  } else if (k < distance_compute_min_k_reservoir) {
697
- HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
737
+ HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
698
738
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
699
739
  } else {
700
- ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
740
+ ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
701
741
  knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
702
742
  }
703
743
  if (imin != 0) {
@@ -735,7 +775,7 @@ void range_search_L2sqr(
735
775
  float radius,
736
776
  RangeSearchResult* res,
737
777
  const IDSelector* sel) {
738
- using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
778
+ using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
739
779
  RH resh(res, radius);
740
780
  if (sel) {
741
781
  exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
@@ -755,7 +795,7 @@ void range_search_inner_product(
755
795
  float radius,
756
796
  RangeSearchResult* res,
757
797
  const IDSelector* sel) {
758
- using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
798
+ using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
759
799
  RH resh(res, radius);
760
800
  if (sel) {
761
801
  exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
@@ -786,9 +826,11 @@ void fvec_inner_products_by_idx(
786
826
  const float* xj = x + j * d;
787
827
  float* __restrict ipj = ip + j * ny;
788
828
  for (size_t i = 0; i < ny; i++) {
789
- if (idsj[i] < 0)
790
- continue;
791
- ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
829
+ if (idsj[i] < 0) {
830
+ ipj[i] = -INFINITY;
831
+ } else {
832
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
833
+ }
792
834
  }
793
835
  }
794
836
  }
@@ -809,9 +851,11 @@ void fvec_L2sqr_by_idx(
809
851
  const float* xj = x + j * d;
810
852
  float* __restrict disj = dis + j * ny;
811
853
  for (size_t i = 0; i < ny; i++) {
812
- if (idsj[i] < 0)
813
- continue;
814
- disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
854
+ if (idsj[i] < 0) {
855
+ disj[i] = INFINITY;
856
+ } else {
857
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
858
+ }
815
859
  }
816
860
  }
817
861
  }
@@ -828,6 +872,8 @@ void pairwise_indexed_L2sqr(
828
872
  for (int64_t j = 0; j < n; j++) {
829
873
  if (ix[j] >= 0 && iy[j] >= 0) {
830
874
  dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
875
+ } else {
876
+ dis[j] = INFINITY;
831
877
  }
832
878
  }
833
879
  }
@@ -844,6 +890,8 @@ void pairwise_indexed_inner_product(
844
890
  for (int64_t j = 0; j < n; j++) {
845
891
  if (ix[j] >= 0 && iy[j] >= 0) {
846
892
  dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
893
+ } else {
894
+ dis[j] = -INFINITY;
847
895
  }
848
896
  }
849
897
  }
@@ -857,6 +905,7 @@ void knn_inner_products_by_idx(
857
905
  size_t d,
858
906
  size_t nx,
859
907
  size_t ny,
908
+ size_t nsubset,
860
909
  size_t k,
861
910
  float* res_vals,
862
911
  int64_t* res_ids,
@@ -874,9 +923,10 @@ void knn_inner_products_by_idx(
874
923
  int64_t* __restrict idxi = res_ids + i * k;
875
924
  minheap_heapify(k, simi, idxi);
876
925
 
877
- for (j = 0; j < ny; j++) {
878
- if (idsi[j] < 0)
926
+ for (j = 0; j < nsubset; j++) {
927
+ if (idsi[j] < 0 || idsi[j] >= ny) {
879
928
  break;
929
+ }
880
930
  float ip = fvec_inner_product(x_, y + d * idsi[j], d);
881
931
 
882
932
  if (ip > simi[0]) {
@@ -894,6 +944,7 @@ void knn_L2sqr_by_idx(
894
944
  size_t d,
895
945
  size_t nx,
896
946
  size_t ny,
947
+ size_t nsubset,
897
948
  size_t k,
898
949
  float* res_vals,
899
950
  int64_t* res_ids,
@@ -908,7 +959,10 @@ void knn_L2sqr_by_idx(
908
959
  float* __restrict simi = res_vals + i * k;
909
960
  int64_t* __restrict idxi = res_ids + i * k;
910
961
  maxheap_heapify(k, simi, idxi);
911
- for (size_t j = 0; j < ny; j++) {
962
+ for (size_t j = 0; j < nsubset; j++) {
963
+ if (idsi[j] < 0 || idsi[j] >= ny) {
964
+ break;
965
+ }
912
966
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
913
967
 
914
968
  if (disij < simi[0]) {
@@ -36,6 +36,34 @@ float fvec_L1(const float* x, const float* y, size_t d);
36
36
  /// infinity distance
37
37
  float fvec_Linf(const float* x, const float* y, size_t d);
38
38
 
39
+ /// Special version of inner product that computes 4 distances
40
+ /// between x and yi, which is performance oriented.
41
+ void fvec_inner_product_batch_4(
42
+ const float* x,
43
+ const float* y0,
44
+ const float* y1,
45
+ const float* y2,
46
+ const float* y3,
47
+ const size_t d,
48
+ float& dis0,
49
+ float& dis1,
50
+ float& dis2,
51
+ float& dis3);
52
+
53
+ /// Special version of L2sqr that computes 4 distances
54
+ /// between x and yi, which is performance oriented.
55
+ void fvec_L2sqr_batch_4(
56
+ const float* x,
57
+ const float* y0,
58
+ const float* y1,
59
+ const float* y2,
60
+ const float* y3,
61
+ const size_t d,
62
+ float& dis0,
63
+ float& dis1,
64
+ float& dis2,
65
+ float& dis3);
66
+
39
67
  /** Compute pairwise distances between sets of vectors
40
68
  *
41
69
  * @param d dimension of the vectors
@@ -170,8 +198,16 @@ void fvec_sub(size_t d, const float* a, const float* b, float* c);
170
198
  * Compute a subset of distances
171
199
  ***************************************************************************/
172
200
 
173
- /* compute the inner product between x and a subset y of ny vectors,
174
- whose indices are given by idy. */
201
+ /** compute the inner product between x and a subset y of ny vectors defined by
202
+ * ids
203
+ *
204
+ * ip(i, j) = inner_product(x(i, :), y(ids(i, j), :))
205
+ *
206
+ * @param ip output array, size nx * ny
207
+ * @param x first-term vector, size nx * d
208
+ * @param y second-term vector, size (max(ids) + 1) * d
209
+ * @param ids ids to sample from y, size nx * ny
210
+ */
175
211
  void fvec_inner_products_by_idx(
176
212
  float* ip,
177
213
  const float* x,
@@ -181,7 +217,16 @@ void fvec_inner_products_by_idx(
181
217
  size_t nx,
182
218
  size_t ny);
183
219
 
184
- /* same but for a subset in y indexed by idsy (ny vectors in total) */
220
+ /** compute the squared L2 distances between x and a subset y of ny vectors
221
+ * defined by ids
222
+ *
223
+ * dis(i, j) = inner_product(x(i, :), y(ids(i, j), :))
224
+ *
225
+ * @param dis output array, size nx * ny
226
+ * @param x first-term vector, size nx * d
227
+ * @param y second-term vector, size (max(ids) + 1) * d
228
+ * @param ids ids to sample from y, size nx * ny
229
+ */
185
230
  void fvec_L2sqr_by_idx(
186
231
  float* dis,
187
232
  const float* x,
@@ -208,7 +253,14 @@ void pairwise_indexed_L2sqr(
208
253
  const int64_t* iy,
209
254
  float* dis);
210
255
 
211
- /* same for inner product */
256
+ /** compute dis[j] = inner_product(x[ix[j]], y[iy[j]]) forall j=0..n-1
257
+ *
258
+ * @param x size (max(ix) + 1, d)
259
+ * @param y size (max(iy) + 1, d)
260
+ * @param ix size n
261
+ * @param iy size n
262
+ * @param dis size n
263
+ */
212
264
  void pairwise_indexed_inner_product(
213
265
  size_t d,
214
266
  size_t n,
@@ -324,6 +376,7 @@ void knn_inner_products_by_idx(
324
376
  const int64_t* subset,
325
377
  size_t d,
326
378
  size_t nx,
379
+ size_t ny,
327
380
  size_t nsubset,
328
381
  size_t k,
329
382
  float* vals,
@@ -346,6 +399,7 @@ void knn_L2sqr_by_idx(
346
399
  const int64_t* subset,
347
400
  size_t d,
348
401
  size_t nx,
402
+ size_t ny,
349
403
  size_t nsubset,
350
404
  size_t k,
351
405
  float* vals,
@@ -406,4 +460,27 @@ void compute_PQ_dis_tables_dsub2(
406
460
  * Templatized versions of distance functions
407
461
  ***************************************************************************/
408
462
 
463
+ /***************************************************************************
464
+ * Misc matrix and vector manipulation functions
465
+ ***************************************************************************/
466
+
467
+ /** compute c := a + bf * b for a, b and c tables
468
+ *
469
+ * @param n size of the tables
470
+ * @param a size n
471
+ * @param b size n
472
+ * @param c restult table, size n
473
+ */
474
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
475
+
476
+ /** same as fvec_madd, also return index of the min of the result table
477
+ * @return index of the min of table c
478
+ */
479
+ int fvec_madd_and_argmin(
480
+ size_t n,
481
+ const float* a,
482
+ float bf,
483
+ const float* b,
484
+ float* c);
485
+
409
486
  } // namespace faiss
@@ -9,7 +9,7 @@
9
9
 
10
10
  #include <faiss/utils/distances_fused/avx512.h>
11
11
 
12
- #ifdef __AVX512__
12
+ #ifdef __AVX512F__
13
13
 
14
14
  #include <immintrin.h>
15
15
 
@@ -68,7 +68,7 @@ void kernel(
68
68
  const float* const __restrict y,
69
69
  const float* const __restrict y_transposed,
70
70
  size_t ny,
71
- SingleBestResultHandler<CMax<float, int64_t>>& res,
71
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
72
72
  const float* __restrict y_norms,
73
73
  size_t i) {
74
74
  const size_t ny_p =
@@ -231,7 +231,7 @@ void exhaustive_L2sqr_fused_cmax(
231
231
  const float* const __restrict y,
232
232
  size_t nx,
233
233
  size_t ny,
234
- SingleBestResultHandler<CMax<float, int64_t>>& res,
234
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
235
235
  const float* __restrict y_norms) {
236
236
  // BLAS does not like empty matrices
237
237
  if (nx == 0 || ny == 0) {
@@ -275,7 +275,7 @@ void exhaustive_L2sqr_fused_cmax(
275
275
  x, y, y_transposed.data(), ny, res, y_norms, i);
276
276
  }
277
277
 
278
- // Does nothing for SingleBestResultHandler, but
278
+ // Does nothing for Top1BlockResultHandler, but
279
279
  // keeping the call for the consistency.
280
280
  res.end_multiple();
281
281
  InterruptCallback::check();
@@ -289,7 +289,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
289
289
  size_t d,
290
290
  size_t nx,
291
291
  size_t ny,
292
- SingleBestResultHandler<CMax<float, int64_t>>& res,
292
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
293
293
  const float* y_norms) {
294
294
  // process only cases with certain dimensionalities
295
295
 
@@ -16,7 +16,7 @@
16
16
 
17
17
  #include <faiss/utils/Heap.h>
18
18
 
19
- #ifdef __AVX512__
19
+ #ifdef __AVX512F__
20
20
 
21
21
  namespace faiss {
22
22
 
@@ -28,7 +28,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
28
28
  size_t d,
29
29
  size_t nx,
30
30
  size_t ny,
31
- SingleBestResultHandler<CMax<float, int64_t>>& res,
31
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
32
32
  const float* y_norms);
33
33
 
34
34
  } // namespace faiss
@@ -20,14 +20,14 @@ bool exhaustive_L2sqr_fused_cmax(
20
20
  size_t d,
21
21
  size_t nx,
22
22
  size_t ny,
23
- SingleBestResultHandler<CMax<float, int64_t>>& res,
23
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
24
24
  const float* y_norms) {
25
25
  if (nx == 0 || ny == 0) {
26
26
  // nothing to do
27
27
  return true;
28
28
  }
29
29
 
30
- #ifdef __AVX512__
30
+ #ifdef __AVX512F__
31
31
  // avx512 kernel
32
32
  return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
33
33
  #elif defined(__AVX2__) || defined(__aarch64__)
@@ -34,7 +34,7 @@ bool exhaustive_L2sqr_fused_cmax(
34
34
  size_t d,
35
35
  size_t nx,
36
36
  size_t ny,
37
- SingleBestResultHandler<CMax<float, int64_t>>& res,
37
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
38
38
  const float* y_norms);
39
39
 
40
40
  } // namespace faiss