faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -0,0 +1,296 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstdint>
11
+ #include <string>
12
+
13
+ #include <immintrin.h>
14
+
15
+ #include <faiss/impl/platform_macros.h>
16
+
17
+ #include <faiss/utils/simdlib_avx2.h>
18
+
19
+ namespace faiss {
20
+
21
+ /** Simple wrapper around the AVX 512-bit registers
22
+ *
23
+ * The objective is to separate the different interpretations of the same
24
+ * registers (as a vector of uint8, uint16 or uint32), to provide printing
25
+ * functions, and to give more readable names to the AVX intrinsics. It does not
26
+ * pretend to be exhausitve, functions are added as needed.
27
+ */
28
+
29
+ /// 512-bit representation without interpretation as a vector
30
+ struct simd512bit {
31
+ union {
32
+ __m512i i;
33
+ __m512 f;
34
+ };
35
+
36
+ simd512bit() {}
37
+
38
+ explicit simd512bit(__m512i i) : i(i) {}
39
+
40
+ explicit simd512bit(__m512 f) : f(f) {}
41
+
42
+ explicit simd512bit(const void* x)
43
+ : i(_mm512_loadu_si512((__m512i const*)x)) {}
44
+
45
+ // sets up a lower half of the register while keeping upper one as zero
46
+ explicit simd512bit(simd256bit lo)
47
+ : simd512bit(_mm512_inserti32x8(
48
+ _mm512_castsi256_si512(lo.i),
49
+ _mm256_setzero_si256(),
50
+ 1)) {}
51
+
52
+ // constructs from lower and upper halves
53
+ explicit simd512bit(simd256bit lo, simd256bit hi)
54
+ : simd512bit(_mm512_inserti32x8(
55
+ _mm512_castsi256_si512(lo.i),
56
+ hi.i,
57
+ 1)) {}
58
+
59
+ void clear() {
60
+ i = _mm512_setzero_si512();
61
+ }
62
+
63
+ void storeu(void* ptr) const {
64
+ _mm512_storeu_si512((__m512i*)ptr, i);
65
+ }
66
+
67
+ void loadu(const void* ptr) {
68
+ i = _mm512_loadu_si512((__m512i*)ptr);
69
+ }
70
+
71
+ void store(void* ptr) const {
72
+ _mm512_storeu_si512((__m512i*)ptr, i);
73
+ }
74
+
75
+ void bin(char bits[513]) const {
76
+ char bytes[64];
77
+ storeu((void*)bytes);
78
+ for (int i = 0; i < 512; i++) {
79
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
80
+ }
81
+ bits[512] = 0;
82
+ }
83
+
84
+ std::string bin() const {
85
+ char bits[257];
86
+ bin(bits);
87
+ return std::string(bits);
88
+ }
89
+ };
90
+
91
+ /// vector of 32 elements in uint16
92
+ struct simd32uint16 : simd512bit {
93
+ simd32uint16() {}
94
+
95
+ explicit simd32uint16(__m512i i) : simd512bit(i) {}
96
+
97
+ explicit simd32uint16(int x) : simd512bit(_mm512_set1_epi16(x)) {}
98
+
99
+ explicit simd32uint16(uint16_t x) : simd512bit(_mm512_set1_epi16(x)) {}
100
+
101
+ explicit simd32uint16(simd512bit x) : simd512bit(x) {}
102
+
103
+ explicit simd32uint16(const uint16_t* x) : simd512bit((const void*)x) {}
104
+
105
+ // sets up a lower half of the register
106
+ explicit simd32uint16(simd256bit lo) : simd512bit(lo) {}
107
+
108
+ // constructs from lower and upper halves
109
+ explicit simd32uint16(simd256bit lo, simd256bit hi) : simd512bit(lo, hi) {}
110
+
111
+ std::string elements_to_string(const char* fmt) const {
112
+ uint16_t bytes[32];
113
+ storeu((void*)bytes);
114
+ char res[2000];
115
+ char* ptr = res;
116
+ for (int i = 0; i < 32; i++) {
117
+ ptr += sprintf(ptr, fmt, bytes[i]);
118
+ }
119
+ // strip last ,
120
+ ptr[-1] = 0;
121
+ return std::string(res);
122
+ }
123
+
124
+ std::string hex() const {
125
+ return elements_to_string("%02x,");
126
+ }
127
+
128
+ std::string dec() const {
129
+ return elements_to_string("%3d,");
130
+ }
131
+
132
+ void set1(uint16_t x) {
133
+ i = _mm512_set1_epi16((short)x);
134
+ }
135
+
136
+ simd32uint16 operator*(const simd32uint16& other) const {
137
+ return simd32uint16(_mm512_mullo_epi16(i, other.i));
138
+ }
139
+
140
+ // shift must be known at compile time
141
+ simd32uint16 operator>>(const int shift) const {
142
+ return simd32uint16(_mm512_srli_epi16(i, shift));
143
+ }
144
+
145
+ // shift must be known at compile time
146
+ simd32uint16 operator<<(const int shift) const {
147
+ return simd32uint16(_mm512_slli_epi16(i, shift));
148
+ }
149
+
150
+ simd32uint16 operator+=(simd32uint16 other) {
151
+ i = _mm512_add_epi16(i, other.i);
152
+ return *this;
153
+ }
154
+
155
+ simd32uint16 operator-=(simd32uint16 other) {
156
+ i = _mm512_sub_epi16(i, other.i);
157
+ return *this;
158
+ }
159
+
160
+ simd32uint16 operator+(simd32uint16 other) const {
161
+ return simd32uint16(_mm512_add_epi16(i, other.i));
162
+ }
163
+
164
+ simd32uint16 operator-(simd32uint16 other) const {
165
+ return simd32uint16(_mm512_sub_epi16(i, other.i));
166
+ }
167
+
168
+ simd32uint16 operator&(simd512bit other) const {
169
+ return simd32uint16(_mm512_and_si512(i, other.i));
170
+ }
171
+
172
+ simd32uint16 operator|(simd512bit other) const {
173
+ return simd32uint16(_mm512_or_si512(i, other.i));
174
+ }
175
+
176
+ simd32uint16 operator^(simd512bit other) const {
177
+ return simd32uint16(_mm512_xor_si512(i, other.i));
178
+ }
179
+
180
+ simd32uint16 operator~() const {
181
+ return simd32uint16(_mm512_xor_si512(i, _mm512_set1_epi32(-1)));
182
+ }
183
+
184
+ simd16uint16 low() const {
185
+ return simd16uint16(_mm512_castsi512_si256(i));
186
+ }
187
+
188
+ simd16uint16 high() const {
189
+ return simd16uint16(_mm512_extracti32x8_epi32(i, 1));
190
+ }
191
+
192
+ // for debugging only
193
+ uint16_t operator[](int i) const {
194
+ ALIGNED(64) uint16_t tab[32];
195
+ store(tab);
196
+ return tab[i];
197
+ }
198
+
199
+ void accu_min(simd32uint16 incoming) {
200
+ i = _mm512_min_epu16(i, incoming.i);
201
+ }
202
+
203
+ void accu_max(simd32uint16 incoming) {
204
+ i = _mm512_max_epu16(i, incoming.i);
205
+ }
206
+ };
207
+
208
+ // decompose in 128-lanes: a = (a0, a1, a2, a3), b = (b0, b1, b2, b3)
209
+ // return (a0 + a1 + a2 + a3, b0 + b1 + b2 + b3)
210
+ inline simd16uint16 combine4x2(simd32uint16 a, simd32uint16 b) {
211
+ return combine2x2(a.low(), b.low()) + combine2x2(a.high(), b.high());
212
+ }
213
+
214
+ // vector of 32 unsigned 8-bit integers
215
+ struct simd64uint8 : simd512bit {
216
+ simd64uint8() {}
217
+
218
+ explicit simd64uint8(__m512i i) : simd512bit(i) {}
219
+
220
+ explicit simd64uint8(int x) : simd512bit(_mm512_set1_epi8(x)) {}
221
+
222
+ explicit simd64uint8(uint8_t x) : simd512bit(_mm512_set1_epi8(x)) {}
223
+
224
+ // sets up a lower half of the register
225
+ explicit simd64uint8(simd256bit lo) : simd512bit(lo) {}
226
+
227
+ // constructs from lower and upper halves
228
+ explicit simd64uint8(simd256bit lo, simd256bit hi) : simd512bit(lo, hi) {}
229
+
230
+ explicit simd64uint8(simd512bit x) : simd512bit(x) {}
231
+
232
+ explicit simd64uint8(const uint8_t* x) : simd512bit((const void*)x) {}
233
+
234
+ std::string elements_to_string(const char* fmt) const {
235
+ uint8_t bytes[64];
236
+ storeu((void*)bytes);
237
+ char res[2000];
238
+ char* ptr = res;
239
+ for (int i = 0; i < 64; i++) {
240
+ ptr += sprintf(ptr, fmt, bytes[i]);
241
+ }
242
+ // strip last ,
243
+ ptr[-1] = 0;
244
+ return std::string(res);
245
+ }
246
+
247
+ std::string hex() const {
248
+ return elements_to_string("%02x,");
249
+ }
250
+
251
+ std::string dec() const {
252
+ return elements_to_string("%3d,");
253
+ }
254
+
255
+ void set1(uint8_t x) {
256
+ i = _mm512_set1_epi8((char)x);
257
+ }
258
+
259
+ simd64uint8 operator&(simd512bit other) const {
260
+ return simd64uint8(_mm512_and_si512(i, other.i));
261
+ }
262
+
263
+ simd64uint8 operator+(simd64uint8 other) const {
264
+ return simd64uint8(_mm512_add_epi8(i, other.i));
265
+ }
266
+
267
+ simd64uint8 lookup_4_lanes(simd64uint8 idx) const {
268
+ return simd64uint8(_mm512_shuffle_epi8(i, idx.i));
269
+ }
270
+
271
+ // extract + 0-extend lane
272
+ // this operation is slow (3 cycles)
273
+ simd32uint16 lane0_as_uint16() const {
274
+ __m256i x = _mm512_extracti32x8_epi32(i, 0);
275
+ return simd32uint16(_mm512_cvtepu8_epi16(x));
276
+ }
277
+
278
+ simd32uint16 lane1_as_uint16() const {
279
+ __m256i x = _mm512_extracti32x8_epi32(i, 1);
280
+ return simd32uint16(_mm512_cvtepu8_epi16(x));
281
+ }
282
+
283
+ simd64uint8 operator+=(simd64uint8 other) {
284
+ i = _mm512_add_epi8(i, other.i);
285
+ return *this;
286
+ }
287
+
288
+ // for debugging only
289
+ uint8_t operator[](int i) const {
290
+ ALIGNED(64) uint8_t tab[64];
291
+ store(tab);
292
+ return tab[i];
293
+ }
294
+ };
295
+
296
+ } // namespace faiss
@@ -168,9 +168,12 @@ static inline std::string elements_to_string(const char* fmt, const S& simd) {
168
168
  simd.store(bytes);
169
169
  char res[1000], *ptr = res;
170
170
  for (size_t i = 0; i < N; ++i) {
171
- ptr += sprintf(ptr, fmt, bytes[i]);
171
+ int bytesWritten =
172
+ snprintf(ptr, sizeof(res) - (ptr - res), fmt, bytes[i]);
173
+ ptr += bytesWritten;
172
174
  }
173
- // strip last ,
175
+ // The format usually contains a ',' separator so this is to remove the last
176
+ // separator.
174
177
  ptr[-1] = 0;
175
178
  return std::string(res);
176
179
  }
