faiss 0.2.5 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (191) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/extconf.rb +1 -1
  5. data/ext/faiss/index.cpp +13 -0
  6. data/lib/faiss/version.rb +1 -1
  7. data/lib/faiss.rb +2 -2
  8. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  9. data/vendor/faiss/faiss/AutoTune.h +0 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  11. data/vendor/faiss/faiss/Clustering.h +0 -2
  12. data/vendor/faiss/faiss/IVFlib.h +0 -2
  13. data/vendor/faiss/faiss/Index.h +1 -2
  14. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  16. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  17. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  19. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  20. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  21. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  22. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  23. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  24. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  25. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  26. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  27. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  29. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  30. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  31. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  32. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  33. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  34. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  35. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  36. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  37. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  38. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  39. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  41. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  43. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  44. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  45. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  46. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  47. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  48. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  49. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  50. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  51. data/vendor/faiss/faiss/IndexShards.h +2 -3
  52. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  53. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  54. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  55. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  56. data/vendor/faiss/faiss/MetricType.h +14 -0
  57. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  58. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  59. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  60. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  61. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  62. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  69. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  70. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  71. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  72. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  73. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  74. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  75. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  76. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  77. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  78. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  81. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  82. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  83. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  84. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  85. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  86. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  87. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  91. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  92. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  93. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  94. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  95. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  96. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  97. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  98. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  99. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  100. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  101. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  102. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  104. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  105. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  106. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  111. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  112. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  113. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  114. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  115. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  116. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  117. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  118. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  119. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  121. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  125. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  128. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  129. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  130. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  131. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  132. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  133. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  134. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  139. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  140. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  141. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  142. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  143. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  144. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  145. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  146. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  147. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  148. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  149. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  150. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  151. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  152. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  153. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  155. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  156. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  157. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  158. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  159. data/vendor/faiss/faiss/utils/distances.h +11 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  164. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  165. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  166. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  167. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  168. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  169. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  170. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  171. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  172. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  173. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  174. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  175. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  176. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  179. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  180. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  181. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  182. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  183. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  184. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  185. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  186. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  187. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  188. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  189. data/vendor/faiss/faiss/utils/utils.h +2 -9
  190. metadata +30 -4
  191. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -817,7 +817,7 @@ template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
817
817
  * Histogram subroutines
818
818
  ******************************************************************/
819
819
 
820
- #ifdef __AVX2__
820
+ #if defined(__AVX2__) || defined(__aarch64__)
821
821
  /// FIXME when MSB of uint16 is set
822
822
  // this code does not compile properly with GCC 7.4.0
823
823
 
@@ -833,7 +833,7 @@ simd32uint8 accu4to8(simd16uint16 a4) {
833
833
  simd16uint16 a8_0 = a4 & mask4;
834
834
  simd16uint16 a8_1 = (a4 >> 4) & mask4;
835
835
 
836
- return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
836
+ return simd32uint8(hadd(a8_0, a8_1));
837
837
  }
838
838
 
