faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -113,6 +113,74 @@ void fvec_L2sqr_ny_ref(
113
113
  }
114
114
  }
115
115
 
116
+ void fvec_L2sqr_ny_y_transposed_ref(
117
+ float* dis,
118
+ const float* x,
119
+ const float* y,
120
+ const float* y_sqlen,
121
+ size_t d,
122
+ size_t d_offset,
123
+ size_t ny) {
124
+ float x_sqlen = 0;
125
+ for (size_t j = 0; j < d; j++) {
126
+ x_sqlen += x[j] * x[j];
127
+ }
128
+
129
+ for (size_t i = 0; i < ny; i++) {
130
+ float dp = 0;
131
+ for (size_t j = 0; j < d; j++) {
132
+ dp += x[j] * y[i + j * d_offset];
133
+ }
134
+
135
+ dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
136
+ }
137
+ }
138
+
139
+ size_t fvec_L2sqr_ny_nearest_ref(
140
+ float* distances_tmp_buffer,
141
+ const float* x,
142
+ const float* y,
143
+ size_t d,
144
+ size_t ny) {
145
+ fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
146
+
147
+ size_t nearest_idx = 0;
148
+ float min_dis = HUGE_VALF;
149
+
150
+ for (size_t i = 0; i < ny; i++) {
151
+ if (distances_tmp_buffer[i] < min_dis) {
152
+ min_dis = distances_tmp_buffer[i];
153
+ nearest_idx = i;
154
+ }
155
+ }
156
+
157
+ return nearest_idx;
158
+ }
159
+
160
+ size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
161
+ float* distances_tmp_buffer,
162
+ const float* x,
163
+ const float* y,
164
+ const float* y_sqlen,
165
+ size_t d,
166
+ size_t d_offset,
167
+ size_t ny) {
168
+ fvec_L2sqr_ny_y_transposed_ref(
169
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
170
+
171
+ size_t nearest_idx = 0;
172
+ float min_dis = HUGE_VALF;
173
+
174
+ for (size_t i = 0; i < ny; i++) {
175
+ if (distances_tmp_buffer[i] < min_dis) {
176
+ min_dis = distances_tmp_buffer[i];
177
+ nearest_idx = i;
178
+ }
179
+ }
180
+
181
+ return nearest_idx;
182
+ }
183
+
116
184
  void fvec_inner_products_ny_ref(
117
185
  float* ip,
118
186
  const float* x,
@@ -258,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
258
326
  }
259
327
  }
260
328
 
