faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -57,6 +57,17 @@ struct simd256bit {
57
57
  bin(bits);
58
58
  return std::string(bits);
59
59
  }
60
+
61
+ // Checks whether the other holds exactly the same bytes.
62
+ bool is_same_as(simd256bit other) const {
63
+ for (size_t i = 0; i < 8; i++) {
64
+ if (u32[i] != other.u32[i]) {
65
+ return false;
66
+ }
67
+ }
68
+
69
+ return true;
70
+ }
60
71
  };
61
72
 
62
73
  /// vector of 16 elements in uint16
@@ -75,6 +86,41 @@ struct simd16uint16 : simd256bit {
75
86
 
76
87
  explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
77
88
 
89
+ explicit simd16uint16(
90
+ uint16_t u0,
91
+ uint16_t u1,
92
+ uint16_t u2,
93
+ uint16_t u3,
94
+ uint16_t u4,
95
+ uint16_t u5,
96
+ uint16_t u6,
97
+ uint16_t u7,
98
+ uint16_t u8,
99
+ uint16_t u9,
100
+ uint16_t u10,
101
+ uint16_t u11,
102
+ uint16_t u12,
103
+ uint16_t u13,
104
+ uint16_t u14,
105
+ uint16_t u15) {
106
+ this->u16[0] = u0;
107
+ this->u16[1] = u1;
108
+ this->u16[2] = u2;
109
+ this->u16[3] = u3;
110
+ this->u16[4] = u4;
111
+ this->u16[5] = u5;
112
+ this->u16[6] = u6;
113
+ this->u16[7] = u7;
114
+ this->u16[8] = u8;
115
+ this->u16[9] = u9;
116
+ this->u16[10] = u10;
117
+ this->u16[11] = u11;
118
+ this->u16[12] = u12;
119
+ this->u16[13] = u13;
120
+ this->u16[14] = u14;
121
+ this->u16[15] = u15;
122
+ }
123
+
78
124
  std::string elements_to_string(const char* fmt) const {
79
125
  char res[1000], *ptr = res;
80
126
  for (int i = 0; i < 16; i++) {
@@ -169,6 +215,13 @@ struct simd16uint16 : simd256bit {
169
215
  });
170
216
  }
171
217
 
218
+ simd16uint16 operator^(const simd256bit& other) const {
219
+ return binary_func(
220
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
221
+ return a ^ b;
222
+ });
223
+ }
224
+
172
225
  // returns binary masks
