faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -23,6 +23,10 @@
23
23
  #include <immintrin.h>
24
24
  #endif
25
25
 
26
+ #ifdef __AVX2__
27
+ #include <faiss/utils/transpose/transpose-avx2-inl.h>
28
+ #endif
29
+
26
30
  #ifdef __aarch64__
27
31
  #include <arm_neon.h>
28
32
  #endif
@@ -56,16 +60,6 @@ namespace faiss {
56
60
  * Reference implementations
57
61
  */
58
62
 
59
- float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
60
- size_t i;
61
- float res = 0;
62
- for (i = 0; i < d; i++) {
63
- const float tmp = x[i] - y[i];
64
- res += tmp * tmp;
65
- }
66
- return res;
67
- }
68
-
69
63
  float fvec_L1_ref(const float* x, const float* y, size_t d) {
70
64
  size_t i;
71
65
  float res = 0;
@@ -85,22 +79,6 @@ float fvec_Linf_ref(const float* x, const float* y, size_t d) {
85
79
  return res;
86
80
  }
87
81
 
88
- float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
89
- size_t i;
90
- float res = 0;
91
- for (i = 0; i < d; i++)
92
- res += x[i] * y[i];
93
- return res;
94
- }
95
-
96
- float fvec_norm_L2sqr_ref(const float* x, size_t d) {
97
- size_t i;
98
- double res = 0;
99
- for (i = 0; i < d; i++)
100
- res += x[i] * x[i];
101
- return res;
102
- }
103
-
104
82
  void fvec_L2sqr_ny_ref(
105
83
  float* dis,
106
84
  const float* x,
@@ -203,6 +181,48 @@ void fvec_inner_products_ny_ref(
203
181
  }
204
182
  }
205
183
 
184
+ /*********************************************************
185
+ * Autovectorized implementations
186
+ */
187
+
188
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
189
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
190
+ float res = 0.F;
191
+ FAISS_PRAGMA_IMPRECISE_LOOP
192
+ for (size_t i = 0; i != d; ++i) {
193
+ res += x[i] * y[i];
194
+ }
195
+ return res;
196
+ }
197
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
198
+
199
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
200
+ float fvec_norm_L2sqr(const float* x, size_t d) {
201
+ // the double in the _ref is suspected to be a typo. Some of the manual
202
+ // implementations this replaces used float.
203
+ float res = 0;
204
+ FAISS_PRAGMA_IMPRECISE_LOOP
205
+ for (size_t i = 0; i != d; ++i) {
206
+ res += x[i] * x[i];
207
+ }
208
+
209
+ return res;
210
+ }
211
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
212
+
213
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
214
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
215
+ size_t i;
216
+ float res = 0;
217
+ FAISS_PRAGMA_IMPRECISE_LOOP
218
+ for (i = 0; i < d; i++) {
219
+ const float tmp = x[i] - y[i];
220
+ res += tmp * tmp;
221
+ }
222
+ return res;
223
+ }
224
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
225
+
206
226
  /*********************************************************
207
227
  * SSE and AVX implementations
208
228
  */
@@ -225,25 +245,6 @@ static inline __m128 masked_read(int d, const float* x) {
225
245
  // cannot use AVX2 _mm_mask_set1_epi32
226
246
  }
227
247
 