839
839
  simd16uint16 accu8to16(simd32uint8 a8) {
@@ -842,10 +842,10 @@ simd16uint16 accu8to16(simd32uint8 a8) {
842
842
  simd16uint16 a8_0 = simd16uint16(a8) & mask8;
843
843
  simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;
844
844
 
845
- return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
845
+ return hadd(a8_0, a8_1);
846
846
  }
847
847
 
848
- static const simd32uint8 shifts(_mm256_setr_epi8(
848
+ static const simd32uint8 shifts = simd32uint8::create<
849
849
  1,
850
850
  16,
851
851
  0,
@@ -877,7 +877,7 @@ static const simd32uint8 shifts(_mm256_setr_epi8(
877
877
  0,
878
878
  0,
879
879
  4,
880
- 64));
880
+ 64>();
881
881
 
882
882
  // 2-bit accumulator: we can add only up to 3 elements
883
883
  // on output we return 2*4-bit results
@@ -937,7 +937,7 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
937
937
  simd16uint16 a16lo = accu8to16(a8lo);
938
938
  simd16uint16 a16hi = accu8to16(a8hi);
939
939
 
940
- simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
940
+ simd16uint16 a16 = hadd(a16lo, a16hi);
941
941
 
942
942
  // the 2 lanes must still be combined
943
943
  return a16;
@@ -947,7 +947,7 @@ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
947
947
  * 16 bins
948
948
  ************************************************************/
949
949
 
950
- static const simd32uint8 shifts2(_mm256_setr_epi8(
950
+ static const simd32uint8 shifts2 = simd32uint8::create<
951
951
  1,
952
952
  2,
953
953
  4,
@@ -955,7 +955,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
955
955
  16,
956
956
  32,
957
957
  64,
958
- (char)128,
958
+ 128,
959
959
  1,
960
960
  2,
961
961
  4,
@@ -963,7 +963,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
963
963
  16,
964
964
  32,
965
965
  64,
966
- (char)128,
966
+ 128,
967
967
  1,
968
968
  2,
969
969
  4,
@@ -971,7 +971,7 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
971
971
  16,
972
972
  32,
973
973
  64,
974
- (char)128,
974
+ 128,
975
975
  1,
976
976
  2,
977
977
  4,
@@ -979,19 +979,12 @@ static const simd32uint8 shifts2(_mm256_setr_epi8(
979
979
  16,
980
980
  32,
981
981
  64,
982
- (char)128));
982
+ 128>();
983
983
 
984
984
  simd32uint8 shiftr_16(simd32uint8 x, int n) {
985
985
  return simd32uint8(simd16uint16(x) >> n);
986
986
  }
987
987
 
988
- inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
989
- __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
990
- __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
991
-
992
- return simd32uint8(a1b0) + simd32uint8(a0b1);
993
- }
994
-
995
988
  // 2-bit accumulator: we can add only up to 3 elements
996
989
  // on output we return 2*4-bit results
997
990
  template <int N, class Preproc>
@@ -1018,7 +1011,7 @@ void compute_accu2_16(
1018
1011
  // contains 0s for out-of-bounds elements
1019
1012
 
1020
1013
  simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
1021
- lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));
1014
+ lt8 = lt8 ^ simd16uint16(0xff00);
1022
1015
 
1023
1016
  a1 = a1 & lt8;
1024
1017
 
@@ -1036,11 +1029,15 @@ void compute_accu2_16(
1036
1029
  simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
1037
1030
  simd32uint8 mask4(0x0f);
1038
1031
 
1039
- simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);
1032
+ simd16uint16 a8_0 = combine2x2(
1033
+ (simd16uint16)(a4_0 & mask4),
1034
+ (simd16uint16)(shiftr_16(a4_0, 4) & mask4));
1040
1035
 
1041
- simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);
1036
+ simd16uint16 a8_1 = combine2x2(
1037
+ (simd16uint16)(a4_1 & mask4),
1038
+ (simd16uint16)(shiftr_16(a4_1, 4) & mask4));
1042
1039
 
1043
- return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
1040
+ return simd32uint8(hadd(a8_0, a8_1));
1044
1041
  }
1045
1042
 
1046
1043
  template <class Preproc>
@@ -1079,10 +1076,9 @@ simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
1079
1076
  simd16uint16 a16lo = accu8to16(a8lo);
1080
1077
  simd16uint16 a16hi = accu8to16(a8hi);
1081
1078
 
1082
- simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
1079
+ simd16uint16 a16 = hadd(a16lo, a16hi);
1083
1080
 
1084
- __m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
1085
- a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);
1081
+ a16 = simd16uint16{simd8uint32{a16}.unzip()};
1086
1082
 
1087
1083
  return a16;
1088
1084
  }
@@ -70,6 +70,13 @@ struct simd256bit {
70
70
  bin(bits);
71
71
  return std::string(bits);
72
72
  }
73
+
74
+ // Checks whether the other holds exactly the same bytes.
75
+ bool is_same_as(simd256bit other) const {
76
+ const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
77
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
78
+ return (bitmask == 0xffffffffU);
79
+ }
73
80
  };