173
226
  simd16uint16 operator==(const simd16uint16& other) const {
174
227
  return binary_func(*this, other, [](uint16_t a, uint16_t b) {
@@ -288,6 +341,62 @@ inline uint32_t cmp_le32(
288
341
  return gem;
289
342
  }
290
343
 
344
+ // hadd does not cross lanes
345
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
346
+ simd16uint16 c;
347
+ c.u16[0] = a.u16[0] + a.u16[1];
348
+ c.u16[1] = a.u16[2] + a.u16[3];
349
+ c.u16[2] = a.u16[4] + a.u16[5];
350
+ c.u16[3] = a.u16[6] + a.u16[7];
351
+ c.u16[4] = b.u16[0] + b.u16[1];
352
+ c.u16[5] = b.u16[2] + b.u16[3];
353
+ c.u16[6] = b.u16[4] + b.u16[5];
354
+ c.u16[7] = b.u16[6] + b.u16[7];
355
+
356
+ c.u16[8] = a.u16[8] + a.u16[9];
357
+ c.u16[9] = a.u16[10] + a.u16[11];
358
+ c.u16[10] = a.u16[12] + a.u16[13];
359
+ c.u16[11] = a.u16[14] + a.u16[15];
360
+ c.u16[12] = b.u16[8] + b.u16[9];
361
+ c.u16[13] = b.u16[10] + b.u16[11];
362
+ c.u16[14] = b.u16[12] + b.u16[13];
363
+ c.u16[15] = b.u16[14] + b.u16[15];
364
+
365
+ return c;
366
+ }
367
+
368
+ // Vectorized version of the following code:
369
+ // for (size_t i = 0; i < n; i++) {
370
+ // bool flag = (candidateValues[i] < currentValues[i]);
371
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
372
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
373
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
374
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
375
+ // }
376
+ // Max indices evaluation is inaccurate in case of equal values (the index of
377
+ // the last equal value is saved instead of the first one), but this behavior
378
+ // saves instructions.
379
+ inline void cmplt_min_max_fast(
380
+ const simd16uint16 candidateValues,
381
+ const simd16uint16 candidateIndices,
382
+ const simd16uint16 currentValues,
383
+ const simd16uint16 currentIndices,
384
+ simd16uint16& minValues,
385
+ simd16uint16& minIndices,
386
+ simd16uint16& maxValues,
387
+ simd16uint16& maxIndices) {
388
+ for (size_t i = 0; i < 16; i++) {
389
+ bool flag = (candidateValues.u16[i] < currentValues.u16[i]);
390
+ minValues.u16[i] = flag ? candidateValues.u16[i] : currentValues.u16[i];
391
+ minIndices.u16[i] =
392
+ flag ? candidateIndices.u16[i] : currentIndices.u16[i];
393
+ maxValues.u16[i] =
394
+ !flag ? candidateValues.u16[i] : currentValues.u16[i];
395
+ maxIndices.u16[i] =
396
+ !flag ? candidateIndices.u16[i] : currentIndices.u16[i];
397
+ }
398
+ }
399
+
291
400
  // vector of 32 unsigned 8-bit integers
292
401
  struct simd32uint8 : simd256bit {
293
402
  simd32uint8() {}
@@ -299,6 +408,75 @@ struct simd32uint8 : simd256bit {
299
408
  explicit simd32uint8(uint8_t x) {
300
409
  set1(x);
301
410
  }
411
+ template <
412
+ uint8_t _0,
413
+ uint8_t _1,
414
+ uint8_t _2,
415
+ uint8_t _3,
416
+ uint8_t _4,
417
+ uint8_t _5,
418
+ uint8_t _6,
419
+ uint8_t _7,
420
+ uint8_t _8,
421
+ uint8_t _9,
422
+ uint8_t _10,
423
+ uint8_t _11,
424
+ uint8_t _12,
425
+ uint8_t _13,
426
+ uint8_t _14,
427
+ uint8_t _15,
428
+ uint8_t _16,
429
+ uint8_t _17,
430
+ uint8_t _18,
431
+ uint8_t _19,
432
+ uint8_t _20,
433
+ uint8_t _21,
434
+ uint8_t _22,
435
+ uint8_t _23,
436
+ uint8_t _24,
437
+ uint8_t _25,
438
+ uint8_t _26,
439
+ uint8_t _27,
440
+ uint8_t _28,
441
+ uint8_t _29,
442
+ uint8_t _30,
443
+ uint8_t _31>
444
+ static simd32uint8 create() {
445
+ simd32uint8 ret;
446
+ ret.u8[0] = _0;
447
+ ret.u8[1] = _1;
448
+ ret.u8[2] = _2;
449
+ ret.u8[3] = _3;
450
+ ret.u8[4] = _4;
451
+ ret.u8[5] = _5;
452
+ ret.u8[6] = _6;
453
+ ret.u8[7] = _7;
454
+ ret.u8[8] = _8;
455
+ ret.u8[9] = _9;
456
+ ret.u8[10] = _10;
457
+ ret.u8[11] = _11;
458
+ ret.u8[12] = _12;
459
+ ret.u8[13] = _13;
460
+ ret.u8[14] = _14;
461
+ ret.u8[15] = _15;
462
+ ret.u8[16] = _16;
463
+ ret.u8[17] = _17;
464
+ ret.u8[18] = _18;
465
+ ret.u8[19] = _19;
466
+ ret.u8[20] = _20;
467
+ ret.u8[21] = _21;
468
+ ret.u8[22] = _22;
469
+ ret.u8[23] = _23;
470
+ ret.u8[24] = _24;
471
+ ret.u8[25] = _25;
472
+ ret.u8[26] = _26;
473
+ ret.u8[27] = _27;
474
+ ret.u8[28] = _28;
475
+ ret.u8[29] = _29;
476
+ ret.u8[30] = _30;
477
+ ret.u8[31] = _31;
478
+ return ret;
479
+ }
302
480
 
303
481
  explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
304
482
 
@@ -440,6 +618,62 @@ struct simd8uint32 : simd256bit {
440
618
 
441
619
  explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {}
442
620
 
621
+ explicit simd8uint32(
622
+ uint32_t u0,
623
+ uint32_t u1,
624
+ uint32_t u2,
625
+ uint32_t u3,
626
+ uint32_t u4,
627
+ uint32_t u5,
628
+ uint32_t u6,
629
+ uint32_t u7) {
630
+ u32[0] = u0;
631
+ u32[1] = u1;
632
+ u32[2] = u2;
633
+ u32[3] = u3;
634
+ u32[4] = u4;
635
+ u32[5] = u5;
636
+ u32[6] = u6;
637
+ u32[7] = u7;
638
+ }
639
+
640
+ simd8uint32 operator+(simd8uint32 other) const {
641
+ simd8uint32 result;
642
+ for (int i = 0; i < 8; i++) {
643
+ result.u32[i] = u32[i] + other.u32[i];
644
+ }
645
+ return result;
646
+ }
647
+
648
+ simd8uint32 operator-(simd8uint32 other) const {
649
+ simd8uint32 result;
650
+ for (int i = 0; i < 8; i++) {
651
+ result.u32[i] = u32[i] - other.u32[i];
652
+ }
653
+ return result;
654
+ }
655
+
656
+ simd8uint32& operator+=(const simd8uint32& other) {
657
+ for (int i = 0; i < 8; i++) {
658
+ u32[i] += other.u32[i];
659
+ }
660
+ return *this;
661
+ }
662
+
663
+ bool operator==(simd8uint32 other) const {
664
+ for (size_t i = 0; i < 8; i++) {
665
+ if (u32[i] != other.u32[i]) {
666
+ return false;
667
+ }
668
+ }
669
+
670
+ return true;
671
+ }
672
+
673
+ bool operator!=(simd8uint32 other) const {
674
+ return !(*this == other);
675
+ }
676
+
443
677
  std::string elements_to_string(const char* fmt) const {
444
678
  char res[1000], *ptr = res;
445
679
  for (int i = 0; i < 8; i++) {
@@ -463,8 +697,46 @@ struct simd8uint32 : simd256bit {
463
697
  u32[i] = x;
464
698
  }
465
699
  }
700
+
701
+ simd8uint32 unzip() const {
702
+ const uint32_t ret[] = {
703
+ u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
704
+ return simd8uint32{ret};
705
+ }
466
706
  };
467
707
 
708
+ // Vectorized version of the following code:
709
+ // for (size_t i = 0; i < n; i++) {
710
+ // bool flag = (candidateValues[i] < currentValues[i]);
711
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
712
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
713
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
714
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
715
+ // }
716
+ // Max indices evaluation is inaccurate in case of equal values (the index of
717
+ // the last equal value is saved instead of the first one), but this behavior
718
+ // saves instructions.
719
+ inline void cmplt_min_max_fast(
720
+ const simd8uint32 candidateValues,
721
+ const simd8uint32 candidateIndices,
722
+ const simd8uint32 currentValues,
723
+ const simd8uint32 currentIndices,
724
+ simd8uint32& minValues,
725
+ simd8uint32& minIndices,
726
+ simd8uint32& maxValues,
727
+ simd8uint32& maxIndices) {
728
+ for (size_t i = 0; i < 8; i++) {
729
+ bool flag = (candidateValues.u32[i] < currentValues.u32[i]);
730
+ minValues.u32[i] = flag ? candidateValues.u32[i] : currentValues.u32[i];
731
+ minIndices.u32[i] =
732
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
733
+ maxValues.u32[i] =
734
+ !flag ? candidateValues.u32[i] : currentValues.u32[i];
735
+ maxIndices.u32[i] =
736
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
737
+ }
738
+ }
739
+
468
740
  struct simd8float32 : simd256bit {
469
741
  simd8float32() {}
470
742
 
@@ -484,6 +756,25 @@ struct simd8float32 : simd256bit {
484
756
  }
485
757
  }
486
758
 
759
+ explicit simd8float32(
760
+ float f0,
761
+ float f1,
762
+ float f2,
763
+ float f3,
764
+ float f4,
765
+ float f5,
766
+ float f6,
767
+ float f7) {
768
+ f32[0] = f0;
769
+ f32[1] = f1;
770
+ f32[2] = f2;
771
+ f32[3] = f3;
772
+ f32[4] = f4;
773
+ f32[5] = f5;
774
+ f32[6] = f6;
775
+ f32[7] = f7;
776
+ }
777
+
487
778
  template <typename F>
488
779
  static simd8float32 binary_func(
489
780
  const simd8float32& a,
@@ -511,6 +802,28 @@ struct simd8float32 : simd256bit {
511
802
  *this, other, [](float a, float b) { return a - b; });
512
803
  }
513
804
 
805
+ simd8float32& operator+=(const simd8float32& other) {
806
+ for (size_t i = 0; i < 8; i++) {
807
+ f32[i] += other.f32[i];
808
+ }
809
+
810
+ return *this;
811
+ }
812
+
813
+ bool operator==(simd8float32 other) const {
814
+ for (size_t i = 0; i < 8; i++) {
815
+ if (f32[i] != other.f32[i]) {
816
+ return false;
817
+ }
818
+ }
819
+
820
+ return true;
821
+ }
822
+
823
+ bool operator!=(simd8float32 other) const {
824
+ return !(*this == other);
825
+ }
826
+
514
827
  std::string tostring() const {
515
828
  char res[1000], *ptr = res;
516
829
  for (int i = 0; i < 8; i++) {
@@ -650,6 +963,83 @@ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
650
963
  return c;
651
964
  }
652
965
 
966
+ // The following primitive is a vectorized version of the following code
967
+ // snippet:
968
+ // float lowestValue = HUGE_VAL;
969
+ // uint lowestIndex = 0;
970
+ // for (size_t i = 0; i < n; i++) {
971
+ // if (values[i] < lowestValue) {
972
+ // lowestValue = values[i];
973
+ // lowestIndex = i;
974
+ // }
975
+ // }
976
+ // Vectorized version can be implemented via two operations: cmp and blend
977
+ // with something like this:
978
+ // lowestValues = [HUGE_VAL; 8];
979
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
980
+ // for (size_t i = 0; i < n; i += 8) {
981
+ // auto comparison = cmp(values + i, lowestValues);
982
+ // lowestValues = blend(
983
+ // comparison,
984
+ // values + i,
985
+ // lowestValues);
986
+ // lowestIndices = blend(
987
+ // comparison,
988
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
989
+ // lowestIndices);
990
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
991
+ // }
992
+ // The problem is that blend primitive needs very different instruction
993
+ // order for AVX and ARM.
994
+ // So, let's introduce a combination of these two in order to avoid
995
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
996
+ // these two ops (cmp and blend) are very often used together.
997
+ inline void cmplt_and_blend_inplace(
998
+ const simd8float32 candidateValues,
999
+ const simd8uint32 candidateIndices,
1000
+ simd8float32& lowestValues,
1001
+ simd8uint32& lowestIndices) {
1002
+ for (size_t j = 0; j < 8; j++) {
1003
+ bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
1004
+ if (comparison) {
1005
+ lowestValues.f32[j] = candidateValues.f32[j];
1006
+ lowestIndices.u32[j] = candidateIndices.u32[j];
1007
+ }
1008
+ }
1009
+ }
1010
+
1011
+ // Vectorized version of the following code:
1012
+ // for (size_t i = 0; i < n; i++) {
1013
+ // bool flag = (candidateValues[i] < currentValues[i]);
1014
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1015
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1016
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1017
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1018
+ // }
1019
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1020
+ // the last equal value is saved instead of the first one), but this behavior
1021
+ // saves instructions.
1022
+ inline void cmplt_min_max_fast(
1023
+ const simd8float32 candidateValues,
1024
+ const simd8uint32 candidateIndices,
1025
+ const simd8float32 currentValues,
1026
+ const simd8uint32 currentIndices,
1027
+ simd8float32& minValues,
1028
+ simd8uint32& minIndices,
1029
+ simd8float32& maxValues,
1030
+ simd8uint32& maxIndices) {
1031
+ for (size_t i = 0; i < 8; i++) {
1032
+ bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
1033
+ minValues.f32[i] = flag ? candidateValues.f32[i] : currentValues.f32[i];
1034
+ minIndices.u32[i] =
1035
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1036
+ maxValues.f32[i] =
1037
+ !flag ? candidateValues.f32[i] : currentValues.f32[i];
1038
+ maxIndices.u32[i] =
1039
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1040
+ }
1041
+ }
1042
+
653
1043
  } // namespace
654
1044
 
655
1045
  } // namespace faiss