faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -817,7 +817,7 @@ template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
817
817
  * Histogram subroutines
818
818
  ******************************************************************/
819
819
 
820
- #ifdef __AVX2__
820
+ #if defined(__AVX2__) || defined(__aarch64__)
821
821
  /// FIXME when MSB of uint16 is set
822
822
  // this code does not compile properly with GCC 7.4.0
823
823
 
@@ -833,7 +833,7 @@ simd32uint8 accu4to8(simd16uint16 a4) {
833
833
  simd16uint16 a8_0 = a4 & mask4;
834
834
  simd16uint16 a8_1 = (a4 >> 4) & mask4;
835
835
 
836
- return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
836
+ return simd32uint8(hadd(a8_0, a8_1));
837
837
  }
838
838
 
839
839
  simd16uint16 accu8to16(simd32uint8 a8) {
@@ -842,10 +842,10 @@ simd16uint16 accu8to16(simd32uint8 a8) {
842
842
  simd16uint16 a8_0 = simd16uint16(a8) & mask8;
843
843
  simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;
844
844
 
845
- return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
845
+ return hadd(a8_0, a8_1);
846
846
  }
847
847
 
848
- static const simd32uint8 shifts(_mm256_setr_epi8(
848
+ static const simd32uint8 shifts = simd32uint8::create<
849
849
  1,
850
850
  16,
851
851
  0,
@@ -877,7 +877,7 @@ static const simd32uint8 shifts(_mm256_setr_epi8(
877
877
  0,
878
878
  0,
879
879
  4,
880
- 64));
880
+ 64>();
881
881
 
882
882
  // 2-bit accumulator: we can add only up to 3 elements
883
883
  // on output we return 2*4-bit results
@@ -937,7 +937,7 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
937
937
  simd16uint16 a16lo = accu8to16(a8lo);
938
938
  simd16uint16 a16hi = accu8to16(a8hi);
939
939
 
940
- simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
940
+ simd16uint16 a16 = hadd(a16lo, a16hi);
941
941
 
942
942
  // the 2 lanes must still be combined
943
943
  return a16;
@@ -947,7 +947,7 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
947
947
  * 16 bins
948
948
  ************************************************************/
949
949
 
950
- static const simd32uint8 shifts2(_mm256_setr_epi8(
950
+ static const simd32uint8 shifts2 = simd32uint8::create<
951
951
  1,
952
952
  2,
953
953
  4,
@@ -955,7 +955,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
955
955
  16,
956
956
  32,
957
957
  64,
958
- (char)128,
958
+ 128,
959
959
  1,
960
960
  2,
961
961
  4,
@@ -963,7 +963,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
963
963
  16,
964
964
  32,
965
965
  64,
966
- (char)128,
966
+ 128,
967
967
  1,
968
968
  2,
969
969
  4,
@@ -971,7 +971,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
971
971
  16,
972
972
  32,
973
973
  64,
974
- (char)128,
974
+ 128,
975
975
  1,
976
976
  2,
977
977
  4,
@@ -979,19 +979,12 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
979
979
  16,
980
980
  32,
981
981
  64,
982
- (char)128));
982
+ 128>();
983
983
 
984
984
  simd32uint8 shiftr_16(simd32uint8 x, int n) {
985
985
  return simd32uint8(simd16uint16(x) >> n);
986
986
  }
987
987
 
988
- inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
989
- __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
990
- __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
991
-
992
- return simd32uint8(a1b0) + simd32uint8(a0b1);
993
- }
994
-
995
988
  // 2-bit accumulator: we can add only up to 3 elements
996
989
  // on output we return 2*4-bit results
997
990
  template <int N, class Preproc>
@@ -1018,7 +1011,7 @@ void compute_accu2_16(
1018
1011
  // contains 0s for out-of-bounds elements
1019
1012
 
1020
1013
  simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
1021
- lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));
1014
+ lt8 = lt8 ^ simd16uint16(0xff00);
1022
1015
 
1023
1016
  a1 = a1 & lt8;
1024
1017
 
@@ -1036,11 +1029,15 @@ void compute_accu2_16(
1036
1029
  simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
1037
1030
  simd32uint8 mask4(0x0f);
1038
1031
 
1039
- simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);
1032
+ simd16uint16 a8_0 = combine2x2(
1033
+ (simd16uint16)(a4_0 & mask4),
1034
+ (simd16uint16)(shiftr_16(a4_0, 4) & mask4));
1040
1035
 
1041
- simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);
1036
+ simd16uint16 a8_1 = combine2x2(
1037
+ (simd16uint16)(a4_1 & mask4),
1038
+ (simd16uint16)(shiftr_16(a4_1, 4) & mask4));
1042
1039
 