74
81
 
75
82
  /// vector of 16 elements in uint16
@@ -86,6 +93,41 @@ struct simd16uint16 : simd256bit {
86
93
 
87
94
  explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
88
95
 
96
+ explicit simd16uint16(
97
+ uint16_t u0,
98
+ uint16_t u1,
99
+ uint16_t u2,
100
+ uint16_t u3,
101
+ uint16_t u4,
102
+ uint16_t u5,
103
+ uint16_t u6,
104
+ uint16_t u7,
105
+ uint16_t u8,
106
+ uint16_t u9,
107
+ uint16_t u10,
108
+ uint16_t u11,
109
+ uint16_t u12,
110
+ uint16_t u13,
111
+ uint16_t u14,
112
+ uint16_t u15)
113
+ : simd256bit(_mm256_setr_epi16(
114
+ u0,
115
+ u1,
116
+ u2,
117
+ u3,
118
+ u4,
119
+ u5,
120
+ u6,
121
+ u7,
122
+ u8,
123
+ u9,
124
+ u10,
125
+ u11,
126
+ u12,
127
+ u13,
128
+ u14,
129
+ u15)) {}
130
+
89
131
  std::string elements_to_string(const char* fmt) const {
90
132
  uint16_t bytes[16];
91
133
  storeu((void*)bytes);
@@ -151,9 +193,19 @@ struct simd16uint16 : simd256bit {
151
193
  return simd16uint16(_mm256_or_si256(i, other.i));
152
194
  }
153
195
 
196
+ simd16uint16 operator^(simd256bit other) const {
197
+ return simd16uint16(_mm256_xor_si256(i, other.i));
198
+ }
199
+
154
200
  // returns binary masks
155
- simd16uint16 operator==(simd256bit other) const {
156
- return simd16uint16(_mm256_cmpeq_epi16(i, other.i));
201
+ friend simd16uint16 operator==(const simd256bit lhs, const simd256bit rhs) {
202
+ return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
203
+ }
204
+
205
+ bool is_same(simd16uint16 other) const {
206
+ const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i);
207
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
208
+ return (bitmask == 0xffffffffU);
157
209
  }
158
210
 
159
211
  simd16uint16 operator~() const {
@@ -255,6 +307,45 @@ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
255
307
  return ge;
256
308
  }
257
309
 
310
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
311
+ return simd16uint16(_mm256_hadd_epi16(a.i, b.i));
312
+ }
313
+
314
+ // Vectorized version of the following code:
315
+ // for (size_t i = 0; i < n; i++) {
316
+ // bool flag = (candidateValues[i] < currentValues[i]);
317
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
318
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
319
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
320
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
321
+ // }
322
+ // Max indices evaluation is inaccurate in case of equal values (the index of
323
+ // the last equal value is saved instead of the first one), but this behavior
324
+ // saves instructions.
325
+ //
326
+ // Works in i16 mode in order to save instructions. One may
327
+ // switch from i16 to u16.
328
+ inline void cmplt_min_max_fast(
329
+ const simd16uint16 candidateValues,
330
+ const simd16uint16 candidateIndices,
331
+ const simd16uint16 currentValues,
332
+ const simd16uint16 currentIndices,
333
+ simd16uint16& minValues,
334
+ simd16uint16& minIndices,
335
+ simd16uint16& maxValues,
336
+ simd16uint16& maxIndices) {
337
+ // there's no lt instruction, so we'll need to emulate one
338
+ __m256i comparison = _mm256_cmpgt_epi16(currentValues.i, candidateValues.i);
339
+ comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi16(-1));
340
+
341
+ minValues.i = _mm256_min_epi16(candidateValues.i, currentValues.i);
342
+ minIndices.i = _mm256_blendv_epi8(
343
+ candidateIndices.i, currentIndices.i, comparison);
344
+ maxValues.i = _mm256_max_epi16(candidateValues.i, currentValues.i);
345
+ maxIndices.i = _mm256_blendv_epi8(
346
+ currentIndices.i, candidateIndices.i, comparison);
347
+ }
348
+
258
349
  // vector of 32 unsigned 8-bit integers
