faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -17,8 +17,13 @@
17
17
 
18
18
  #include <omp.h>
19
19
 
20
+ #ifdef __AVX2__
21
+ #include <immintrin.h>
22
+ #endif
23
+
20
24
  #include <faiss/impl/AuxIndexStructures.h>
21
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
22
27
  #include <faiss/impl/ResultHandler.h>
23
28
 
24
29
  #ifndef FINTEGER
@@ -96,17 +101,21 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
96
101
  namespace {
97
102
 
98
103
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
99
- template <class ResultHandler>
104
+ template <class ResultHandler, bool use_sel = false>
100
105
  void exhaustive_inner_product_seq(
101
106
  const float* x,
102
107
  const float* y,
103
108
  size_t d,
104
109
  size_t nx,
105
110
  size_t ny,
106
- ResultHandler& res) {
111
+ ResultHandler& res,
112
+ const IDSelector* sel = nullptr) {
107
113
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
114
+ int nt = std::min(int(nx), omp_get_max_threads());
108
115
 
109
- #pragma omp parallel
116
+ FAISS_ASSERT(use_sel == (sel != nullptr));
117
+
118
+ #pragma omp parallel num_threads(nt)
110
119
  {
111
120
  SingleResultHandler resi(res);
112
121
  #pragma omp for
@@ -116,27 +125,33 @@ void exhaustive_inner_product_seq(
116
125
 
117
126
  resi.begin(i);
118
127
 
119
- for (size_t j = 0; j < ny; j++) {
128
+ for (size_t j = 0; j < ny; j++, y_j += d) {
129
+ if (use_sel && !sel->is_member(j)) {
130
+ continue;
131
+ }
120
132
  float ip = fvec_inner_product(x_i, y_j, d);
121
133
  resi.add_result(ip, j);
122
- y_j += d;
123
134
  }
124
135
  resi.end();
125
136
  }
126
137
  }
127
138
  }
128
139
 
129
- template <class ResultHandler>
140
+ template <class ResultHandler, bool use_sel = false>
130
141
  void exhaustive_L2sqr_seq(
131
142
  const float* x,
132
143
  const float* y,
133
144
  size_t d,
134
145
  size_t nx,
135
146
  size_t ny,
136
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const IDSelector* sel = nullptr) {
137
149
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
150
+ int nt = std::min(int(nx), omp_get_max_threads());
151
+
152
+ FAISS_ASSERT(use_sel == (sel != nullptr));
138
153
 
139
- #pragma omp parallel
154
+ #pragma omp parallel num_threads(nt)
140
155
  {
141
156
  SingleResultHandler resi(res);
142
157
  #pragma omp for
@@ -144,10 +159,12 @@ void exhaustive_L2sqr_seq(
144
159
  const float* x_i = x + i * d;
145
160
  const float* y_j = y;
146
161
  resi.begin(i);
147
- for (size_t j = 0; j < ny; j++) {
162
+ for (size_t j = 0; j < ny; j++, y_j += d) {
163
+ if (use_sel && !sel->is_member(j)) {
164
+ continue;
165
+ }
148
166
  float disij = fvec_L2sqr(x_i, y_j, d);
149
167
  resi.add_result(disij, j);
150
- y_j += d;
151
168
  }
152
169
  resi.end();
153
170
  }
@@ -294,6 +311,232 @@ void exhaustive_L2sqr_blas(
294
311
  }
295
312
  }
296
313
 
314
+ #ifdef __AVX2__
315
+ // an override for AVX2 if only a single closest point is needed.
316
+ template <>
317
+ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
318
+ const float* x,
319
+ const float* y,
320
+ size_t d,
321
+ size_t nx,
322
+ size_t ny,
323
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
324
+ const float* y_norms) {
325
+ // BLAS does not like empty matrices
326
+ if (nx == 0 || ny == 0)
327
+ return;
328
+
329
+ /* block sizes */
330
+ const size_t bs_x = distance_compute_blas_query_bs;
331
+ const size_t bs_y = distance_compute_blas_database_bs;
332
+ // const size_t bs_x = 16, bs_y = 16;
333
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
334
+ std::unique_ptr<float[]> x_norms(new float[nx]);
335
+ std::unique_ptr<float[]> del2;
336
+
337
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
338
+
339
+ if (!y_norms) {
340
+ float* y_norms2 = new float[ny];
341
+ del2.reset(y_norms2);
342
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
343
+ y_norms = y_norms2;
344
+ }
345
+
346
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
347
+ size_t i1 = i0 + bs_x;
348
+ if (i1 > nx)
349
+ i1 = nx;
350
+
351
+ res.begin_multiple(i0, i1);
352
+
353
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
354
+ size_t j1 = j0 + bs_y;
355
+ if (j1 > ny)
356
+ j1 = ny;
357
+ /* compute the actual dot products */
358
+ {
359
+ float one = 1, zero = 0;
360
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
361
+ sgemm_("Transpose",
362
+ "Not transpose",
363
+ &nyi,
364
+ &nxi,
365
+ &di,
366
+ &one,
367
+ y + j0 * d,
368
+ &di,
369
+ x + i0 * d,
370
+ &di,
371
+ &zero,
372
+ ip_block.get(),
373
+ &nyi);
374
+ }
375
+ #pragma omp parallel for
376
+ for (int64_t i = i0; i < i1; i++) {
377
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
378
+
379
+ _mm_prefetch(ip_line, _MM_HINT_NTA);
380
+ _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
381
+
382
+ // constant
383
+ const __m256 mul_minus2 = _mm256_set1_ps(-2);
384
+
385
+ // Track 8 min distances + 8 min indices.
386
+ // All the distances tracked do not take x_norms[i]
387
+ // into account in order to get rid of extra
388
+ // _mm256_add_ps(x_norms[i], ...) instructions
389
+ // in distance computations.
390
+ __m256 min_distances =
391
+ _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
392
+
393
+ // these indices are local and are relative to j0.
394
+ // so, value 0 means j0.
395
+ __m256i min_indices = _mm256_set1_epi32(0);
396
+
397
+ __m256i current_indices =
398
+ _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
399
+ const __m256i indices_delta = _mm256_set1_epi32(8);
400
+
401
+ // current j index
402
+ size_t idx_j = 0;
403
+ size_t count = j1 - j0;
404
+
405
+ // process 16 elements per loop
406
+ for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
407
+ _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
408
+ _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
409
+
410
+ // load values for norms
411
+ const __m256 y_norm_0 =
412
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
413
+ const __m256 y_norm_1 =
414
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
415
+
416
+ // load values for dot products
417
+ const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
418
+ const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
419
+
420
+ // compute dis = y_norm[j] - 2 * dot(x[i], y[j]).
421
+ // x_norm[i] was dropped off because it is a constant for a
422
+ // given i. We'll deal with it later.
423
+ __m256 distances_0 =
424
+ _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
425
+ __m256 distances_1 =
426
+ _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
427
+
428
+ // compare the new distances to the min distances
429
+ // for each of the first group of 8 AVX2 components.
430
+ const __m256 comparison_0 = _mm256_cmp_ps(
431
+ min_distances, distances_0, _CMP_LE_OS);
432
+
433
+ // update min distances and indices with closest vectors if
434
+ // needed.
435
+ min_distances = _mm256_blendv_ps(
436
+ distances_0, min_distances, comparison_0);
437
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
438
+ _mm256_castsi256_ps(current_indices),
439
+ _mm256_castsi256_ps(min_indices),
440
+ comparison_0));
441
+ current_indices =
442
+ _mm256_add_epi32(current_indices, indices_delta);
443
+
444
+ // compare the new distances to the min distances
445
+ // for each of the second group of 8 AVX2 components.
446
+ const __m256 comparison_1 = _mm256_cmp_ps(
447
+ min_distances, distances_1, _CMP_LE_OS);
448
+
449
+ // update min distances and indices with closest vectors if
450
+ // needed.
451
+ min_distances = _mm256_blendv_ps(
452
+ distances_1, min_distances, comparison_1);
453
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
454
+ _mm256_castsi256_ps(current_indices),
455
+ _mm256_castsi256_ps(min_indices),
456
+ comparison_1));
457
+ current_indices =
458
+ _mm256_add_epi32(current_indices, indices_delta);
459
+ }
460
+
461
+ // dump values and find the minimum distance / minimum index
462
+ float min_distances_scalar[8];
463
+ uint32_t min_indices_scalar[8];
464
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
465
+ _mm256_storeu_si256(
466
+ (__m256i*)(min_indices_scalar), min_indices);
467
+
468
+ float current_min_distance = res.dis_tab[i];
469
+ uint32_t current_min_index = res.ids_tab[i];
470
+
471
+ // This unusual comparison is needed to maintain the behavior
472
+ // of the original implementation: if two indices are
473
+ // represented with equal distance values, then
474
+ // the index with the min value is returned.
475
+ for (size_t jv = 0; jv < 8; jv++) {
476
+ // add missing x_norms[i]
477
+ float distance_candidate =
478
+ min_distances_scalar[jv] + x_norms[i];
479
+
480
+ // negative values can occur for identical vectors
481
+ // due to roundoff errors.
482
+ if (distance_candidate < 0)
483
+ distance_candidate = 0;
484
+
485
+ int64_t index_candidate = min_indices_scalar[jv] + j0;
486
+
487
+ if (current_min_distance > distance_candidate) {
488
+ current_min_distance = distance_candidate;
489
+ current_min_index = index_candidate;
490
+ } else if (
491
+ current_min_distance == distance_candidate &&
492
+ current_min_index > index_candidate) {
493
+ current_min_index = index_candidate;
494
+ }
495
+ }
496
+
497
+ // process leftovers
498
+ for (; idx_j < count; idx_j++, ip_line++) {
499
+ float ip = *ip_line;
500
+ float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
501
+ // negative values can occur for identical vectors
502
+ // due to roundoff errors.
503
+ if (dis < 0)
504
+ dis = 0;
505
+
506
+ if (current_min_distance > dis) {
507
+ current_min_distance = dis;
508
+ current_min_index = idx_j + j0;
509
+ }
510
+ }
511
+
512
+ //
513
+ res.add_result(i, current_min_distance, current_min_index);
514
+ }
515
+ }
516
+ InterruptCallback::check();
517
+ }
518
+ }
519
+ #endif
520
+
521
+ template <class ResultHandler>
522
+ void knn_L2sqr_select(
523
+ const float* x,
524
+ const float* y,
525
+ size_t d,
526
+ size_t nx,
527
+ size_t ny,
528
+ ResultHandler& res,
529
+ const float* y_norm2,
530
+ const IDSelector* sel) {
531
+ if (sel) {
532
+ exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
533
+ } else if (nx < distance_compute_blas_threshold) {
534
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
535
+ } else {
536
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
537
+ }
538
+ }
539
+
297
540
  } // anonymous namespace