1043
- return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
1040
+ return simd32uint8(hadd(a8_0, a8_1));
1044
1041
  }
1045
1042
 
1046
1043
  template <class Preproc>
@@ -1079,10 +1076,9 @@ simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
1079
1076
  simd16uint16 a16lo = accu8to16(a8lo);
1080
1077
  simd16uint16 a16hi = accu8to16(a8hi);
1081
1078
 
1082
- simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
1079
+ simd16uint16 a16 = hadd(a16lo, a16hi);
1083
1080
 
1084
- __m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
1085
- a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);
1081
+ a16 = simd16uint16{simd8uint32{a16}.unzip()};
1086
1082
 
1087
1083
  return a16;
1088
1084
  }
@@ -70,6 +70,13 @@ struct simd256bit {
70
70
  bin(bits);
71
71
  return std::string(bits);
72
72
  }
73
+
74
+ // Checks whether the other holds exactly the same bytes.
75
+ bool is_same_as(simd256bit other) const {
76
+ const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
77
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
78
+ return (bitmask == 0xffffffffU);
79
+ }
73
80
  };
74
81
 
75
82
  /// vector of 16 elements in uint16
@@ -86,6 +93,41 @@ struct simd16uint16 : simd256bit {
86
93
 
87
94
  explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
88
95
 
96
+ explicit simd16uint16(
97
+ uint16_t u0,
98
+ uint16_t u1,
99
+ uint16_t u2,
100
+ uint16_t u3,
101
+ uint16_t u4,
102
+ uint16_t u5,
103
+ uint16_t u6,
104
+ uint16_t u7,
105
+ uint16_t u8,
106
+ uint16_t u9,
107
+ uint16_t u10,
108
+ uint16_t u11,
109
+ uint16_t u12,
110
+ uint16_t u13,
111
+ uint16_t u14,
112
+ uint16_t u15)
113
+ : simd256bit(_mm256_setr_epi16(
114
+ u0,
115
+ u1,
116
+ u2,
117
+ u3,
118
+ u4,
119
+ u5,
120
+ u6,
121
+ u7,
122
+ u8,
123
+ u9,
124
+ u10,
125
+ u11,
126
+ u12,
127
+ u13,
128
+ u14,
129
+ u15)) {}
130
+
89
131
  std::string elements_to_string(const char* fmt) const {
90
132
  uint16_t bytes[16];
91
133
  storeu((void*)bytes);
@@ -151,9 +193,19 @@ struct simd16uint16 : simd256bit {
151
193
  return simd16uint16(_mm256_or_si256(i, other.i));
152
194
  }
153
195
 
196
+ simd16uint16 operator^(simd256bit other) const {
197
+ return simd16uint16(_mm256_xor_si256(i, other.i));
198
+ }
199
+
154
200
  // returns binary masks
155
- simd16uint16 operator==(simd256bit other) const {
156
- return simd16uint16(_mm256_cmpeq_epi16(i, other.i));
201
+ friend simd16uint16 operator==(const simd256bit lhs, const simd256bit rhs) {
202
+ return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
203
+ }
204
+
205
+ bool is_same(simd16uint16 other) const {
206
+ const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i);
207
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
208
+ return (bitmask == 0xffffffffU);
157
209
  }
158
210
 
159
211
  simd16uint16 operator~() const {
@@ -255,6 +307,45 @@ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
255
307
  return ge;
256
308
  }
257
309
 
310
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
311
+ return simd16uint16(_mm256_hadd_epi16(a.i, b.i));
312
+ }
313
+
314
+ // Vectorized version of the following code:
315
+ // for (size_t i = 0; i < n; i++) {
316
+ // bool flag = (candidateValues[i] < currentValues[i]);
317
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
318
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
319
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
320
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
321
+ // }
322
+ // Max indices evaluation is inaccurate in case of equal values (the index of
323
+ // the last equal value is saved instead of the first one), but this behavior
324
+ // saves instructions.
325
+ //
326
+ // Works in i16 mode in order to save instructions. One may
327
+ // switch from i16 to u16.
328
+ inline void cmplt_min_max_fast(
329
+ const simd16uint16 candidateValues,
330
+ const simd16uint16 candidateIndices,
331
+ const simd16uint16 currentValues,
332
+ const simd16uint16 currentIndices,
333
+ simd16uint16& minValues,
334
+ simd16uint16& minIndices,
335
+ simd16uint16& maxValues,
336
+ simd16uint16& maxIndices) {
337
+ // there's no lt instruction, so we'll need to emulate one
338
+ __m256i comparison = _mm256_cmpgt_epi16(currentValues.i, candidateValues.i);
339
+ comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi16(-1));
340
+
341
+ minValues.i = _mm256_min_epi16(candidateValues.i, currentValues.i);
342
+ minIndices.i = _mm256_blendv_epi8(
343
+ candidateIndices.i, currentIndices.i, comparison);
344
+ maxValues.i = _mm256_max_epi16(candidateValues.i, currentValues.i);
345
+ maxIndices.i = _mm256_blendv_epi8(
346
+ currentIndices.i, candidateIndices.i, comparison);
347
+ }
348
+
258
349
  // vector of 32 unsigned 8-bit integers