329
+ #ifdef __AVX2__
330
+
331
+ // Specialized versions for AVX2 for any CPUs that support gather/scatter.
332
+ // Todo: implement fvec_op_ny_Dxxx in the same way.
333
+
334
+ template <>
335
+ void fvec_op_ny_D4<ElementOpIP>(
336
+ float* dis,
337
+ const float* x,
338
+ const float* y,
339
+ size_t ny) {
340
+ const size_t ny8 = ny / 8;
341
+ size_t i = 0;
342
+
343
+ if (ny8 > 0) {
344
+ // process 8 D4-vectors per loop.
345
+ _mm_prefetch(y, _MM_HINT_NTA);
346
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
347
+
348
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
349
+ const __m256 m0 = _mm256_set1_ps(x[0]);
350
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
351
+ const __m256 m1 = _mm256_set1_ps(x[1]);
352
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
353
+ const __m256 m2 = _mm256_set1_ps(x[2]);
354
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
355
+ const __m256 m3 = _mm256_set1_ps(x[3]);
356
+
357
+ const __m256i indices0 =
358
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
359
+
360
+ for (i = 0; i < ny8 * 8; i += 8) {
361
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
362
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
363
+
364
+ // collect dim 0 for 8 D4-vectors.
365
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
366
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
367
+ // collect dim 1 for 8 D4-vectors.
368
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
369
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
370
+ // collect dim 2 for 8 D4-vectors.
371
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
372
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
373
+ // collect dim 3 for 8 D4-vectors.
374
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
375
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
376
+
377
+ // compute distances
378
+ __m256 distances = _mm256_mul_ps(m0, v0);
379
+ distances = _mm256_fmadd_ps(m1, v1, distances);
380
+ distances = _mm256_fmadd_ps(m2, v2, distances);
381
+ distances = _mm256_fmadd_ps(m3, v3, distances);
382
+
383
+ // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
384
+ // (x[1] * y[(i * 8 + 0) * 4 + 1]) +
385
+ // (x[2] * y[(i * 8 + 0) * 4 + 2]) +
386
+ // (x[3] * y[(i * 8 + 0) * 4 + 3])
387
+ // ...
388
+ // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
389
+ // (x[1] * y[(i * 8 + 7) * 4 + 1]) +
390
+ // (x[2] * y[(i * 8 + 7) * 4 + 2]) +
391
+ // (x[3] * y[(i * 8 + 7) * 4 + 3])
392
+ _mm256_storeu_ps(dis + i, distances);
393
+
394
+ y += 32;
395
+ }
396
+ }
397
+
398
+ if (i < ny) {
399
+ // process leftovers
400
+ __m128 x0 = _mm_loadu_ps(x);
401
+
402
+ for (; i < ny; i++) {
403
+ __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
404
+ y += 4;
405
+ accu = _mm_hadd_ps(accu, accu);
406
+ accu = _mm_hadd_ps(accu, accu);
407
+ dis[i] = _mm_cvtss_f32(accu);
408
+ }
409
+ }
410
+ }
411
+
412
+ template <>
413
+ void fvec_op_ny_D4<ElementOpL2>(
414
+ float* dis,
415
+ const float* x,
416
+ const float* y,
417
+ size_t ny) {
418
+ const size_t ny8 = ny / 8;
419
+ size_t i = 0;
420
+
421
+ if (ny8 > 0) {
422
+ // process 8 D4-vectors per loop.
423
+ _mm_prefetch(y, _MM_HINT_NTA);
424
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
425
+
426
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
427
+ const __m256 m0 = _mm256_set1_ps(x[0]);
428
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
429
+ const __m256 m1 = _mm256_set1_ps(x[1]);
430
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
431
+ const __m256 m2 = _mm256_set1_ps(x[2]);
432
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
433
+ const __m256 m3 = _mm256_set1_ps(x[3]);
434
+
435
+ const __m256i indices0 =
436
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
437
+
438
+ for (i = 0; i < ny8 * 8; i += 8) {
439
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
440
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
441
+
442
+ // collect dim 0 for 8 D4-vectors.
443
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
444
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
445
+ // collect dim 1 for 8 D4-vectors.
446
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
447
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
448
+ // collect dim 2 for 8 D4-vectors.
449
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
450
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
451
+ // collect dim 3 for 8 D4-vectors.
452
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
453
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
454
+
455
+ // compute differences
456
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
457
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
458
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
459
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
460
+
461
+ // compute squares of differences
462
+ __m256 distances = _mm256_mul_ps(d0, d0);
463
+ distances = _mm256_fmadd_ps(d1, d1, distances);
464
+ distances = _mm256_fmadd_ps(d2, d2, distances);
465
+ distances = _mm256_fmadd_ps(d3, d3, distances);
466
+
467
+ // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
468
+ // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
469
+ // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
470
+ // (x[3] - y[(i * 8 + 0) * 4 + 3])
471
+ // ...
472
+ // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
473
+ // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
474
+ // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
475
+ // (x[3] - y[(i * 8 + 7) * 4 + 3])
476
+ _mm256_storeu_ps(dis + i, distances);
477
+
478
+ y += 32;
479
+ }
480
+ }
481
+
482
+ if (i < ny) {
483
+ // process leftovers
484
+ __m128 x0 = _mm_loadu_ps(x);
485
+
486
+ for (; i < ny; i++) {
487
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
488
+ y += 4;
489
+ accu = _mm_hadd_ps(accu, accu);
490
+ accu = _mm_hadd_ps(accu, accu);
491
+ dis[i] = _mm_cvtss_f32(accu);
492
+ }
493
+ }
494
+ }
495
+
496
+ #endif
497
+
261
498
  template <class ElementOp>