298
541
 
299
542
  /*******************************************************
@@ -311,24 +554,63 @@ void knn_inner_product(
311
554
  size_t d,
312
555
  size_t nx,
313
556
  size_t ny,
314
- float_minheap_array_t* ha) {
315
- if (ha->k < distance_compute_min_k_reservoir) {
316
- HeapResultHandler<CMin<float, int64_t>> res(
317
- ha->nh, ha->val, ha->ids, ha->k);
318
- if (nx < distance_compute_blas_threshold) {
557
+ size_t k,
558
+ float* val,
559
+ int64_t* ids,
560
+ const IDSelector* sel) {
561
+ int64_t imin = 0;
562
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
563
+ imin = std::max(selr->imin, int64_t(0));
564
+ int64_t imax = std::min(selr->imax, int64_t(ny));
565
+ ny = imax - imin;
566
+ y += d * imin;
567
+ sel = nullptr;
568
+ }
569
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
570
+ knn_inner_products_by_idx(
571
+ x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
572
+ return;
573
+ }
574
+ if (k < distance_compute_min_k_reservoir) {
575
+ using RH = HeapResultHandler<CMin<float, int64_t>>;
576
+ RH res(nx, val, ids, k);
577
+ if (sel) {
578
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
579
+ } else if (nx < distance_compute_blas_threshold) {
319
580
  exhaustive_inner_product_seq(x, y, d, nx, ny, res);
320
581
  } else {
321
582
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
322
583
  }
323
584
  } else {
324
- ReservoirResultHandler<CMin<float, int64_t>> res(
325
- ha->nh, ha->val, ha->ids, ha->k);
326
- if (nx < distance_compute_blas_threshold) {
327
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
585
+ using RH = ReservoirResultHandler<CMin<float, int64_t>>;
586
+ RH res(nx, val, ids, k);
587
+ if (sel) {
588
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
589
+ } else if (nx < distance_compute_blas_threshold) {
590
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
328
591
  } else {
329
592
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
330
593
  }
331
594
  }
595
+ if (imin != 0) {
596
+ for (size_t i = 0; i < nx * k; i++) {
597
+ if (ids[i] >= 0) {
598
+ ids[i] += imin;
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ void knn_inner_product(
605
+ const float* x,
606
+ const float* y,
607
+ size_t d,
608
+ size_t nx,
609
+ size_t ny,
610
+ float_minheap_array_t* res,
611
+ const IDSelector* sel) {
612
+ FAISS_THROW_IF_NOT(nx == res->nh);
613
+ knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
332
614
  }
333
615
 
334
616
  void knn_L2sqr(
@@ -337,28 +619,55 @@ void knn_L2sqr(
337
619
  size_t d,
338
620
  size_t nx,
339
621
  size_t ny,
340
- float_maxheap_array_t* ha,
341
- const float* y_norm2) {
342
- if (ha->k < distance_compute_min_k_reservoir) {
343
- HeapResultHandler<CMax<float, int64_t>> res(
344
- ha->nh, ha->val, ha->ids, ha->k);
345
-
346
- if (nx < distance_compute_blas_threshold) {
347
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
348
- } else {
349
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
350
- }
622
+ size_t k,
623
+ float* vals,
624
+ int64_t* ids,
625
+ const float* y_norm2,
626
+ const IDSelector* sel) {
627
+ int64_t imin = 0;
628
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
629
+ imin = std::max(selr->imin, int64_t(0));
630
+ int64_t imax = std::min(selr->imax, int64_t(ny));
631
+ ny = imax - imin;
632
+ y += d * imin;
633
+ sel = nullptr;
634
+ }
635
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
636
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
637
+ return;
638
+ }
639
+ if (k == 1) {
640
+ SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
641
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
642
+ } else if (k < distance_compute_min_k_reservoir) {
643
+ HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
644
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
351
645
  } else {
352
- ReservoirResultHandler<CMax<float, int64_t>> res(
353
- ha->nh, ha->val, ha->ids, ha->k);
354
- if (nx < distance_compute_blas_threshold) {
355
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
356
- } else {
357
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
646
+ ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
647
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
648
+ }
649
+ if (imin != 0) {
650
+ for (size_t i = 0; i < nx * k; i++) {
651
+ if (ids[i] >= 0) {
652
+ ids[i] += imin;
653
+ }
358
654
  }
359
655
  }
360
656
  }
361
657
 
658
+ void knn_L2sqr(
659
+ const float* x,
660
+ const float* y,
661
+ size_t d,
662
+ size_t nx,
663
+ size_t ny,
664
+ float_maxheap_array_t* res,
665
+ const float* y_norm2,
666
+ const IDSelector* sel) {
667
+ FAISS_THROW_IF_NOT(res->nh == nx);
668
+ knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
669
+ }
670
+
362
671
  /***************************************************************************
363
672
  * Range search
364
673
  ***************************************************************************/