@@ -559,15 +562,13 @@ struct simd16uint16 {
559
562
  }
560
563
 
561
564
  // Checks whether the other holds exactly the same bytes.
562
- bool is_same_as(simd16uint16 other) const {
563
- const bool equal0 =
564
- (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565
- 0xffff);
566
- const bool equal1 =
567
- (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568
- 0xffff);
569
-
570
- return equal0 && equal1;
565
+ template <typename T>
566
+ bool is_same_as(T other) const {
567
+ const auto o = detail::simdlib::reinterpret_u16(other.data);
568
+ const auto equals = detail::simdlib::binary_func(data, o)
569
+ .template call<&vceqq_u16>();
570
+ const auto equal = vandq_u16(equals.val[0], equals.val[1]);
571
+ return vminvq_u16(equal) == 0xffffu;
571
572
  }
572
573
 
573
574
  simd16uint16 operator~() const {
@@ -689,13 +690,12 @@ inline void cmplt_min_max_fast(
689
690
  simd16uint16& minIndices,
690
691
  simd16uint16& maxValues,
691
692
  simd16uint16& maxIndices) {
692
- const uint16x8x2_t comparison = uint16x8x2_t{
693
- vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694
- vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
693
+ const uint16x8x2_t comparison =
694
+ detail::simdlib::binary_func(
695
+ candidateValues.data, currentValues.data)
696
+ .call<&vcltq_u16>();
695
697
 
696
- minValues.data = uint16x8x2_t{
697
- vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698
- vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
698
+ minValues = min(candidateValues, currentValues);
699
699
  minIndices.data = uint16x8x2_t{
700
700
  vbslq_u16(
701
701
  comparison.val[0],
@@ -706,9 +706,7 @@ inline void cmplt_min_max_fast(
706
706
  candidateIndices.data.val[1],
707
707
  currentIndices.data.val[1])};
708
708
 
709
- maxValues.data = uint16x8x2_t{
710
- vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711
- vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
709
+ maxValues = max(candidateValues, currentValues);
712
710
  maxIndices.data = uint16x8x2_t{
713
711
  vbslq_u16(
714
712
  comparison.val[0],
@@ -869,13 +867,13 @@ struct simd32uint8 {
869
867
  }
870
868
 
871
869
  // Checks whether the other holds exactly the same bytes.
872
- bool is_same_as(simd32uint8 other) const {
873
- const bool equal0 =
874
- (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875
- const bool equal1 =
876
- (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877
-
878
- return equal0 && equal1;
870
+ template <typename T>
871
+ bool is_same_as(T other) const {
872
+ const auto o = detail::simdlib::reinterpret_u8(other.data);
873
+ const auto equals = detail::simdlib::binary_func(data, o)
874
+ .template call<&vceqq_u8>();
875
+ const auto equal = vandq_u8(equals.val[0], equals.val[1]);
876
+ return vminvq_u8(equal) == 0xffu;
879
877
  }
880
878
  };
881
879
 
@@ -960,27 +958,28 @@ struct simd8uint32 {
960
958
  return *this;
961
959
  }
962
960
 
963
- bool operator==(simd8uint32 other) const {
964
- const auto equals = detail::simdlib::binary_func(data, other.data)
965
- .call<&vceqq_u32>();
966
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
967
- return vminvq_u32(equal) == 0xffffffff;
961
+ simd8uint32 operator==(simd8uint32 other) const {
962
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
963
+ .call<&vceqq_u32>()};
968
964
  }
969
965
 
970
- bool operator!=(simd8uint32 other) const {
971
- return !(*this == other);
966
+ simd8uint32 operator~() const {
967
+ return simd8uint32{
968
+ detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
972
969
  }
973
970
 
974
- // Checks whether the other holds exactly the same bytes.
975
- bool is_same_as(simd8uint32 other) const {
976
- const bool equal0 =
977
- (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978
- 0xffffffff);
979
- const bool equal1 =
980
- (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981
- 0xffffffff);
971
+ simd8uint32 operator!=(simd8uint32 other) const {
972
+ return ~(*this == other);
973
+ }
982
974
 
983
- return equal0 && equal1;
975
+ // Checks whether the other holds exactly the same bytes.
976
+ template <typename T>
977
+ bool is_same_as(T other) const {
978
+ const auto o = detail::simdlib::reinterpret_u32(other.data);
979
+ const auto equals = detail::simdlib::binary_func(data, o)
980
+ .template call<&vceqq_u32>();
981
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
982
+ return vminvq_u32(equal) == 0xffffffffu;
984
983
  }
985
984
 
986
985
  void clear() {
@@ -1053,13 +1052,14 @@ inline void cmplt_min_max_fast(
1053
1052
  simd8uint32& minIndices,
1054
1053
  simd8uint32& maxValues,
1055
1054
  simd8uint32& maxIndices) {
1056
- const uint32x4x2_t comparison = uint32x4x2_t{
1057
- vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058
- vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059
-
1060
- minValues.data = uint32x4x2_t{
1061
- vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062
- vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1055
+ const uint32x4x2_t comparison =
1056
+ detail::simdlib::binary_func(
1057
+ candidateValues.data, currentValues.data)
1058
+ .call<&vcltq_u32>();
1059
+
1060
+ minValues.data = detail::simdlib::binary_func(
1061
+ candidateValues.data, currentValues.data)
1062
+ .call<&vminq_u32>();
1063
1063
  minIndices.data = uint32x4x2_t{
1064
1064
  vbslq_u32(
1065
1065
  comparison.val[0],
@@ -1070,9 +1070,9 @@ inline void cmplt_min_max_fast(
1070
1070
  candidateIndices.data.val[1],
1071
1071
  currentIndices.data.val[1])};
1072
1072
 
1073
- maxValues.data = uint32x4x2_t{
1074
- vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075
- vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1073
+ maxValues.data = detail::simdlib::binary_func(
1074
+ candidateValues.data, currentValues.data)
1075
+ .call<&vmaxq_u32>();
1076
1076
  maxIndices.data = uint32x4x2_t{
1077
1077
  vbslq_u32(
1078
1078
  comparison.val[0],
@@ -1167,28 +1167,25 @@ struct simd8float32 {
1167
1167
  return *this;
1168
1168
  }
1169
1169
 
1170
- bool operator==(simd8float32 other) const {
1171
- const auto equals =
1170
+ simd8uint32 operator==(simd8float32 other) const {
1171
+ return simd8uint32{
1172
1172
  detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
1173
- .call<&vceqq_f32>();
1174
- const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1175
- return vminvq_u32(equal) == 0xffffffff;
1173
+ .call<&vceqq_f32>()};
1176
1174
  }
1177
1175
 
1178
- bool operator!=(simd8float32 other) const {
1179
- return !(*this == other);
1176
+ simd8uint32 operator!=(simd8float32 other) const {
1177
+ return ~(*this == other);
1180
1178
  }
1181
1179
 
1182
1180
  // Checks whether the other holds exactly the same bytes.
1183
- bool is_same_as(simd8float32 other) const {
1184
- const bool equal0 =
1185
- (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186
- 0xffffffff);
1187
- const bool equal1 =
1188
- (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189
- 0xffffffff);
1190
-
1191
- return equal0 && equal1;
1181
+ template <typename T>
1182
+ bool is_same_as(T other) const {
1183
+ const auto o = detail::simdlib::reinterpret_f32(other.data);
1184
+ const auto equals =
1185
+ detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
1186
+ .template call<&vceqq_f32>();
1187
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1188
+ return vminvq_u32(equal) == 0xffffffffu;
1192
1189
  }
1193
1190
 
1194
1191
  std::string tostring() const {
@@ -1302,13 +1299,14 @@ inline void cmplt_min_max_fast(
1302
1299
  simd8uint32& minIndices,
1303
1300
  simd8float32& maxValues,
1304
1301
  simd8uint32& maxIndices) {
1305
- const uint32x4x2_t comparison = uint32x4x2_t{
1306
- vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307
- vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308
-
1309
- minValues.data = float32x4x2_t{
1310
- vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311
- vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1302
+ const uint32x4x2_t comparison =
1303
+ detail::simdlib::binary_func<::uint32x4x2_t>(
1304
+ candidateValues.data, currentValues.data)
1305
+ .call<&vcltq_f32>();
1306
+
1307
+ minValues.data = detail::simdlib::binary_func(
1308
+ candidateValues.data, currentValues.data)
1309
+ .call<&vminq_f32>();
1312
1310
  minIndices.data = uint32x4x2_t{
1313
1311
  vbslq_u32(
1314
1312
  comparison.val[0],
@@ -1319,9 +1317,9 @@ inline void cmplt_min_max_fast(
1319
1317
  candidateIndices.data.val[1],
1320
1318
  currentIndices.data.val[1])};
1321
1319
 
1322
- maxValues.data = float32x4x2_t{
1323
- vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324
- vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1320
+ maxValues.data = detail::simdlib::binary_func(
1321
+ candidateValues.data, currentValues.data)
1322
+ .call<&vmaxq_f32>();
1325
1323
  maxIndices.data = uint32x4x2_t{
1326
1324
  vbslq_u32(
1327
1325
  comparison.val[0],