259
350
  struct simd32uint8 : simd256bit {
260
351
  simd32uint8() {}
@@ -265,6 +356,75 @@ struct simd32uint8 : simd256bit {
265
356
 
266
357
  explicit simd32uint8(uint8_t x) : simd256bit(_mm256_set1_epi8(x)) {}
267
358
 
359
+ template <
360
+ uint8_t _0,
361
+ uint8_t _1,
362
+ uint8_t _2,
363
+ uint8_t _3,
364
+ uint8_t _4,
365
+ uint8_t _5,
366
+ uint8_t _6,
367
+ uint8_t _7,
368
+ uint8_t _8,
369
+ uint8_t _9,
370
+ uint8_t _10,
371
+ uint8_t _11,
372
+ uint8_t _12,
373
+ uint8_t _13,
374
+ uint8_t _14,
375
+ uint8_t _15,
376
+ uint8_t _16,
377
+ uint8_t _17,
378
+ uint8_t _18,
379
+ uint8_t _19,
380
+ uint8_t _20,
381
+ uint8_t _21,
382
+ uint8_t _22,
383
+ uint8_t _23,
384
+ uint8_t _24,
385
+ uint8_t _25,
386
+ uint8_t _26,
387
+ uint8_t _27,
388
+ uint8_t _28,
389
+ uint8_t _29,
390
+ uint8_t _30,
391
+ uint8_t _31>
392
+ static simd32uint8 create() {
393
+ return simd32uint8(_mm256_setr_epi8(
394
+ (char)_0,
395
+ (char)_1,
396
+ (char)_2,
397
+ (char)_3,
398
+ (char)_4,
399
+ (char)_5,
400
+ (char)_6,
401
+ (char)_7,
402
+ (char)_8,
403
+ (char)_9,
404
+ (char)_10,
405
+ (char)_11,
406
+ (char)_12,
407
+ (char)_13,
408
+ (char)_14,
409
+ (char)_15,
410
+ (char)_16,
411
+ (char)_17,
412
+ (char)_18,
413
+ (char)_19,
414
+ (char)_20,
415
+ (char)_21,
416
+ (char)_22,
417
+ (char)_23,
418
+ (char)_24,
419
+ (char)_25,
420
+ (char)_26,
421
+ (char)_27,
422
+ (char)_28,
423
+ (char)_29,
424
+ (char)_30,
425
+ (char)_31));
426
+ }
427
+
268
428
  explicit simd32uint8(simd256bit x) : simd256bit(x) {}
269
429
 
270
430
  explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
@@ -359,6 +519,40 @@ struct simd8uint32 : simd256bit {
359
519
 
360
520
  explicit simd8uint32(const uint8_t* x) : simd256bit((const void*)x) {}
361
521
 
522
+ explicit simd8uint32(
523
+ uint32_t u0,
524
+ uint32_t u1,
525
+ uint32_t u2,
526
+ uint32_t u3,
527
+ uint32_t u4,
528
+ uint32_t u5,
529
+ uint32_t u6,
530
+ uint32_t u7)
531
+ : simd256bit(_mm256_setr_epi32(u0, u1, u2, u3, u4, u5, u6, u7)) {}
532
+
533
+ simd8uint32 operator+(simd8uint32 other) const {
534
+ return simd8uint32(_mm256_add_epi32(i, other.i));
535
+ }
536
+
537
+ simd8uint32 operator-(simd8uint32 other) const {
538
+ return simd8uint32(_mm256_sub_epi32(i, other.i));
539
+ }
540
+
541
+ simd8uint32& operator+=(const simd8uint32& other) {
542
+ i = _mm256_add_epi32(i, other.i);
543
+ return *this;
544
+ }
545
+
546
+ bool operator==(simd8uint32 other) const {
547
+ const __m256i pcmp = _mm256_cmpeq_epi32(i, other.i);
548
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
549
+ return (bitmask == 0xffffffffU);
550
+ }
551
+
552
+ bool operator!=(simd8uint32 other) const {
553
+ return !(*this == other);
554
+ }
555
+
362
556
  std::string elements_to_string(const char* fmt) const {
363
557
  uint32_t bytes[8];
364
558
  storeu((void*)bytes);
@@ -383,8 +577,49 @@ struct simd8uint32 : simd256bit {
383
577
  void set1(uint32_t x) {
384
578
  i = _mm256_set1_epi32((int)x);
385
579
  }
580
+
581
+ simd8uint32 unzip() const {
582
+ return simd8uint32(_mm256_permutevar8x32_epi32(
583
+ i, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)));
584
+ }
386
585
  };
387
586
 
587
+ // Vectorized version of the following code:
588
+ // for (size_t i = 0; i < n; i++) {
589
+ // bool flag = (candidateValues[i] < currentValues[i]);
590
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
591
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
592
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
593
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
594
+ // }
595
+ // Max indices evaluation is inaccurate in case of equal values (the index of
596
+ // the last equal value is saved instead of the first one), but this behavior
597
+ // saves instructions.
598
+ inline void cmplt_min_max_fast(
599
+ const simd8uint32 candidateValues,
600
+ const simd8uint32 candidateIndices,
601
+ const simd8uint32 currentValues,
602
+ const simd8uint32 currentIndices,
603
+ simd8uint32& minValues,
604
+ simd8uint32& minIndices,
605
+ simd8uint32& maxValues,
606
+ simd8uint32& maxIndices) {
607
+ // there's no lt instruction, so we'll need to emulate one
608
+ __m256i comparison = _mm256_cmpgt_epi32(currentValues.i, candidateValues.i);
609
+ comparison = _mm256_andnot_si256(comparison, _mm256_set1_epi32(-1));
610
+
611
+ minValues.i = _mm256_min_epi32(candidateValues.i, currentValues.i);
612
+ minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
613
+ _mm256_castsi256_ps(candidateIndices.i),
614
+ _mm256_castsi256_ps(currentIndices.i),
615
+ _mm256_castsi256_ps(comparison)));
616
+ maxValues.i = _mm256_max_epi32(candidateValues.i, currentValues.i);
617
+ maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
618
+ _mm256_castsi256_ps(currentIndices.i),
619
+ _mm256_castsi256_ps(candidateIndices.i),
620
+ _mm256_castsi256_ps(comparison)));
621
+ }
622
+
388
623
  struct simd8float32 : simd256bit {
389
624
  simd8float32() {}
390
625
 
@@ -394,7 +629,18 @@ struct simd8float32 : simd256bit {
394
629
 
395
630
  explicit simd8float32(float x) : simd256bit(_mm256_set1_ps(x)) {}
396
631
 
397
- explicit simd8float32(const float* x) : simd256bit(_mm256_load_ps(x)) {}
632
+ explicit simd8float32(const float* x) : simd256bit(_mm256_loadu_ps(x)) {}
633
+
634
+ explicit simd8float32(
635
+ float f0,
636
+ float f1,
637
+ float f2,
638
+ float f3,
639
+ float f4,
640
+ float f5,
641
+ float f6,
642
+ float f7)
643
+ : simd256bit(_mm256_setr_ps(f0, f1, f2, f3, f4, f5, f6, f7)) {}
398
644
 
399
645
  simd8float32 operator*(simd8float32 other) const {
400
646
  return simd8float32(_mm256_mul_ps(f, other.f));
@@ -408,6 +654,22 @@ struct simd8float32 : simd256bit {
408
654
  return simd8float32(_mm256_sub_ps(f, other.f));
409
655
  }
410
656
 
657
+ simd8float32& operator+=(const simd8float32& other) {
658
+ f = _mm256_add_ps(f, other.f);
659
+ return *this;
660
+ }
661
+
662
+ bool operator==(simd8float32 other) const {
663
+ const __m256i pcmp =
664
+ _mm256_castps_si256(_mm256_cmp_ps(f, other.f, _CMP_EQ_OQ));
665
+ unsigned bitmask = _mm256_movemask_epi8(pcmp);
666
+ return (bitmask == 0xffffffffU);
667
+ }
668
+
669
+ bool operator!=(simd8float32 other) const {
670
+ return !(*this == other);
671
+ }
672
+
411
673
  std::string tostring() const {
412
674
  float tab[8];
413
675
  storeu((void*)tab);
@@ -439,6 +701,85 @@ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
439
701
  return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f));
440
702
  }
441
703
 
704
+ // The following primitive is a vectorized version of the following code
705
+ // snippet:
706
+ // float lowestValue = HUGE_VAL;
707
+ // uint lowestIndex = 0;
708
+ // for (size_t i = 0; i < n; i++) {
709
+ // if (values[i] < lowestValue) {
710
+ // lowestValue = values[i];
711
+ // lowestIndex = i;
712
+ // }
713
+ // }
714
+ // Vectorized version can be implemented via two operations: cmp and blend
715
+ // with something like this:
716
+ // lowestValues = [HUGE_VAL; 8];
717
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
718
+ // for (size_t i = 0; i < n; i += 8) {
719
+ // auto comparison = cmp(values + i, lowestValues);
720
+ // lowestValues = blend(
721
+ // comparison,
722
+ // values + i,
723
+ // lowestValues);
724
+ // lowestIndices = blend(
725
+ // comparison,
726
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
727
+ // lowestIndices);
728
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
729
+ // }
730
+ // The problem is that blend primitive needs very different instruction
731
+ // order for AVX and ARM.
732
+ // So, let's introduce a combination of these two in order to avoid
733
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
734
+ // these two ops (cmp and blend) are very often used together.
735
+ inline void cmplt_and_blend_inplace(
736
+ const simd8float32 candidateValues,
737
+ const simd8uint32 candidateIndices,
738
+ simd8float32& lowestValues,
739
+ simd8uint32& lowestIndices) {
740
+ const __m256 comparison =
741
+ _mm256_cmp_ps(lowestValues.f, candidateValues.f, _CMP_LE_OS);
742
+ lowestValues.f = _mm256_min_ps(candidateValues.f, lowestValues.f);
743
+ lowestIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
744
+ _mm256_castsi256_ps(candidateIndices.i),
745
+ _mm256_castsi256_ps(lowestIndices.i),
746
+ comparison));
747
+ }
748
+
749
+ // Vectorized version of the following code:
750
+ // for (size_t i = 0; i < n; i++) {
751
+ // bool flag = (candidateValues[i] < currentValues[i]);
752
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
753
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
754
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
755
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
756
+ // }
757
+ // Max indices evaluation is inaccurate in case of equal values (the index of
758
+ // the last equal value is saved instead of the first one), but this behavior
759
+ // saves instructions.
760
+ inline void cmplt_min_max_fast(
761
+ const simd8float32 candidateValues,
762
+ const simd8uint32 candidateIndices,
763
+ const simd8float32 currentValues,
764
+ const simd8uint32 currentIndices,
765
+ simd8float32& minValues,
766
+ simd8uint32& minIndices,
767
+ simd8float32& maxValues,
768
+ simd8uint32& maxIndices) {
769
+ const __m256 comparison =
770
+ _mm256_cmp_ps(currentValues.f, candidateValues.f, _CMP_LE_OS);
771
+ minValues.f = _mm256_min_ps(candidateValues.f, currentValues.f);
772
+ minIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
773
+ _mm256_castsi256_ps(candidateIndices.i),
774
+ _mm256_castsi256_ps(currentIndices.i),
775
+ comparison));
776
+ maxValues.f = _mm256_max_ps(candidateValues.f, currentValues.f);
777
+ maxIndices.i = _mm256_castps_si256(_mm256_blendv_ps(
778
+ _mm256_castsi256_ps(currentIndices.i),
779
+ _mm256_castsi256_ps(candidateIndices.i),
780
+ comparison));
781
+ }
782
+
442
783
  namespace {
443
784
 
444
785
  // get even float32's of a and b, interleaved