faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -17,8 +17,13 @@
17
17
 
18
18
  #include <omp.h>
19
19
 
20
+ #ifdef __AVX2__
21
+ #include <immintrin.h>
22
+ #endif
23
+
20
24
  #include <faiss/impl/AuxIndexStructures.h>
21
25
  #include <faiss/impl/FaissAssert.h>
26
+ #include <faiss/impl/IDSelector.h>
22
27
  #include <faiss/impl/ResultHandler.h>
23
28
 
24
29
  #ifndef FINTEGER
@@ -96,17 +101,20 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
96
101
  namespace {
97
102
 
98
103
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
99
- template <class ResultHandler>
104
+ template <class ResultHandler, bool use_sel = false>
100
105
  void exhaustive_inner_product_seq(
101
106
  const float* x,
102
107
  const float* y,
103
108
  size_t d,
104
109
  size_t nx,
105
110
  size_t ny,
106
- ResultHandler& res) {
111
+ ResultHandler& res,
112
+ const IDSelector* sel = nullptr) {
107
113
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
108
114
  int nt = std::min(int(nx), omp_get_max_threads());
109
115
 
116
+ FAISS_ASSERT(use_sel == (sel != nullptr));
117
+
110
118
  #pragma omp parallel num_threads(nt)
111
119
  {
112
120
  SingleResultHandler resi(res);
@@ -117,27 +125,32 @@ void exhaustive_inner_product_seq(
117
125
 
118
126
  resi.begin(i);
119
127
 
120
- for (size_t j = 0; j < ny; j++) {
128
+ for (size_t j = 0; j < ny; j++, y_j += d) {
129
+ if (use_sel && !sel->is_member(j)) {
130
+ continue;
131
+ }
121
132
  float ip = fvec_inner_product(x_i, y_j, d);
122
133
  resi.add_result(ip, j);
123
- y_j += d;
124
134
  }
125
135
  resi.end();
126
136
  }
127
137
  }
128
138
  }
129
139
 
130
- template <class ResultHandler>
140
+ template <class ResultHandler, bool use_sel = false>
131
141
  void exhaustive_L2sqr_seq(
132
142
  const float* x,
133
143
  const float* y,
134
144
  size_t d,
135
145
  size_t nx,
136
146
  size_t ny,
137
- ResultHandler& res) {
147
+ ResultHandler& res,
148
+ const IDSelector* sel = nullptr) {
138
149
  using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
150
  int nt = std::min(int(nx), omp_get_max_threads());
140
151
 
152
+ FAISS_ASSERT(use_sel == (sel != nullptr));
153
+
141
154
  #pragma omp parallel num_threads(nt)
142
155
  {
143
156
  SingleResultHandler resi(res);
@@ -146,10 +159,12 @@ void exhaustive_L2sqr_seq(
146
159
  const float* x_i = x + i * d;
147
160
  const float* y_j = y;
148
161
  resi.begin(i);
149
- for (size_t j = 0; j < ny; j++) {
162
+ for (size_t j = 0; j < ny; j++, y_j += d) {
163
+ if (use_sel && !sel->is_member(j)) {
164
+ continue;
165
+ }
150
166
  float disij = fvec_L2sqr(x_i, y_j, d);
151
167
  resi.add_result(disij, j);
152
- y_j += d;
153
168
  }
154
169
  resi.end();
155
170
  }
@@ -296,6 +311,232 @@ void exhaustive_L2sqr_blas(
296
311
  }
297
312
  }
298
313
 
314
+ #ifdef __AVX2__
315
+ // an override for AVX2 if only a single closest point is needed.
316
+ template <>
317
+ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
318
+ const float* x,
319
+ const float* y,
320
+ size_t d,
321
+ size_t nx,
322
+ size_t ny,
323
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
324
+ const float* y_norms) {
325
+ // BLAS does not like empty matrices
326
+ if (nx == 0 || ny == 0)
327
+ return;
328
+
329
+ /* block sizes */
330
+ const size_t bs_x = distance_compute_blas_query_bs;
331
+ const size_t bs_y = distance_compute_blas_database_bs;
332
+ // const size_t bs_x = 16, bs_y = 16;
333
+ std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
334
+ std::unique_ptr<float[]> x_norms(new float[nx]);
335
+ std::unique_ptr<float[]> del2;
336
+
337
+ fvec_norms_L2sqr(x_norms.get(), x, d, nx);
338
+
339
+ if (!y_norms) {
340
+ float* y_norms2 = new float[ny];
341
+ del2.reset(y_norms2);
342
+ fvec_norms_L2sqr(y_norms2, y, d, ny);
343
+ y_norms = y_norms2;
344
+ }
345
+
346
+ for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
347
+ size_t i1 = i0 + bs_x;
348
+ if (i1 > nx)
349
+ i1 = nx;
350
+
351
+ res.begin_multiple(i0, i1);
352
+
353
+ for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
354
+ size_t j1 = j0 + bs_y;
355
+ if (j1 > ny)
356
+ j1 = ny;
357
+ /* compute the actual dot products */
358
+ {
359
+ float one = 1, zero = 0;
360
+ FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
361
+ sgemm_("Transpose",
362
+ "Not transpose",
363
+ &nyi,
364
+ &nxi,
365
+ &di,
366
+ &one,
367
+ y + j0 * d,
368
+ &di,
369
+ x + i0 * d,
370
+ &di,
371
+ &zero,
372
+ ip_block.get(),
373
+ &nyi);
374
+ }
375
+ #pragma omp parallel for
376
+ for (int64_t i = i0; i < i1; i++) {
377
+ float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
378
+
379
+ _mm_prefetch(ip_line, _MM_HINT_NTA);
380
+ _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
381
+
382
+ // constant
383
+ const __m256 mul_minus2 = _mm256_set1_ps(-2);
384
+
385
+ // Track 8 min distances + 8 min indices.
386
+ // All the distances tracked do not take x_norms[i]
387
+ // into account in order to get rid of extra
388
+ // _mm256_add_ps(x_norms[i], ...) instructions
389
+ // in distance computations.
390
+ __m256 min_distances =
391
+ _mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
392
+
393
+ // these indices are local and are relative to j0.
394
+ // so, value 0 means j0.
395
+ __m256i min_indices = _mm256_set1_epi32(0);
396
+
397
+ __m256i current_indices =
398
+ _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
399
+ const __m256i indices_delta = _mm256_set1_epi32(8);
400
+
401
+ // current j index
402
+ size_t idx_j = 0;
403
+ size_t count = j1 - j0;
404
+
405
+ // process 16 elements per loop
406
+ for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
407
+ _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
408
+ _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
409
+
410
+ // load values for norms
411
+ const __m256 y_norm_0 =
412
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 0);
413
+ const __m256 y_norm_1 =
414
+ _mm256_loadu_ps(y_norms + idx_j + j0 + 8);
415
+
416
+ // load values for dot products
417
+ const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
418
+ const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
419
+
420
+ // compute dis = y_norm[j] - 2 * dot(x[i], y[j]).
421
+ // x_norm[i] was dropped off because it is a constant for a
422
+ // given i. We'll deal with it later.
423
+ __m256 distances_0 =
424
+ _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
425
+ __m256 distances_1 =
426
+ _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
427
+
428
+ // compare the new distances to the min distances
429
+ // for each of the first group of 8 AVX2 components.
430
+ const __m256 comparison_0 = _mm256_cmp_ps(
431
+ min_distances, distances_0, _CMP_LE_OS);
432
+
433
+ // update min distances and indices with closest vectors if
434
+ // needed.
435
+ min_distances = _mm256_blendv_ps(
436
+ distances_0, min_distances, comparison_0);
437
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
438
+ _mm256_castsi256_ps(current_indices),
439
+ _mm256_castsi256_ps(min_indices),
440
+ comparison_0));
441
+ current_indices =
442
+ _mm256_add_epi32(current_indices, indices_delta);
443
+
444
+ // compare the new distances to the min distances
445
+ // for each of the second group of 8 AVX2 components.
446
+ const __m256 comparison_1 = _mm256_cmp_ps(
447
+ min_distances, distances_1, _CMP_LE_OS);
448
+
449
+ // update min distances and indices with closest vectors if
450
+ // needed.
451
+ min_distances = _mm256_blendv_ps(
452
+ distances_1, min_distances, comparison_1);
453
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
454
+ _mm256_castsi256_ps(current_indices),
455
+ _mm256_castsi256_ps(min_indices),
456
+ comparison_1));
457
+ current_indices =
458
+ _mm256_add_epi32(current_indices, indices_delta);
459
+ }
460
+
461
+ // dump values and find the minimum distance / minimum index
462
+ float min_distances_scalar[8];
463
+ uint32_t min_indices_scalar[8];
464
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
465
+ _mm256_storeu_si256(
466
+ (__m256i*)(min_indices_scalar), min_indices);
467
+
468
+ float current_min_distance = res.dis_tab[i];
469
+ uint32_t current_min_index = res.ids_tab[i];
470
+
471
+ // This unusual comparison is needed to maintain the behavior
472
+ // of the original implementation: if two indices are
473
+ // represented with equal distance values, then
474
+ // the index with the min value is returned.
475
+ for (size_t jv = 0; jv < 8; jv++) {
476
+ // add missing x_norms[i]
477
+ float distance_candidate =
478
+ min_distances_scalar[jv] + x_norms[i];
479
+
480
+ // negative values can occur for identical vectors
481
+ // due to roundoff errors.
482
+ if (distance_candidate < 0)
483
+ distance_candidate = 0;
484
+
485
+ int64_t index_candidate = min_indices_scalar[jv] + j0;
486
+
487
+ if (current_min_distance > distance_candidate) {
488
+ current_min_distance = distance_candidate;
489
+ current_min_index = index_candidate;
490
+ } else if (
491
+ current_min_distance == distance_candidate &&
492
+ current_min_index > index_candidate) {
493
+ current_min_index = index_candidate;
494
+ }
495
+ }
496
+
497
+ // process leftovers
498
+ for (; idx_j < count; idx_j++, ip_line++) {
499
+ float ip = *ip_line;
500
+ float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
501
+ // negative values can occur for identical vectors
502
+ // due to roundoff errors.
503
+ if (dis < 0)
504
+ dis = 0;
505
+
506
+ if (current_min_distance > dis) {
507
+ current_min_distance = dis;
508
+ current_min_index = idx_j + j0;
509
+ }
510
+ }
511
+
512
+ // report the best distance and index found for query i
513
+ res.add_result(i, current_min_distance, current_min_index);
514
+ }
515
+ }
516
+ InterruptCallback::check();
517
+ }
518
+ }
519
+ #endif
520
+
521
+ template <class ResultHandler>
522
+ void knn_L2sqr_select(
523
+ const float* x,
524
+ const float* y,
525
+ size_t d,
526
+ size_t nx,
527
+ size_t ny,
528
+ ResultHandler& res,
529
+ const float* y_norm2,
530
+ const IDSelector* sel) {
531
+ if (sel) {
532
+ exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
533
+ } else if (nx < distance_compute_blas_threshold) {
534
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
535
+ } else {
536
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
537
+ }
538
+ }
539
+
299
540
  } // anonymous namespace
