faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
+ #include <algorithm>
12
13
  #include <cassert>
13
14
  #include <cmath>
14
15
  #include <cstdio>
@@ -112,6 +113,74 @@ void fvec_L2sqr_ny_ref(
112
113
  }
113
114
  }
114
115
 
116
+ void fvec_L2sqr_ny_y_transposed_ref(
117
+ float* dis,
118
+ const float* x,
119
+ const float* y,
120
+ const float* y_sqlen,
121
+ size_t d,
122
+ size_t d_offset,
123
+ size_t ny) {
124
+ float x_sqlen = 0;
125
+ for (size_t j = 0; j < d; j++) {
126
+ x_sqlen += x[j] * x[j];
127
+ }
128
+
129
+ for (size_t i = 0; i < ny; i++) {
130
+ float dp = 0;
131
+ for (size_t j = 0; j < d; j++) {
132
+ dp += x[j] * y[i + j * d_offset];
133
+ }
134
+
135
+ dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
136
+ }
137
+ }
138
+
139
+ size_t fvec_L2sqr_ny_nearest_ref(
140
+ float* distances_tmp_buffer,
141
+ const float* x,
142
+ const float* y,
143
+ size_t d,
144
+ size_t ny) {
145
+ fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
146
+
147
+ size_t nearest_idx = 0;
148
+ float min_dis = HUGE_VALF;
149
+
150
+ for (size_t i = 0; i < ny; i++) {
151
+ if (distances_tmp_buffer[i] < min_dis) {
152
+ min_dis = distances_tmp_buffer[i];
153
+ nearest_idx = i;
154
+ }
155
+ }
156
+
157
+ return nearest_idx;
158
+ }
159
+
160
+ size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
161
+ float* distances_tmp_buffer,
162
+ const float* x,
163
+ const float* y,
164
+ const float* y_sqlen,
165
+ size_t d,
166
+ size_t d_offset,
167
+ size_t ny) {
168
+ fvec_L2sqr_ny_y_transposed_ref(
169
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
170
+
171
+ size_t nearest_idx = 0;
172
+ float min_dis = HUGE_VALF;
173
+
174
+ for (size_t i = 0; i < ny; i++) {
175
+ if (distances_tmp_buffer[i] < min_dis) {
176
+ min_dis = distances_tmp_buffer[i];
177
+ nearest_idx = i;
178
+ }
179
+ }
180
+
181
+ return nearest_idx;
182
+ }
183
+
115
184
  void fvec_inner_products_ny_ref(
116
185
  float* ip,
117
186
  const float* x,
@@ -257,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
257
326
  }
258
327
  }
259
328
 