259
350
  struct simd32uint8 : simd256bit {
260
351
  simd32uint8() {}
@@ -265,6 +356,75 @@ struct simd32uint8 : simd256bit {
265
356
 
266
357
  explicit simd32uint8(uint8_t x) : simd256bit(_mm256_set1_epi8(x)) {}
267
358
 
359
+ template <
360
+ uint8_t _0,
361
+ uint8_t _1,
362
+ uint8_t _2,
363
+ uint8_t _3,
364
+ uint8_t _4,
365
+ uint8_t _5,
366
+ uint8_t _6,
367
+ uint8_t _7,
368
+ uint8_t _8,
369
+ uint8_t _9,
370
+ uint8_t _10,
371
+ uint8_t _11,
372
+ uint8_t _12,
373
+ uint8_t _13,
374
+ uint8_t _14,
375
+ uint8_t _15,
376
+ uint8_t _16,
377
+ uint8_t _17,
378
+ uint8_t _18,
379
+ uint8_t _19,
380
+ uint8_t _20,
381
+ uint8_t _21,
382
+ uint8_t _22,
383
+ uint8_t _23,
384
+ uint8_t _24,
385
+ uint8_t _25,
386
+ uint8_t _26,
387
+ uint8_t _27,
388
+ uint8_t _28,
389
+ uint8_t _29,
390
+ uint8_t _30,
391
+ uint8_t _31>
392
+ static simd32uint8 create() {
393
+ return simd32uint8(_mm256_setr_epi8(
394
+ (char)_0,
395
+ (char)_1,
396
+ (char)_2,
397
+ (char)_3,
398
+ (char)_4,
399
+ (char)_5,
400
+ (char)_6,
401
+ (char)_7,
402
+ (char)_8,
403
+ (char)_9,
404
+ (char)_10,
405
+ (char)_11,
406
+ (char)_12,
407
+ (char)_13,
408
+ (char)_14,
409
+ (char)_15,
410
+ (char)_16,
411
+ (char)_17,
412
+ (char)_18,
413
+ (char)_19,
414
+ (char)_20,
415
+ (char)_21,
416
+ (char)_22,
417
+ (char)_23,
418
+ (char)_24,
419
+ (char)_25,
420
+ (char)_26,
421
+ (char)_27,
422
+ (char)_28,
423
+ (char)_29,
424
+ (char)_30,
425
+ (char)_31));
426
+ }
427
+
268
428
  explicit simd32uint8(simd256bit x) : simd256bit(x) {}
269
429
 
270
430
  explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
@@ -359,6 +519,40 @@ struct simd8uint32 : simd256bit {
359
519
 
360
520
  explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
361
521
 
522
+ explicit simd8uint32(
523
+ uint32_t u0,
524
+ uint32_t u1,
525
+ uint32_t u2,
526
+ uint32_t u3,
527
+ uint32_t u4,
528
+ uint32_t u5,
529
+ uint32_t u6,
530
+ uint32_t u7)
531
+ : simd256bit(_mm256_setr_epi32(u0, u1, u2, u3, u4, u5, u6, u7)) {}
532
+
533
+ simd8uint32 operator+(simd8uint32 other) const {
534
+ return simd8uint32(_mm256_add_epi32(i, other.i));
535
+ }
536
+
537
+ simd8uint32 operator-(simd8uint32 other) const {
538
+ return simd8uint32(_mm256_sub_epi32(i, other.i));
539
+ }
540
+
541
+ simd8uint32& operator+=(const simd8uint32& other) {
542
+ i = _mm256_add_epi32(i, other.i);
543
+ return *this;
544
+ }
545
+
546
+ bool operator==(simd8uint32 other) const {
547
+ const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
548
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
549
+ return (bitmask == 0xffffffffU);
550
+ }
551
+
552
+ bool operator!=(simd8uint32 other) const {
553
+ return !(*this == other);
554
+ }
555
+
362
556
  std::string elements_to_string(const char* fmt) const {
363
557
  uint32_t bytes[8];
364
558
  storeu((void*)bytes);
@@ -383,8 +577,49 @@ struct simd8uint32 : simd256bit {
383
577
  void set1(uint32_t x) {
384
578
  i = _mm256_set1_epi32((int)x);
385
579
  }
580
+
581
+ simd8uint32 unzip() const {
582
+ return simd8uint32(_mm256_permutevar8x32_epi32(
583
+ i, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)));
584
+ }
386
585
  };
387
586
 
587
+ // Vectorized version of the following code:
588
+ // for (size_t i = 0; i < n; i++) {
589
+ // bool flag = (candidateValues[i] < currentValues[i]);
590
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
591
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
592
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
593
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
594
+ // }
595
+ // Max indices evaluation is inaccurate in case of equal values (the index of
596
+ // the last equal value is saved instead of the first one), but this behavior
597
+ // saves instructions.
598
+ inline void cmplt_min_max_fast(
599
+ const simd8uint32 candidateValues,
600
+ const simd8uint32 candidateIndices,
601
+ const simd8uint32 currentValues,
602
+ const simd8uint32 currentIndices,
603
+ simd8uint32& minValues,
604
+ simd8uint32& minIndices,
605
+ simd8uint32& maxValues,
606
+ simd8uint32& maxIndices) {
607
+ // there's no lt instruction, so we'll need to emulate one
608
+ __m256i comparison = _mm256_cmpgt_epi32(currentValues.i, candidateValues.i);
609
+ comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi32(-1));
610
+
611
+ minValues.i = _mm256_min_epi32(candidateValues.i, currentValues.i);
612
+ minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
613
+ _mm256_castsi256_ps(candidateIndices.i),
614
+ _mm256_castsi256_ps(currentIndices.i),
615
+ _mm256_castsi256_ps(comparison)));
616
+ maxValues.i = _mm256_max_epi32(candidateValues.i, currentValues.i);
617
+ maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
618
+ _mm256_castsi256_ps(currentIndices.i),
619
+ _mm256_castsi256_ps(candidateIndices.i),
620
+ _mm256_castsi256_ps(comparison)));
621
+ }
622
+
388
623
  struct simd8float32 : simd256bit {
389
624
  simd8float32() {}
390
625
 
@@ -394,7 +629,18 @@ struct simd8float32 : simd256bit {
394
629
 
395
630
  explicit simd8float32(float x) : simd256bit(_mm256_set1_ps(x)) {}
396
631
 
397
- explicit simd8float32(const float* x) : simd256bit(_mm256_load_ps(x)) {}
632
+ explicit simd8float32(const float* x) : simd256bit(_mm256_loadu_ps(x)) {}
633
+
634
+ explicit simd8float32(
635
+ float f0,
636
+ float f1,
637
+ float f2,
638
+ float f3,
639
+ float f4,
640
+ float f5,
641
+ float f6,
642
+ float f7)
643
+ : simd256bit(_mm256_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7)) {}
398
644
 
399
645
  simd8float32 operator*(simd8float32 other) const {
400
646
  return simd8float32(_mm256_mul_ps(f, other.f));
@@ -408,6 +654,22 @@ struct simd8float32 : simd256bit {
408
654
  return simd8float32(_mm256_sub_ps(f, other.f));
409
655
  }
410
656
 
657
+ simd8float32& operator+=(const simd8float32& other) {
658
+ f = _mm256_add_ps(f, other.f);
659
+ return *this;
660
+ }
661
+
662
+ bool operator==(simd8float32 other) const {
663
+ const __m256i pcmp =
664
+ _mm256_castps_si256(_mm256_cmp_ps(f, other.f, _CMP_EQ_OQ));
665
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
666
+ return (bitmask == 0xffffffffU);
667
+ }
668
+
669
+ bool operator!=(simd8float32 other) const {
670
+ return !(*this == other);
671
+ }
672
+
411
673
  std::string tostring() const {
412
674
  float tab[8];
413
675
  storeu((void*)tab);
@@ -439,6 +701,85 @@ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
439
701
  return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f));
440
702
  }
