faiss 0.2.7 → 0.3.1

Files changed (172)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -62,7 +62,7 @@ void kernel(
  const float* const __restrict y,
  const float* const __restrict y_transposed,
  const size_t ny,
- SingleBestResultHandler<CMax<float, int64_t>>& res,
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
  const float* __restrict y_norms,
  const size_t i) {
  const size_t ny_p =
@@ -73,7 +73,7 @@ void kernel(

  // prefetch the next point
  #if defined(__AVX2__)
- _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
+ _mm_prefetch((const char*)(xd_0 + DIM * sizeof(float)), _MM_HINT_NTA);
  #endif

  // load a single point from x
@@ -226,7 +226,7 @@ void exhaustive_L2sqr_fused_cmax(
  const float* const __restrict y,
  size_t nx,
  size_t ny,
- SingleBestResultHandler<CMax<float, int64_t>>& res,
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
  const float* __restrict y_norms) {
  // BLAS does not like empty matrices
  if (nx == 0 || ny == 0) {
@@ -270,7 +270,7 @@ void exhaustive_L2sqr_fused_cmax(
  x, y, y_transposed.data(), ny, res, y_norms, i);
  }

- // Does nothing for SingleBestResultHandler, but
+ // Does nothing for Top1BlockResultHandler, but
  // keeping the call for the consistency.
  res.end_multiple();
  InterruptCallback::check();
@@ -284,7 +284,7 @@ bool exhaustive_L2sqr_fused_cmax_simdlib(
  size_t d,
  size_t nx,
  size_t ny,
- SingleBestResultHandler<CMax<float, int64_t>>& res,
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
  const float* y_norms) {
  // Process only cases with certain dimensionalities.
  // An acceptable dimensionality value is limited by the number of
@@ -24,7 +24,7 @@ bool exhaustive_L2sqr_fused_cmax_simdlib(
  size_t d,
  size_t nx,
  size_t ny,
- SingleBestResultHandler<CMax<float, int64_t>>& res,
+ Top1BlockResultHandler<CMax<float, int64_t>>& res,
  const float* y_norms);

  } // namespace faiss
@@ -223,6 +223,76 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
  }
  FAISS_PRAGMA_IMPRECISE_FUNCTION_END

+ /// Special version of inner product that computes 4 distances
+ /// between x and yi
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+ void fvec_inner_product_batch_4(
+ const float* __restrict x,
+ const float* __restrict y0,
+ const float* __restrict y1,
+ const float* __restrict y2,
+ const float* __restrict y3,
+ const size_t d,
+ float& dis0,
+ float& dis1,
+ float& dis2,
+ float& dis3) {
+ float d0 = 0;
+ float d1 = 0;
+ float d2 = 0;
+ float d3 = 0;
+ FAISS_PRAGMA_IMPRECISE_LOOP
+ for (size_t i = 0; i < d; ++i) {
+ d0 += x[i] * y0[i];
+ d1 += x[i] * y1[i];
+ d2 += x[i] * y2[i];
+ d3 += x[i] * y3[i];
+ }
+
+ dis0 = d0;
+ dis1 = d1;
+ dis2 = d2;
+ dis3 = d3;
+ }
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+ /// Special version of L2sqr that computes 4 distances
+ /// between x and yi, which is performance oriented.
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+ void fvec_L2sqr_batch_4(
+ const float* x,
+ const float* y0,
+ const float* y1,
+ const float* y2,
+ const float* y3,
+ const size_t d,
+ float& dis0,
+ float& dis1,
+ float& dis2,
+ float& dis3) {
+ float d0 = 0;
+ float d1 = 0;
+ float d2 = 0;
+ float d3 = 0;
+ FAISS_PRAGMA_IMPRECISE_LOOP
+ for (size_t i = 0; i < d; ++i) {
+ const float q0 = x[i] - y0[i];
+ const float q1 = x[i] - y1[i];
+ const float q2 = x[i] - y2[i];
+ const float q3 = x[i] - y3[i];
+ d0 += q0 * q0;
+ d1 += q1 * q1;
+ d2 += q2 * q2;
+ d3 += q3 * q3;
+ }
+
+ dis0 = d0;
+ dis1 = d1;
+ dis2 = d2;
+ dis3 = d3;
+ }
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
  /*********************************************************
  * SSE and AVX implementations
  */
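The two batch-of-4 kernels above let a caller amortize one pass over the query x across four database vectors at a time. A minimal caller sketch follows; it assumes the declaration is exported through a faiss header (upstream it lives alongside fvec_L2sqr in faiss/utils/distances.h) and that the program links against the library:

#include <cstddef>
#include <cstdio>
#include <vector>

namespace faiss {
// Declaration copied from the definition added in the hunk above
// (assumed to match the one exported by the library headers).
void fvec_L2sqr_batch_4(
        const float* x,
        const float* y0,
        const float* y1,
        const float* y2,
        const float* y3,
        const size_t d,
        float& dis0,
        float& dis1,
        float& dis2,
        float& dis3);
} // namespace faiss

int main() {
    const size_t d = 8;
    std::vector<float> x(d, 1.0f);   // query vector
    std::vector<float> y(4 * d);     // four database vectors, packed back to back
    for (size_t i = 0; i < y.size(); i++) {
        y[i] = static_cast<float>(i % d);
    }
    float d0, d1, d2, d3;
    faiss::fvec_L2sqr_batch_4(
            x.data(),
            y.data() + 0 * d,
            y.data() + 1 * d,
            y.data() + 2 * d,
            y.data() + 3 * d,
            d,
            d0, d1, d2, d3);
    std::printf("%.1f %.1f %.1f %.1f\n", d0, d1, d2, d3);
    return 0;
}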
@@ -236,8 +306,10 @@ static inline __m128 masked_read(int d, const float* x) {
  switch (d) {
  case 3:
  buf[2] = x[2];
+ [[fallthrough]];
  case 2:
  buf[1] = x[1];
+ [[fallthrough]];
  case 1:
  buf[0] = x[0];
  }
@@ -247,6 +319,33 @@ static inline __m128 masked_read(int d, const float* x) {

  namespace {

+ /// helper function
+ inline float horizontal_sum(const __m128 v) {
+ // say, v is [x0, x1, x2, x3]
+
+ // v0 is [x2, x3, ..., ...]
+ const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
+ // v1 is [x0 + x2, x1 + x3, ..., ...]
+ const __m128 v1 = _mm_add_ps(v, v0);
+ // v2 is [x1 + x3, ..., .... ,...]
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+ // v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
+ const __m128 v3 = _mm_add_ps(v1, v2);
+ // return v3[0]
+ return _mm_cvtss_f32(v3);
+ }
+
+ #ifdef __AVX2__
+ /// helper function for AVX2
+ inline float horizontal_sum(const __m256 v) {
+ // add high and low parts
+ const __m128 v0 =
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+ // perform horizontal sum on v0
+ return horizontal_sum(v0);
+ }
+ #endif
+
  /// Function that does a component-wise operation between x and y
  /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
  /// functions below
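The horizontal_sum helpers reduce an SSE or AVX register to the scalar sum of its lanes and replace the repeated _mm_hadd_ps sequences that the later hunks delete. A small standalone sanity check of the same reduction against a plain loop (a hypothetical test harness, not part of the library; compile with -mavx2):

#include <immintrin.h>
#include <cstdio>

// Scalar reference: sum of the n leading lanes.
static float horizontal_sum_ref(const float* v, int n) {
    float s = 0;
    for (int i = 0; i < n; i++) {
        s += v[i];
    }
    return s;
}

int main() {
    float vals[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    __m256 v = _mm256_loadu_ps(vals);
    // Mirror the reduction from the hunk above: add the high and low
    // 128-bit halves, then shuffle/add down to a single lane.
    __m128 s4 = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    __m128 s2 = _mm_add_ps(s4, _mm_shuffle_ps(s4, s4, _MM_SHUFFLE(0, 0, 3, 2)));
    __m128 s1 = _mm_add_ps(s2, _mm_shuffle_ps(s2, s2, _MM_SHUFFLE(0, 0, 0, 1)));
    const float simd = _mm_cvtss_f32(s1);
    std::printf("simd=%g ref=%g\n", simd, horizontal_sum_ref(vals, 8)); // both 36
    return 0;
}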
@@ -260,6 +359,13 @@ struct ElementOpL2 {
  __m128 tmp = _mm_sub_ps(x, y);
  return _mm_mul_ps(tmp, tmp);
  }
+
+ #ifdef __AVX2__
+ static __m256 op(__m256 x, __m256 y) {
+ __m256 tmp = _mm256_sub_ps(x, y);
+ return _mm256_mul_ps(tmp, tmp);
+ }
+ #endif
  };

  /// Function that does a component-wise operation between x and y
@@ -272,6 +378,12 @@ struct ElementOpIP {
  static __m128 op(__m128 x, __m128 y) {
  return _mm_mul_ps(x, y);
  }
+
+ #ifdef __AVX2__
+ static __m256 op(__m256 x, __m256 y) {
+ return _mm256_mul_ps(x, y);
+ }
+ #endif
  };

  template <class ElementOp>
@@ -314,6 +426,131 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
  }
  }

+ #ifdef __AVX2__
+
+ template <>
+ void fvec_op_ny_D2<ElementOpIP>(
+ float* dis,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ const size_t ny8 = ny / 8;
+ size_t i = 0;
+
+ if (ny8 > 0) {
+ // process 8 D2-vectors per loop.
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
+
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+
+ for (i = 0; i < ny8 * 8; i += 8) {
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+ // load 8x2 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+
+ transpose_8x2(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ v0,
+ v1);
+
+ // compute distances
+ __m256 distances = _mm256_mul_ps(m0, v0);
+ distances = _mm256_fmadd_ps(m1, v1, distances);
+
+ // store
+ _mm256_storeu_ps(dis + i, distances);
+
+ y += 16;
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ float x0 = x[0];
+ float x1 = x[1];
+
+ for (; i < ny; i++) {
+ float distance = x0 * y[0] + x1 * y[1];
+ y += 2;
+ dis[i] = distance;
+ }
+ }
+ }
+
+ template <>
+ void fvec_op_ny_D2<ElementOpL2>(
+ float* dis,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ const size_t ny8 = ny / 8;
+ size_t i = 0;
+
+ if (ny8 > 0) {
+ // process 8 D2-vectors per loop.
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
+
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+
+ for (i = 0; i < ny8 * 8; i += 8) {
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+ // load 8x2 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+
+ transpose_8x2(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ v0,
+ v1);
+
+ // compute differences
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
+
+ // compute squares of differences
+ __m256 distances = _mm256_mul_ps(d0, d0);
+ distances = _mm256_fmadd_ps(d1, d1, distances);
+
+ // store
+ _mm256_storeu_ps(dis + i, distances);
+
+ y += 16;
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ float x0 = x[0];
+ float x1 = x[1];
+
+ for (; i < ny; i++) {
+ float sub0 = x0 - y[0];
+ float sub1 = x1 - y[1];
+ float distance = sub0 * sub0 + sub1 * sub1;
+
+ y += 2;
+ dis[i] = distance;
+ }
+ }
+ }
+
+ #endif
+
  template <class ElementOp>
  void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
  __m128 x0 = _mm_loadu_ps(x);
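In both D2 specializations above, transpose_8x2 deinterleaves eight packed 2-D vectors into one register of first components (v0) and one of second components (v1) before the FMA step; that layout is inferred from how v0/v1 are consumed, not stated in this diff. For reference, a scalar sketch of what the inner-product variant computes:

#include <cstddef>

// Scalar reference for the AVX2 fvec_op_ny_D2<ElementOpIP> kernel above:
// dis[i] = <x, y_i> for ny consecutive 2-D vectors y_i stored back to back.
static void inner_product_ny_D2_ref(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    for (size_t i = 0; i < ny; i++) {
        // After the transpose, lane j of v0 holds y[2*(i+j)] and lane j of
        // v1 holds y[2*(i+j) + 1]; the FMA evaluates this dot product for
        // eight values of i at once.
        dis[i] = x[0] * y[2 * i + 0] + x[1] * y[2 * i + 1];
    }
}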
@@ -321,17 +558,12 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
  for (size_t i = 0; i < ny; i++) {
  __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
  y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
  }
  }

  #ifdef __AVX2__

- // Specialized versions for AVX2 for any CPUs that support gather/scatter.
- // Todo: implement fvec_op_ny_Dxxx in the same way.
-
  template <>
  void fvec_op_ny_D4<ElementOpIP>(
  float* dis,
@@ -343,16 +575,9 @@ void fvec_op_ny_D4<ElementOpIP>(

  if (ny8 > 0) {
  // process 8 D4-vectors per loop.
- _mm_prefetch(y, _MM_HINT_NTA);
- _mm_prefetch(y + 16, _MM_HINT_NTA);
-
- // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
  const __m256 m0 = _mm256_set1_ps(x[0]);
- // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
  const __m256 m1 = _mm256_set1_ps(x[1]);
- // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
  const __m256 m2 = _mm256_set1_ps(x[2]);
- // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
  const __m256 m3 = _mm256_set1_ps(x[3]);

  for (i = 0; i < ny8 * 8; i += 8) {
@@ -395,9 +620,7 @@ void fvec_op_ny_D4<ElementOpIP>(
  for (; i < ny; i++) {
  __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
  y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
  }
  }
  }
@@ -413,16 +636,9 @@ void fvec_op_ny_D4<ElementOpL2>(

  if (ny8 > 0) {
  // process 8 D4-vectors per loop.
- _mm_prefetch(y, _MM_HINT_NTA);
- _mm_prefetch(y + 16, _MM_HINT_NTA);
-
- // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
  const __m256 m0 = _mm256_set1_ps(x[0]);
- // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
  const __m256 m1 = _mm256_set1_ps(x[1]);
- // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
  const __m256 m2 = _mm256_set1_ps(x[2]);
- // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
  const __m256 m3 = _mm256_set1_ps(x[3]);

  for (i = 0; i < ny8 * 8; i += 8) {
@@ -471,9 +687,7 @@ void fvec_op_ny_D4<ElementOpL2>(
  for (; i < ny; i++) {
  __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
  y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
  }
  }
  }
@@ -496,6 +710,182 @@ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
  }
  }

+ #ifdef __AVX2__
+
+ template <>
+ void fvec_op_ny_D8<ElementOpIP>(
+ float* dis,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ const size_t ny8 = ny / 8;
+ size_t i = 0;
+
+ if (ny8 > 0) {
+ // process 8 D8-vectors per loop.
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+ const __m256 m2 = _mm256_set1_ps(x[2]);
+ const __m256 m3 = _mm256_set1_ps(x[3]);
+ const __m256 m4 = _mm256_set1_ps(x[4]);
+ const __m256 m5 = _mm256_set1_ps(x[5]);
+ const __m256 m6 = _mm256_set1_ps(x[6]);
+ const __m256 m7 = _mm256_set1_ps(x[7]);
+
+ for (i = 0; i < ny8 * 8; i += 8) {
+ // load 8x8 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+ __m256 v4;
+ __m256 v5;
+ __m256 v6;
+ __m256 v7;
+
+ transpose_8x8(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ _mm256_loadu_ps(y + 4 * 8),
+ _mm256_loadu_ps(y + 5 * 8),
+ _mm256_loadu_ps(y + 6 * 8),
+ _mm256_loadu_ps(y + 7 * 8),
+ v0,
+ v1,
+ v2,
+ v3,
+ v4,
+ v5,
+ v6,
+ v7);
+
+ // compute distances
+ __m256 distances = _mm256_mul_ps(m0, v0);
+ distances = _mm256_fmadd_ps(m1, v1, distances);
+ distances = _mm256_fmadd_ps(m2, v2, distances);
+ distances = _mm256_fmadd_ps(m3, v3, distances);
+ distances = _mm256_fmadd_ps(m4, v4, distances);
+ distances = _mm256_fmadd_ps(m5, v5, distances);
+ distances = _mm256_fmadd_ps(m6, v6, distances);
+ distances = _mm256_fmadd_ps(m7, v7, distances);
+
+ // store
+ _mm256_storeu_ps(dis + i, distances);
+
+ y += 64;
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ __m256 x0 = _mm256_loadu_ps(x);
+
+ for (; i < ny; i++) {
+ __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+ y += 8;
+ dis[i] = horizontal_sum(accu);
+ }
+ }
+ }
+
+ template <>
+ void fvec_op_ny_D8<ElementOpL2>(
+ float* dis,
+ const float* x,
+ const float* y,
+ size_t ny) {
+ const size_t ny8 = ny / 8;
+ size_t i = 0;
+
+ if (ny8 > 0) {
+ // process 8 D8-vectors per loop.
+ const __m256 m0 = _mm256_set1_ps(x[0]);
+ const __m256 m1 = _mm256_set1_ps(x[1]);
+ const __m256 m2 = _mm256_set1_ps(x[2]);
+ const __m256 m3 = _mm256_set1_ps(x[3]);
+ const __m256 m4 = _mm256_set1_ps(x[4]);
+ const __m256 m5 = _mm256_set1_ps(x[5]);
+ const __m256 m6 = _mm256_set1_ps(x[6]);
+ const __m256 m7 = _mm256_set1_ps(x[7]);
+
+ for (i = 0; i < ny8 * 8; i += 8) {
+ // load 8x8 matrix and transpose it in registers.
+ // the typical bottleneck is memory access, so
+ // let's trade instructions for the bandwidth.
+
+ __m256 v0;
+ __m256 v1;
+ __m256 v2;
+ __m256 v3;
+ __m256 v4;
+ __m256 v5;
+ __m256 v6;
+ __m256 v7;
+
+ transpose_8x8(
+ _mm256_loadu_ps(y + 0 * 8),
+ _mm256_loadu_ps(y + 1 * 8),
+ _mm256_loadu_ps(y + 2 * 8),
+ _mm256_loadu_ps(y + 3 * 8),
+ _mm256_loadu_ps(y + 4 * 8),
+ _mm256_loadu_ps(y + 5 * 8),
+ _mm256_loadu_ps(y + 6 * 8),
+ _mm256_loadu_ps(y + 7 * 8),
+ v0,
+ v1,
+ v2,
+ v3,
+ v4,
+ v5,
+ v6,
+ v7);
+
+ // compute differences
+ const __m256 d0 = _mm256_sub_ps(m0, v0);
+ const __m256 d1 = _mm256_sub_ps(m1, v1);
+ const __m256 d2 = _mm256_sub_ps(m2, v2);
+ const __m256 d3 = _mm256_sub_ps(m3, v3);
+ const __m256 d4 = _mm256_sub_ps(m4, v4);
+ const __m256 d5 = _mm256_sub_ps(m5, v5);
+ const __m256 d6 = _mm256_sub_ps(m6, v6);
+ const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+ // compute squares of differences
+ __m256 distances = _mm256_mul_ps(d0, d0);
+ distances = _mm256_fmadd_ps(d1, d1, distances);
+ distances = _mm256_fmadd_ps(d2, d2, distances);
+ distances = _mm256_fmadd_ps(d3, d3, distances);
+ distances = _mm256_fmadd_ps(d4, d4, distances);
+ distances = _mm256_fmadd_ps(d5, d5, distances);
+ distances = _mm256_fmadd_ps(d6, d6, distances);
+ distances = _mm256_fmadd_ps(d7, d7, distances);
+
+ // store
+ _mm256_storeu_ps(dis + i, distances);
+
+ y += 64;
+ }
+ }
+
+ if (i < ny) {
+ // process leftovers
+ __m256 x0 = _mm256_loadu_ps(x);
+
+ for (; i < ny; i++) {
+ __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+ y += 8;
+ dis[i] = horizontal_sum(accu);
+ }
+ }
+ }
+
+ #endif
+
  template <class ElementOp>
  void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
  __m128 x0 = _mm_loadu_ps(x);
@@ -509,9 +899,7 @@ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
  y += 4;
  accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
  y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
- dis[i] = _mm_cvtss_f32(accu);
+ dis[i] = horizontal_sum(accu);
  }
  }

@@ -581,7 +969,6 @@ void fvec_L2sqr_ny_y_transposed_D(

  // squared length of x
  float x_sqlen = 0;
- ;
  for (size_t j = 0; j < DIM; j++) {
  x_sqlen += x[j] * x[j];
  }
@@ -697,8 +1084,8 @@ size_t fvec_L2sqr_ny_nearest_D2(
  // process 8 D2-vectors per loop.
  const size_t ny8 = ny / 8;
  if (ny8 > 0) {
- _mm_prefetch(y, _MM_HINT_T0);
- _mm_prefetch(y + 16, _MM_HINT_T0);
+ _mm_prefetch((const char*)y, _MM_HINT_T0);
+ _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);

  // track min distance and the closest vector independently
  // for each of 8 AVX2 components.
@@ -713,7 +1100,7 @@ size_t fvec_L2sqr_ny_nearest_D2(
  const __m256 m1 = _mm256_set1_ps(x[1]);

  for (; i < ny8 * 8; i += 8) {
- _mm_prefetch(y + 32, _MM_HINT_T0);
+ _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);

  __m256 v0;
  __m256 v1;
@@ -892,10 +1279,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
  for (; i < ny; i++) {
  __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
  y += 4;
- accu = _mm_hadd_ps(accu, accu);
- accu = _mm_hadd_ps(accu, accu);
-
- const auto distance = _mm_cvtss_f32(accu);
+ const float distance = horizontal_sum(accu);

  if (current_min_distance > distance) {
  current_min_distance = distance;
@@ -1031,23 +1415,9 @@ size_t fvec_L2sqr_ny_nearest_D8(
  __m256 x0 = _mm256_loadu_ps(x);

  for (; i < ny; i++) {
- __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
- __m256 accu = _mm256_mul_ps(sub, sub);
+ __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
  y += 8;
-
- // horitontal sum
- const __m256 h0 = _mm256_hadd_ps(accu, accu);
- const __m256 h1 = _mm256_hadd_ps(h0, h0);
-
- // extract high and low __m128 regs from __m256
- const __m128 h2 = _mm256_extractf128_ps(h1, 1);
- const __m128 h3 = _mm256_castps256_ps128(h1);
-
- // get a final hsum into all 4 regs
- const __m128 h4 = _mm_add_ss(h2, h3);
-
- // extract f[0] from __m128
- const float distance = _mm_cvtss_f32(h4);
+ const float distance = horizontal_sum(accu);

  if (current_min_distance > distance) {
  current_min_distance = distance;
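The fvec_L2sqr_ny_nearest_* kernels touched in the hunks above scan ny database vectors, track the running minimum squared L2 distance to x, and return the index of the closest one. A scalar sketch of that contract (hypothetical helper name and signature, not the library's API):

#include <cstddef>
#include <limits>

// Scalar sketch of the "nearest" kernels being optimized above: return the
// index of the vector among y[0..ny) (each of dimension d) with the
// smallest squared L2 distance to x.
static size_t l2sqr_nearest_ref(
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    float best = std::numeric_limits<float>::max();
    size_t best_idx = 0;
    for (size_t i = 0; i < ny; i++, y += d) {
        float dist = 0;
        for (size_t j = 0; j < d; j++) {
            const float q = x[j] - y[j];
            dist += q * q;
        }
        if (dist < best) {
            best = dist;
            best_idx = i;
        }
    }
    return best_idx;
}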
@@ -1260,21 +1630,6 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(

  #ifdef USE_AVX

- // reads 0 <= d < 8 floats as __m256
- static inline __m256 masked_read_8(int d, const float* x) {
- assert(0 <= d && d < 8);
- if (d < 4) {
- __m256 res = _mm256_setzero_ps();
- res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
- return res;
- } else {
- __m256 res = _mm256_setzero_ps();
- res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
- res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
- return res;
- }
- }
-
  float fvec_L1(const float* x, const float* y, size_t d) {
  __m256 msum1 = _mm256_setzero_ps();
  __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1493,7 +1848,7 @@ void fvec_inner_products_ny(
  * heavily optimized table computations
  ***************************************************************************/

- static inline void fvec_madd_ref(
+ [[maybe_unused]] static inline void fvec_madd_ref(
  size_t n,
  const float* a,
  float bf,
@@ -1560,7 +1915,7 @@ static inline void fvec_madd_avx2(

  #ifdef __SSE3__

- static inline void fvec_madd_sse(
+ [[maybe_unused]] static inline void fvec_madd_sse(
  size_t n,
  const float* a,
  float bf,
@@ -1807,10 +2162,13 @@ void pq2_8cents_table(
  switch (nout) {
  case 4:
  ip3.storeu(out + 3 * ldo);
+ [[fallthrough]];
  case 3:
  ip2.storeu(out + 2 * ldo);
+ [[fallthrough]];
  case 2:
  ip1.storeu(out + 1 * ldo);
+ [[fallthrough]];
  case 1:
  ip0.storeu(out);
  }
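The reversed switch in pq2_8cents_table stores only the first nout of four result rows, falling through from case 4 down to case 1; the new [[fallthrough]] attributes mark that cascade as intentional for compilers warning on implicit fallthrough. A minimal standalone illustration of the same pattern (not library code):

#include <cstdio>

// Store the first nout of four precomputed values, mirroring the reversed
// fallthrough switch used by pq2_8cents_table above.
static void store_first_n(float* out, int nout) {
    const float r0 = 10.f, r1 = 11.f, r2 = 12.f, r3 = 13.f;
    switch (nout) {
        case 4:
            out[3] = r3;
            [[fallthrough]];
        case 3:
            out[2] = r2;
            [[fallthrough]];
        case 2:
            out[1] = r1;
            [[fallthrough]];
        case 1:
            out[0] = r0;
    }
}

int main() {
    float out[4] = {0, 0, 0, 0};
    store_first_n(out, 3); // writes out[0..2], leaves out[3] untouched
    std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
    return 0;
}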