228
- float fvec_norm_L2sqr(const float* x, size_t d) {
229
- __m128 mx;
230
- __m128 msum1 = _mm_setzero_ps();
231
-
232
- while (d >= 4) {
233
- mx = _mm_loadu_ps(x);
234
- x += 4;
235
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
236
- d -= 4;
237
- }
238
-
239
- mx = masked_read(d, x);
240
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
241
-
242
- msum1 = _mm_hadd_ps(msum1, msum1);
243
- msum1 = _mm_hadd_ps(msum1, msum1);
244
- return _mm_cvtss_f32(msum1);
245
- }
246
-
247
248
  namespace {
248
249
 
249
250
  /// Function that does a component-wise operation between x and y
@@ -354,25 +355,25 @@ void fvec_op_ny_D4<ElementOpIP>(
354
355
  // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
355
356
  const __m256 m3 = _mm256_set1_ps(x[3]);
356
357
 
357
- const __m256i indices0 =
358
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
359
-
360
358
  for (i = 0; i < ny8 * 8; i += 8) {
361
- _mm_prefetch(y + 32, _MM_HINT_NTA);
362
- _mm_prefetch(y + 48, _MM_HINT_NTA);
363
-
364
- // collect dim 0 for 8 D4-vectors.
365
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
366
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
367
- // collect dim 1 for 8 D4-vectors.
368
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
369
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
370
- // collect dim 2 for 8 D4-vectors.
371
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
372
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
373
- // collect dim 3 for 8 D4-vectors.
374
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
375
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
359
+ // load 8x4 matrix and transpose it in registers.
360
+ // the typical bottleneck is memory access, so
361
+ // let's trade instructions for the bandwidth.
362
+
363
+ __m256 v0;
364
+ __m256 v1;
365
+ __m256 v2;
366
+ __m256 v3;
367
+
368
+ transpose_8x4(
369
+ _mm256_loadu_ps(y + 0 * 8),
370
+ _mm256_loadu_ps(y + 1 * 8),
371
+ _mm256_loadu_ps(y + 2 * 8),
372
+ _mm256_loadu_ps(y + 3 * 8),
373
+ v0,
374
+ v1,
375
+ v2,
376
+ v3);
376
377
 
377
378
  // compute distances
378
379
  __m256 distances = _mm256_mul_ps(m0, v0);
@@ -380,15 +381,7 @@ void fvec_op_ny_D4<ElementOpIP>(
380
381
  distances = _mm256_fmadd_ps(m2, v2, distances);
381
382
  distances = _mm256_fmadd_ps(m3, v3, distances);
382
383
 
383
- // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
384
- // (x[1] * y[(i * 8 + 0) * 4 + 1]) +
385
- // (x[2] * y[(i * 8 + 0) * 4 + 2]) +
386
- // (x[3] * y[(i * 8 + 0) * 4 + 3])
387
- // ...
388
- // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
389
- // (x[1] * y[(i * 8 + 7) * 4 + 1]) +
390
- // (x[2] * y[(i * 8 + 7) * 4 + 2]) +
391
- // (x[3] * y[(i * 8 + 7) * 4 + 3])
384
+ // store
392
385
  _mm256_storeu_ps(dis + i, distances);
393
386
 
394
387
  y += 32;
@@ -432,25 +425,25 @@ void fvec_op_ny_D4<ElementOpL2>(
432
425
  // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
433
426
  const __m256 m3 = _mm256_set1_ps(x[3]);
434
427
 
435
- const __m256i indices0 =
436
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
437
-
438
428
  for (i = 0; i < ny8 * 8; i += 8) {
439
- _mm_prefetch(y + 32, _MM_HINT_NTA);
440
- _mm_prefetch(y + 48, _MM_HINT_NTA);
441
-
442
- // collect dim 0 for 8 D4-vectors.
443
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
444
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
445
- // collect dim 1 for 8 D4-vectors.
446
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
447
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
448
- // collect dim 2 for 8 D4-vectors.
449
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
450
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
451
- // collect dim 3 for 8 D4-vectors.
452
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
453
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
429
+ // load 8x4 matrix and transpose it in registers.
430
+ // the typical bottleneck is memory access, so
431
+ // let's trade instructions for the bandwidth.
432
+
433
+ __m256 v0;
434
+ __m256 v1;
435
+ __m256 v2;
436
+ __m256 v3;
437
+
438
+ transpose_8x4(
439
+ _mm256_loadu_ps(y + 0 * 8),
440
+ _mm256_loadu_ps(y + 1 * 8),
441
+ _mm256_loadu_ps(y + 2 * 8),
442
+ _mm256_loadu_ps(y + 3 * 8),
443
+ v0,
444
+ v1,
445
+ v2,
446
+ v3);
454
447
 
455
448
  // compute differences
456
449
  const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -464,15 +457,7 @@ void fvec_op_ny_D4<ElementOpL2>(
464
457
  distances = _mm256_fmadd_ps(d2, d2, distances);
465
458
  distances = _mm256_fmadd_ps(d3, d3, distances);
466
459
 
467
- // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
468
- // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
469
- // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
470
- // (x[3] - y[(i * 8 + 0) * 4 + 3])
471
- // ...
472
- // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
473
- // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
474
- // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
475
- // (x[3] - y[(i * 8 + 7) * 4 + 3])
460
+ // store
476
461
  _mm256_storeu_ps(dis + i, distances);
477
462
 
478
463
  y += 32;
@@ -583,6 +568,228 @@ void fvec_inner_products_ny(
583
568
  }
584
569
 
585
570
  #ifdef __AVX2__
571
+ template <size_t DIM>
572
+ void fvec_L2sqr_ny_y_transposed_D(
573
+ float* distances,
574
+ const float* x,
575
+ const float* y,
576
+ const float* y_sqlen,
577
+ const size_t d_offset,
578
+ size_t ny) {
579
+ // current index being processed
580
+ size_t i = 0;
581
+
582
+ // squared length of x
583
+ float x_sqlen = 0;
584
+ ;
585
+ for (size_t j = 0; j < DIM; j++) {
586
+ x_sqlen += x[j] * x[j];
587
+ }
588
+
589
+ // process 8 vectors per loop.
590
+ const size_t ny8 = ny / 8;
591
+
592
+ if (ny8 > 0) {
593
+ // m[i] = (2 * x[i], ... 2 * x[i])
594
+ __m256 m[DIM];
595
+ for (size_t j = 0; j < DIM; j++) {
596
+ m[j] = _mm256_set1_ps(x[j]);
597
+ m[j] = _mm256_add_ps(m[j], m[j]);
598
+ }
599
+
600
+ __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
601
+
602
+ for (; i < ny8 * 8; i += 8) {
603
+ // collect dim 0 for 8 D4-vectors.
604
+ const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
605
+
606
+ // compute dot products
607
+ // this is x^2 - 2x[0]*y[0]
608
+ __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
609
+
610
+ for (size_t j = 1; j < DIM; j++) {
611
+ // collect dim j for 8 D4-vectors.
612
+ const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
613
+ dp = _mm256_fnmadd_ps(m[j], vj, dp);
614
+ }
615
+
616
+ // we've got x^2 - (2x, y) at this point
617
+
618
+ // y^2 - (2x, y) + x^2
619
+ __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
620
+
621
+ _mm256_storeu_ps(distances + i, distances_v);
622
+
623
+ // scroll y and y_sqlen forward.
624
+ y += 8;
625
+ y_sqlen += 8;
626
+ }
627
+ }
628
+
629
+ if (i < ny) {
630
+ // process leftovers
631
+ for (; i < ny; i++) {
632
+ float dp = 0;
633
+ for (size_t j = 0; j < DIM; j++) {
634
+ dp += x[j] * y[j * d_offset];
635
+ }
636
+
637
+ // compute y^2 - 2 * (x, y), which is sufficient for looking for the
638
+ // lowest distance.
639
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
640
+ distances[i] = distance;
641
+
642
+ y += 1;
643
+ y_sqlen += 1;
644
+ }
645
+ }
646
+ }
647
+ #endif
648
+
649
+ void fvec_L2sqr_ny_transposed(
650
+ float* dis,
651
+ const float* x,
652
+ const float* y,
653
+ const float* y_sqlen,
654
+ size_t d,
655
+ size_t d_offset,
656
+ size_t ny) {
657
+ // optimized for a few special cases
658
+
659
+ #ifdef __AVX2__
660
+ #define DISPATCH(dval) \
661
+ case dval: \
662
+ return fvec_L2sqr_ny_y_transposed_D<dval>( \
663
+ dis, x, y, y_sqlen, d_offset, ny);
664
+
665
+ switch (d) {
666
+ DISPATCH(1)
667
+ DISPATCH(2)
668
+ DISPATCH(4)
669
+ DISPATCH(8)
670
+ default:
671
+ return fvec_L2sqr_ny_y_transposed_ref(
672
+ dis, x, y, y_sqlen, d, d_offset, ny);
673
+ }
674
+ #undef DISPATCH
675
+ #else
676
+ // non-AVX2 case
677
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
678
+ #endif
679
+ }
680
+
681
+ #ifdef __AVX2__
682
+
683
+ size_t fvec_L2sqr_ny_nearest_D2(
684
+ float* distances_tmp_buffer,
685
+ const float* x,
686
+ const float* y,
687
+ size_t ny) {
688
+ // this implementation does not use distances_tmp_buffer.
689
+
690
+ // current index being processed
691
+ size_t i = 0;
692
+
693
+ // min distance and the index of the closest vector so far
694
+ float current_min_distance = HUGE_VALF;
695
+ size_t current_min_index = 0;
696
+
697
+ // process 8 D2-vectors per loop.
698
+ const size_t ny8 = ny / 8;
699
+ if (ny8 > 0) {
700
+ _mm_prefetch(y, _MM_HINT_T0);
701
+ _mm_prefetch(y + 16, _MM_HINT_T0);
702
+
703
+ // track min distance and the closest vector independently
704
+ // for each of 8 AVX2 components.
705
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
706
+ __m256i min_indices = _mm256_set1_epi32(0);
707
+
708
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
709
+ const __m256i indices_increment = _mm256_set1_epi32(8);
710
+
711
+ // 1 value per register
712
+ const __m256 m0 = _mm256_set1_ps(x[0]);
713
+ const __m256 m1 = _mm256_set1_ps(x[1]);
714
+
715
+ for (; i < ny8 * 8; i += 8) {
716
+ _mm_prefetch(y + 32, _MM_HINT_T0);
717
+
718
+ __m256 v0;
719
+ __m256 v1;
720
+
721
+ transpose_8x2(
722
+ _mm256_loadu_ps(y + 0 * 8),
723
+ _mm256_loadu_ps(y + 1 * 8),
724
+ v0,
725
+ v1);
726
+
727
+ // compute differences
728
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
729
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
730
+
731
+ // compute squares of differences
732
+ __m256 distances = _mm256_mul_ps(d0, d0);
733
+ distances = _mm256_fmadd_ps(d1, d1, distances);
734
+
735
+ // compare the new distances to the min distances
736
+ // for each of 8 AVX2 components.
737
+ __m256 comparison =
738
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
739
+
740
+ // update min distances and indices with closest vectors if needed.
741
+ min_distances = _mm256_min_ps(distances, min_distances);
742
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
743
+ _mm256_castsi256_ps(current_indices),
744
+ _mm256_castsi256_ps(min_indices),
745
+ comparison));
746
+
747
+ // update current indices values. Basically, +8 to each of the
748
+ // 8 AVX2 components.
749
+ current_indices =
750
+ _mm256_add_epi32(current_indices, indices_increment);
751
+
752
+ // scroll y forward (8 vectors 2 DIM each).
753
+ y += 16;
754
+ }
755
+
756
+ // dump values and find the minimum distance / minimum index
757
+ float min_distances_scalar[8];
758
+ uint32_t min_indices_scalar[8];
759
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
760
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
761
+
762
+ for (size_t j = 0; j < 8; j++) {
763
+ if (current_min_distance > min_distances_scalar[j]) {
764
+ current_min_distance = min_distances_scalar[j];
765
+ current_min_index = min_indices_scalar[j];
766
+ }
767
+ }
768
+ }
769
+
770
+ if (i < ny) {
771
+ // process leftovers.
772
+ // the following code is not optimal, but it is rarely invoked.
773
+ float x0 = x[0];
774
+ float x1 = x[1];
775
+
776
+ for (; i < ny; i++) {
777
+ float sub0 = x0 - y[0];
778
+ float sub1 = x1 - y[1];
779
+ float distance = sub0 * sub0 + sub1 * sub1;
780
+
781
+ y += 2;
782
+
783
+ if (current_min_distance > distance) {
784
+ current_min_distance = distance;
785
+ current_min_index = i;
786
+ }
787
+ }
788
+ }
789
+
790
+ return current_min_index;
791
+ }
792
+
586
793
  size_t fvec_L2sqr_ny_nearest_D4(
587
794
  float* distances_tmp_buffer,
588
795
  const float* x,
@@ -609,38 +816,27 @@ size_t fvec_L2sqr_ny_nearest_D4(
609
816
  __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
610
817
  const __m256i indices_increment = _mm256_set1_epi32(8);
611
818
 
612
- //
613
- _mm_prefetch(y, _MM_HINT_NTA);
614
- _mm_prefetch(y + 16, _MM_HINT_NTA);
615
-
616
- // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
819
+ // 1 value per register
617
820
  const __m256 m0 = _mm256_set1_ps(x[0]);
618
- // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
619
821
  const __m256 m1 = _mm256_set1_ps(x[1]);
620
- // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
621
822
  const __m256 m2 = _mm256_set1_ps(x[2]);
622
- // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
623
823
  const __m256 m3 = _mm256_set1_ps(x[3]);
624
824
 
625
- const __m256i indices0 =
626
- _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
627
-
628
825
  for (; i < ny8 * 8; i += 8) {
629
- _mm_prefetch(y + 32, _MM_HINT_NTA);
630
- _mm_prefetch(y + 48, _MM_HINT_NTA);
631
-
632
- // collect dim 0 for 8 D4-vectors.
633
- // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
634
- const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
635
- // collect dim 1 for 8 D4-vectors.
636
- // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
637
- const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
638
- // collect dim 2 for 8 D4-vectors.
639
- // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
640
- const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
641
- // collect dim 3 for 8 D4-vectors.
642
- // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
643
- const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
826
+ __m256 v0;
827
+ __m256 v1;
828
+ __m256 v2;
829
+ __m256 v3;
830
+
831
+ transpose_8x4(
832
+ _mm256_loadu_ps(y + 0 * 8),
833
+ _mm256_loadu_ps(y + 1 * 8),
834
+ _mm256_loadu_ps(y + 2 * 8),
835
+ _mm256_loadu_ps(y + 3 * 8),
836
+ v0,
837
+ v1,
838
+ v2,
839
+ v3);
644
840
 
645
841
  // compute differences
646
842
  const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -654,24 +850,13 @@ size_t fvec_L2sqr_ny_nearest_D4(
654
850
  distances = _mm256_fmadd_ps(d2, d2, distances);
655
851
  distances = _mm256_fmadd_ps(d3, d3, distances);
656
852
 
657
- // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
658
- // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
659
- // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
660
- // (x[3] - y[(i * 8 + 0) * 4 + 3])
661
- // ...
662
- // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
663
- // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
664
- // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
665
- // (x[3] - y[(i * 8 + 7) * 4 + 3])
666
-
667
853
  // compare the new distances to the min distances
668
854
  // for each of 8 AVX2 components.
669
855
  __m256 comparison =
670
856
  _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
671
857
 
672
858
  // update min distances and indices with closest vectors if needed.
673
- min_distances =
674
- _mm256_blendv_ps(distances, min_distances, comparison);
859
+ min_distances = _mm256_min_ps(distances, min_distances);
675
860
  min_indices = _mm256_castps_si256(_mm256_blendv_ps(
676
861
  _mm256_castsi256_ps(current_indices),
677
862
  _mm256_castsi256_ps(min_indices),
@@ -721,7 +906,168 @@ size_t fvec_L2sqr_ny_nearest_D4(
721
906
 
722
907
  return current_min_index;
723
908
  }
909
+
910
+ size_t fvec_L2sqr_ny_nearest_D8(
911
+ float* distances_tmp_buffer,
912
+ const float* x,
913
+ const float* y,
914
+ size_t ny) {
915
+ // this implementation does not use distances_tmp_buffer.
916
+
917
+ // current index being processed
918
+ size_t i = 0;
919
+
920
+ // min distance and the index of the closest vector so far
921
+ float current_min_distance = HUGE_VALF;
922
+ size_t current_min_index = 0;
923
+
924
+ // process 8 D8-vectors per loop.
925
+ const size_t ny8 = ny / 8;
926
+ if (ny8 > 0) {
927
+ // track min distance and the closest vector independently
928
+ // for each of 8 AVX2 components.
929
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
930
+ __m256i min_indices = _mm256_set1_epi32(0);
931
+
932
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
933
+ const __m256i indices_increment = _mm256_set1_epi32(8);
934
+
935
+ // 1 value per register
936
+ const __m256 m0 = _mm256_set1_ps(x[0]);
937
+ const __m256 m1 = _mm256_set1_ps(x[1]);
938
+ const __m256 m2 = _mm256_set1_ps(x[2]);
939
+ const __m256 m3 = _mm256_set1_ps(x[3]);
940
+
941
+ const __m256 m4 = _mm256_set1_ps(x[4]);
942
+ const __m256 m5 = _mm256_set1_ps(x[5]);
943
+ const __m256 m6 = _mm256_set1_ps(x[6]);
944
+ const __m256 m7 = _mm256_set1_ps(x[7]);
945
+
946
+ for (; i < ny8 * 8; i += 8) {
947
+ __m256 v0;
948
+ __m256 v1;
949
+ __m256 v2;
950
+ __m256 v3;
951
+ __m256 v4;
952
+ __m256 v5;
953
+ __m256 v6;
954
+ __m256 v7;
955
+
956
+ transpose_8x8(
957
+ _mm256_loadu_ps(y + 0 * 8),
958
+ _mm256_loadu_ps(y + 1 * 8),
959
+ _mm256_loadu_ps(y + 2 * 8),
960
+ _mm256_loadu_ps(y + 3 * 8),
961
+ _mm256_loadu_ps(y + 4 * 8),
962
+ _mm256_loadu_ps(y + 5 * 8),
963
+ _mm256_loadu_ps(y + 6 * 8),
964
+ _mm256_loadu_ps(y + 7 * 8),
965
+ v0,
966
+ v1,
967
+ v2,
968
+ v3,
969
+ v4,
970
+ v5,
971
+ v6,
972
+ v7);
973
+
974
+ // compute differences
975
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
976
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
977
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
978
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
979
+ const __m256 d4 = _mm256_sub_ps(m4, v4);
980
+ const __m256 d5 = _mm256_sub_ps(m5, v5);
981
+ const __m256 d6 = _mm256_sub_ps(m6, v6);
982
+ const __m256 d7 = _mm256_sub_ps(m7, v7);
983
+
984
+ // compute squares of differences
985
+ __m256 distances = _mm256_mul_ps(d0, d0);
986
+ distances = _mm256_fmadd_ps(d1, d1, distances);
987
+ distances = _mm256_fmadd_ps(d2, d2, distances);
988
+ distances = _mm256_fmadd_ps(d3, d3, distances);
989
+ distances = _mm256_fmadd_ps(d4, d4, distances);
990
+ distances = _mm256_fmadd_ps(d5, d5, distances);
991
+ distances = _mm256_fmadd_ps(d6, d6, distances);
992
+ distances = _mm256_fmadd_ps(d7, d7, distances);
993
+
994
+ // compare the new distances to the min distances
995
+ // for each of 8 AVX2 components.
996
+ __m256 comparison =
997
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
998
+
999
+ // update min distances and indices with closest vectors if needed.
1000
+ min_distances = _mm256_min_ps(distances, min_distances);
1001
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
1002
+ _mm256_castsi256_ps(current_indices),
1003
+ _mm256_castsi256_ps(min_indices),
1004
+ comparison));
1005
+
1006
+ // update current indices values. Basically, +8 to each of the
1007
+ // 8 AVX2 components.
1008
+ current_indices =
1009
+ _mm256_add_epi32(current_indices, indices_increment);
1010
+
1011
+ // scroll y forward (8 vectors 8 DIM each).
1012
+ y += 64;
1013
+ }
1014
+
1015
+ // dump values and find the minimum distance / minimum index
1016
+ float min_distances_scalar[8];
1017
+ uint32_t min_indices_scalar[8];
1018
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
1019
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
1020
+
1021
+ for (size_t j = 0; j < 8; j++) {
1022
+ if (current_min_distance > min_distances_scalar[j]) {
1023
+ current_min_distance = min_distances_scalar[j];
1024
+ current_min_index = min_indices_scalar[j];
1025
+ }
1026
+ }
1027
+ }
1028
+
1029
+ if (i < ny) {
1030
+ // process leftovers
1031
+ __m256 x0 = _mm256_loadu_ps(x);
1032
+
1033
+ for (; i < ny; i++) {
1034
+ __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
1035
+ __m256 accu = _mm256_mul_ps(sub, sub);
1036
+ y += 8;
1037
+
1038
+ // horitontal sum
1039
+ const __m256 h0 = _mm256_hadd_ps(accu, accu);
1040
+ const __m256 h1 = _mm256_hadd_ps(h0, h0);
1041
+
1042
+ // extract high and low __m128 regs from __m256
1043
+ const __m128 h2 = _mm256_extractf128_ps(h1, 1);
1044
+ const __m128 h3 = _mm256_castps256_ps128(h1);
1045
+
1046
+ // get a final hsum into all 4 regs
1047
+ const __m128 h4 = _mm_add_ss(h2, h3);
1048
+
1049
+ // extract f[0] from __m128
1050
+ const float distance = _mm_cvtss_f32(h4);
1051
+
1052
+ if (current_min_distance > distance) {
1053
+ current_min_distance = distance;
1054
+ current_min_index = i;
1055
+ }
1056
+ }
1057
+ }
1058
+
1059
+ return current_min_index;
1060
+ }
1061
+
724
1062
  #else
1063
+ size_t fvec_L2sqr_ny_nearest_D2(
1064
+ float* distances_tmp_buffer,
1065
+ const float* x,
1066
+ const float* y,
1067
+ size_t ny) {
1068
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
1069
+ }
1070
+
725
1071
  size_t fvec_L2sqr_ny_nearest_D4(
726
1072
  float* distances_tmp_buffer,
727
1073
  const float* x,
@@ -729,6 +1075,14 @@ size_t fvec_L2sqr_ny_nearest_D4(
729
1075
  size_t ny) {
730
1076
  return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
731
1077
  }
1078
+
1079
+ size_t fvec_L2sqr_ny_nearest_D8(
1080
+ float* distances_tmp_buffer,
1081
+ const float* x,
1082
+ const float* y,
1083
+ size_t ny) {
1084
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
1085
+ }
732
1086
  #endif
733
1087
 
734
1088
  size_t fvec_L2sqr_ny_nearest(
@@ -743,7 +1097,9 @@ size_t fvec_L2sqr_ny_nearest(
743
1097
  return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
744
1098
 
745
1099
  switch (d) {
1100
+ DISPATCH(2)
746
1101
  DISPATCH(4)
1102
+ DISPATCH(8)
747
1103
  default:
748
1104
  return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
749
1105
  }
@@ -919,79 +1275,6 @@ static inline __m256 masked_read_8(int d, const float* x) {
919
1275
  }
920
1276
  }
921
1277
 
922
- float fvec_inner_product(const float* x, const float* y, size_t d) {
923
- __m256 msum1 = _mm256_setzero_ps();
924
-
925
- while (d >= 8) {
926
- __m256 mx = _mm256_loadu_ps(x);
927
- x += 8;
928
- __m256 my = _mm256_loadu_ps(y);
929
- y += 8;
930
- msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
931
- d -= 8;
932
- }
933
-
934
- __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
935
- msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
936
-
937
- if (d >= 4) {
938
- __m128 mx = _mm_loadu_ps(x);
939
- x += 4;
940
- __m128 my = _mm_loadu_ps(y);
941
- y += 4;
942
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
943
- d -= 4;
944
- }
945
-
946
- if (d > 0) {
947
- __m128 mx = masked_read(d, x);
948
- __m128 my = masked_read(d, y);
949
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
950
- }
951
-
952
- msum2 = _mm_hadd_ps(msum2, msum2);
953
- msum2 = _mm_hadd_ps(msum2, msum2);
954
- return _mm_cvtss_f32(msum2);
955
- }
956
-
957
- float fvec_L2sqr(const float* x, const float* y, size_t d) {
958
- __m256 msum1 = _mm256_setzero_ps();
959
-
960
- while (d >= 8) {
961
- __m256 mx = _mm256_loadu_ps(x);
962
- x += 8;
963
- __m256 my = _mm256_loadu_ps(y);
964
- y += 8;
965
- const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
966
- msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
967
- d -= 8;
968
- }
969
-
970
- __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
971
- msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
972
-
973
- if (d >= 4) {
974
- __m128 mx = _mm_loadu_ps(x);
975
- x += 4;
976
- __m128 my = _mm_loadu_ps(y);
977
- y += 4;
978
- const __m128 a_m_b1 = _mm_sub_ps(mx, my);
979
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
980
- d -= 4;
981
- }
982
-
983
- if (d > 0) {
984
- __m128 mx = masked_read(d, x);
985
- __m128 my = masked_read(d, y);
986
- __m128 a_m_b1 = _mm_sub_ps(mx, my);
987
- msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
988
- }
989
-
990
- msum2 = _mm_hadd_ps(msum2, msum2);
991
- msum2 = _mm_hadd_ps(msum2, msum2);
992
- return _mm_cvtss_f32(msum2);
993
- }
994
-
995
1278
  float fvec_L1(const float* x, const float* y, size_t d) {
996
1279
  __m256 msum1 = _mm256_setzero_ps();
997
1280
  __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1082,113 +1365,8 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
1082
1365
  return fvec_Linf_ref(x, y, d);
1083
1366
  }
1084
1367
 
1085
- float fvec_L2sqr(const float* x, const float* y, size_t d) {
1086
- __m128 msum1 = _mm_setzero_ps();
1087
-
1088
- while (d >= 4) {
1089
- __m128 mx = _mm_loadu_ps(x);
1090
- x += 4;
1091
- __m128 my = _mm_loadu_ps(y);
1092
- y += 4;
1093
- const __m128 a_m_b1 = _mm_sub_ps(mx, my);
1094
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
1095
- d -= 4;
1096
- }
1097
-
1098
- if (d > 0) {
1099
- // add the last 1, 2 or 3 values
1100
- __m128 mx = masked_read(d, x);
1101
- __m128 my = masked_read(d, y);
1102
- __m128 a_m_b1 = _mm_sub_ps(mx, my);
1103
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
1104
- }
1105
-
1106
- msum1 = _mm_hadd_ps(msum1, msum1);
1107
- msum1 = _mm_hadd_ps(msum1, msum1);
1108
- return _mm_cvtss_f32(msum1);
1109
- }
1110
-
1111
- float fvec_inner_product(const float* x, const float* y, size_t d) {
1112
- __m128 mx, my;
1113
- __m128 msum1 = _mm_setzero_ps();
1114
-
1115
- while (d >= 4) {
1116
- mx = _mm_loadu_ps(x);
1117
- x += 4;
1118
- my = _mm_loadu_ps(y);
1119
- y += 4;
1120
- msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
1121
- d -= 4;
1122
- }
1123
-
1124
- // add the last 1, 2, or 3 values
1125
- mx = masked_read(d, x);
1126
- my = masked_read(d, y);
1127
- __m128 prod = _mm_mul_ps(mx, my);
1128
-
1129
- msum1 = _mm_add_ps(msum1, prod);
1130
-
1131
- msum1 = _mm_hadd_ps(msum1, msum1);
1132
- msum1 = _mm_hadd_ps(msum1, msum1);
1133
- return _mm_cvtss_f32(msum1);
1134
- }
1135
-
1136
1368
  #elif defined(__aarch64__)
1137
1369
 
1138
- float fvec_L2sqr(const float* x, const float* y, size_t d) {
1139
- float32x4_t accux4 = vdupq_n_f32(0);
1140
- const size_t d_simd = d - (d & 3);
1141
- size_t i;
1142
- for (i = 0; i < d_simd; i += 4) {
1143
- float32x4_t xi = vld1q_f32(x + i);
1144
- float32x4_t yi = vld1q_f32(y + i);
1145
- float32x4_t sq = vsubq_f32(xi, yi);
1146
- accux4 = vfmaq_f32(accux4, sq, sq);
1147
- }
1148
- float32_t accux1 = vaddvq_f32(accux4);
1149
- for (; i < d; ++i) {
1150
- float32_t xi = x[i];
1151
- float32_t yi = y[i];
1152
- float32_t sq = xi - yi;
1153
- accux1 += sq * sq;
1154
- }
1155
- return accux1;
1156
- }
1157
-
1158
- float fvec_inner_product(const float* x, const float* y, size_t d) {
1159
- float32x4_t accux4 = vdupq_n_f32(0);
1160
- const size_t d_simd = d - (d & 3);
1161
- size_t i;
1162
- for (i = 0; i < d_simd; i += 4) {
1163
- float32x4_t xi = vld1q_f32(x + i);
1164
- float32x4_t yi = vld1q_f32(y + i);
1165
- accux4 = vfmaq_f32(accux4, xi, yi);
1166
- }
1167
- float32_t accux1 = vaddvq_f32(accux4);
1168
- for (; i < d; ++i) {
1169
- float32_t xi = x[i];
1170
- float32_t yi = y[i];
1171
- accux1 += xi * yi;
1172
- }
1173
- return accux1;
1174
- }
1175
-
1176
- float fvec_norm_L2sqr(const float* x, size_t d) {
1177
- float32x4_t accux4 = vdupq_n_f32(0);
1178
- const size_t d_simd = d - (d & 3);
1179
- size_t i;
1180
- for (i = 0; i < d_simd; i += 4) {
1181
- float32x4_t xi = vld1q_f32(x + i);
1182
- accux4 = vfmaq_f32(accux4, xi, xi);
1183
- }
1184
- float32_t accux1 = vaddvq_f32(accux4);
1185
- for (; i < d; ++i) {
1186
- float32_t xi = x[i];
1187
- accux1 += xi * xi;
1188
- }
1189
- return accux1;
1190
- }
1191
-
1192
1370
  // not optimized for ARM
1193
1371
  void fvec_L2sqr_ny(
1194
1372
  float* dis,
@@ -1199,6 +1377,17 @@ void fvec_L2sqr_ny(
1199
1377
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
1200
1378
  }
1201
1379
 
1380
+ void fvec_L2sqr_ny_transposed(
1381
+ float* dis,
1382
+ const float* x,
1383
+ const float* y,
1384
+ const float* y_sqlen,
1385
+ size_t d,
1386
+ size_t d_offset,
1387
+ size_t ny) {
1388
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
1389
+ }
1390
+
1202
1391
  size_t fvec_L2sqr_ny_nearest(
1203
1392
  float* distances_tmp_buffer,
1204
1393
  const float* x,
@@ -1240,10 +1429,6 @@ void fvec_inner_products_ny(
1240
1429
  #else
1241
1430
  // scalar implementation
1242
1431
 
1243
- float fvec_L2sqr(const float* x, const float* y, size_t d) {
1244
- return fvec_L2sqr_ref(x, y, d);
1245
- }
1246
-
1247
1432
  float fvec_L1(const float* x, const float* y, size_t d) {
1248
1433
  return fvec_L1_ref(x, y, d);
1249
1434
  }
@@ -1252,14 +1437,6 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
1252
1437
  return fvec_Linf_ref(x, y, d);
1253
1438
  }
1254
1439
 
1255
- float fvec_inner_product(const float* x, const float* y, size_t d) {
1256
- return fvec_inner_product_ref(x, y, d);
1257
- }
1258
-
1259
- float fvec_norm_L2sqr(const float* x, size_t d) {
1260
- return fvec_norm_L2sqr_ref(x, d);
1261
- }
1262
-
1263
1440
  void fvec_L2sqr_ny(
1264
1441
  float* dis,
1265
1442
  const float* x,
@@ -1269,6 +1446,17 @@ void fvec_L2sqr_ny(
1269
1446
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
1270
1447
  }
1271
1448
 
1449
+ void fvec_L2sqr_ny_transposed(
1450
+ float* dis,
1451
+ const float* x,
1452
+ const float* y,
1453
+ const float* y_sqlen,
1454
+ size_t d,
1455
+ size_t d_offset,
1456
+ size_t ny) {
1457
+ return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
1458
+ }
1459
+
1272
1460
  size_t fvec_L2sqr_ny_nearest(
1273
1461
  float* distances_tmp_buffer,
1274
1462
  const float* x,