faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -17,8 +17,13 @@
17
17
 
18
18
  #include <omp.h>
19
19
 
20
+ #ifdef __AVX2__
21
+ #include <immintrin.h>
22
+ #endif
23
+
20
24
  #include <faiss/impl/AuxIndexStructures.h>
21
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
22
27
  #include <faiss/impl/ResultHandler.h>
23
28
 
24
29
  #ifndef FINTEGER
@@ -96,17 +101,20 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
96
101
  namespace {
97
102
 
98
103
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
99
- template <class ResultHandler>
104
+ template <class ResultHandler, bool use_sel = false>
100
105
  void exhaustive_inner_product_seq(
101
106
  const float* x,
102
107
  const float* y,
103
108
  size_t d,
104
109
  size_t nx,
105
110
  size_t ny,
106
- ResultHandler& res) {
111
+ ResultHandler& res,
112
+ const IDSelector* sel = nullptr) {
107
113
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
108
114
  int nt = std::min(int(nx), omp_get_max_threads());
109
115
 
116
+ FAISS_ASSERT(use_sel == (sel != nullptr));
117
+
110
118
  #pragma omp parallel num_threads(nt)
111
119
  {
112
120
  SingleResultHandler resi(res);
@@ -117,27 +125,32 @@ void exhaustive_inner_product_seq(
117
125
 
118
126
  resi.begin(i);
119
127
 
120
- for (size_t j = 0; j < ny; j++) {
128
+ for (size_t j = 0; j < ny; j++, y_j += d) {
129
+ if (use_sel && !sel->is_member(j)) {
130
+ continue;
131
+ }
121
132
  float ip = fvec_inner_product(x_i, y_j, d);
122
133
  resi.add_result(ip, j);
123
- y_j += d;
124
134
  }
125
135
  resi.end();
126
136
  }
127
137
  }
128
138
  }
129
139
 
130
- template <class ResultHandler>
140
+ template <class ResultHandler, bool use_sel = false>
131
141
  void exhaustive_L2sqr_seq(
132
142
  const float* x,
133
143
  const float* y,
134
144
  size_t d,
135
145
  size_t nx,
136
146
  size_t ny,
137
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const IDSelector* sel = nullptr) {
138
149
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
150
  int nt = std::min(int(nx), omp_get_max_threads());
140
151
 
152
+ FAISS_ASSERT(use_sel == (sel != nullptr));
153
+
141
154
  #pragma omp parallel num_threads(nt)
142
155
  {
143
156
  SingleResultHandler resi(res);
@@ -146,10 +159,12 @@ void exhaustive_L2sqr_seq(
146
159
  const float* x_i = x + i * d;
147
160
  const float* y_j = y;
148
161
  resi.begin(i);
149
- for (size_t j = 0; j < ny; j++) {
162
+ for (size_t j = 0; j < ny; j++, y_j += d) {
163
+ if (use_sel && !sel->is_member(j)) {
164
+ continue;
165
+ }
150
166
  float disij = fvec_L2sqr(x_i, y_j, d);
151
167
  resi.add_result(disij, j);
152
- y_j += d;
153
168
  }
154
169
  resi.end();
155
170
  }
@@ -296,6 +311,232 @@ void exhaustive_L2sqr_blas(
296
311
  }
297
312
  }
298
313
 
314
+ #ifdef __AVX2__
315
+ // an override for AVX2 if only a single closest point is needed.
316
+ template <>
317
+ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
318
+ const float* x,
319
+ const float* y,
320
+ size_t d,
321
+ size_t nx,
322
+ size_t ny,
323
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
324
+ const float* y_norms) {
325
+ // BLAS does not like empty matrices
326
+ if (nx == 0 || ny == 0)
327
+ return;
328
+
329
+ /* block sizes */
330
+ const size_t bs_x = distance_compute_blas_query_bs;
331
+ const size_t bs_y = distance_compute_blas_database_bs;
332
+ // const size_t bs_x = 16, bs_y = 16;
333
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
334
+ std::unique_ptr<float[]> x_norms(new float[nx]);
335
+ std::unique_ptr<float[]> del2;
336
+
337
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
338
+
339
+ if (!y_norms) {
340
+ float* y_norms2 = new float[ny];
341
+ del2.reset(y_norms2);
342
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
343
+ y_norms = y_norms2;
344
+ }
345
+
346
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
347
+ size_t i1 = i0 + bs_x;
348
+ if (i1 > nx)
349
+ i1 = nx;
350
+
351
+ res.begin_multiple(i0, i1);
352
+
353
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
354
+ size_t j1 = j0 + bs_y;
355
+ if (j1 > ny)
356
+ j1 = ny;
357
+ /* compute the actual dot products */
358
+ {
359
+ float one = 1, zero = 0;
360
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
361
+ sgemm_("Transpose",
362
+ "Not transpose",
363
+ &nyi,
364
+ &nxi,
365
+ &di,
366
+ &one,
367
+ y + j0 * d,
368
+ &di,
369
+ x + i0 * d,
370
+ &di,
371
+ &zero,
372
+ ip_block.get(),
373
+ &nyi);
374
+ }
375
+ #pragma omp parallel for
376
+ for (int64_t i = i0; i < i1; i++) {
377
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
378
+
379
+ _mm_prefetch(ip_line, _MM_HINT_NTA);
380
+ _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
381
+
382
+ // constant
383
+ const __m256 mul_minus2 = _mm256_set1_ps(-2);
384
+
385
+ // Track 8 min distances + 8 min indices.
386
+ // All the distances tracked do not take x_norms[i]
387
+ // into account in order to get rid of extra
388
+ // _mm256_add_ps(x_norms[i], ...) instructions
389
+ // in distance computations.
390
+ __m256 min_distances =
391
+ _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
392
+
393
+ // these indices are local and are relative to j0.
394
+ // so, value 0 means j0.
395
+ __m256i min_indices = _mm256_set1_epi32(0);
396
+
397
+ __m256i current_indices =
398
+ _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
399
+ const __m256i indices_delta = _mm256_set1_epi32(8);
400
+
401
+ // current j index
402
+ size_t idx_j = 0;
403
+ size_t count = j1 - j0;
404
+
405
+ // process 16 elements per loop
406
+ for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
407
+ _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
408
+ _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
409
+
410
+ // load values for norms
411
+ const __m256 y_norm_0 =
412
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
413
+ const __m256 y_norm_1 =
414
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
415
+
416
+ // load values for dot products
417
+ const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
418
+ const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
419
+
420
+ // compute dis = y_norm[j] - 2 * dot(x[i], y[j]).
421
+ // x_norm[i] was dropped off because it is a constant for a
422
+ // given i. We'll deal with it later.
423
+ __m256 distances_0 =
424
+ _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
425
+ __m256 distances_1 =
426
+ _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
427
+
428
+ // compare the new distances to the min distances
429
+ // for each of the first group of 8 AVX2 components.
430
+ const __m256 comparison_0 = _mm256_cmp_ps(
431
+ min_distances, distances_0, _CMP_LE_OS);
432
+
433
+ // update min distances and indices with closest vectors if
434
+ // needed.
435
+ min_distances = _mm256_blendv_ps(
436
+ distances_0, min_distances, comparison_0);
437
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
438
+ _mm256_castsi256_ps(current_indices),
439
+ _mm256_castsi256_ps(min_indices),
440
+ comparison_0));
441
+ current_indices =
442
+ _mm256_add_epi32(current_indices, indices_delta);
443
+
444
+ // compare the new distances to the min distances
445
+ // for each of the second group of 8 AVX2 components.
446
+ const __m256 comparison_1 = _mm256_cmp_ps(
447
+ min_distances, distances_1, _CMP_LE_OS);
448
+
449
+ // update min distances and indices with closest vectors if
450
+ // needed.
451
+ min_distances = _mm256_blendv_ps(
452
+ distances_1, min_distances, comparison_1);
453
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
454
+ _mm256_castsi256_ps(current_indices),
455
+ _mm256_castsi256_ps(min_indices),
456
+ comparison_1));
457
+ current_indices =
458
+ _mm256_add_epi32(current_indices, indices_delta);
459
+ }
460
+
461
+ // dump values and find the minimum distance / minimum index
462
+ float min_distances_scalar[8];
463
+ uint32_t min_indices_scalar[8];
464
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
465
+ _mm256_storeu_si256(
466
+ (__m256i*)(min_indices_scalar), min_indices);
467
+
468
+ float current_min_distance = res.dis_tab[i];
469
+ uint32_t current_min_index = res.ids_tab[i];
470
+
471
+ // This unusual comparison is needed to maintain the behavior
472
+ // of the original implementation: if two indices are
473
+ // represented with equal distance values, then
474
+ // the index with the min value is returned.
475
+ for (size_t jv = 0; jv < 8; jv++) {
476
+ // add missing x_norms[i]
477
+ float distance_candidate =
478
+ min_distances_scalar[jv] + x_norms[i];
479
+
480
+ // negative values can occur for identical vectors
481
+ // due to roundoff errors.
482
+ if (distance_candidate < 0)
483
+ distance_candidate = 0;
484
+
485
+ int64_t index_candidate = min_indices_scalar[jv] + j0;
486
+
487
+ if (current_min_distance > distance_candidate) {
488
+ current_min_distance = distance_candidate;
489
+ current_min_index = index_candidate;
490
+ } else if (
491
+ current_min_distance == distance_candidate &&
492
+ current_min_index > index_candidate) {
493
+ current_min_index = index_candidate;
494
+ }
495
+ }
496
+
497
+ // process leftovers
498
+ for (; idx_j < count; idx_j++, ip_line++) {
499
+ float ip = *ip_line;
500
+ float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
501
+ // negative values can occur for identical vectors
502
+ // due to roundoff errors.
503
+ if (dis < 0)
504
+ dis = 0;
505
+
506
+ if (current_min_distance > dis) {
507
+ current_min_distance = dis;
508
+ current_min_index = idx_j + j0;
509
+ }
510
+ }
511
+
512
+ //
513
+ res.add_result(i, current_min_distance, current_min_index);
514
+ }
515
+ }
516
+ InterruptCallback::check();
517
+ }
518
+ }
519
+ #endif
520
+
521
+ template <class ResultHandler>
522
+ void knn_L2sqr_select(
523
+ const float* x,
524
+ const float* y,
525
+ size_t d,
526
+ size_t nx,
527
+ size_t ny,
528
+ ResultHandler& res,
529
+ const float* y_norm2,
530
+ const IDSelector* sel) {
531
+ if (sel) {
532
+ exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
533
+ } else if (nx < distance_compute_blas_threshold) {
534
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
535
+ } else {
536
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
537
+ }
538
+ }
539
+
299
540
  } // anonymous namespace