329
+ #ifdef __AVX2__
330
+
331
+ // Specialized versions for AVX2 for any CPUs that support gather/scatter.
332
+ // Todo: implement fvec_op_ny_Dxxx in the same way.
333
+
334
+ template <>
335
+ void fvec_op_ny_D4<ElementOpIP>(
336
+ float* dis,
337
+ const float* x,
338
+ const float* y,
339
+ size_t ny) {
340
+ const size_t ny8 = ny / 8;
341
+ size_t i = 0;
342
+
343
+ if (ny8 > 0) {
344
+ // process 8 D4-vectors per loop.
345
+ _mm_prefetch(y, _MM_HINT_NTA);
346
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
347
+
348
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
349
+ const __m256 m0 = _mm256_set1_ps(x[0]);
350
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
351
+ const __m256 m1 = _mm256_set1_ps(x[1]);
352
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
353
+ const __m256 m2 = _mm256_set1_ps(x[2]);
354
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
355
+ const __m256 m3 = _mm256_set1_ps(x[3]);
356
+
357
+ const __m256i indices0 =
358
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
359
+
360
+ for (i = 0; i < ny8 * 8; i += 8) {
361
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
362
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
363
+
364
+ // collect dim 0 for 8 D4-vectors.
365
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
366
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
367
+ // collect dim 1 for 8 D4-vectors.
368
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
369
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
370
+ // collect dim 2 for 8 D4-vectors.
371
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
372
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
373
+ // collect dim 3 for 8 D4-vectors.
374
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
375
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
376
+
377
+ // compute distances
378
+ __m256 distances = _mm256_mul_ps(m0, v0);
379
+ distances = _mm256_fmadd_ps(m1, v1, distances);
380
+ distances = _mm256_fmadd_ps(m2, v2, distances);
381
+ distances = _mm256_fmadd_ps(m3, v3, distances);
382
+
383
+ // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
384
+ // (x[1] * y[(i * 8 + 0) * 4 + 1]) +
385
+ // (x[2] * y[(i * 8 + 0) * 4 + 2]) +
386
+ // (x[3] * y[(i * 8 + 0) * 4 + 3])
387
+ // ...
388
+ // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
389
+ // (x[1] * y[(i * 8 + 7) * 4 + 1]) +
390
+ // (x[2] * y[(i * 8 + 7) * 4 + 2]) +
391
+ // (x[3] * y[(i * 8 + 7) * 4 + 3])
392
+ _mm256_storeu_ps(dis + i, distances);
393
+
394
+ y += 32;
395
+ }
396
+ }
397
+
398
+ if (i < ny) {
399
+ // process leftovers
400
+ __m128 x0 = _mm_loadu_ps(x);
401
+
402
+ for (; i < ny; i++) {
403
+ __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
404
+ y += 4;
405
+ accu = _mm_hadd_ps(accu, accu);
406
+ accu = _mm_hadd_ps(accu, accu);
407
+ dis[i] = _mm_cvtss_f32(accu);
408
+ }
409
+ }
410
+ }
411
+
412
+ template <>
413
+ void fvec_op_ny_D4<ElementOpL2>(
414
+ float* dis,
415
+ const float* x,
416
+ const float* y,
417
+ size_t ny) {
418
+ const size_t ny8 = ny / 8;
419
+ size_t i = 0;
420
+
421
+ if (ny8 > 0) {
422
+ // process 8 D4-vectors per loop.
423
+ _mm_prefetch(y, _MM_HINT_NTA);
424
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
425
+
426
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
427
+ const __m256 m0 = _mm256_set1_ps(x[0]);
428
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
429
+ const __m256 m1 = _mm256_set1_ps(x[1]);
430
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
431
+ const __m256 m2 = _mm256_set1_ps(x[2]);
432
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
433
+ const __m256 m3 = _mm256_set1_ps(x[3]);
434
+
435
+ const __m256i indices0 =
436
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
437
+
438
+ for (i = 0; i < ny8 * 8; i += 8) {
439
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
440
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
441
+
442
+ // collect dim 0 for 8 D4-vectors.
443
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
444
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
445
+ // collect dim 1 for 8 D4-vectors.
446
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
447
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
448
+ // collect dim 2 for 8 D4-vectors.
449
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
450
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
451
+ // collect dim 3 for 8 D4-vectors.
452
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
453
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
454
+
455
+ // compute differences
456
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
457
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
458
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
459
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
460
+
461
+ // compute squares of differences
462
+ __m256 distances = _mm256_mul_ps(d0, d0);
463
+ distances = _mm256_fmadd_ps(d1, d1, distances);
464
+ distances = _mm256_fmadd_ps(d2, d2, distances);
465
+ distances = _mm256_fmadd_ps(d3, d3, distances);
466
+
467
+ // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
468
+ // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
469
+ // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
470
+ // (x[3] - y[(i * 8 + 0) * 4 + 3])
471
+ // ...
472
+ // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
473
+ // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
474
+ // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
475
+ // (x[3] - y[(i * 8 + 7) * 4 + 3])
476
+ _mm256_storeu_ps(dis + i, distances);
477
+
478
+ y += 32;
479
+ }
480
+ }
481
+
482
+ if (i < ny) {
483
+ // process leftovers
484
+ __m128 x0 = _mm_loadu_ps(x);
485
+
486
+ for (; i < ny; i++) {
487
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
488
+ y += 4;
489
+ accu = _mm_hadd_ps(accu, accu);
490
+ accu = _mm_hadd_ps(accu, accu);
491
+ dis[i] = _mm_cvtss_f32(accu);
492
+ }
493
+ }
494
+ }
495
+
496
+ #endif
497
+
260
498
  template <class ElementOp>
261
499
  void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
262
500
  __m128 x0 = _mm_loadu_ps(x);
@@ -344,6 +582,324 @@ void fvec_inner_products_ny(
344
582
  #undef DISPATCH
345
583
  }
346
584
 
