faiss 0.2.6 → 0.2.7

Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
data/vendor/faiss/faiss/utils/distances_simd.cpp
@@ -23,6 +23,10 @@
  #include <immintrin.h>
  #endif

+ #ifdef __AVX2__
+ #include <faiss/utils/transpose/transpose-avx2-inl.h>
+ #endif
+
  #ifdef __aarch64__
  #include <arm_neon.h>
  #endif
@@ -56,16 +60,6 @@ namespace faiss {
  * Reference implementations
  */

- float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
- size_t i;
- float res = 0;
- for (i = 0; i < d; i++) {
- const float tmp = x[i] - y[i];
- res += tmp * tmp;
- }
- return res;
- }
-
  float fvec_L1_ref(const float* x, const float* y, size_t d) {
  size_t i;
  float res = 0;
@@ -85,22 +79,6 @@ float fvec_Linf_ref(const float* x, const float* y, size_t d) {
  return res;
  }

- float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
- size_t i;
- float res = 0;
- for (i = 0; i < d; i++)
- res += x[i] * y[i];
- return res;
- }
-
- float fvec_norm_L2sqr_ref(const float* x, size_t d) {
- size_t i;
- double res = 0;
- for (i = 0; i < d; i++)
- res += x[i] * x[i];
- return res;
- }
-
  void fvec_L2sqr_ny_ref(
  float* dis,
  const float* x,
@@ -203,6 +181,48 @@ void fvec_inner_products_ny_ref(
  }
  }

+ /*********************************************************
+ * Autovectorized implementations
+ */
+
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
+ float res = 0.F;
+ FAISS_PRAGMA_IMPRECISE_LOOP
+ for (size_t i = 0; i != d; ++i) {
+ res += x[i] * y[i];
+ }
+ return res;
+ }
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+ float fvec_norm_L2sqr(const float* x, size_t d) {
+ // the double in the _ref is suspected to be a typo. Some of the manual
+ // implementations this replaces used float.
+ float res = 0;
+ FAISS_PRAGMA_IMPRECISE_LOOP
+ for (size_t i = 0; i != d; ++i) {
+ res += x[i] * x[i];
+ }
+
+ return res;
+ }
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
+ size_t i;
+ float res = 0;
+ FAISS_PRAGMA_IMPRECISE_LOOP
+ for (i = 0; i < d; i++) {
+ const float tmp = x[i] - y[i];
+ res += tmp * tmp;
+ }
+ return res;
+ }
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
  /*********************************************************
  * SSE and AVX implementations
  */
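The three functions added above (fvec_inner_product, fvec_norm_L2sqr, fvec_L2sqr) replace hand-written SSE/AVX/NEON kernels that are deleted further down: a plain loop wrapped in the FAISS_PRAGMA_IMPRECISE_* macros asks the compiler to vectorize the reduction with relaxed floating-point semantics. A minimal sketch of the same idea, using the portable #pragma omp simd reduction clause as a stand-in (the stand-in pragma and the compile flags are assumptions for illustration; the real macros are compiler-specific and come from faiss's platform_macros.h, which this release also extends):

    // Sketch only: let the compiler auto-vectorize a dot-product reduction
    // instead of writing intrinsics. Build with e.g. -O3 -fopenmp-simd
    // (or -ffast-math) so the loop is actually vectorized.
    #include <cstddef>

    float inner_product_autovec(const float* x, const float* y, std::size_t d) {
        float res = 0.f;
    #pragma omp simd reduction(+ : res) // stand-in for FAISS_PRAGMA_IMPRECISE_LOOP
        for (std::size_t i = 0; i != d; ++i) {
            res += x[i] * y[i];
        }
        return res;
    }
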
@@ -225,25 +245,6 @@ static inline __m128 masked_read(int d, const float* x) {
  // cannot use AVX2 _mm_mask_set1_epi32
  }

- float fvec_norm_L2sqr(const float* x, size_t d) {
- __m128 mx;
- __m128 msum1 = _mm_setzero_ps();
-
- while (d >= 4) {
- mx = _mm_loadu_ps(x);
- x += 4;
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
- d -= 4;
- }
-
- mx = masked_read(d, x);
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
-
- msum1 = _mm_hadd_ps(msum1, msum1);
- msum1 = _mm_hadd_ps(msum1, msum1);
- return _mm_cvtss_f32(msum1);
- }
-
  namespace {

  /// Function that does a component-wise operation between x and y
@@ -354,25 +355,25 @@ void fvec_op_ny_D4<ElementOpIP>(
  // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
  const __m256 m3 = _mm256_set1_ps(x[3]);

- const __m256i indices0 =
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
  for (i = 0; i < ny8 * 8; i += 8) {
- _mm_prefetch(y + 32, _MM_HINT_NTA);
- _mm_prefetch(y + 48, _MM_HINT_NTA);
-
- // collect dim 0 for 8 D4-vectors.
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
- // collect dim 1 for 8 D4-vectors.
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
- // collect dim 2 for 8 D4-vectors.
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
- // collect dim 3 for 8 D4-vectors.
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+ // load 8x4 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+
+ transpose_8x4(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ v0,
+ v1,
+ v2,
+ v3);

  // compute distances
  __m256 distances = _mm256_mul_ps(m0, v0);
@@ -380,15 +381,7 @@
  distances = _mm256_fmadd_ps(m2, v2, distances);
  distances = _mm256_fmadd_ps(m3, v3, distances);

- // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
- // (x[1] * y[(i * 8 + 0) * 4 + 1]) +
- // (x[2] * y[(i * 8 + 0) * 4 + 2]) +
- // (x[3] * y[(i * 8 + 0) * 4 + 3])
- // ...
- // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
- // (x[1] * y[(i * 8 + 7) * 4 + 1]) +
- // (x[2] * y[(i * 8 + 7) * 4 + 2]) +
- // (x[3] * y[(i * 8 + 7) * 4 + 3])
+ // store
  _mm256_storeu_ps(dis + i, distances);

  y += 32;
@@ -432,25 +425,25 @@ void fvec_op_ny_D4<ElementOpL2>(
  // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
  const __m256 m3 = _mm256_set1_ps(x[3]);

- const __m256i indices0 =
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
  for (i = 0; i < ny8 * 8; i += 8) {
- _mm_prefetch(y + 32, _MM_HINT_NTA);
- _mm_prefetch(y + 48, _MM_HINT_NTA);
-
- // collect dim 0 for 8 D4-vectors.
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
- // collect dim 1 for 8 D4-vectors.
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
- // collect dim 2 for 8 D4-vectors.
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
- // collect dim 3 for 8 D4-vectors.
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+ // load 8x4 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+
+ transpose_8x4(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ v0,
+ v1,
+ v2,
+ v3);

  // compute differences
  const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -464,15 +457,7 @@
  distances = _mm256_fmadd_ps(d2, d2, distances);
  distances = _mm256_fmadd_ps(d3, d3, distances);

- // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
- // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
- // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
- // (x[3] - y[(i * 8 + 0) * 4 + 3])
- // ...
- // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
- // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
- // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
- // (x[3] - y[(i * 8 + 7) * 4 + 3])
+ // store
  _mm256_storeu_ps(dis + i, distances);

  y += 32;
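In both specializations above, the per-dimension _mm256_i32gather_ps loads are replaced by four contiguous _mm256_loadu_ps loads followed by an in-register transpose_8x4 (from the new transpose-avx2-inl.h header), since contiguous loads use memory bandwidth better than gathers. A scalar sketch of the data movement the 8x4 transpose performs (illustration only, not the AVX2 implementation; the helper name transpose_8x4_scalar is invented here):

    #include <cstddef>

    // y holds 8 vectors of dimension 4, row-major: component `dim` of vector `v`
    // is y[v * 4 + dim]. After the transpose, out[dim][v] groups one dimension of
    // all 8 vectors together, i.e. each out[dim] maps onto one 8-wide register.
    void transpose_8x4_scalar(const float* y, float out[4][8]) {
        for (std::size_t v = 0; v < 8; ++v) {
            for (std::size_t dim = 0; dim < 4; ++dim) {
                out[dim][v] = y[v * 4 + dim];
            }
        }
    }
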
@@ -583,6 +568,228 @@ void fvec_inner_products_ny(
  }

  #ifdef __AVX2__
+ template <size_t DIM>
+ void fvec_L2sqr_ny_y_transposed_D(
+ float* distances,
+ const float* x,
+ const float* y,
+ const float* y_sqlen,
+ const size_t d_offset,
+ size_t ny) {
+ // current index being processed
+ size_t i = 0;
+
+ // squared length of x
+ float x_sqlen = 0;
+ ;
+ for (size_t j = 0; j < DIM; j++) {
+ x_sqlen += x[j] * x[j];
+ }
+
+ // process 8 vectors per loop.
+ const size_t ny8 = ny / 8;
+
+ if (ny8 > 0) {
+ // m[i] = (2 * x[i], ... 2 * x[i])
+ __m256 m[DIM];
+ for (size_t j = 0; j < DIM; j++) {
+ m[j] = _mm256_set1_ps(x[j]);
+ m[j] = _mm256_add_ps(m[j], m[j]);
+ }
+
+ __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
+
+ for (; i < ny8 * 8; i += 8) {
+ // collect dim 0 for 8 D4-vectors.
+ const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+
+ // compute dot products
+ // this is x^2 - 2x[0]*y[0]
+ __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
+
+ for (size_t j = 1; j < DIM; j++) {
+ // collect dim j for 8 D4-vectors.
+ const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+ dp = _mm256_fnmadd_ps(m[j], vj, dp);
+ }
+
+ // we've got x^2 - (2x, y) at this point
+
+ // y^2 - (2x, y) + x^2
+ __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+ _mm256_storeu_ps(distances + i, distances_v);
+
+ // scroll y and y_sqlen forward.
+ y += 8;
+ y_sqlen += 8;
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ for (; i < ny; i++) {
+ float dp = 0;
+ for (size_t j = 0; j < DIM; j++) {
+ dp += x[j] * y[j * d_offset];
+ }
+
+ // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+ // lowest distance.
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+ distances[i] = distance;
+
+ y += 1;
+ y_sqlen += 1;
+ }
+ }
+ }
+ #endif
+
+ void fvec_L2sqr_ny_transposed(
+ float* dis,
+ const float* x,
+ const float* y,
+ const float* y_sqlen,
+ size_t d,
+ size_t d_offset,
+ size_t ny) {
+ // optimized for a few special cases
+
+ #ifdef __AVX2__
+ #define DISPATCH(dval) \
+ case dval: \
+ return fvec_L2sqr_ny_y_transposed_D<dval>( \
+ dis, x, y, y_sqlen, d_offset, ny);
+
+ switch (d) {
+ DISPATCH(1)
+ DISPATCH(2)
+ DISPATCH(4)
+ DISPATCH(8)
+ default:
+ return fvec_L2sqr_ny_y_transposed_ref(
+ dis, x, y, y_sqlen, d, d_offset, ny);
+ }
+ #undef DISPATCH
+ #else
+ // non-AVX2 case
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+ #endif
+ }
+
+ #ifdef __AVX2__
+
+ size_t fvec_L2sqr_ny_nearest_D2(
+ float* distances_tmp_buffer,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ // this implementation does not use distances_tmp_buffer.
+
+ // current index being processed
+ size_t i = 0;
+
+ // min distance and the index of the closest vector so far
+ float current_min_distance = HUGE_VALF;
+ size_t current_min_index = 0;
+
+ // process 8 D2-vectors per loop.
+ const size_t ny8 = ny / 8;
+ if (ny8 > 0) {
+ _mm_prefetch(y, _MM_HINT_T0);
+ _mm_prefetch(y + 16, _MM_HINT_T0);
+
+ // track min distance and the closest vector independently
+ // for each of 8 AVX2 components.
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+ __m256i min_indices = _mm256_set1_epi32(0);
+
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ const __m256i indices_increment = _mm256_set1_epi32(8);
+
+ // 1 value per register
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+
+ for (; i < ny8 * 8; i += 8) {
+ _mm_prefetch(y + 32, _MM_HINT_T0);
+
+ __m256 v0;
+ __m256 v1;
+
+ transpose_8x2(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ v0,
+ v1);
+
+ // compute differences
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
+
+ // compute squares of differences
+ __m256 distances = _mm256_mul_ps(d0, d0);
+ distances = _mm256_fmadd_ps(d1, d1, distances);
+
+ // compare the new distances to the min distances
+ // for each of 8 AVX2 components.
+ __m256 comparison =
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+ // update min distances and indices with closest vectors if needed.
+ min_distances = _mm256_min_ps(distances, min_distances);
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+ _mm256_castsi256_ps(current_indices),
+ _mm256_castsi256_ps(min_indices),
+ comparison));
+
+ // update current indices values. Basically, +8 to each of the
+ // 8 AVX2 components.
+ current_indices =
+ _mm256_add_epi32(current_indices, indices_increment);
+
+ // scroll y forward (8 vectors 2 DIM each).
+ y += 16;
+ }
+
+ // dump values and find the minimum distance / minimum index
+ float min_distances_scalar[8];
+ uint32_t min_indices_scalar[8];
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+ for (size_t j = 0; j < 8; j++) {
+ if (current_min_distance > min_distances_scalar[j]) {
+ current_min_distance = min_distances_scalar[j];
+ current_min_index = min_indices_scalar[j];
+ }
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers.
+ // the following code is not optimal, but it is rarely invoked.
+ float x0 = x[0];
+ float x1 = x[1];
+
+ for (; i < ny; i++) {
+ float sub0 = x0 - y[0];
+ float sub1 = x1 - y[1];
+ float distance = sub0 * sub0 + sub1 * sub1;
+
+ y += 2;
+
+ if (current_min_distance > distance) {
+ current_min_distance = distance;
+ current_min_index = i;
+ }
+ }
+ }
+
+ return current_min_index;
+ }
+
  size_t fvec_L2sqr_ny_nearest_D4(
  float* distances_tmp_buffer,
  const float* x,
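The fvec_L2sqr_ny_y_transposed_D kernel added above relies on the expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2, with ||y_i||^2 supplied precomputed in y_sqlen and y stored dimension-major with stride d_offset. A plain scalar sketch of the same computation (illustration only; the function name and the loop structure here are assumptions, the real AVX2 kernel is dispatched through fvec_L2sqr_ny_transposed above):

    #include <cstddef>

    // dis[i] = ||x - y_i||^2 for ny vectors of dimension d, where y is stored
    // transposed (component j of vector i at y[j * d_offset + i]) and
    // y_sqlen[i] = ||y_i||^2 is precomputed by the caller.
    void l2sqr_ny_transposed_scalar(
            float* dis,
            const float* x,
            const float* y,
            const float* y_sqlen,
            std::size_t d,
            std::size_t d_offset,
            std::size_t ny) {
        float x_sqlen = 0.f;
        for (std::size_t j = 0; j < d; ++j) {
            x_sqlen += x[j] * x[j];
        }
        for (std::size_t i = 0; i < ny; ++i) {
            float dp = 0.f;
            for (std::size_t j = 0; j < d; ++j) {
                dp += x[j] * y[j * d_offset + i];
            }
            // ||x||^2 - 2 <x, y_i> + ||y_i||^2
            dis[i] = x_sqlen - 2.f * dp + y_sqlen[i];
        }
    }
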
@@ -609,38 +816,27 @@ size_t fvec_L2sqr_ny_nearest_D4(
  __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  const __m256i indices_increment = _mm256_set1_epi32(8);

- //
- _mm_prefetch(y, _MM_HINT_NTA);
- _mm_prefetch(y + 16, _MM_HINT_NTA);
-
- // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+ // 1 value per register
  const __m256 m0 = _mm256_set1_ps(x[0]);
- // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
  const __m256 m1 = _mm256_set1_ps(x[1]);
- // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
  const __m256 m2 = _mm256_set1_ps(x[2]);
- // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
  const __m256 m3 = _mm256_set1_ps(x[3]);

- const __m256i indices0 =
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
  for (; i < ny8 * 8; i += 8) {
- _mm_prefetch(y + 32, _MM_HINT_NTA);
- _mm_prefetch(y + 48, _MM_HINT_NTA);
-
- // collect dim 0 for 8 D4-vectors.
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
- // collect dim 1 for 8 D4-vectors.
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
- // collect dim 2 for 8 D4-vectors.
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
- // collect dim 3 for 8 D4-vectors.
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+
+ transpose_8x4(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ v0,
+ v1,
+ v2,
+ v3);

  // compute differences
  const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -654,24 +850,13 @@ size_t fvec_L2sqr_ny_nearest_D4(
  distances = _mm256_fmadd_ps(d2, d2, distances);
  distances = _mm256_fmadd_ps(d3, d3, distances);

- // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
- // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
- // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
- // (x[3] - y[(i * 8 + 0) * 4 + 3])
- // ...
- // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
- // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
- // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
- // (x[3] - y[(i * 8 + 7) * 4 + 3])
-
  // compare the new distances to the min distances
  // for each of 8 AVX2 components.
  __m256 comparison =
  _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

  // update min distances and indices with closest vectors if needed.
- min_distances =
- _mm256_blendv_ps(distances, min_distances, comparison);
+ min_distances = _mm256_min_ps(distances, min_distances);
  min_indices = _mm256_castps_si256(_mm256_blendv_ps(
  _mm256_castsi256_ps(current_indices),
  _mm256_castsi256_ps(min_indices),
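The hunk above also switches the running-minimum update from a blend to _mm256_min_ps, while the matching indices are still selected by a blend driven by the _CMP_LT_OS comparison. Each of the 8 AVX2 lanes keeps its own running minimum and argmin, merged in a final scalar pass; a scalar model of that pattern (a sketch only, covering the full groups of 8, as the kernel does before its leftover loop):

    #include <cstddef>
    #include <cstdint>
    #include <limits>

    // Scalar model of the 8-lane argmin tracking: lane j only ever sees
    // candidates j, j + 8, j + 16, ..., and the per-lane results are merged
    // at the end, mirroring min_ps + blendv_ps on the index register.
    std::size_t argmin_8lane_model(const float* distances, std::size_t n) {
        float lane_min[8];
        std::uint32_t lane_idx[8] = {0};
        for (int j = 0; j < 8; ++j) {
            lane_min[j] = std::numeric_limits<float>::infinity();
        }
        for (std::size_t i = 0; i + 8 <= n; i += 8) {
            for (int j = 0; j < 8; ++j) {
                if (distances[i + j] < lane_min[j]) {
                    lane_min[j] = distances[i + j];
                    lane_idx[j] = static_cast<std::uint32_t>(i + j);
                }
            }
        }
        // final horizontal reduction; leftovers (n % 8) would be scanned
        // separately, as in the kernel's scalar tail.
        std::size_t best = 0;
        float best_dist = std::numeric_limits<float>::infinity();
        for (int j = 0; j < 8; ++j) {
            if (lane_min[j] < best_dist) {
                best_dist = lane_min[j];
                best = lane_idx[j];
            }
        }
        return best;
    }
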
@@ -721,7 +906,168 @@

  return current_min_index;
  }
+
+ size_t fvec_L2sqr_ny_nearest_D8(
+ float* distances_tmp_buffer,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ // this implementation does not use distances_tmp_buffer.
+
+ // current index being processed
+ size_t i = 0;
+
+ // min distance and the index of the closest vector so far
+ float current_min_distance = HUGE_VALF;
+ size_t current_min_index = 0;
+
+ // process 8 D8-vectors per loop.
+ const size_t ny8 = ny / 8;
+ if (ny8 > 0) {
+ // track min distance and the closest vector independently
+ // for each of 8 AVX2 components.
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+ __m256i min_indices = _mm256_set1_epi32(0);
+
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ const __m256i indices_increment = _mm256_set1_epi32(8);
+
+ // 1 value per register
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+ const __m256 m2 = _mm256_set1_ps(x[2]);
+ const __m256 m3 = _mm256_set1_ps(x[3]);
+
+ const __m256 m4 = _mm256_set1_ps(x[4]);
+ const __m256 m5 = _mm256_set1_ps(x[5]);
+ const __m256 m6 = _mm256_set1_ps(x[6]);
+ const __m256 m7 = _mm256_set1_ps(x[7]);
+
+ for (; i < ny8 * 8; i += 8) {
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+ __m256 v4;
+ __m256 v5;
+ __m256 v6;
+ __m256 v7;
+
+ transpose_8x8(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ _mm256_loadu_ps(y + 4 * 8),
+ _mm256_loadu_ps(y + 5 * 8),
+ _mm256_loadu_ps(y + 6 * 8),
+ _mm256_loadu_ps(y + 7 * 8),
+ v0,
+ v1,
+ v2,
+ v3,
+ v4,
+ v5,
+ v6,
+ v7);
+
+ // compute differences
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
+ const __m256 d4 = _mm256_sub_ps(m4, v4);
+ const __m256 d5 = _mm256_sub_ps(m5, v5);
+ const __m256 d6 = _mm256_sub_ps(m6, v6);
+ const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+ // compute squares of differences
+ __m256 distances = _mm256_mul_ps(d0, d0);
+ distances = _mm256_fmadd_ps(d1, d1, distances);
+ distances = _mm256_fmadd_ps(d2, d2, distances);
+ distances = _mm256_fmadd_ps(d3, d3, distances);
+ distances = _mm256_fmadd_ps(d4, d4, distances);
+ distances = _mm256_fmadd_ps(d5, d5, distances);
+ distances = _mm256_fmadd_ps(d6, d6, distances);
+ distances = _mm256_fmadd_ps(d7, d7, distances);
+
+ // compare the new distances to the min distances
+ // for each of 8 AVX2 components.
+ __m256 comparison =
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+ // update min distances and indices with closest vectors if needed.
+ min_distances = _mm256_min_ps(distances, min_distances);
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+ _mm256_castsi256_ps(current_indices),
+ _mm256_castsi256_ps(min_indices),
+ comparison));
+
+ // update current indices values. Basically, +8 to each of the
+ // 8 AVX2 components.
+ current_indices =
+ _mm256_add_epi32(current_indices, indices_increment);
+
+ // scroll y forward (8 vectors 8 DIM each).
+ y += 64;
+ }
+
+ // dump values and find the minimum distance / minimum index
+ float min_distances_scalar[8];
+ uint32_t min_indices_scalar[8];
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+ for (size_t j = 0; j < 8; j++) {
+ if (current_min_distance > min_distances_scalar[j]) {
+ current_min_distance = min_distances_scalar[j];
+ current_min_index = min_indices_scalar[j];
+ }
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ __m256 x0 = _mm256_loadu_ps(x);
+
+ for (; i < ny; i++) {
+ __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
+ __m256 accu = _mm256_mul_ps(sub, sub);
+ y += 8;
+
+ // horitontal sum
+ const __m256 h0 = _mm256_hadd_ps(accu, accu);
+ const __m256 h1 = _mm256_hadd_ps(h0, h0);
+
+ // extract high and low __m128 regs from __m256
+ const __m128 h2 = _mm256_extractf128_ps(h1, 1);
+ const __m128 h3 = _mm256_castps256_ps128(h1);
+
+ // get a final hsum into all 4 regs
+ const __m128 h4 = _mm_add_ss(h2, h3);
+
+ // extract f[0] from __m128
+ const float distance = _mm_cvtss_f32(h4);
+
+ if (current_min_distance > distance) {
+ current_min_distance = distance;
+ current_min_index = i;
+ }
+ }
+ }
+
+ return current_min_index;
+ }
+
  #else
+ size_t fvec_L2sqr_ny_nearest_D2(
+ float* distances_tmp_buffer,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
+ }
+
  size_t fvec_L2sqr_ny_nearest_D4(
  float* distances_tmp_buffer,
  const float* x,
@@ -729,6 +1075,14 @@ size_t fvec_L2sqr_ny_nearest_D4(
  size_t ny) {
  return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
  }
+
+ size_t fvec_L2sqr_ny_nearest_D8(
+ float* distances_tmp_buffer,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
+ }
  #endif

  size_t fvec_L2sqr_ny_nearest(
@@ -743,7 +1097,9 @@ size_t fvec_L2sqr_ny_nearest(
  return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);

  switch (d) {
+ DISPATCH(2)
  DISPATCH(4)
+ DISPATCH(8)
  default:
  return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
  }
@@ -919,79 +1275,6 @@ static inline __m256 masked_read_8(int d, const float* x) {
  }
  }

- float fvec_inner_product(const float* x, const float* y, size_t d) {
- __m256 msum1 = _mm256_setzero_ps();
-
- while (d >= 8) {
- __m256 mx = _mm256_loadu_ps(x);
- x += 8;
- __m256 my = _mm256_loadu_ps(y);
- y += 8;
- msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
- d -= 8;
- }
-
- __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
- msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
-
- if (d >= 4) {
- __m128 mx = _mm_loadu_ps(x);
- x += 4;
- __m128 my = _mm_loadu_ps(y);
- y += 4;
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
- d -= 4;
- }
-
- if (d > 0) {
- __m128 mx = masked_read(d, x);
- __m128 my = masked_read(d, y);
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
- }
-
- msum2 = _mm_hadd_ps(msum2, msum2);
- msum2 = _mm_hadd_ps(msum2, msum2);
- return _mm_cvtss_f32(msum2);
- }
-
- float fvec_L2sqr(const float* x, const float* y, size_t d) {
- __m256 msum1 = _mm256_setzero_ps();
-
- while (d >= 8) {
- __m256 mx = _mm256_loadu_ps(x);
- x += 8;
- __m256 my = _mm256_loadu_ps(y);
- y += 8;
- const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
- msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
- d -= 8;
- }
-
- __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
- msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
-
- if (d >= 4) {
- __m128 mx = _mm_loadu_ps(x);
- x += 4;
- __m128 my = _mm_loadu_ps(y);
- y += 4;
- const __m128 a_m_b1 = _mm_sub_ps(mx, my);
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
- d -= 4;
- }
-
- if (d > 0) {
- __m128 mx = masked_read(d, x);
- __m128 my = masked_read(d, y);
- __m128 a_m_b1 = _mm_sub_ps(mx, my);
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
- }
-
- msum2 = _mm_hadd_ps(msum2, msum2);
- msum2 = _mm_hadd_ps(msum2, msum2);
- return _mm_cvtss_f32(msum2);
- }
-
  float fvec_L1(const float* x, const float* y, size_t d) {
  __m256 msum1 = _mm256_setzero_ps();
  __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1082,113 +1365,8 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
  return fvec_Linf_ref(x, y, d);
  }

- float fvec_L2sqr(const float* x, const float* y, size_t d) {
- __m128 msum1 = _mm_setzero_ps();
-
- while (d >= 4) {
- __m128 mx = _mm_loadu_ps(x);
- x += 4;
- __m128 my = _mm_loadu_ps(y);
- y += 4;
- const __m128 a_m_b1 = _mm_sub_ps(mx, my);
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
- d -= 4;
- }
-
- if (d > 0) {
- // add the last 1, 2 or 3 values
- __m128 mx = masked_read(d, x);
- __m128 my = masked_read(d, y);
- __m128 a_m_b1 = _mm_sub_ps(mx, my);
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
- }
-
- msum1 = _mm_hadd_ps(msum1, msum1);
- msum1 = _mm_hadd_ps(msum1, msum1);
- return _mm_cvtss_f32(msum1);
- }
-
- float fvec_inner_product(const float* x, const float* y, size_t d) {
- __m128 mx, my;
- __m128 msum1 = _mm_setzero_ps();
-
- while (d >= 4) {
- mx = _mm_loadu_ps(x);
- x += 4;
- my = _mm_loadu_ps(y);
- y += 4;
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
- d -= 4;
- }
-
- // add the last 1, 2, or 3 values
- mx = masked_read(d, x);
- my = masked_read(d, y);
- __m128 prod = _mm_mul_ps(mx, my);
-
- msum1 = _mm_add_ps(msum1, prod);
-
- msum1 = _mm_hadd_ps(msum1, msum1);
- msum1 = _mm_hadd_ps(msum1, msum1);
- return _mm_cvtss_f32(msum1);
- }
-
  #elif defined(__aarch64__)

- float fvec_L2sqr(const float* x, const float* y, size_t d) {
- float32x4_t accux4 = vdupq_n_f32(0);
- const size_t d_simd = d - (d & 3);
- size_t i;
- for (i = 0; i < d_simd; i += 4) {
- float32x4_t xi = vld1q_f32(x + i);
- float32x4_t yi = vld1q_f32(y + i);
- float32x4_t sq = vsubq_f32(xi, yi);
- accux4 = vfmaq_f32(accux4, sq, sq);
- }
- float32_t accux1 = vaddvq_f32(accux4);
- for (; i < d; ++i) {
- float32_t xi = x[i];
- float32_t yi = y[i];
- float32_t sq = xi - yi;
- accux1 += sq * sq;
- }
- return accux1;
- }
-
- float fvec_inner_product(const float* x, const float* y, size_t d) {
- float32x4_t accux4 = vdupq_n_f32(0);
- const size_t d_simd = d - (d & 3);
- size_t i;
- for (i = 0; i < d_simd; i += 4) {
- float32x4_t xi = vld1q_f32(x + i);
- float32x4_t yi = vld1q_f32(y + i);
- accux4 = vfmaq_f32(accux4, xi, yi);
- }
- float32_t accux1 = vaddvq_f32(accux4);
- for (; i < d; ++i) {
- float32_t xi = x[i];
- float32_t yi = y[i];
- accux1 += xi * yi;
- }
- return accux1;
- }
-
- float fvec_norm_L2sqr(const float* x, size_t d) {
- float32x4_t accux4 = vdupq_n_f32(0);
- const size_t d_simd = d - (d & 3);
- size_t i;
- for (i = 0; i < d_simd; i += 4) {
- float32x4_t xi = vld1q_f32(x + i);
- accux4 = vfmaq_f32(accux4, xi, xi);
- }
- float32_t accux1 = vaddvq_f32(accux4);
- for (; i < d; ++i) {
- float32_t xi = x[i];
- accux1 += xi * xi;
- }
- return accux1;
- }
-
  // not optimized for ARM
  void fvec_L2sqr_ny(
  float* dis,
@@ -1199,6 +1377,17 @@ void fvec_L2sqr_ny(
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
  }

+ void fvec_L2sqr_ny_transposed(
+ float* dis,
+ const float* x,
+ const float* y,
+ const float* y_sqlen,
+ size_t d,
+ size_t d_offset,
+ size_t ny) {
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+ }
+
  size_t fvec_L2sqr_ny_nearest(
  float* distances_tmp_buffer,
  const float* x,
@@ -1240,10 +1429,6 @@ void fvec_inner_products_ny(
  #else
  // scalar implementation

- float fvec_L2sqr(const float* x, const float* y, size_t d) {
- return fvec_L2sqr_ref(x, y, d);
- }
-
  float fvec_L1(const float* x, const float* y, size_t d) {
  return fvec_L1_ref(x, y, d);
  }
@@ -1252,14 +1437,6 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
  return fvec_Linf_ref(x, y, d);
  }

- float fvec_inner_product(const float* x, const float* y, size_t d) {
- return fvec_inner_product_ref(x, y, d);
- }
-
- float fvec_norm_L2sqr(const float* x, size_t d) {
- return fvec_norm_L2sqr_ref(x, d);
- }
-
  void fvec_L2sqr_ny(
  float* dis,
  const float* x,
@@ -1269,6 +1446,17 @@ void fvec_L2sqr_ny(
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
  }

+ void fvec_L2sqr_ny_transposed(
+ float* dis,
+ const float* x,
+ const float* y,
+ const float* y_sqlen,
+ size_t d,
+ size_t d_offset,
+ size_t ny) {
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+ }
+
  size_t fvec_L2sqr_ny_nearest(
  float* distances_tmp_buffer,
  const float* x,