262
499
  void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
263
500
  __m128 x0 = _mm_loadu_ps(x);
@@ -345,6 +582,324 @@ void fvec_inner_products_ny(
345
582
  #undef DISPATCH
346
583
  }
347
584
 
585
+ #ifdef __AVX2__
586
+ size_t fvec_L2sqr_ny_nearest_D4(
587
+ float* distances_tmp_buffer,
588
+ const float* x,
589
+ const float* y,
590
+ size_t ny) {
591
+ // this implementation does not use distances_tmp_buffer.
592
+
593
+ // current index being processed
594
+ size_t i = 0;
595
+
596
+ // min distance and the index of the closest vector so far
597
+ float current_min_distance = HUGE_VALF;
598
+ size_t current_min_index = 0;
599
+
600
+ // process 8 D4-vectors per loop.
601
+ const size_t ny8 = ny / 8;
602
+
603
+ if (ny8 > 0) {
604
+ // track min distance and the closest vector independently
605
+ // for each of 8 AVX2 components.
606
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
607
+ __m256i min_indices = _mm256_set1_epi32(0);
608
+
609
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
610
+ const __m256i indices_increment = _mm256_set1_epi32(8);
611
+
612
+ //
613
+ _mm_prefetch(y, _MM_HINT_NTA);
614
+ _mm_prefetch(y + 16, _MM_HINT_NTA);
615
+
616
+ // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
617
+ const __m256 m0 = _mm256_set1_ps(x[0]);
618
+ // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
619
+ const __m256 m1 = _mm256_set1_ps(x[1]);
620
+ // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
621
+ const __m256 m2 = _mm256_set1_ps(x[2]);
622
+ // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
623
+ const __m256 m3 = _mm256_set1_ps(x[3]);
624
+
625
+ const __m256i indices0 =
626
+ _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
627
+
628
+ for (; i < ny8 * 8; i += 8) {
629
+ _mm_prefetch(y + 32, _MM_HINT_NTA);
630
+ _mm_prefetch(y + 48, _MM_HINT_NTA);
631
+
632
+ // collect dim 0 for 8 D4-vectors.
633
+ // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
634
+ const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
635
+ // collect dim 1 for 8 D4-vectors.
636
+ // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
637
+ const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
638
+ // collect dim 2 for 8 D4-vectors.
639
+ // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
640
+ const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
641
+ // collect dim 3 for 8 D4-vectors.
642
+ // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
643
+ const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
644
+
645
+ // compute differences
646
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
647
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
648
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
649
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
650
+
651
+ // compute squares of differences
652
+ __m256 distances = _mm256_mul_ps(d0, d0);
653
+ distances = _mm256_fmadd_ps(d1, d1, distances);
654
+ distances = _mm256_fmadd_ps(d2, d2, distances);
655
+ distances = _mm256_fmadd_ps(d3, d3, distances);
656
+
657
+ // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
658
+ // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
659
+ // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
660
+ // (x[3] - y[(i * 8 + 0) * 4 + 3])
661
+ // ...
662
+ // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
663
+ // (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
664
+ // (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
665
+ // (x[3] - y[(i * 8 + 7) * 4 + 3])
666
+
667
+ // compare the new distances to the min distances
668
+ // for each of 8 AVX2 components.
669
+ __m256 comparison =
670
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
671
+
672
+ // update min distances and indices with closest vectors if needed.
673
+ min_distances =
674
+ _mm256_blendv_ps(distances, min_distances, comparison);
675
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
676
+ _mm256_castsi256_ps(current_indices),
677
+ _mm256_castsi256_ps(min_indices),
678
+ comparison));
679
+
680
+ // update current indices values. Basically, +8 to each of the
681
+ // 8 AVX2 components.
682
+ current_indices =
683
+ _mm256_add_epi32(current_indices, indices_increment);
684
+
685
+ // scroll y forward (8 vectors 4 DIM each).
686
+ y += 32;
687
+ }
688
+
689
+ // dump values and find the minimum distance / minimum index
690
+ float min_distances_scalar[8];
691
+ uint32_t min_indices_scalar[8];
692
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
693
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
694
+
695
+ for (size_t j = 0; j < 8; j++) {
696
+ if (current_min_distance > min_distances_scalar[j]) {
697
+ current_min_distance = min_distances_scalar[j];
698
+ current_min_index = min_indices_scalar[j];
699
+ }
700
+ }
701
+ }
702
+
703
+ if (i < ny) {
704
+ // process leftovers
705
+ __m128 x0 = _mm_loadu_ps(x);
706
+
707
+ for (; i < ny; i++) {
708
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
709
+ y += 4;
710
+ accu = _mm_hadd_ps(accu, accu);
711
+ accu = _mm_hadd_ps(accu, accu);
712
+
713
+ const auto distance = _mm_cvtss_f32(accu);
714
+
715
+ if (current_min_distance > distance) {
716
+ current_min_distance = distance;
717
+ current_min_index = i;
718
+ }
719
+ }
720
+ }
721
+
722
+ return current_min_index;
723
+ }
724
+ #else
725
+ size_t fvec_L2sqr_ny_nearest_D4(
726
+ float* distances_tmp_buffer,
727
+ const float* x,
728
+ const float* y,
729
+ size_t ny) {
730
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
731
+ }
732
+ #endif
733
+
734
+ size_t fvec_L2sqr_ny_nearest(
735
+ float* distances_tmp_buffer,
736
+ const float* x,
737
+ const float* y,
738
+ size_t d,
739
+ size_t ny) {
740
+ // optimized for a few special cases
741
+ #define DISPATCH(dval) \
742
+ case dval: \
743
+ return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
744
+
745
+ switch (d) {
746
+ DISPATCH(4)
747
+ default:
748
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
749
+ }
750
+ #undef DISPATCH
751
+ }
752
+
753
+ #ifdef __AVX2__
754
+ template <size_t DIM>
755
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
756
+ float* distances_tmp_buffer,
757
+ const float* x,
758
+ const float* y,
759
+ const float* y_sqlen,
760
+ const size_t d_offset,
761
+ size_t ny) {
762
+ // this implementation does not use distances_tmp_buffer.
763
+
764
+ // current index being processed
765
+ size_t i = 0;
766
+
767
+ // min distance and the index of the closest vector so far
768
+ float current_min_distance = HUGE_VALF;
769
+ size_t current_min_index = 0;
770
+
771
+ // process 8 vectors per loop.
772
+ const size_t ny8 = ny / 8;
773
+
774
+ if (ny8 > 0) {
775
+ // track min distance and the closest vector independently
776
+ // for each of 8 AVX2 components.
777
+ __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
778
+ __m256i min_indices = _mm256_set1_epi32(0);
779
+
780
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
781
+ const __m256i indices_increment = _mm256_set1_epi32(8);
782
+
783
+ // m[i] = (2 * x[i], ... 2 * x[i])
784
+ __m256 m[DIM];
785
+ for (size_t j = 0; j < DIM; j++) {
786
+ m[j] = _mm256_set1_ps(x[j]);
787
+ m[j] = _mm256_add_ps(m[j], m[j]);
788
+ }
789
+
790
+ for (; i < ny8 * 8; i += 8) {
791
+ // collect dim 0 for 8 D4-vectors.
792
+ const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
793
+ // compute dot products
794
+ __m256 dp = _mm256_mul_ps(m[0], v0);
795
+
796
+ for (size_t j = 1; j < DIM; j++) {
797
+ // collect dim j for 8 D4-vectors.
798
+ const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
799
+ dp = _mm256_fmadd_ps(m[j], vj, dp);
800
+ }
801
+
802
+ // compute y^2 - (2 * x, y), which is sufficient for looking for the
803
+ // lowest distance.
804
+ // x^2 is the constant that can be avoided.
805
+ const __m256 distances =
806
+ _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
807
+
808
+ // compare the new distances to the min distances
809
+ // for each of 8 AVX2 components.
810
+ const __m256 comparison =
811
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
812
+
813
+ // update min distances and indices with closest vectors if needed.
814
+ min_distances =
815
+ _mm256_blendv_ps(distances, min_distances, comparison);
816
+ min_indices = _mm256_castps_si256(_mm256_blendv_ps(
817
+ _mm256_castsi256_ps(current_indices),
818
+ _mm256_castsi256_ps(min_indices),
819
+ comparison));
820
+
821
+ // update current indices values. Basically, +8 to each of the
822
+ // 8 AVX2 components.
823
+ current_indices =
824
+ _mm256_add_epi32(current_indices, indices_increment);
825
+
826
+ // scroll y and y_sqlen forward.
827
+ y += 8;
828
+ y_sqlen += 8;
829
+ }
830
+
831
+ // dump values and find the minimum distance / minimum index
832
+ float min_distances_scalar[8];
833
+ uint32_t min_indices_scalar[8];
834
+ _mm256_storeu_ps(min_distances_scalar, min_distances);
835
+ _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
836
+
837
+ for (size_t j = 0; j < 8; j++) {
838
+ if (current_min_distance > min_distances_scalar[j]) {
839
+ current_min_distance = min_distances_scalar[j];
840
+ current_min_index = min_indices_scalar[j];
841
+ }
842
+ }
843
+ }
844
+
845
+ if (i < ny) {
846
+ // process leftovers
847
+ for (; i < ny; i++) {
848
+ float dp = 0;
849
+ for (size_t j = 0; j < DIM; j++) {
850
+ dp += x[j] * y[j * d_offset];
851
+ }
852
+
853
+ // compute y^2 - 2 * (x, y), which is sufficient for looking for the
854
+ // lowest distance.
855
+ const float distance = y_sqlen[0] - 2 * dp;
856
+
857
+ if (current_min_distance > distance) {
858
+ current_min_distance = distance;
859
+ current_min_index = i;
860
+ }
861
+
862
+ y += 1;
863
+ y_sqlen += 1;
864
+ }
865
+ }
866
+
867
+ return current_min_index;
868
+ }
869
+ #endif
870
+
871
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
872
+ float* distances_tmp_buffer,
873
+ const float* x,
874
+ const float* y,
875
+ const float* y_sqlen,
876
+ size_t d,
877
+ size_t d_offset,
878
+ size_t ny) {
879
+ // optimized for a few special cases
880
+ #ifdef __AVX2__
881
+ #define DISPATCH(dval) \
882
+ case dval: \
883
+ return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
884
+ distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
885
+
886
+ switch (d) {
887
+ DISPATCH(1)
888
+ DISPATCH(2)
889
+ DISPATCH(4)
890
+ DISPATCH(8)
891
+ default:
892
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
893
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
894
+ }
895
+ #undef DISPATCH
896
+ #else
897
+ // non-AVX2 case
898
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
899
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
900
+ #endif
901
+ }
902
+
348
903
  #endif