441
703
 
704
+ // The following primitive is a vectorized version of the following code
705
+ // snippet:
706
+ // float lowestValue = HUGE_VAL;
707
+ // uint lowestIndex = 0;
708
+ // for (size_t i = 0; i < n; i++) {
709
+ // if (values[i] < lowestValue) {
710
+ // lowestValue = values[i];
711
+ // lowestIndex = i;
712
+ // }
713
+ // }
714
+ // Vectorized version can be implemented via two operations: cmp and blend
715
+ // with something like this:
716
+ // lowestValues = [HUGE_VAL; 8];
717
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
718
+ // for (size_t i = 0; i < n; i += 8) {
719
+ // auto comparison = cmp(values + i, lowestValues);
720
+ // lowestValues = blend(
721
+ // comparison,
722
+ // values + i,
723
+ // lowestValues);
724
+ // lowestIndices = blend(
725
+ // comparison,
726
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
727
+ // lowestIndices);
728
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
729
+ // }
730
+ // The problem is that blend primitive needs very different instruction
731
+ // order for AVX and ARM.
732
+ // So, let's introduce a combination of these two in order to avoid
733
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
734
+ // these two ops (cmp and blend) are very often used together.
735
+ inline void cmplt_and_blend_inplace(
736
+ const simd8float32 candidateValues,
737
+ const simd8uint32 candidateIndices,
738
+ simd8float32& lowestValues,
739
+ simd8uint32& lowestIndices) {
740
+ const __m256 comparison =
741
+ _mm256_cmp_ps(lowestValues.f, candidateValues.f, _CMP_LE_OS);
742
+ lowestValues.f = _mm256_min_ps(candidateValues.f, lowestValues.f);
743
+ lowestIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
744
+ _mm256_castsi256_ps(candidateIndices.i),
745
+ _mm256_castsi256_ps(lowestIndices.i),
746
+ comparison));
747
+ }
748
+
749
+ // Vectorized version of the following code:
750
+ // for (size_t i = 0; i < n; i++) {
751
+ // bool flag = (candidateValues[i] < currentValues[i]);
752
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
753
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
754
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
755
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
756
+ // }
757
+ // Max indices evaluation is inaccurate in case of equal values (the index of
758
+ // the last equal value is saved instead of the first one), but this behavior
759
+ // saves instructions.
760
+ inline void cmplt_min_max_fast(
761
+ const simd8float32 candidateValues,
762
+ const simd8uint32 candidateIndices,
763
+ const simd8float32 currentValues,
764
+ const simd8uint32 currentIndices,
765
+ simd8float32& minValues,
766
+ simd8uint32& minIndices,
767
+ simd8float32& maxValues,
768
+ simd8uint32& maxIndices) {
769
+ const __m256 comparison =
770
+ _mm256_cmp_ps(currentValues.f, candidateValues.f, _CMP_LE_OS);
771
+ minValues.f = _mm256_min_ps(candidateValues.f, currentValues.f);
772
+ minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
773
+ _mm256_castsi256_ps(candidateIndices.i),
774
+ _mm256_castsi256_ps(currentIndices.i),
775
+ comparison));
776
+ maxValues.f = _mm256_max_ps(candidateValues.f, currentValues.f);
777
+ maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
778
+ _mm256_castsi256_ps(currentIndices.i),
779
+ _mm256_castsi256_ps(candidateIndices.i),
780
+ comparison));
781
+ }
782
+
442
783
  namespace {
443
784
 
444
785
  // get even float32's of a and b, interleaved