585
+ #ifdef __AVX2__
586
+ size_t fvec_L2sqr_ny_nearest_D4(
587
+ float* distances_tmp_buffer,
588
+ const float* x,
589
+ const float* y,
590
+ size_t ny) {
591
+ // this implementation does not use distances_tmp_buffer.
592
+
593
+ // current index being processed
594
+ size_t i = 0;
595
+
596
+ // min distance and the index of the closest vector so far
597
+ float current_min_distance = HUGE_VALF;
598
+ size_t current_min_index = 0;
599
+
600
+ // process 8 D4-vectors per loop.
601
+ const size_t ny8 = ny / 8;
602
+
603
+ if (ny8 > 0) {
604
+ // track min distance and the closest vector independently
605
+ // for each of 8 AVX2 components.
606
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
607
+ __m256i min_indices = _mm256_set1_epi32(0);
608
+
609
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
610
+ const __m256i indices_increment = _mm256_set1_epi32(8);
611
+
612
+ //
613
+ _mm_prefetch(y, _MM_HINT_NTA);
614
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
615
+
616
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
617
+ const __m256 m0 = _mm256_set1_ps(x[0]);
618
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
619
+ const __m256 m1 = _mm256_set1_ps(x[1]);
620
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
621
+ const __m256 m2 = _mm256_set1_ps(x[2]);
622
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
623
+ const __m256 m3 = _mm256_set1_ps(x[3]);
624
+
625
+ const __m256i indices0 =
626
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
627
+
628
+ for (; i < ny8 * 8; i += 8) {
629
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
630
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
631
+
632
+ // collect dim 0 for 8 D4-vectors.
633
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
634
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
635
+ // collect dim 1 for 8 D4-vectors.
636
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
637
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
638
+ // collect dim 2 for 8 D4-vectors.
639
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
640
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
641
+ // collect dim 3 for 8 D4-vectors.
642
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
643
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
644
+
645
+ // compute differences
646
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
647
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
648
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
649
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
650
+
651
+ // compute squares of differences
652
+ __m256 distances = _mm256_mul_ps(d0, d0);
653
+ distances = _mm256_fmadd_ps(d1, d1, distances);
654
+ distances = _mm256_fmadd_ps(d2, d2, distances);
655
+ distances = _mm256_fmadd_ps(d3, d3, distances);
656
+
657
+ // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
658
+ // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
659
+ // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
660
+ // (x[3] - y[(i * 8 + 0) * 4 + 3])
661
+ // ...
662
+ // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
663
+ // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
664
+ // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
665
+ // (x[3] - y[(i * 8 + 7) * 4 + 3])
666
+
667
+ // compare the new distances to the min distances
668
+ // for each of 8 AVX2 components.
669
+ __m256 comparison =
670
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
671
+
672
+ // update min distances and indices with closest vectors if needed.
673
+ min_distances =
674
+ _mm256_blendv_ps(distances, min_distances, comparison);
675
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
676
+ _mm256_castsi256_ps(current_indices),
677
+ _mm256_castsi256_ps(min_indices),
678
+ comparison));
679
+
680
+ // update current indices values. Basically, +8 to each of the
681
+ // 8 AVX2 components.
682
+ current_indices =
683
+ _mm256_add_epi32(current_indices, indices_increment);
684
+
685
+ // scroll y forward (8 vectors 4 DIM each).
686
+ y += 32;
687
+ }
688
+
689
+ // dump values and find the minimum distance / minimum index
690
+ float min_distances_scalar[8];
691
+ uint32_t min_indices_scalar[8];
692
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
693
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
694
+
695
+ for (size_t j = 0; j < 8; j++) {
696
+ if (current_min_distance > min_distances_scalar[j]) {
697
+ current_min_distance = min_distances_scalar[j];
698
+ current_min_index = min_indices_scalar[j];
699
+ }
700
+ }
701
+ }
702
+
703
+ if (i < ny) {
704
+ // process leftovers
705
+ __m128 x0 = _mm_loadu_ps(x);
706
+
707
+ for (; i < ny; i++) {
708
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
709
+ y += 4;
710
+ accu = _mm_hadd_ps(accu, accu);
711
+ accu = _mm_hadd_ps(accu, accu);
712
+
713
+ const auto distance = _mm_cvtss_f32(accu);
714
+
715
+ if (current_min_distance > distance) {
716
+ current_min_distance = distance;
717
+ current_min_index = i;
718
+ }
719
+ }
720
+ }
721
+
722
+ return current_min_index;
723
+ }
724
+ #else
725
+ size_t fvec_L2sqr_ny_nearest_D4(
726
+ float* distances_tmp_buffer,
727
+ const float* x,
728
+ const float* y,
729
+ size_t ny) {
730
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
731
+ }
732
+ #endif
733
+
734
+ size_t fvec_L2sqr_ny_nearest(
735
+ float* distances_tmp_buffer,
736
+ const float* x,
737
+ const float* y,
738
+ size_t d,
739
+ size_t ny) {
740
+ // optimized for a few special cases
741
+ #define DISPATCH(dval) \
742
+ case dval: \
743
+ return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
744
+
745
+ switch (d) {
746
+ DISPATCH(4)
747
+ default:
748
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
749
+ }
750
+ #undef DISPATCH
751
+ }
752
+
753
+ #ifdef __AVX2__
754
+ template <size_t DIM>
755
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
756
+ float* distances_tmp_buffer,
757
+ const float* x,
758
+ const float* y,
759
+ const float* y_sqlen,
760
+ const size_t d_offset,
761
+ size_t ny) {
762
+ // this implementation does not use distances_tmp_buffer.
763
+
764
+ // current index being processed
765
+ size_t i = 0;
766
+
767
+ // min distance and the index of the closest vector so far
768
+ float current_min_distance = HUGE_VALF;
769
+ size_t current_min_index = 0;
770
+
771
+ // process 8 vectors per loop.
772
+ const size_t ny8 = ny / 8;
773
+
774
+ if (ny8 > 0) {
775
+ // track min distance and the closest vector independently
776
+ // for each of 8 AVX2 components.
777
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
778
+ __m256i min_indices = _mm256_set1_epi32(0);
779
+
780
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
781
+ const __m256i indices_increment = _mm256_set1_epi32(8);
782
+
783
+ // m[i] = (2 * x[i], ... 2 * x[i])
784
+ __m256 m[DIM];
785
+ for (size_t j = 0; j < DIM; j++) {
786
+ m[j] = _mm256_set1_ps(x[j]);
787
+ m[j] = _mm256_add_ps(m[j], m[j]);
788
+ }
789
+
790
+ for (; i < ny8 * 8; i += 8) {
791
+ // collect dim 0 for 8 D4-vectors.
792
+ const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
793
+ // compute dot products
794
+ __m256 dp = _mm256_mul_ps(m[0], v0);
795
+
796
+ for (size_t j = 1; j < DIM; j++) {
797
+ // collect dim j for 8 D4-vectors.
798
+ const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
799
+ dp = _mm256_fmadd_ps(m[j], vj, dp);
800
+ }
801
+
802
+ // compute y^2 - (2 * x, y), which is sufficient for looking for the
803
+ // lowest distance.
804
+ // x^2 is the constant that can be avoided.
805
+ const __m256 distances =
806
+ _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
807
+
808
+ // compare the new distances to the min distances
809
+ // for each of 8 AVX2 components.
810
+ const __m256 comparison =
811
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
812
+
813
+ // update min distances and indices with closest vectors if needed.
814
+ min_distances =
815
+ _mm256_blendv_ps(distances, min_distances, comparison);
816
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
817
+ _mm256_castsi256_ps(current_indices),
818
+ _mm256_castsi256_ps(min_indices),
819
+ comparison));
820
+
821
+ // update current indices values. Basically, +8 to each of the
822
+ // 8 AVX2 components.
823
+ current_indices =
824
+ _mm256_add_epi32(current_indices, indices_increment);
825
+
826
+ // scroll y and y_sqlen forward.
827
+ y += 8;
828
+ y_sqlen += 8;
829
+ }
830
+
831
+ // dump values and find the minimum distance / minimum index
832
+ float min_distances_scalar[8];
833
+ uint32_t min_indices_scalar[8];
834
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
835
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
836
+
837
+ for (size_t j = 0; j < 8; j++) {
838
+ if (current_min_distance > min_distances_scalar[j]) {
839
+ current_min_distance = min_distances_scalar[j];
840
+ current_min_index = min_indices_scalar[j];
841
+ }
842
+ }
843
+ }
844
+
845
+ if (i < ny) {
846
+ // process leftovers
847
+ for (; i < ny; i++) {
848
+ float dp = 0;
849
+ for (size_t j = 0; j < DIM; j++) {
850
+ dp += x[j] * y[j * d_offset];
851
+ }
852
+
853
+ // compute y^2 - 2 * (x, y), which is sufficient for looking for the
854
+ // lowest distance.
855
+ const float distance = y_sqlen[0] - 2 * dp;
856
+
857
+ if (current_min_distance > distance) {
858
+ current_min_distance = distance;
859
+ current_min_index = i;
860
+ }
861
+
862
+ y += 1;
863
+ y_sqlen += 1;
864
+ }
865
+ }
866
+
867
+ return current_min_index;
868
+ }
869
+ #endif
870
+
871
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
872
+ float* distances_tmp_buffer,
873
+ const float* x,
874
+ const float* y,
875
+ const float* y_sqlen,
876
+ size_t d,
877
+ size_t d_offset,
878
+ size_t ny) {
879
+ // optimized for a few special cases
880
+ #ifdef __AVX2__
881
+ #define DISPATCH(dval) \
882
+ case dval: \
883
+ return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
884
+ distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
885
+
886
+ switch (d) {
887
+ DISPATCH(1)
888
+ DISPATCH(2)
889
+ DISPATCH(4)
890
+ DISPATCH(8)
891
+ default:
892
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
893
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
894
+ }
895
+ #undef DISPATCH
896
+ #else
897
+ // non-AVX2 case
898
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
899
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
900
+ #endif
901
+ }
902
+
347
903
  #endif