349
904
 
350
905
  #ifdef USE_AVX
@@ -590,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
590
1145
  float32x4_t sq = vsubq_f32(xi, yi);
591
1146
  accux4 = vfmaq_f32(accux4, sq, sq);
592
1147
  }
593
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
594
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1148
+ float32_t accux1 = vaddvq_f32(accux4);
595
1149
  for (; i < d; ++i) {
596
1150
  float32_t xi = x[i];
597
1151
  float32_t yi = y[i];
@@ -610,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
610
1164
  float32x4_t yi = vld1q_f32(y + i);
611
1165
  accux4 = vfmaq_f32(accux4, xi, yi);
612
1166
  }
613
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
614
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1167
+ float32_t accux1 = vaddvq_f32(accux4);
615
1168
  for (; i < d; ++i) {
616
1169
  float32_t xi = x[i];
617
1170
  float32_t yi = y[i];
@@ -628,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
628
1181
  float32x4_t xi = vld1q_f32(x + i);
629
1182
  accux4 = vfmaq_f32(accux4, xi, xi);
630
1183
  }
631
- float32x4_t accux2 = vpaddq_f32(accux4, accux4);
632
- float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
1184
+ float32_t accux1 = vaddvq_f32(accux4);
633
1185
  for (; i < d; ++i) {
634
1186
  float32_t xi = x[i];
635
1187
  accux1 += xi * xi;
@@ -647,6 +1199,27 @@ void fvec_L2sqr_ny(
647
1199
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
648
1200
  }
649
1201
 
1202
+ size_t fvec_L2sqr_ny_nearest(
1203
+ float* distances_tmp_buffer,
1204
+ const float* x,
1205
+ const float* y,
1206
+ size_t d,
1207
+ size_t ny) {
1208
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
1209
+ }
1210
+
1211
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
1212
+ float* distances_tmp_buffer,
1213
+ const float* x,
1214
+ const float* y,
1215
+ const float* y_sqlen,
1216
+ size_t d,
1217
+ size_t d_offset,
1218
+ size_t ny) {
1219
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
1220
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
1221
+ }
1222
+
650
1223
  float fvec_L1(const float* x, const float* y, size_t d) {
651
1224
  return fvec_L1_ref(x, y, d);
652
1225
  }
