faiss 0.2.3 → 0.2.5

Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -9,6 +9,7 @@
 
 #include <faiss/utils/distances.h>
 
+#include <algorithm>
 #include <cassert>
 #include <cmath>
 #include <cstdio>
@@ -112,6 +113,74 @@ void fvec_L2sqr_ny_ref(
     }
 }
 
+void fvec_L2sqr_ny_y_transposed_ref(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    float x_sqlen = 0;
+    for (size_t j = 0; j < d; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    for (size_t i = 0; i < ny; i++) {
+        float dp = 0;
+        for (size_t j = 0; j < d; j++) {
+            dp += x[j] * y[i + j * d_offset];
+        }
+
+        dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
+    }
+}
+
+size_t fvec_L2sqr_ny_nearest_ref(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
+
+    size_t nearest_idx = 0;
+    float min_dis = HUGE_VALF;
+
+    for (size_t i = 0; i < ny; i++) {
+        if (distances_tmp_buffer[i] < min_dis) {
+            min_dis = distances_tmp_buffer[i];
+            nearest_idx = i;
+        }
+    }
+
+    return nearest_idx;
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    fvec_L2sqr_ny_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+
+    size_t nearest_idx = 0;
+    float min_dis = HUGE_VALF;
+
+    for (size_t i = 0; i < ny; i++) {
+        if (distances_tmp_buffer[i] < min_dis) {
+            min_dis = distances_tmp_buffer[i];
+            nearest_idx = i;
+        }
+    }
+
+    return nearest_idx;
+}
+
 void fvec_inner_products_ny_ref(
         float* ip,
         const float* x,
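The diff hunks shown here line up with entry 170 in the file list (data/vendor/faiss/faiss/utils/distances_simd.cpp, +776 -6). The hunk above adds the scalar reference implementations. The *_y_transposed helpers expect y stored dimension-major: component j of candidate i lives at y[i + j * d_offset], and y_sqlen[i] holds the precomputed squared norm of candidate i, so x_sqlen + y_sqlen[i] - 2 * dp is the ordinary squared L2 distance. A minimal usage sketch, not part of the diff, with the prototype copied from the definition above (normally it would come from faiss/utils/distances.h):

    #include <cstdio>
    #include <vector>

    namespace faiss {
    // prototype copied from the definition in the hunk above
    size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
            float* distances_tmp_buffer,
            const float* x,
            const float* y,
            const float* y_sqlen,
            size_t d,
            size_t d_offset,
            size_t ny);
    } // namespace faiss

    int main() {
        const size_t d = 2, ny = 3, d_offset = ny; // tightly packed, dimension-major
        // candidates y0 = (0, 0), y1 = (1, 0), y2 = (0, 2)
        std::vector<float> y = {0, 1, 0,   // dim 0 of y0..y2
                                0, 0, 2};  // dim 1 of y0..y2
        std::vector<float> y_sqlen = {0, 1, 4}; // precomputed ||yi||^2
        std::vector<float> tmp(ny);
        const float x[2] = {0.9f, 0.1f};
        size_t nearest = faiss::fvec_L2sqr_ny_nearest_y_transposed_ref(
                tmp.data(), x, y.data(), y_sqlen.data(), d, d_offset, ny);
        std::printf("nearest = %zu\n", nearest); // expected: 1
    }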
@@ -257,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
+#ifdef __AVX2__
+
+// Specialized versions for AVX2 for any CPUs that support gather/scatter.
+// Todo: implement fvec_op_ny_Dxxx in the same way.
+
+template <>
+void fvec_op_ny_D4<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+
+            // distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
+            //   (x[1] * y[(i * 8 + 0) * 4 + 1]) +
+            //   (x[2] * y[(i * 8 + 0) * 4 + 2]) +
+            //   (x[3] * y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
+            //   (x[1] * y[(i * 8 + 7) * 4 + 1]) +
+            //   (x[2] * y[(i * 8 + 7) * 4 + 2]) +
+            //   (x[3] * y[(i * 8 + 7) * 4 + 3])
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+            dis[i] = _mm_cvtss_f32(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D4<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+
+            // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
+            //   (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
+            //   (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
+            //   (x[3] - y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
+            //   (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
+            //   (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
+            //   (x[3] - y[(i * 8 + 7) * 4 + 3])
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+            dis[i] = _mm_cvtss_f32(accu);
+        }
+    }
+}
+
+#endif
+
 template <class ElementOp>
 void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
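One detail in the specializations above that is easy to misread: indices0 = {0, 16, 32, ..., 112} is used with a gather scale of 1, so the values are byte offsets; consecutive 4-float vectors are 16 bytes apart, and moving the base pointer to y + dim selects the dimension. A scalar sketch, not part of the diff, of what a single _mm256_i32gather_ps(y + dim, indices0, 1) call reads:

    #include <cstddef>

    // lane k of the gather result holds dimension `dim` of the k-th
    // 4-float vector in the current block of 8 (byte offset 16 * k).
    void gather_dim_of_8_d4_vectors(const float* y, size_t dim, float out[8]) {
        for (size_t k = 0; k < 8; k++) {
            out[k] = y[dim + 4 * k]; // 16 bytes == 4 floats
        }
    }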
@@ -344,6 +582,324 @@ void fvec_inner_products_ny(
 #undef DISPATCH
 }
 
+#ifdef __AVX2__
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 D4-vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        //
+        _mm_prefetch(y, _MM_HINT_NTA);
+        _mm_prefetch(y + 16, _MM_HINT_NTA);
+
+        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256i indices0 =
+                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
+
+        for (; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_NTA);
+            _mm_prefetch(y + 48, _MM_HINT_NTA);
+
+            // collect dim 0 for 8 D4-vectors.
+            // v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
+            const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
+            // collect dim 1 for 8 D4-vectors.
+            // v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
+            const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
+            // collect dim 2 for 8 D4-vectors.
+            // v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
+            const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
+            // collect dim 3 for 8 D4-vectors.
+            // v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
+            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+
+            // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
+            //   (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
+            //   (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
+            //   (x[3] - y[(i * 8 + 0) * 4 + 3])
+            // ...
+            // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
+            //   (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
+            //   (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
+            //   (x[3] - y[(i * 8 + 7) * 4 + 3])
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances =
+                    _mm256_blendv_ps(distances, min_distances, comparison);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y forward (8 vectors 4 DIM each).
+            y += 32;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            accu = _mm_hadd_ps(accu, accu);
+            accu = _mm_hadd_ps(accu, accu);
+
+            const auto distance = _mm_cvtss_f32(accu);
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+#else
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
+}
+#endif
+
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    // optimized for a few special cases
+#define DISPATCH(dval) \
+    case dval:         \
+        return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
+
+    switch (d) {
+        DISPATCH(4)
+        default:
+            return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+    }
+#undef DISPATCH
+}
+
+#ifdef __AVX2__
+template <size_t DIM>
+size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m256 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm256_set1_ps(x[j]);
+            m[j] = _mm256_add_ps(m[j], m[j]);
+        }
+
+        for (; i < ny8 * 8; i += 8) {
+            // collect dim 0 for 8 D4-vectors.
+            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+            // compute dot products
+            __m256 dp = _mm256_mul_ps(m[0], v0);
+
+            for (size_t j = 1; j < DIM; j++) {
+                // collect dim j for 8 D4-vectors.
+                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+                dp = _mm256_fmadd_ps(m[j], vj, dp);
+            }
+
+            // compute y^2 - (2 * x, y), which is sufficient for looking for the
+            // lowest distance.
+            // x^2 is the constant that can be avoided.
+            const __m256 distances =
+                    _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            const __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances =
+                    _mm256_blendv_ps(distances, min_distances, comparison);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y and y_sqlen forward.
+            y += 8;
+            y_sqlen += 8;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+
+    return current_min_index;
+}
+#endif
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+#ifdef __AVX2__
+#define DISPATCH(dval)                                     \
+    case dval:                                             \
+        return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
+                distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_nearest_y_transposed_ref(
+                    distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
+
 #endif
 
 #ifdef USE_AVX
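The _y_transposed kernel above ranks candidates by y^2 - 2*(x, y) rather than the full ||x - y||^2 = x^2 + y^2 - 2*(x, y); since x^2 is the same for every candidate it cannot change which index wins, so the kernel simply skips it (the scalar leftover loop does the same). A small self-contained check of that identity, not part of the diff (the two argmins agree up to floating-point ties):

    #include <cstdio>
    #include <cstdlib>

    int main() {
        const int d = 4, ny = 100;
        float x[d], y[ny][d];
        for (int j = 0; j < d; j++)
            x[j] = std::rand() / (float)RAND_MAX;
        for (int i = 0; i < ny; i++)
            for (int j = 0; j < d; j++)
                y[i][j] = std::rand() / (float)RAND_MAX;

        int best_full = 0, best_partial = 0;
        float min_full = 1e30f, min_partial = 1e30f;
        for (int i = 0; i < ny; i++) {
            float l2 = 0, dp = 0, ysq = 0;
            for (int j = 0; j < d; j++) {
                l2 += (x[j] - y[i][j]) * (x[j] - y[i][j]); // full squared distance
                dp += x[j] * y[i][j];
                ysq += y[i][j] * y[i][j];
            }
            if (l2 < min_full) { min_full = l2; best_full = i; }
            if (ysq - 2 * dp < min_partial) { min_partial = ysq - 2 * dp; best_partial = i; }
        }
        std::printf("%d %d\n", best_full, best_partial); // expected to match
    }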
@@ -589,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
         float32x4_t sq = vsubq_f32(xi, yi);
         accux4 = vfmaq_f32(accux4, sq, sq);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         float32_t yi = y[i];
@@ -609,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
         float32x4_t yi = vld1q_f32(y + i);
         accux4 = vfmaq_f32(accux4, xi, yi);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         float32_t yi = y[i];
@@ -627,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
         float32x4_t xi = vld1q_f32(x + i);
         accux4 = vfmaq_f32(accux4, xi, xi);
     }
-    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
-    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    float32_t accux1 = vaddvq_f32(accux4);
     for (; i < d; ++i) {
         float32_t xi = x[i];
         accux1 += xi * xi;
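The three aarch64 hunks above apply the same simplification to fvec_L2sqr, fvec_inner_product and fvec_norm_L2sqr: the hand-rolled pairwise-add/lane-extract sequence is replaced by vaddvq_f32, which sums all four lanes of a float32x4_t in a single intrinsic. A minimal illustration, not part of the diff (compiles only for AArch64 targets):

    #include <arm_neon.h>

    // vaddvq_f32 performs the full horizontal add that the removed
    // vpaddq_f32 / vdups_laneq_f32 pair used to spell out by hand.
    float horizontal_sum(float32x4_t v) {
        return vaddvq_f32(v); // v[0] + v[1] + v[2] + v[3]
    }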
@@ -646,6 +1199,27 @@ void fvec_L2sqr_ny(
     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
 float fvec_L1(const float* x, const float* y, size_t d) {
     return fvec_L1_ref(x, y, d);
 }
@@ -695,6 +1269,27 @@ void fvec_L2sqr_ny(
     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
+}
+
+size_t fvec_L2sqr_ny_nearest_y_transposed(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_y_transposed_ref(
+            distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
+}
+
 void fvec_inner_products_ny(
         float* dis,
         const float* x,
@@ -720,6 +1315,61 @@ static inline void fvec_madd_ref(
         c[i] = a[i] + bf * b[i];
 }
 
+#ifdef __AVX2__
+static inline void fvec_madd_avx2(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    //
+    const size_t n8 = n / 8;
+    const size_t n_for_masking = n % 8;
+
+    const __m256 bfmm = _mm256_set1_ps(bf);
+
+    size_t idx = 0;
+    for (idx = 0; idx < n8 * 8; idx += 8) {
+        const __m256 ax = _mm256_loadu_ps(a + idx);
+        const __m256 bx = _mm256_loadu_ps(b + idx);
+        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
+        _mm256_storeu_ps(c + idx, abmul);
+    }
+
+    if (n_for_masking > 0) {
+        __m256i mask;
+        switch (n_for_masking) {
+            case 1:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
+                break;
+            case 2:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
+                break;
+            case 3:
+                mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
+                break;
+            case 4:
+                mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
+                break;
+            case 5:
+                mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
+                break;
+            case 6:
+                mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
+                break;
+            case 7:
+                mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
+                break;
+        }
+
+        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
+        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
+        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
+        _mm256_maskstore_ps(c + idx, mask, abmul);
+    }
+}
+#endif
+
 #ifdef __SSE3__
 
 static inline void fvec_madd_sse(
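fvec_madd computes c[i] = a[i] + bf * b[i]. With the hunks above and below, an AVX2 build always takes fvec_madd_avx2, which covers the trailing n % 8 elements with a lane mask instead of requiring the 16-byte alignment that the SSE path checks for, and aarch64 gets its own NEON path. A hedged usage sketch, not part of the diff (the include path is assumed):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    #include <faiss/utils/distances.h> // assumed location of the fvec_madd declaration

    int main() {
        const size_t n = 11; // not a multiple of 8: exercises the masked tail
        std::vector<float> a(n, 1.0f), b(n, 2.0f), c(n);
        faiss::fvec_madd(n, a.data(), 0.5f, b.data(), c.data());
        std::printf("%.1f\n", c[10]); // expected: 2.0 = 1.0 + 0.5 * 2.0
    }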
@@ -743,10 +1393,30 @@ static inline void fvec_madd_sse(
 }
 
 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+#ifdef __AVX2__
+    fvec_madd_avx2(n, a, bf, b, c);
+#else
     if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
         fvec_madd_sse(n, a, bf, b, c);
     else
         fvec_madd_ref(n, a, bf, b, c);
+#endif
+}
+
+#elif defined(__aarch64__)
+
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    const size_t n_simd = n - (n & 3);
+    const float32x4_t bfv = vdupq_n_f32(bf);
+    size_t i;
+    for (i = 0; i < n_simd; i += 4) {
+        const float32x4_t ai = vld1q_f32(a + i);
+        const float32x4_t bi = vld1q_f32(b + i);
+        const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
+        vst1q_f32(c + i, ci);
+    }
+    for (; i < n; ++i)
+        c[i] = a[i] + bf * b[i];
 }
 
 #else
@@ -842,6 +1512,57 @@ int fvec_madd_and_argmin(
     return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
+#elif defined(__aarch64__)
+
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    float32x4_t vminv = vdupq_n_f32(1e20);
+    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
+    size_t i;
+    {
+        const size_t n_simd = n - (n & 3);
+        const uint32_t iota[] = {0, 1, 2, 3};
+        uint32x4_t iv = vld1q_u32(iota);
+        const uint32x4_t incv = vdupq_n_u32(4);
+        const float32x4_t bfv = vdupq_n_f32(bf);
+        for (i = 0; i < n_simd; i += 4) {
+            const float32x4_t ai = vld1q_f32(a + i);
+            const float32x4_t bi = vld1q_f32(b + i);
+            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
+            vst1q_f32(c + i, ci);
+            const uint32x4_t less_than = vcltq_f32(ci, vminv);
+            vminv = vminq_f32(ci, vminv);
+            iminv = vorrq_u32(
+                    vandq_u32(less_than, iv),
+                    vandq_u32(vmvnq_u32(less_than), iminv));
+            iv = vaddq_u32(iv, incv);
+        }
+    }
+    float vmin = vminvq_f32(vminv);
+    uint32_t imin;
+    {
+        const float32x4_t vminy = vdupq_n_f32(vmin);
+        const uint32x4_t equals = vceqq_f32(vminv, vminy);
+        imin = vminvq_u32(vorrq_u32(
+                vandq_u32(equals, iminv),
+                vandq_u32(
+                        vmvnq_u32(equals),
+                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
+    }
+    for (; i < n; ++i) {
+        c[i] = a[i] + bf * b[i];
+        if (c[i] < vmin) {
+            vmin = c[i];
+            imin = static_cast<uint32_t>(i);
+        }
+    }
+    return static_cast<int>(imin);
+}
+
 #else
 
 int fvec_madd_and_argmin(
@@ -973,4 +1694,53 @@ void compute_PQ_dis_tables_dsub2(
     }
 }
 
+/*********************************************************
+ * Vector to vector functions
+ *********************************************************/
+
+void fvec_sub(size_t d, const float* a, const float* b, float* c) {
+    size_t i;
+    for (i = 0; i + 7 < d; i += 8) {
+        simd8float32 ci, ai, bi;
+        ai.loadu(a + i);
+        bi.loadu(b + i);
+        ci = ai - bi;
+        ci.storeu(c + i);
+    }
+    // finish non-multiple of 8 remainder
+    for (; i < d; i++) {
+        c[i] = a[i] - b[i];
+    }
+}
+
+void fvec_add(size_t d, const float* a, const float* b, float* c) {
+    size_t i;
+    for (i = 0; i + 7 < d; i += 8) {
+        simd8float32 ci, ai, bi;
+        ai.loadu(a + i);
+        bi.loadu(b + i);
+        ci = ai + bi;
+        ci.storeu(c + i);
+    }
+    // finish non-multiple of 8 remainder
+    for (; i < d; i++) {
+        c[i] = a[i] + b[i];
+    }
+}
+
+void fvec_add(size_t d, const float* a, float b, float* c) {
+    size_t i;
+    simd8float32 bv(b);
+    for (i = 0; i + 7 < d; i += 8) {
+        simd8float32 ci, ai, bi;
+        ai.loadu(a + i);
+        ci = ai + bv;
+        ci.storeu(c + i);
+    }
+    // finish non-multiple of 8 remainder
+    for (; i < d; i++) {
+        c[i] = a[i] + b;
+    }
+}
+
 } // namespace faiss
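The final hunk adds simple element-wise helpers built on the simd8float32 wrapper (8-float chunks plus a scalar remainder loop). A hedged usage sketch, not part of the diff; the declarations are assumed to be exposed through faiss/utils/distances.h:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    #include <faiss/utils/distances.h> // assumed location of the declarations

    int main() {
        const size_t d = 10;
        std::vector<float> a(d, 3.0f), b(d, 1.0f), diff(d), shifted(d);
        faiss::fvec_sub(d, a.data(), b.data(), diff.data());   // diff[i] = a[i] - b[i]
        faiss::fvec_add(d, diff.data(), 0.5f, shifted.data()); // shifted[i] = diff[i] + 0.5
        std::printf("%.1f %.1f\n", diff[9], shifted[9]); // expected: 2.0 2.5
    }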