faiss 0.2.4 → 0.2.5

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
data/vendor/faiss/faiss/utils/distances_simd.cpp
@@ -113,6 +113,74 @@ void fvec_L2sqr_ny_ref(
     }
 }
 
+void fvec_L2sqr_ny_y_transposed_ref(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    float x_sqlen = 0;
+    for (size_t j = 0; j < d; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    for (size_t i = 0; i < ny; i++) {
+        float dp = 0;
+        for (size_t j = 0; j < d; j++) {
+            dp += x[j] * y[i + j * d_offset];
+        }
+
+        dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
+    }
+}
+
+size_t fvec_L2sqr_ny_nearest_ref(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
+
+    size_t nearest_idx = 0;
+    float min_dis = HUGE_VALF;
+
+    for (size_t i = 0; i < ny; i++) {
+        if (distances_tmp_buffer[i] < min_dis) {
+            min_dis = distances_tmp_buffer[i];
+            nearest_idx = i;
+        }
+    }
+
+    return nearest_idx;
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    fvec_L2sqr_ny_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+
+    size_t nearest_idx = 0;
+    float min_dis = HUGE_VALF;
+
+    for (size_t i = 0; i < ny; i++) {
+        if (distances_tmp_buffer[i] < min_dis) {
+            min_dis = distances_tmp_buffer[i];
+            nearest_idx = i;
+        }
+    }
+
+    return nearest_idx;
+}
+
 void fvec_inner_products_ny_ref(
         float* ip,
         const float* x,
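The hunk above adds reference (scalar) nearest-neighbor variants of the fvec_L2sqr_ny kernels. For orientation only, a minimal sketch of how the dispatching entry point fvec_L2sqr_ny_nearest (defined further down in this file) might be called; it assumes the declaration is exported via <faiss/utils/distances.h>, which is also updated in this release, and is not part of the diff itself.

// Illustrative sketch only (not part of the gem): find the row of y
// closest to x in squared L2 distance.
#include <faiss/utils/distances.h>

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t d = 4, ny = 3;
    std::vector<float> x = {1, 0, 0, 0};
    std::vector<float> y = {
            0, 0, 0, 0, // candidate 0
            1, 0, 0, 0, // candidate 1 (identical to x)
            2, 2, 2, 2, // candidate 2
    };
    std::vector<float> tmp(ny); // scratch buffer; SIMD paths may leave it untouched
    const size_t nearest = faiss::fvec_L2sqr_ny_nearest(
            tmp.data(), x.data(), y.data(), d, ny);
    std::printf("nearest = %zu\n", nearest); // expected: 1
    return 0;
}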
@@ -258,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+// Specialized versions for AVX2 for any CPUs that support gather/scatter.
+// Todo: implement fvec_op_ny_Dxxx in the same way.
+
+template <>
+void fvec_op_ny_D4<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+
+            // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
+            //                (x[1] * y[(i * 8 + 0) * 4 + 1]) +
+            //                (x[2] * y[(i * 8 + 0) * 4 + 2]) +
+            //                (x[3] * y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
+            //                (x[1] * y[(i * 8 + 7) * 4 + 1]) +
+            //                (x[2] * y[(i * 8 + 7) * 4 + 2]) +
+            //                (x[3] * y[(i * 8 + 7) * 4 + 3])
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+            dis[i] = _mm_cvtss_f32(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D4<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+
+            // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
+            //                (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
+            //                (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
+            //                (x[3] - y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
+            //                (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
+            //                (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
+            //                (x[3] - y[(i * 8 + 7) * 4 + 3])
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+            dis[i] = _mm_cvtss_f32(accu);
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
@@ -345,6 +582,324 @@ void fvec_inner_products_ny(
 #undef DISPATCH
 }
 
+#ifdef __AVX2__
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 D4-vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        //
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+
+            // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
+            //                (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
+            //                (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
+            //                (x[3] - y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
+            //                (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
+            //                (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
+            //                (x[3] - y[(i * 8 + 7) * 4 + 3])
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances =
+                    _mm256_blendv_ps(distances, min_distances, comparison);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y forward (8 vectors 4 DIM each).
+            y += 32;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+
+            const auto distance = _mm_cvtss_f32(accu);
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+#else
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
+}
+#endif
+
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    // optimized for a few special cases
+#define DISPATCH(dval) \
+    case dval: \
+        return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
+
+    switch (d) {
+        DISPATCH(4)
+        default:
+            return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+    }
+#undef DISPATCH
+}
+
+#ifdef __AVX2__
+template <size_t DIM>
+size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m256 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm256_set1_ps(x[j]);
+            m[j] = _mm256_add_ps(m[j], m[j]);
+        }
+
+        for (; i < ny8 * 8; i += 8) {
+            // collect dim 0 for 8 D4-vectors.
+            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+            // compute dot products
+            __m256 dp = _mm256_mul_ps(m[0], v0);
+
+            for (size_t j = 1; j < DIM; j++) {
+                // collect dim j for 8 D4-vectors.
+                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+                dp = _mm256_fmadd_ps(m[j], vj, dp);
+            }
+
+            // compute y^2 - (2 * x, y), which is sufficient for looking for the
+            // lowest distance.
+            // x^2 is the constant that can be avoided.
+            const __m256 distances =
+                    _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            const __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances =
+                    _mm256_blendv_ps(distances, min_distances, comparison);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y and y_sqlen forward.
+            y += 8;
+            y_sqlen += 8;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+
+    return current_min_index;
+}
+#endif
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+#ifdef __AVX2__
+#define DISPATCH(dval) \
+    case dval: \
+        return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
+                distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_nearest_y_transposed_ref(
+                    distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
+
 #endif
 
 #ifdef USE_AVX
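The transposed kernels above rank candidates by ||y||^2 - 2*(x . y) rather than the full squared distance, because ||x||^2 is the same for every candidate and cannot change the argmin. A small self-contained check (illustrative only, no faiss dependency, all names local to the sketch) that the two rankings pick the same nearest vector:

// Illustrative check (not part of the gem): argmin of ||x - y||^2 equals
// argmin of ||y||^2 - 2<x, y>, since ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t d = 4, ny = 5;
    std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.25f};
    std::vector<float> y(ny * d);
    for (size_t i = 0; i < y.size(); i++) {
        y[i] = static_cast<float>(i * 37 % 11) * 0.3f - 1.0f;
    }

    size_t best_full = 0, best_partial = 0;
    float min_full = 1e30f, min_partial = 1e30f;
    for (size_t i = 0; i < ny; i++) {
        float full = 0, dp = 0, y_sq = 0;
        for (size_t j = 0; j < d; j++) {
            const float yj = y[i * d + j];
            full += (x[j] - yj) * (x[j] - yj); // full squared distance
            dp += x[j] * yj;
            y_sq += yj * yj;
        }
        const float partial = y_sq - 2 * dp; // what the kernel ranks by
        if (full < min_full) { min_full = full; best_full = i; }
        if (partial < min_partial) { min_partial = partial; best_partial = i; }
    }
    assert(best_full == best_partial); // same nearest neighbour either way
    return 0;
}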
@@ -590,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
         float32x4_t sq = vsubq_f32(xi, yi);
         accux4 = vfmaq_f32(accux4, sq, sq);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         float32_t yi = y[i];
@@ -610,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
         float32x4_t yi = vld1q_f32(y + i);
         accux4 = vfmaq_f32(accux4, xi, yi);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         float32_t yi = y[i];
@@ -628,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
         float32x4_t xi = vld1q_f32(x + i);
         accux4 = vfmaq_f32(accux4, xi, xi);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         accux1 += xi * xi;
@@ -647,6 +1199,27 @@ void fvec_L2sqr_ny(
     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
 float fvec_L1(const float* x, const float* y, size_t d) {
     return fvec_L1_ref(x, y, d);
 }
@@ -696,6 +1269,27 @@ void fvec_L2sqr_ny(
     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
 void fvec_inner_products_ny(
         float* dis,
         const float* x,
@@ -721,6 +1315,61 @@ static inline void fvec_madd_ref(
         c[i] = a[i] + bf * b[i];
 }
 
+#ifdef __AVX2__
+static inline void fvec_madd_avx2(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    //
+    const size_t n8 = n / 8;
+    const size_t n_for_masking = n % 8;
+
+    const __m256 bfmm = _mm256_set1_ps(bf);
+
+    size_t idx = 0;
+    for (idx = 0; idx < n8 * 8; idx += 8) {
+        const __m256 ax = _mm256_loadu_ps(a + idx);
+        const __m256 bx = _mm256_loadu_ps(b + idx);
+        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
+        _mm256_storeu_ps(c + idx, abmul);
+    }
+
+    if (n_for_masking > 0) {
+        __m256i mask;
+        switch (n_for_masking) {
+            case 1:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
+                break;
+            case 2:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
+                break;
+            case 3:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
+                break;
+            case 4:
+                mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
+                break;
+            case 5:
+                mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
+                break;
+            case 6:
+                mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
+                break;
+            case 7:
+                mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
+                break;
+        }
+
+        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
+        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
+        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
+        _mm256_maskstore_ps(c + idx, mask, abmul);
+    }
+}
+#endif
+
 #ifdef __SSE3__
 
 static inline void fvec_madd_sse(
@@ -744,10 +1393,30 @@ static inline void fvec_madd_sse(
 }
 
 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+#ifdef __AVX2__
+    fvec_madd_avx2(n, a, bf, b, c);
+#else
     if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
         fvec_madd_sse(n, a, bf, b, c);
     else
         fvec_madd_ref(n, a, bf, b, c);
+#endif
+}
+
+#elif defined(__aarch64__)
+
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    const size_t n_simd = n - (n & 3);
+    const float32x4_t bfv = vdupq_n_f32(bf);
+    size_t i;
+    for (i = 0; i < n_simd; i += 4) {
+        const float32x4_t ai = vld1q_f32(a + i);
+        const float32x4_t bi = vld1q_f32(b + i);
+        const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
+        vst1q_f32(c + i, ci);
+    }
+    for (; i < n; ++i)
+        c[i] = a[i] + bf * b[i];
 }
 
 #else
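For completeness, a minimal illustrative call of fvec_madd (not part of the diff; it assumes the declaration comes from <faiss/utils/distances.h>): the function computes c[i] = a[i] + bf * b[i], and this release routes it through the AVX2 path above on x86 and the new NEON path on aarch64.

// Illustrative sketch only (not part of the gem).
#include <faiss/utils/distances.h>

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t n = 10;
    std::vector<float> a(n, 1.0f), b(n, 2.0f), c(n, 0.0f);
    faiss::fvec_madd(n, a.data(), 0.5f, b.data(), c.data());
    std::printf("c[0] = %.1f\n", c[0]); // 1.0 + 0.5 * 2.0 = 2.0
    return 0;
}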
@@ -843,6 +1512,57 @@ int fvec_madd_and_argmin(
     return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
+#elif defined(__aarch64__)
+
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    float32x4_t vminv = vdupq_n_f32(1e20);
+    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
+    size_t i;
+    {
+        const size_t n_simd = n - (n & 3);
+        const uint32_t iota[] = {0, 1, 2, 3};
+        uint32x4_t iv = vld1q_u32(iota);
+        const uint32x4_t incv = vdupq_n_u32(4);
+        const float32x4_t bfv = vdupq_n_f32(bf);
+        for (i = 0; i < n_simd; i += 4) {
+            const float32x4_t ai = vld1q_f32(a + i);
+            const float32x4_t bi = vld1q_f32(b + i);
+            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
+            vst1q_f32(c + i, ci);
+            const uint32x4_t less_than = vcltq_f32(ci, vminv);
+            vminv = vminq_f32(ci, vminv);
+            iminv = vorrq_u32(
+                    vandq_u32(less_than, iv),
+                    vandq_u32(vmvnq_u32(less_than), iminv));
+            iv = vaddq_u32(iv, incv);
+        }
+    }
+    float vmin = vminvq_f32(vminv);
+    uint32_t imin;
+    {
+        const float32x4_t vminy = vdupq_n_f32(vmin);
+        const uint32x4_t equals = vceqq_f32(vminv, vminy);
+        imin = vminvq_u32(vorrq_u32(
+                vandq_u32(equals, iminv),
+                vandq_u32(
+                        vmvnq_u32(equals),
+                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
+    }
+    for (; i < n; ++i) {
+        c[i] = a[i] + bf * b[i];
+        if (c[i] < vmin) {
+            vmin = c[i];
+            imin = static_cast<uint32_t>(i);
+        }
+    }
+    return static_cast<int>(imin);
+}
+
 #else
 
 int fvec_madd_and_argmin(
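Similarly, an illustrative call of fvec_madd_and_argmin (again assuming the declaration is in <faiss/utils/distances.h>): it performs the same update as fvec_madd and additionally returns the index of the smallest resulting c[i], which the new aarch64 path above computes with NEON compare/min operations.

// Illustrative sketch only (not part of the gem).
#include <faiss/utils/distances.h>

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> a = {5.0f, 1.0f, 3.0f};
    std::vector<float> b = {1.0f, 1.0f, 1.0f};
    std::vector<float> c(a.size(), 0.0f);
    // c becomes {4, 0, 2}; the smallest entry is c[1]
    const int imin = faiss::fvec_madd_and_argmin(
            a.size(), a.data(), -1.0f, b.data(), c.data());
    std::printf("imin = %d, c[imin] = %.1f\n", imin, c[imin]); // imin = 1
    return 0;
}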