@@ -696,6 +1269,27 @@ void fvec_L2sqr_ny(
696
1269
  fvec_L2sqr_ny_ref(dis, x, y, d, ny);
697
1270
  }
698
1271
 
1272
+ size_t fvec_L2sqr_ny_nearest(
1273
+ float* distances_tmp_buffer,
1274
+ const float* x,
1275
+ const float* y,
1276
+ size_t d,
1277
+ size_t ny) {
1278
+ return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
1279
+ }
1280
+
1281
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
1282
+ float* distances_tmp_buffer,
1283
+ const float* x,
1284
+ const float* y,
1285
+ const float* y_sqlen,
1286
+ size_t d,
1287
+ size_t d_offset,
1288
+ size_t ny) {
1289
+ return fvec_L2sqr_ny_nearest_y_transposed_ref(
1290
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
1291
+ }
1292
+
699
1293
  void fvec_inner_products_ny(
700
1294
  float* dis,
701
1295
  const float* x,
@@ -721,6 +1315,61 @@ static inline void fvec_madd_ref(
721
1315
  c[i] = a[i] + bf * b[i];
722
1316
  }
723
1317
 
1318
+ #ifdef __AVX2__
1319
/// Compute c[i] = a[i] + bf * b[i] for i in [0, n), AVX2+FMA version.
/// Processes 8 floats per iteration; the 0..7 leftover elements are
/// handled with a masked load/store so no scalar tail loop is needed.
/// Buffers need not be aligned (unaligned loads/stores throughout).
static inline void fvec_madd_avx2(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t n8 = n / 8;
    const size_t n_for_masking = n % 8;

    const __m256 bfmm = _mm256_set1_ps(bf);

    size_t idx = 0;
    for (idx = 0; idx < n8 * 8; idx += 8) {
        const __m256 ax = _mm256_loadu_ps(a + idx);
        const __m256 bx = _mm256_loadu_ps(b + idx);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_storeu_ps(c + idx, abmul);
    }

    if (n_for_masking > 0) {
        // Build the tail mask arithmetically: lane i is all-ones iff
        // i < n_for_masking. This replaces the previous 7-case switch,
        // which left `mask` formally uninitialized on its (unreachable)
        // default path and triggered -Wmaybe-uninitialized.
        const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i mask = _mm256_cmpgt_epi32(
                _mm256_set1_epi32(static_cast<int>(n_for_masking)), lane_ids);

        // _mm256_maskload_ps/_mm256_maskstore_ps only touch lanes whose
        // mask high bit is set, so out-of-range memory is never accessed.
        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_maskstore_ps(c + idx, mask, abmul);
    }
}
1371
+ #endif
1372
+
724
1373
  #ifdef __SSE3__
725
1374
 
726
1375
  static inline void fvec_madd_sse(
@@ -744,10 +1393,30 @@ static inline void fvec_madd_sse(
744
1393
  }
745
1394
 
746
1395
  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
1396
+ #ifdef __AVX2__
1397
+ fvec_madd_avx2(n, a, bf, b, c);
1398
+ #else
747
1399
  if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
748
1400
  fvec_madd_sse(n, a, bf, b, c);
749
1401
  else
750
1402
  fvec_madd_ref(n, a, bf, b, c);
1403
+ #endif
1404
+ }
1405
+
1406
+ #elif defined(__aarch64__)
1407
+
1408
/// Compute c[i] = a[i] + bf * b[i] for i in [0, n), NEON version.
/// Four lanes per step via fused multiply-add, then a scalar loop for the
/// remaining 0..3 elements.
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    const float32x4_t scale = vdupq_n_f32(bf);
    // Largest multiple of 4 not exceeding n.
    const size_t vec_end = n & ~static_cast<size_t>(3);
    size_t j = 0;
    while (j < vec_end) {
        const float32x4_t av = vld1q_f32(a + j);
        const float32x4_t bv = vld1q_f32(b + j);
        // vfmaq_f32(acc, x, y) = acc + x * y, single rounding.
        vst1q_f32(c + j, vfmaq_f32(av, scale, bv));
        j += 4;
    }
    while (j < n) {
        c[j] = a[j] + bf * b[j];
        ++j;
    }
}
752
1421
 
753
1422
  #else
@@ -843,6 +1512,57 @@ int fvec_madd_and_argmin(
843
1512
  return fvec_madd_and_argmin_ref(n, a, bf, b, c);
844
1513
  }
845
1514
 
1515
+ #elif defined(__aarch64__)
1516
+
1517
/// Compute c[i] = a[i] + bf * b[i] for i in [0, n) and return the index of
/// the minimum element of c (NEON version). On ties among SIMD lanes the
/// smallest index wins (vminvq_u32 over candidate indices below).
/// NOTE(review): the running minimum starts at 1e20, matching the scalar
/// variants; if every c[i] exceeds 1e20 the returned index comes from the
/// all-ones iminv sentinel — confirm callers never hit that range.
int fvec_madd_and_argmin(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Per-lane running minima and the element index that produced each one.
    float32x4_t vminv = vdupq_n_f32(1e20);
    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
    size_t i;
    {
        const size_t n_simd = n - (n & 3); // multiple-of-4 prefix
        // iv holds the absolute element indices of the current 4 lanes.
        const uint32_t iota[] = {0, 1, 2, 3};
        uint32x4_t iv = vld1q_u32(iota);
        const uint32x4_t incv = vdupq_n_u32(4);
        const float32x4_t bfv = vdupq_n_f32(bf);
        for (i = 0; i < n_simd; i += 4) {
            const float32x4_t ai = vld1q_f32(a + i);
            const float32x4_t bi = vld1q_f32(b + i);
            // ci = ai + bf * bi (fused multiply-add), stored to c.
            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
            vst1q_f32(c + i, ci);
            // Lanes where the new value beats the running minimum.
            const uint32x4_t less_than = vcltq_f32(ci, vminv);
            vminv = vminq_f32(ci, vminv);
            // Blend: take iv where less_than, keep old iminv elsewhere.
            iminv = vorrq_u32(
                    vandq_u32(less_than, iv),
                    vandq_u32(vmvnq_u32(less_than), iminv));
            iv = vaddq_u32(iv, incv);
        }
    }
    // Horizontal reduction: global minimum across the 4 lanes.
    float vmin = vminvq_f32(vminv);
    uint32_t imin;
    {
        // Select the index from every lane that equals the global minimum;
        // non-matching lanes are forced to UINT32_MAX so vminvq_u32 picks
        // the smallest matching index.
        const float32x4_t vminy = vdupq_n_f32(vmin);
        const uint32x4_t equals = vceqq_f32(vminv, vminy);
        imin = vminvq_u32(vorrq_u32(
                vandq_u32(equals, iminv),
                vandq_u32(
                        vmvnq_u32(equals),
                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
    }
    // Scalar tail for the remaining 0..3 elements.
    for (; i < n; ++i) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = static_cast<uint32_t>(i);
        }
    }
    return static_cast<int>(imin);
}
1565
+
846
1566
  #else
847
1567
 
848
1568
  int fvec_madd_and_argmin(