faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -5,13 +5,12 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/utils/distances.h>
11
9
 
12
10
  #include <algorithm>
13
11
  #include <cassert>
14
12
  #include <cmath>
13
+ #include <cstddef>
15
14
  #include <cstdio>
16
15
  #include <cstring>
17
16
 
@@ -64,7 +63,7 @@ void fvec_norms_L2(
64
63
  const float* __restrict x,
65
64
  size_t d,
66
65
  size_t nx) {
67
- #pragma omp parallel for
66
+ #pragma omp parallel for if (nx > 10000)
68
67
  for (int64_t i = 0; i < nx; i++) {
69
68
  nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
70
69
  }
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
75
74
  const float* __restrict x,
76
75
  size_t d,
77
76
  size_t nx) {
78
- #pragma omp parallel for
77
+ #pragma omp parallel for if (nx > 10000)
79
78
  for (int64_t i = 0; i < nx; i++)
80
79
  nr[i] = fvec_norm_L2sqr(x + i * d, d);
81
80
  }
82
81
 
83
- void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
84
- #pragma omp parallel for
82
+ // The following is a workaround to a problem
83
+ // in OpenMP in fbcode. The crash occurs
84
+ // inside OMP when IndexIVFSpectralHash::set_query()
85
+ // calls fvec_renorm_L2. set_query() is always
86
+ // calling this function with nx == 1, so even
87
+ // the omp version should run single threaded,
88
+ // as per the if condition of the omp pragma.
89
+ // Instead, the omp version crashes inside OMP.
90
+ // The workaround below is explicitly branching
91
+ // off to a codepath without omp.
92
+
93
+ #define FVEC_RENORM_L2_IMPL \
94
+ float* __restrict xi = x + i * d; \
95
+ \
96
+ float nr = fvec_norm_L2sqr(xi, d); \
97
+ \
98
+ if (nr > 0) { \
99
+ size_t j; \
100
+ const float inv_nr = 1.0 / sqrtf(nr); \
101
+ for (j = 0; j < d; j++) \
102
+ xi[j] *= inv_nr; \
103
+ }
104
+
105
+ void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
85
106
  for (int64_t i = 0; i < nx; i++) {
86
- float* __restrict xi = x + i * d;
107
+ FVEC_RENORM_L2_IMPL
108
+ }
109
+ }
87
110
 
88
- float nr = fvec_norm_L2sqr(xi, d);
111
+ void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
112
+ #pragma omp parallel for if (nx > 10000)
113
+ for (int64_t i = 0; i < nx; i++) {
114
+ FVEC_RENORM_L2_IMPL
115
+ }
116
+ }
89
117
 
90
- if (nr > 0) {
91
- size_t j;
92
- const float inv_nr = 1.0 / sqrtf(nr);
93
- for (j = 0; j < d; j++)
94
- xi[j] *= inv_nr;
95
- }
118
+ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
119
+ if (nx <= 10000) {
120
+ fvec_renorm_L2_noomp(d, nx, x);
121
+ } else {
122
+ fvec_renorm_L2_omp(d, nx, x);
96
123
  }
97
124
  }
98
125
 