348
904
 
349
905
  #ifdef USE_AVX
@@ -589,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
589
1145
  float32x4_t sq = vsubq_f32(xi, yi);
590
1146
  accux4 = vfmaq_f32(accux4, sq, sq);
591
1147
  }
592
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
593
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1148
+ float32_t accux1 = vaddvq_f32(accux4);
594
1149
  for (; i < d; ++i) {
595
1150
  float32_t xi = x[i];
596
1151
  float32_t yi = y[i];
@@ -609,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
609
1164
  float32x4_t yi = vld1q_f32(y + i);
610
1165
  accux4 = vfmaq_f32(accux4, xi, yi);
611
1166
  }
612
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
613
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1167
+ float32_t accux1 = vaddvq_f32(accux4);
614
1168
  for (; i < d; ++i) {
615
1169
  float32_t xi = x[i];
616
1170
  float32_t yi = y[i];
@@ -627,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
627
1181
  float32x4_t xi = vld1q_f32(x + i);
628
1182
  accux4 = vfmaq_f32(accux4, xi, xi);
629
1183
  }
630
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
631
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1184
+ float32_t accux1 = vaddvq_f32(accux4);
632
1185
  for (; i < d; ++i) {
633
1186
  float32_t xi = x[i];
634
1187
  accux1 += xi * xi;
@@ -646,6 +1199,27 @@ void fvec_L2sqr_ny(
646
1199
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
647
1200
  }
648
1201
 
1202
+ size_t fvec_L2sqr_ny_nearest(
1203
+ float* distances_tmp_buffer,
1204
+ const float* x,
1205
+ const float* y,
1206
+ size_t d,
1207
+ size_t ny) {
1208
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
1209
+ }
1210
+
1211
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
1212
+ float* distances_tmp_buffer,
1213
+ const float* x,
1214
+ const float* y,
1215
+ const float* y_sqlen,
1216
+ size_t d,
1217
+ size_t d_offset,
1218
+ size_t ny) {
1219
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
1220
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
1221
+ }
1222
+
649
1223
  float fvec_L1(const float* x, const float* y, size_t d) {
650
1224
  return fvec_L1_ref(x, y, d);
651
1225
  }