300
541
 
301
542
  /*******************************************************
@@ -313,24 +554,63 @@ void knn_inner_product(
313
554
  size_t d,
314
555
  size_t nx,
315
556
  size_t ny,
316
- float_minheap_array_t* ha) {
317
- if (ha->k < distance_compute_min_k_reservoir) {
318
- HeapResultHandler<CMin<float, int64_t>> res(
319
- ha->nh, ha->val, ha->ids, ha->k);
320
- if (nx < distance_compute_blas_threshold) {
557
+ size_t k,
558
+ float* val,
559
+ int64_t* ids,
560
+ const IDSelector* sel) {
561
+ int64_t imin = 0;
562
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
563
+ imin = std::max(selr->imin, int64_t(0));
564
+ int64_t imax = std::min(selr->imax, int64_t(ny));
565
+ ny = imax - imin;
566
+ y += d * imin;
567
+ sel = nullptr;
568
+ }
569
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
570
+ knn_inner_products_by_idx(
571
+ x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
572
+ return;
573
+ }
574
+ if (k < distance_compute_min_k_reservoir) {
575
+ using RH = HeapResultHandler<CMin<float, int64_t>>;
576
+ RH res(nx, val, ids, k);
577
+ if (sel) {
578
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
579
+ } else if (nx < distance_compute_blas_threshold) {
321
580
  exhaustive_inner_product_seq(x, y, d, nx, ny, res);
322
581
  } else {
323
582
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
324
583
  }
325
584
  } else {
326
- ReservoirResultHandler<CMin<float, int64_t>> res(
327
- ha->nh, ha->val, ha->ids, ha->k);
328
- if (nx < distance_compute_blas_threshold) {
329
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
585
+ using RH = ReservoirResultHandler<CMin<float, int64_t>>;
586
+ RH res(nx, val, ids, k);
587
+ if (sel) {
588
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
589
+ } else if (nx < distance_compute_blas_threshold) {
590
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
330
591
  } else {
331
592
  exhaustive_inner_product_blas(x, y, d, nx, ny, res);
332
593
  }
333
594
  }
595
+ if (imin != 0) {
596
+ for (size_t i = 0; i < nx * k; i++) {
597
+ if (ids[i] >= 0) {
598
+ ids[i] += imin;
599
+ }
600
+ }
601
+ }
602
+ }
603
+
604
+ void knn_inner_product(
605
+ const float* x,
606
+ const float* y,
607
+ size_t d,
608
+ size_t nx,
609
+ size_t ny,
610
+ float_minheap_array_t* res,
611
+ const IDSelector* sel) {
612
+ FAISS_THROW_IF_NOT(nx == res->nh);
613
+ knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
334
614
  }
335
615
 
336
616
  void knn_L2sqr(
@@ -339,28 +619,55 @@ void knn_L2sqr(
339
619
  size_t d,
340
620
  size_t nx,
341
621
  size_t ny,
342
- float_maxheap_array_t* ha,
343
- const float* y_norm2) {
344
- if (ha->k < distance_compute_min_k_reservoir) {
345
- HeapResultHandler<CMax<float, int64_t>> res(
346
- ha->nh, ha->val, ha->ids, ha->k);
347
-
348
- if (nx < distance_compute_blas_threshold) {
349
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
350
- } else {
351
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
352
- }
622
+ size_t k,
623
+ float* vals,
624
+ int64_t* ids,
625
+ const float* y_norm2,
626
+ const IDSelector* sel) {
627
+ int64_t imin = 0;
628
+ if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
629
+ imin = std::max(selr->imin, int64_t(0));
630
+ int64_t imax = std::min(selr->imax, int64_t(ny));
631
+ ny = imax - imin;
632
+ y += d * imin;
633
+ sel = nullptr;
634
+ }
635
+ if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
636
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
637
+ return;
638
+ }
639
+ if (k == 1) {
640
+ SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
641
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
642
+ } else if (k < distance_compute_min_k_reservoir) {
643
+ HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
644
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
353
645
  } else {
354
- ReservoirResultHandler<CMax<float, int64_t>> res(
355
- ha->nh, ha->val, ha->ids, ha->k);
356
- if (nx < distance_compute_blas_threshold) {
357
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
358
- } else {
359
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
646
+ ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
647
+ knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
648
+ }
649
+ if (imin != 0) {
650
+ for (size_t i = 0; i < nx * k; i++) {
651
+ if (ids[i] >= 0) {
652
+ ids[i] += imin;
653
+ }
360
654
  }
361
655
  }
362
656
  }
363
657
 
658
+ void knn_L2sqr(
659
+ const float* x,
660
+ const float* y,
661
+ size_t d,
662
+ size_t nx,
663
+ size_t ny,
664
+ float_maxheap_array_t* res,
665
+ const float* y_norm2,
666
+ const IDSelector* sel) {
667
+ FAISS_THROW_IF_NOT(res->nh == nx);
668
+ knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
669
+ }
670
+
364
671
  /***************************************************************************
365
672
  * Range search
366
673
  ***************************************************************************/