@@ -370,10 +679,14 @@ void range_search_L2sqr(
370
679
  size_t nx,
371
680
  size_t ny,
372
681
  float radius,
373
- RangeSearchResult* res) {
374
- RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
375
- if (nx < distance_compute_blas_threshold) {
376
- exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
682
+ RangeSearchResult* res,
683
+ const IDSelector* sel) {
684
+ using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
685
+ RH resh(res, radius);
686
+ if (sel) {
687
+ exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
688
+ } else if (nx < distance_compute_blas_threshold) {
689
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
377
690
  } else {
378
691
  exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
379
692
  }
@@ -386,9 +699,13 @@ void range_search_inner_product(
386
699
  size_t nx,
387
700
  size_t ny,
388
701
  float radius,
389
- RangeSearchResult* res) {
390
- RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
391
- if (nx < distance_compute_blas_threshold) {
702
+ RangeSearchResult* res,
703
+ const IDSelector* sel) {
704
+ using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
705
+ RH resh(res, radius);
706
+ if (sel) {
707
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
708
+ } else if (nx < distance_compute_blas_threshold) {
392
709
  exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
393
710
  } else {
394
711
  exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
@@ -486,16 +803,21 @@ void knn_inner_products_by_idx(
486
803
  size_t d,
487
804
  size_t nx,
488
805
  size_t ny,
489
- float_minheap_array_t* res) {
490
- size_t k = res->k;
806
+ size_t k,
807
+ float* res_vals,
808
+ int64_t* res_ids,
809
+ int64_t ld_ids) {
810
+ if (ld_ids < 0) {
811
+ ld_ids = ny;
812
+ }
491
813
 
492
- #pragma omp parallel for
814
+ #pragma omp parallel for if (nx > 100)
493
815
  for (int64_t i = 0; i < nx; i++) {
494
816
  const float* x_ = x + i * d;
495
- const int64_t* idsi = ids + i * ny;
817
+ const int64_t* idsi = ids + i * ld_ids;
496
818
  size_t j;
497
- float* __restrict simi = res->get_val(i);
498
- int64_t* __restrict idxi = res->get_ids(i);
819
+ float* __restrict simi = res_vals + i * k;
820
+ int64_t* __restrict idxi = res_ids + i * k;
499
821
  minheap_heapify(k, simi, idxi);
500
822
 
501
823
  for (j = 0; j < ny; j++) {
@@ -518,16 +840,20 @@ void knn_L2sqr_by_idx(
518
840
  size_t d,
519
841
  size_t nx,
520
842
  size_t ny,
521
- float_maxheap_array_t* res) {
522
- size_t k = res->k;
523
-
524
- #pragma omp parallel for
843
+ size_t k,
844
+ float* res_vals,
845
+ int64_t* res_ids,
846
+ int64_t ld_ids) {
847
+ if (ld_ids < 0) {
848
+ ld_ids = ny;
849
+ }
850
+ #pragma omp parallel for if (nx > 100)
525
851
  for (int64_t i = 0; i < nx; i++) {
526
852
  const float* x_ = x + i * d;
527
- const int64_t* __restrict idsi = ids + i * ny;
528
- float* __restrict simi = res->get_val(i);
529
- int64_t* __restrict idxi = res->get_ids(i);
530
- maxheap_heapify(res->k, simi, idxi);
853
+ const int64_t* __restrict idsi = ids + i * ld_ids;
854
+ float* __restrict simi = res_vals + i * k;
855
+ int64_t* __restrict idxi = res_ids + i * k;
856
+ maxheap_heapify(k, simi, idxi);
531
857
  for (size_t j = 0; j < ny; j++) {
532
858
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
533
859
 
@@ -535,7 +861,7 @@ void knn_L2sqr_by_idx(
535
861
  maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
536
862
  }
537
863
  }
538
- maxheap_reorder(res->k, simi, idxi);
864
+ maxheap_reorder(k, simi, idxi);
539
865
  }
540
866
  }
541
867