300
541
 
301
542
  /*******************************************************
@@ -313,24 +554,63 @@ void knn_inner_product(
313
554
  size_t d,
314
555
  size_t nx,
315
556
  size_t ny,
316
- float_minheap_array_t* ha) {
317
- if (ha->k < distance_compute_min_k_reservoir) {
318
- HeapResultHandler<CMin<float, int64_t>> res(
319
- ha->nh, ha->val, ha->ids, ha->k);
320
- if (nx < distance_compute_blas_threshold) {
557
+ size_t k,
558
+ float* val,
559
+ int64_t* ids,
560
+ const IDSelector* sel) {
561
+ int64_t imin = 0;
562
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
563
+ imin = std::max(selr->imin, int64_t(0));
564
+ int64_t imax = std::min(selr->imax, int64_t(ny));
565
+ ny = imax - imin;
566
+ y += d * imin;
567
+ sel = nullptr;
568
+ }
569
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
570
+ knn_inner_products_by_idx(
571
+ x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
572
+ return;
573
+ }
574
+ if (k < distance_compute_min_k_reservoir) {
575
+ using RH = HeapResultHandler<CMin<float, int64_t>>;
576
+ RH res(nx, val, ids, k);
577
+ if (sel) {
578
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
579
+ } else if (nx < distance_compute_blas_threshold) {
321
580
  exhaustive_inner_product_seq(x, y, d, nx, ny, res);
322
581
  } else {
323
582
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
324
583
  }
325
584
  } else {
326
- ReservoirResultHandler<CMin<float, int64_t>> res(
327
- ha->nh, ha->val, ha->ids, ha->k);
328
- if (nx < distance_compute_blas_threshold) {
329
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
585
+ using RH = ReservoirResultHandler<CMin<float, int64_t>>;
586
+ RH res(nx, val, ids, k);
587
+ if (sel) {
588
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
589
+ } else if (nx < distance_compute_blas_threshold) {
590
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
330
591
  } else {
331
592
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
332
593
  }
333
594
  }
595
+ if (imin != 0) {
596
+ for (size_t i = 0; i < nx * k; i++) {
597
+ if (ids[i] >= 0) {
598
+ ids[i] += imin;
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ void knn_inner_product(
605
+ const float* x,
606
+ const float* y,
607
+ size_t d,
608
+ size_t nx,
609
+ size_t ny,
610
+ float_minheap_array_t* res,
611
+ const IDSelector* sel) {
612
+ FAISS_THROW_IF_NOT(nx == res->nh);
613
+ knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
334
614
  }
335
615
 
336
616
  void knn_L2sqr(
@@ -339,28 +619,55 @@ void knn_L2sqr(
339
619
  size_t d,
340
620
  size_t nx,
341
621
  size_t ny,
342
- float_maxheap_array_t* ha,
343
- const float* y_norm2) {
344
- if (ha->k < distance_compute_min_k_reservoir) {
345
- HeapResultHandler<CMax<float, int64_t>> res(
346
- ha->nh, ha->val, ha->ids, ha->k);
347
-
348
- if (nx < distance_compute_blas_threshold) {
349
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
350
- } else {
351
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
352
- }
622
+ size_t k,
623
+ float* vals,
624
+ int64_t* ids,
625
+ const float* y_norm2,
626
+ const IDSelector* sel) {
627
+ int64_t imin = 0;
628
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
629
+ imin = std::max(selr->imin, int64_t(0));
630
+ int64_t imax = std::min(selr->imax, int64_t(ny));
631
+ ny = imax - imin;
632
+ y += d * imin;
633
+ sel = nullptr;
634
+ }
635
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
636
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
637
+ return;
638
+ }
639
+ if (k == 1) {
640
+ SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
641
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
642
+ } else if (k < distance_compute_min_k_reservoir) {
643
+ HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
644
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
353
645
  } else {
354
- ReservoirResultHandler<CMax<float, int64_t>> res(
355
- ha->nh, ha->val, ha->ids, ha->k);
356
- if (nx < distance_compute_blas_threshold) {
357
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
358
- } else {
359
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
646
+ ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
647
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
648
+ }
649
+ if (imin != 0) {
650
+ for (size_t i = 0; i < nx * k; i++) {
651
+ if (ids[i] >= 0) {
652
+ ids[i] += imin;
653
+ }
360
654
  }
361
655
  }
362
656
  }
363
657
 
658
+ void knn_L2sqr(
659
+ const float* x,
660
+ const float* y,
661
+ size_t d,
662
+ size_t nx,
663
+ size_t ny,
664
+ float_maxheap_array_t* res,
665
+ const float* y_norm2,
666
+ const IDSelector* sel) {
667
+ FAISS_THROW_IF_NOT(res->nh == nx);
668
+ knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
669
+ }
670
+
364
671
  /***************************************************************************
365
672
  * Range search
366
673
  ***************************************************************************/