@@ -372,10 +679,14 @@ void range_search_L2sqr(
372
679
  size_t nx,
373
680
  size_t ny,
374
681
  float radius,
375
- RangeSearchResult* res) {
376
- RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
377
- if (nx < distance_compute_blas_threshold) {
378
- exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
682
+ RangeSearchResult* res,
683
+ const IDSelector* sel) {
684
+ using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
685
+ RH resh(res, radius);
686
+ if (sel) {
687
+ exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
688
+ } else if (nx < distance_compute_blas_threshold) {
689
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
379
690
  } else {
380
691
  exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
381
692
  }
@@ -388,9 +699,13 @@ void range_search_inner_product(
388
699
  size_t nx,
389
700
  size_t ny,
390
701
  float radius,
391
- RangeSearchResult* res) {
392
- RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
393
- if (nx < distance_compute_blas_threshold) {
702
+ RangeSearchResult* res,
703
+ const IDSelector* sel) {
704
+ using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
705
+ RH resh(res, radius);
706
+ if (sel) {
707
+ exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
708
+ } else if (nx < distance_compute_blas_threshold) {
394
709
  exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
395
710
  } else {
396
711
  exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
@@ -488,16 +803,21 @@ void knn_inner_products_by_idx(
488
803
  size_t d,
489
804
  size_t nx,
490
805
  size_t ny,
491
- float_minheap_array_t* res) {
492
- size_t k = res->k;
806
+ size_t k,
807
+ float* res_vals,
808
+ int64_t* res_ids,
809
+ int64_t ld_ids) {
810
+ if (ld_ids < 0) {
811
+ ld_ids = ny;
812
+ }
493
813
 
494
- #pragma omp parallel for
814
+ #pragma omp parallel for if (nx > 100)
495
815
  for (int64_t i = 0; i < nx; i++) {
496
816
  const float* x_ = x + i * d;
497
- const int64_t* idsi = ids + i * ny;
817
+ const int64_t* idsi = ids + i * ld_ids;
498
818
  size_t j;
499
- float* __restrict simi = res->get_val(i);
500
- int64_t* __restrict idxi = res->get_ids(i);
819
+ float* __restrict simi = res_vals + i * k;
820
+ int64_t* __restrict idxi = res_ids + i * k;
501
821
  minheap_heapify(k, simi, idxi);
502
822
 
503
823
  for (j = 0; j < ny; j++) {
@@ -520,16 +840,20 @@ void knn_L2sqr_by_idx(
520
840
  size_t d,
521
841
  size_t nx,
522
842
  size_t ny,
523
- float_maxheap_array_t* res) {
524
- size_t k = res->k;
525
-
526
- #pragma omp parallel for
843
+ size_t k,
844
+ float* res_vals,
845
+ int64_t* res_ids,
846
+ int64_t ld_ids) {
847
+ if (ld_ids < 0) {
848
+ ld_ids = ny;
849
+ }
850
+ #pragma omp parallel for if (nx > 100)
527
851
  for (int64_t i = 0; i < nx; i++) {
528
852
  const float* x_ = x + i * d;
529
- const int64_t* __restrict idsi = ids + i * ny;
530
- float* __restrict simi = res->get_val(i);
531
- int64_t* __restrict idxi = res->get_ids(i);
532
- maxheap_heapify(res->k, simi, idxi);
853
+ const int64_t* __restrict idsi = ids + i * ld_ids;
854
+ float* __restrict simi = res_vals + i * k;
855
+ int64_t* __restrict idxi = res_ids + i * k;
856
+ maxheap_heapify(k, simi, idxi);
533
857
  for (size_t j = 0; j < ny; j++) {
534
858
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
535
859
 
@@ -537,7 +861,7 @@ void knn_L2sqr_by_idx(
537
861
  maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
538
862
  }
539
863
  }
540
- maxheap_reorder(res->k, simi, idxi);
864
+ maxheap_reorder(k, simi, idxi);
541
865
  }
542
866
  }
543
867