@@ -103,19 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
103
130
  namespace {
104
131
 
105
132
  /* Find the nearest neighbors for nx queries in a set of ny vectors */
106
- template <class ResultHandler, bool use_sel = false>
133
+ template <class BlockResultHandler>
107
134
  void exhaustive_inner_product_seq(
108
135
  const float* x,
109
136
  const float* y,
110
137
  size_t d,
111
138
  size_t nx,
112
139
  size_t ny,
113
- ResultHandler& res,
114
- const IDSelector* sel = nullptr) {
115
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
116
- int nt = std::min(int(nx), omp_get_max_threads());
117
-
118
- FAISS_ASSERT(use_sel == (sel != nullptr));
140
+ BlockResultHandler& res) {
141
+ using SingleResultHandler =
142
+ typename BlockResultHandler::SingleResultHandler;
143
+ [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
119
144
 
120
145
  #pragma omp parallel num_threads(nt)
121
146
  {
@@ -128,7 +153,7 @@ void exhaustive_inner_product_seq(
128
153
  resi.begin(i);
129
154
 
130
155
  for (size_t j = 0; j < ny; j++, y_j += d) {
131
- if (use_sel && !sel->is_member(j)) {
156
+ if (!res.is_in_selection(j)) {
132
157
  continue;
133
158
  }
134
159
  float ip = fvec_inner_product(x_i, y_j, d);
@@ -139,19 +164,17 @@ void exhaustive_inner_product_seq(
139
164
  }
140
165
  }
141
166
 
142
- template <class ResultHandler, bool use_sel = false>
167
+ template <class BlockResultHandler>
143
168
  void exhaustive_L2sqr_seq(
144
169
  const float* x,
145
170
  const float* y,
146
171
  size_t d,
147
172
  size_t nx,
148
173
  size_t ny,
149
- ResultHandler& res,
150
- const IDSelector* sel = nullptr) {
151
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
152
- int nt = std::min(int(nx), omp_get_max_threads());
153
-
154
- FAISS_ASSERT(use_sel == (sel != nullptr));
174
+ BlockResultHandler& res) {
175
+ using SingleResultHandler =
176
+ typename BlockResultHandler::SingleResultHandler;
177
+ [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
155
178
 
156
179
  #pragma omp parallel num_threads(nt)
157
180
  {
@@ -162,7 +185,7 @@ void exhaustive_L2sqr_seq(
162
185
  const float* y_j = y;
163
186
  resi.begin(i);
164
187
  for (size_t j = 0; j < ny; j++, y_j += d) {
165
- if (use_sel && !sel->is_member(j)) {
188
+ if (!res.is_in_selection(j)) {
166
189
  continue;
167
190
  }
168
191
  float disij = fvec_L2sqr(x_i, y_j, d);
@@ -174,14 +197,14 @@ void exhaustive_L2sqr_seq(
174
197
  }
175
198
 
176
199
  /** Find the nearest neighbors for nx queries in a set of ny vectors */
177
- template <class ResultHandler>
200
+ template <class BlockResultHandler>
178
201
  void exhaustive_inner_product_blas(
179
202
  const float* x,
180
203
  const float* y,
181
204
  size_t d,
182
205
  size_t nx,
183
206
  size_t ny,
184
- ResultHandler& res) {
207
+ BlockResultHandler& res) {
185
208
  // BLAS does not like empty matrices
186
209
  if (nx == 0 || ny == 0)
187
210
  return;
@@ -230,14 +253,14 @@ void exhaustive_inner_product_blas(
230
253
 
231
254
  // distance correction is an operator that can be applied to transform
232
255
  // the distances
233
- template <class ResultHandler>
256
+ template <class BlockResultHandler>
234
257
  void exhaustive_L2sqr_blas_default_impl(
235
258
  const float* x,
236
259
  const float* y,
237
260
  size_t d,
238
261
  size_t nx,
239
262
  size_t ny,
240
- ResultHandler& res,
263
+ BlockResultHandler& res,
241
264
  const float* y_norms = nullptr) {
242
265
  // BLAS does not like empty matrices
243
266
  if (nx == 0 || ny == 0)
@@ -297,6 +320,9 @@ void exhaustive_L2sqr_blas_default_impl(
297
320
  float ip = *ip_line;
298
321
  float dis = x_norms[i] + y_norms[j] - 2 * ip;
299
322
 
323
+ if (!res.is_in_selection(j)) {
324
+ dis = HUGE_VALF;
325
+ }
300
326
  // negative values can occur for identical vectors
301
327
  // due to roundoff errors
302
328
  if (dis < 0)
@@ -313,14 +339,14 @@ void exhaustive_L2sqr_blas_default_impl(
313
339
  }
314
340
  }
315
341
 
316
- template <class ResultHandler>
342
+ template <class BlockResultHandler>
317
343
  void exhaustive_L2sqr_blas(
318
344
  const float* x,
319
345
  const float* y,
320
346
  size_t d,
321
347
  size_t nx,
322
348
  size_t ny,
323
- ResultHandler& res,
349
+ BlockResultHandler& res,
324
350
  const float* y_norms = nullptr) {
325
351
  exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
326
352
  }
@@ -332,7 +358,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
332
358
  size_t d,
333
359
  size_t nx,
334
360
  size_t ny,
335
- SingleBestResultHandler<CMax<float, int64_t>>& res,
361
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
336
362
  const float* y_norms) {
337
363
  // BLAS does not like empty matrices
338
364
  if (nx == 0 || ny == 0)
@@ -388,8 +414,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
388
414
  for (int64_t i = i0; i < i1; i++) {
389
415
  float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
390
416
 
391
- _mm_prefetch(ip_line, _MM_HINT_NTA);
392
- _mm_prefetch(ip_line + 16, _MM_HINT_NTA);
417
+ _mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
418
+ _mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
393
419
 
394
420
  // constant
395
421
  const __m256 mul_minus2 = _mm256_set1_ps(-2);
@@ -416,8 +442,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
416
442
 
417
443
  // process 16 elements per loop
418
444
  for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
419
- _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
420
- _mm_prefetch(ip_line + 48, _MM_HINT_NTA);
445
+ _mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
446
+ _mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
421
447
 
422
448
  // load values for norms
423
449
  const __m256 y_norm_0 =
@@ -535,13 +561,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
535
561
 
536
562
  // an override if only a single closest point is needed
537
563
  template <>
538
- void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
564
+ void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
539
565
  const float* x,
540
566
  const float* y,
541
567
  size_t d,
542
568
  size_t nx,
543
569
  size_t ny,
544
- SingleBestResultHandler<CMax<float, int64_t>>& res,
570
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
545
571
  const float* y_norms) {
546
572
  #if defined(__AVX2__)
547
573
  // use a faster fused kernel if available
@@ -562,34 +588,50 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
562
588
 
563
589
  // run the default implementation
564
590
  exhaustive_L2sqr_blas_default_impl<
565
- SingleBestResultHandler<CMax<float, int64_t>>>(
591
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
566
592
  x, y, d, nx, ny, res, y_norms);
567
593
  #else
568
594
  // run the default implementation
569
595
  exhaustive_L2sqr_blas_default_impl<
570
- SingleBestResultHandler<CMax<float, int64_t>>>(
596
+ Top1BlockResultHandler<CMax<float, int64_t>>>(
571
597
  x, y, d, nx, ny, res, y_norms);
572
598
  #endif
573
599
  }
574
600
 
575
- template <class ResultHandler>
576
- void knn_L2sqr_select(
577
- const float* x,
578
- const float* y,
579
- size_t d,
580
- size_t nx,
581
- size_t ny,
582
- ResultHandler& res,
583
- const float* y_norm2,
584
- const IDSelector* sel) {
585
- if (sel) {
586
- exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
587
- } else if (nx < distance_compute_blas_threshold) {
588
- exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
589
- } else {
590
- exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
601
+ struct Run_search_inner_product {
602
+ using T = void;
603
+ template <class BlockResultHandler>
604
+ void f(BlockResultHandler& res,
605
+ const float* x,
606
+ const float* y,
607
+ size_t d,
608
+ size_t nx,
609
+ size_t ny) {
610
+ if (res.sel || nx < distance_compute_blas_threshold) {
611
+ exhaustive_inner_product_seq(x, y, d, nx, ny, res);
612
+ } else {
613
+ exhaustive_inner_product_blas(x, y, d, nx, ny, res);
614
+ }
591
615
  }
592
- }
616
+ };
617
+
618
+ struct Run_search_L2sqr {
619
+ using T = void;
620
+ template <class BlockResultHandler>
621
+ void f(BlockResultHandler& res,
622
+ const float* x,
623
+ const float* y,
624
+ size_t d,
625
+ size_t nx,
626
+ size_t ny,
627
+ const float* y_norm2) {
628
+ if (res.sel || nx < distance_compute_blas_threshold) {
629
+ exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
630
+ } else {
631
+ exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
632
+ }
633
+ }
634
+ };
593
635
 
594
636
  } // anonymous namespace
595
637
 
@@ -609,7 +651,7 @@ void knn_inner_product(
609
651
  size_t nx,
610
652
  size_t ny,
611
653
  size_t k,
612
- float* val,
654
+ float* vals,
613
655
  int64_t* ids,
614
656
  const IDSelector* sel) {
615
657
  int64_t imin = 0;
@@ -622,30 +664,14 @@ void knn_inner_product(
622
664
  }
623
665
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
624
666
  knn_inner_products_by_idx(
625
- x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
667
+ x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
626
668
  return;
627
669
  }
628
- if (k < distance_compute_min_k_reservoir) {
629
- using RH = HeapResultHandler<CMin<float, int64_t>>;
630
- RH res(nx, val, ids, k);
631
- if (sel) {
632
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
633
- } else if (nx < distance_compute_blas_threshold) {
634
- exhaustive_inner_product_seq(x, y, d, nx, ny, res);
635
- } else {
636
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
637
- }
638
- } else {
639
- using RH = ReservoirResultHandler<CMin<float, int64_t>>;
640
- RH res(nx, val, ids, k);
641
- if (sel) {
642
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
643
- } else if (nx < distance_compute_blas_threshold) {
644
- exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
645
- } else {
646
- exhaustive_inner_product_blas(x, y, d, nx, ny, res);
647
- }
648
- }
670
+
671
+ Run_search_inner_product r;
672
+ dispatch_knn_ResultHandler(
673
+ nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
674
+
649
675
  if (imin != 0) {
650
676
  for (size_t i = 0; i < nx * k; i++) {
651
677
  if (ids[i] >= 0) {
@@ -687,19 +713,14 @@ void knn_L2sqr(
687
713
  sel = nullptr;
688
714
  }
689
715
  if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
690
- knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
716
+ knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
691
717
  return;
692
718
  }
693
- if (k == 1) {
694
- SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
695
- knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
696
- } else if (k < distance_compute_min_k_reservoir) {
697
- HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
698
- knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
699
- } else {
700
- ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
701
- knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
702
- }
719
+
720
+ Run_search_L2sqr r;
721
+ dispatch_knn_ResultHandler(
722
+ nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
723
+
703
724
  if (imin != 0) {
704
725
  for (size_t i = 0; i < nx * k; i++) {
705
726
  if (ids[i] >= 0) {
@@ -726,6 +747,7 @@ void knn_L2sqr(
726
747
  * Range search
727
748
  ***************************************************************************/
728
749
 
750
+ // TODO accept a y_norm2 as well
729
751
  void range_search_L2sqr(
730
752
  const float* x,
731
753
  const float* y,
@@ -735,15 +757,9 @@ void range_search_L2sqr(
735
757
  float radius,
736
758
  RangeSearchResult* res,
737
759
  const IDSelector* sel) {
738
- using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
739
- RH resh(res, radius);
740
- if (sel) {
741
- exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
742
- } else if (nx < distance_compute_blas_threshold) {
743
- exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
744
- } else {
745
- exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
746
- }
760
+ Run_search_L2sqr r;
761
+ dispatch_range_ResultHandler(
762
+ res, radius, METRIC_L2, sel, r, x, y, d, nx, ny, nullptr);
747
763
  }
748
764
 
749
765
  void range_search_inner_product(
@@ -755,15 +771,9 @@ void range_search_inner_product(
755
771
  float radius,
756
772
  RangeSearchResult* res,
757
773
  const IDSelector* sel) {
758
- using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
759
- RH resh(res, radius);
760
- if (sel) {
761
- exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
762
- } else if (nx < distance_compute_blas_threshold) {
763
- exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
764
- } else {
765
- exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
766
- }
774
+ Run_search_inner_product r;
775
+ dispatch_range_ResultHandler(
776
+ res, radius, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
767
777
  }
768
778
 
769
779
  /***************************************************************************
@@ -786,9 +796,11 @@ void fvec_inner_products_by_idx(
786
796
  const float* xj = x + j * d;
787
797
  float* __restrict ipj = ip + j * ny;
788
798
  for (size_t i = 0; i < ny; i++) {
789
- if (idsj[i] < 0)
790
- continue;
791
- ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
799
+ if (idsj[i] < 0) {
800
+ ipj[i] = -INFINITY;
801
+ } else {
802
+ ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
803
+ }
792
804
  }
793
805
  }
794
806
  }
@@ -809,9 +821,11 @@ void fvec_L2sqr_by_idx(
809
821
  const float* xj = x + j * d;
810
822
  float* __restrict disj = dis + j * ny;
811
823
  for (size_t i = 0; i < ny; i++) {
812
- if (idsj[i] < 0)
813
- continue;
814
- disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
824
+ if (idsj[i] < 0) {
825
+ disj[i] = INFINITY;
826
+ } else {
827
+ disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
828
+ }
815
829
  }
816
830
  }
817
831
  }
@@ -828,6 +842,8 @@ void pairwise_indexed_L2sqr(
828
842
  for (int64_t j = 0; j < n; j++) {
829
843
  if (ix[j] >= 0 && iy[j] >= 0) {
830
844
  dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
845
+ } else {
846
+ dis[j] = INFINITY;
831
847
  }
832
848
  }
833
849
  }
@@ -844,6 +860,8 @@ void pairwise_indexed_inner_product(
844
860
  for (int64_t j = 0; j < n; j++) {
845
861
  if (ix[j] >= 0 && iy[j] >= 0) {
846
862
  dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
863
+ } else {
864
+ dis[j] = -INFINITY;
847
865
  }
848
866
  }
849
867
  }
@@ -857,6 +875,7 @@ void knn_inner_products_by_idx(
857
875
  size_t d,
858
876
  size_t nx,
859
877
  size_t ny,
878
+ size_t nsubset,
860
879
  size_t k,
861
880
  float* res_vals,
862
881
  int64_t* res_ids,
@@ -874,9 +893,10 @@ void knn_inner_products_by_idx(
874
893
  int64_t* __restrict idxi = res_ids + i * k;
875
894
  minheap_heapify(k, simi, idxi);
876
895
 
877
- for (j = 0; j < ny; j++) {
878
- if (idsi[j] < 0)
896
+ for (j = 0; j < nsubset; j++) {
897
+ if (idsi[j] < 0 || idsi[j] >= ny) {
879
898
  break;
899
+ }
880
900
  float ip = fvec_inner_product(x_, y + d * idsi[j], d);
881
901
 
882
902
  if (ip > simi[0]) {
@@ -894,6 +914,7 @@ void knn_L2sqr_by_idx(
894
914
  size_t d,
895
915
  size_t nx,
896
916
  size_t ny,
917
+ size_t nsubset,
897
918
  size_t k,
898
919
  float* res_vals,
899
920
  int64_t* res_ids,
@@ -908,7 +929,10 @@ void knn_L2sqr_by_idx(
908
929
  float* __restrict simi = res_vals + i * k;
909
930
  int64_t* __restrict idxi = res_ids + i * k;
910
931
  maxheap_heapify(k, simi, idxi);
911
- for (size_t j = 0; j < ny; j++) {
932
+ for (size_t j = 0; j < nsubset; j++) {
933
+ if (idsi[j] < 0 || idsi[j] >= ny) {
934
+ break;
935
+ }
912
936
  float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
913
937
 
914
938
  if (disij < simi[0]) {