@@ -695,6 +1269,27 @@ void fvec_L2sqr_ny(
695
1269
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
696
1270
  }
697
1271
 
1272
/// Nearest-neighbor search of x among ny contiguous d-dim vectors in y
/// (squared L2 metric); distances_tmp_buffer is scratch space used by the
/// reference implementation. Returns the index of the closest vector.
/// Fallback path for platforms without a vectorized kernel: forwards to
/// the scalar reference implementation.
size_t fvec_L2sqr_ny_nearest(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
}
1280
+
1281
/// Transposed-layout variant of fvec_L2sqr_ny_nearest. y_sqlen appears to
/// carry precomputed squared norms of the y vectors and d_offset the
/// per-dimension stride — verify against the _ref implementation.
/// Returns the index of the vector nearest to x.
/// Fallback path: forwards to the scalar reference implementation.
size_t fvec_L2sqr_ny_nearest_y_transposed(
        float* distances_tmp_buffer,
        const float* x,
        const float* y,
        const float* y_sqlen,
        size_t d,
        size_t d_offset,
        size_t ny) {
    return fvec_L2sqr_ny_nearest_y_transposed_ref(
            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
}
1292
+
698
1293
  void fvec_inner_products_ny(
699
1294
  float* dis,
700
1295
  const float* x,
@@ -720,6 +1315,61 @@ static inline void fvec_madd_ref(
720
1315
  c[i] = a[i] + bf * b[i];
721
1316
  }
722
1317
 
1318
#ifdef __AVX2__
/// AVX2 kernel for c[i] = a[i] + bf * b[i], i in [0, n).
/// Processes 8 floats per iteration with FMA; the trailing n % 8 elements
/// are handled by a single masked load/FMA/store instead of a scalar tail.
/// @param n  number of elements
/// @param a  addend vector (read-only)
/// @param bf scalar multiplier applied to b
/// @param b  multiplicand vector (read-only)
/// @param c  output vector (may not alias a or b — __restrict)
static inline void fvec_madd_avx2(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t n8 = n / 8;
    const size_t n_for_masking = n % 8;

    const __m256 bfmm = _mm256_set1_ps(bf);

    size_t idx = 0;
    for (idx = 0; idx < n8 * 8; idx += 8) {
        const __m256 ax = _mm256_loadu_ps(a + idx);
        const __m256 bx = _mm256_loadu_ps(b + idx);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_storeu_ps(c + idx, abmul);
    }

    if (n_for_masking > 0) {
        // Build the tail mask arithmetically: lane i is enabled (all bits
        // set, so the sign bit maskload/maskstore tests is 1) iff
        // i < n_for_masking. This replaces the previous 7-way switch,
        // which left `mask` formally uninitialized when no case matched.
        const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i mask = _mm256_cmpgt_epi32(
                _mm256_set1_epi32(static_cast<int>(n_for_masking)),
                lane_ids);

        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_maskstore_ps(c + idx, mask, abmul);
    }
}
#endif
1372
+
723
1373
  #ifdef __SSE3__
724
1374
 
725
1375
  static inline void fvec_madd_sse(
@@ -743,10 +1393,30 @@ static inline void fvec_madd_sse(
743
1393
  }
744
1394
 
745
1395
  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
1396
+ #ifdef __AVX2__
1397
+ fvec_madd_avx2(n, a, bf, b, c);
1398
+ #else
746
1399
  if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
747
1400
  fvec_madd_sse(n, a, bf, b, c);
748
1401
  else
749
1402
  fvec_madd_ref(n, a, bf, b, c);
1403
+ #endif
1404
+ }
1405
+
1406
#elif defined(__aarch64__)

/// c := a + bf * b on aarch64: NEON fused multiply-add, 4 floats per
/// iteration, with a scalar loop for the final n % 4 elements.
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    const float32x4_t scale = vdupq_n_f32(bf);
    // largest multiple of 4 not exceeding n
    const size_t simd_len = n & ~static_cast<size_t>(3);

    size_t i = 0;
    while (i < simd_len) {
        const float32x4_t av = vld1q_f32(a + i);
        const float32x4_t bv = vld1q_f32(b + i);
        vst1q_f32(c + i, vfmaq_f32(av, scale, bv));
        i += 4;
    }
    // scalar remainder
    for (; i < n; ++i) {
        c[i] = a[i] + bf * b[i];
    }
}
751
1421
 
752
1422
  #else
@@ -842,6 +1512,57 @@ int fvec_madd_and_argmin(
842
1512
  return fvec_madd_and_argmin_ref(n, a, bf, b, c);
843
1513
  }
844
1514
 
1515
#elif defined(__aarch64__)

/// Compute c := a + bf * b over n floats (NEON, 4 lanes at a time) and
/// return the index of the minimum element of c.
/// @return index of the smallest c[i]; ties resolve to the smallest index.
int fvec_madd_and_argmin(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Per-lane running minima and the global indices where they occurred.
    float32x4_t vminv = vdupq_n_f32(1e20);
    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
    size_t i;
    {
        const size_t n_simd = n - (n & 3);
        // iv holds the global element indices of the current 4 lanes.
        const uint32_t iota[] = {0, 1, 2, 3};
        uint32x4_t iv = vld1q_u32(iota);
        const uint32x4_t incv = vdupq_n_u32(4);
        const float32x4_t bfv = vdupq_n_f32(bf);
        for (i = 0; i < n_simd; i += 4) {
            const float32x4_t ai = vld1q_f32(a + i);
            const float32x4_t bi = vld1q_f32(b + i);
            // fused multiply-add: ci = ai + bfv * bi
            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
            vst1q_f32(c + i, ci);
            // Strict < keeps the FIRST occurrence of each lane's minimum.
            const uint32x4_t less_than = vcltq_f32(ci, vminv);
            vminv = vminq_f32(ci, vminv);
            // Mask-blend: take the new index where ci improved the lane
            // minimum, otherwise keep the previously recorded index.
            iminv = vorrq_u32(
                    vandq_u32(less_than, iv),
                    vandq_u32(vmvnq_u32(less_than), iminv));
            iv = vaddq_u32(iv, incv);
        }
    }
    // Horizontal reduction of the 4 lane minima to the global minimum.
    float vmin = vminvq_f32(vminv);
    uint32_t imin;
    {
        // Among lanes holding the global minimum, pick the smallest index;
        // non-matching lanes are forced to UINT32_MAX so they cannot win.
        const float32x4_t vminy = vdupq_n_f32(vmin);
        const uint32x4_t equals = vceqq_f32(vminv, vminy);
        imin = vminvq_u32(vorrq_u32(
                vandq_u32(equals, iminv),
                vandq_u32(
                        vmvnq_u32(equals),
                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
    }
    // Scalar tail for the remaining n % 4 elements.
    for (; i < n; ++i) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = static_cast<uint32_t>(i);
        }
    }
    return static_cast<int>(imin);
}
1565
+
845
1566
  #else
846
1567
 
847
1568
  int fvec_madd_and_argmin(
@@ -973,4 +1694,53 @@ void compute_PQ_dis_tables_dsub2(
973
1694
  }
974
1695
  }
975
1696
 
1697
+ /*********************************************************
1698
+ * Vector to vector functions
1699
+ *********************************************************/
1700
+
1701
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
1702
+ size_t i;
1703
+ for (i = 0; i + 7 < d; i += 8) {
1704
+ simd8float32 ci, ai, bi;
1705
+ ai.loadu(a + i);
1706
+ bi.loadu(b + i);
1707
+ ci = ai - bi;
1708
+ ci.storeu(c + i);
1709
+ }
1710
+ // finish non-multiple of 8 remainder
1711
+ for (; i < d; i++) {
1712
+ c[i] = a[i] - b[i];
1713
+ }
1714
+ }
1715
+
1716
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
1717
+ size_t i;
1718
+ for (i = 0; i + 7 < d; i += 8) {
1719
+ simd8float32 ci, ai, bi;
1720
+ ai.loadu(a + i);
1721
+ bi.loadu(b + i);
1722
+ ci = ai + bi;
1723
+ ci.storeu(c + i);
1724
+ }
1725
+ // finish non-multiple of 8 remainder
1726
+ for (; i < d; i++) {
1727
+ c[i] = a[i] + b[i];
1728
+ }
1729
+ }
1730
+
1731
+ void fvec_add(size_t d, const float* a, float b, float* c) {
1732
+ size_t i;
1733
+ simd8float32 bv(b);
1734
+ for (i = 0; i + 7 < d; i += 8) {
1735
+ simd8float32 ci, ai, bi;
1736
+ ai.loadu(a + i);
1737
+ ci = ai + bv;
1738
+ ci.storeu(c + i);
1739
+ }
1740
+ // finish non-multiple of 8 remainder
1741
+ for (; i < d; i++) {
1742
+ c[i] = a[i] + b;
1743
+ }
1744
+ }
1745
+
976
1746
  } // namespace faiss