@@ -372,10 +679,14 @@ void range_search_L2sqr(
372
679
  size_t nx,
373
680
  size_t ny,
374
681
  float radius,
375
- RangeSearchResult* res) {
376
- RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
377
- if (nx < distance_compute_blas_threshold) {
378
- exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
682
+ RangeSearchResult* res,
683
+ const IDSelector* sel) {
684
+ using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
685
+ RH resh(res, radius);
686
+ if (sel) {
687
+ exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
688
+ } else if (nx < distance_compute_blas_threshold) {
689
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
379
690
  } else {
380
691
  exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
381
692
  }
@@ -388,9 +699,13 @@ void range_search_inner_product(
388
699
  size_t nx,
389
700
  size_t ny,
390
701
  float radius,
391
- RangeSearchResult* res) {
392
- RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
393
- if (nx < distance_compute_blas_threshold) {
702
+ RangeSearchResult* res,
703
+ const IDSelector* sel) {
704
+ using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
705
+ RH resh(res, radius);
706
+ if (sel) {
707
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
708
+ } else if (nx < distance_compute_blas_threshold) {
394
709
  exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
395
710
  } else {
396
711
  exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
@@ -488,16 +803,21 @@ void knn_inner_products_by_idx(
488
803
  size_t d,
489
804
  size_t nx,
490
805
  size_t ny,
491
- float_minheap_array_t* res) {
492
- size_t k = res->k;
806
+ size_t k,
807
+ float* res_vals,
808
+ int64_t* res_ids,
809
+ int64_t ld_ids) {
810
+ if (ld_ids < 0) {
811
+ ld_ids = ny;
812
+ }
493
813
 
494
- #pragma omp parallel for
814
+ #pragma omp parallel for if (nx > 100)
495
815
  for (int64_t i = 0; i < nx; i++) {
496
816
  const float* x_ = x + i * d;
497
- const int64_t* idsi = ids + i * ny;
817
+ const int64_t* idsi = ids + i * ld_ids;
498
818
  size_t j;
499
- float* __restrict simi = res->get_val(i);
500
- int64_t* __restrict idxi = res->get_ids(i);
819
+ float* __restrict simi = res_vals + i * k;
820
+ int64_t* __restrict idxi = res_ids + i * k;
501
821
  minheap_heapify(k, simi, idxi);
502
822
 
503
823
  for (j = 0; j < ny; j++) {
@@ -520,16 +840,20 @@ void knn_L2sqr_by_idx(
520
840
  size_t d,
521
841
  size_t nx,
522
842
  size_t ny,
523
- float_maxheap_array_t* res) {
524
- size_t k = res->k;
525
-
526
- #pragma omp parallel for
843
+ size_t k,
844
+ float* res_vals,
845
+ int64_t* res_ids,
846
+ int64_t ld_ids) {
847
+ if (ld_ids < 0) {
848
+ ld_ids = ny;
849
+ }
850
+ #pragma omp parallel for if (nx > 100)
527
851
  for (int64_t i = 0; i < nx; i++) {
528
852
  const float* x_ = x + i * d;
529
- const int64_t* __restrict idsi = ids + i * ny;
530
- float* __restrict simi = res->get_val(i);
531
- int64_t* __restrict idxi = res->get_ids(i);
532
- maxheap_heapify(res->k, simi, idxi);
853
+ const int64_t* __restrict idsi = ids + i * ld_ids;
854
+ float* __restrict simi = res_vals + i * k;
855
+ int64_t* __restrict idxi = res_ids + i * k;
856
+ maxheap_heapify(k, simi, idxi);
533
857
  for (size_t j = 0; j < ny; j++) {
534
858
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
535
859
 
@@ -537,7 +861,7 @@ void knn_L2sqr_by_idx(
537
861
  maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
538
862
  }
539
863
  }
540
- maxheap_reorder(res->k, simi, idxi);
864
+ maxheap_reorder(k, simi, idxi);
541
865
  }
542
866
  }
543
867