faiss 0.3.0 → 0.3.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
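
The largest new file in this release is `data/vendor/faiss/faiss/utils/simdlib_ppc64.h` (+1084 lines), shown in full below: a scalar, loop-based implementation of the 256-bit SIMD wrapper types (`simd256bit`, `simd16uint16`, `simd32uint8`, `simd8uint32`, `simd8float32`) that lets the vendored Faiss code build on ppc64 without AVX2 or NEON intrinsics. Per the comments in the file itself, the byte-lookup loop was hand-unrolled for Power 10.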
@@ -0,0 +1,1084 @@

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>

namespace faiss {

struct simd256bit {
    union {
        uint8_t u8[32];
        uint16_t u16[16];
        uint32_t u32[8];
        float f32[8];
    };

    simd256bit() {}

    explicit simd256bit(const void* x) {
        memcpy(u8, x, 32);
    }

    void clear() {
        memset(u8, 0, 32);
    }

    void storeu(void* ptr) const {
        memcpy(ptr, u8, 32);
    }

    void loadu(const void* ptr) {
        memcpy(u8, ptr, 32);
    }

    void store(void* ptr) const {
        storeu(ptr);
    }

    void bin(char bits[257]) const {
        const char* bytes = (char*)this->u8;
        for (int i = 0; i < 256; i++) {
            bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
        }
        bits[256] = 0;
    }

    std::string bin() const {
        char bits[257];
        bin(bits);
        return std::string(bits);
    }

    // Checks whether the other holds exactly the same bytes.
    bool is_same_as(simd256bit other) const {
        for (size_t i = 0; i < 8; i++) {
            if (u32[i] != other.u32[i]) {
                return false;
            }
        }

        return true;
    }
};

/// vector of 16 elements in uint16
struct simd16uint16 : simd256bit {
    simd16uint16() {}

    explicit simd16uint16(int x) {
        set1(x);
    }

    explicit simd16uint16(uint16_t x) {
        set1(x);
    }

    explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}

    explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}

    explicit simd16uint16(
            uint16_t u0,
            uint16_t u1,
            uint16_t u2,
            uint16_t u3,
            uint16_t u4,
            uint16_t u5,
            uint16_t u6,
            uint16_t u7,
            uint16_t u8,
            uint16_t u9,
            uint16_t u10,
            uint16_t u11,
            uint16_t u12,
            uint16_t u13,
            uint16_t u14,
            uint16_t u15) {
        this->u16[0] = u0;
        this->u16[1] = u1;
        this->u16[2] = u2;
        this->u16[3] = u3;
        this->u16[4] = u4;
        this->u16[5] = u5;
        this->u16[6] = u6;
        this->u16[7] = u7;
        this->u16[8] = u8;
        this->u16[9] = u9;
        this->u16[10] = u10;
        this->u16[11] = u11;
        this->u16[12] = u12;
        this->u16[13] = u13;
        this->u16[14] = u14;
        this->u16[15] = u15;
    }

    std::string elements_to_string(const char* fmt) const {
        char res[1000], *ptr = res;
        for (int i = 0; i < 16; i++) {
            ptr += sprintf(ptr, fmt, u16[i]);
        }
        // strip last ,
        ptr[-1] = 0;
        return std::string(res);
    }

    std::string hex() const {
        return elements_to_string("%02x,");
    }

    std::string dec() const {
        return elements_to_string("%3d,");
    }

    template <typename F>
    static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
        simd16uint16 c;
        for (int j = 0; j < 16; j++) {
            c.u16[j] = f(a.u16[j]);
        }
        return c;
    }

    template <typename F>
    static simd16uint16 binary_func(
            const simd16uint16& a,
            const simd16uint16& b,
            F&& f) {
        simd16uint16 c;
        for (int j = 0; j < 16; j++) {
            c.u16[j] = f(a.u16[j], b.u16[j]);
        }
        return c;
    }

    void set1(uint16_t x) {
        for (int i = 0; i < 16; i++) {
            u16[i] = x;
        }
    }

    simd16uint16 operator*(const simd16uint16& other) const {
        return binary_func(
                *this, other, [](uint16_t a, uint16_t b) { return a * b; });
    }

    // shift must be known at compile time
    simd16uint16 operator>>(const int shift) const {
        return unary_func(*this, [shift](uint16_t a) { return a >> shift; });
    }

    // shift must be known at compile time
    simd16uint16 operator<<(const int shift) const {
        return unary_func(*this, [shift](uint16_t a) { return a << shift; });
    }

    simd16uint16 operator+=(const simd16uint16& other) {
        *this = *this + other;
        return *this;
    }

    simd16uint16 operator-=(const simd16uint16& other) {
        *this = *this - other;
        return *this;
    }

    simd16uint16 operator+(const simd16uint16& other) const {
        return binary_func(
                *this, other, [](uint16_t a, uint16_t b) { return a + b; });
    }

    simd16uint16 operator-(const simd16uint16& other) const {
        return binary_func(
                *this, other, [](uint16_t a, uint16_t b) { return a - b; });
    }

    simd16uint16 operator&(const simd256bit& other) const {
        return binary_func(
                *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
                    return a & b;
                });
    }

    simd16uint16 operator|(const simd256bit& other) const {
        return binary_func(
                *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
                    return a | b;
                });
    }

    simd16uint16 operator^(const simd256bit& other) const {
        return binary_func(
                *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
                    return a ^ b;
                });
    }

    // returns binary masks
    simd16uint16 operator==(const simd16uint16& other) const {
        return binary_func(*this, other, [](uint16_t a, uint16_t b) {
            return a == b ? 0xffff : 0;
        });
    }

    simd16uint16 operator~() const {
        return unary_func(*this, [](uint16_t a) { return ~a; });
    }

    // get scalar at index 0
    uint16_t get_scalar_0() const {
        return u16[0];
    }

    // mask of elements where this >= thresh
    // 2 bit per component: 16 * 2 = 32 bit
    uint32_t ge_mask(const simd16uint16& thresh) const {
        uint32_t gem = 0;
        for (int j = 0; j < 16; j++) {
            if (u16[j] >= thresh.u16[j]) {
                gem |= 3 << (j * 2);
            }
        }
        return gem;
    }

    uint32_t le_mask(const simd16uint16& thresh) const {
        return thresh.ge_mask(*this);
    }

    uint32_t gt_mask(const simd16uint16& thresh) const {
        return ~le_mask(thresh);
    }

    bool all_gt(const simd16uint16& thresh) const {
        return le_mask(thresh) == 0;
    }

    // for debugging only
    uint16_t operator[](int i) const {
        return u16[i];
    }

    void accu_min(const simd16uint16& incoming) {
        for (int j = 0; j < 16; j++) {
            if (incoming.u16[j] < u16[j]) {
                u16[j] = incoming.u16[j];
            }
        }
    }

    void accu_max(const simd16uint16& incoming) {
        for (int j = 0; j < 16; j++) {
            if (incoming.u16[j] > u16[j]) {
                u16[j] = incoming.u16[j];
            }
        }
    }
};

// not really a std::min because it returns an elementwise min
inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
    return simd16uint16::binary_func(
            av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
}

inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
    return simd16uint16::binary_func(
            av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
}

// decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
// return (a0 + a1, b0 + b1)
// TODO find a better name
inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
    simd16uint16 c;
    for (int j = 0; j < 8; j++) {
        c.u16[j] = a.u16[j] + a.u16[j + 8];
        c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
    }
    return c;
}

// compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
// of d0 and d1 with thr
inline uint32_t cmp_ge32(
        const simd16uint16& d0,
        const simd16uint16& d1,
        const simd16uint16& thr) {
    uint32_t gem = 0;
    for (int j = 0; j < 16; j++) {
        if (d0.u16[j] >= thr.u16[j]) {
            gem |= 1 << j;
        }
        if (d1.u16[j] >= thr.u16[j]) {
            gem |= 1 << (j + 16);
        }
    }
    return gem;
}

inline uint32_t cmp_le32(
        const simd16uint16& d0,
        const simd16uint16& d1,
        const simd16uint16& thr) {
    uint32_t gem = 0;
    for (int j = 0; j < 16; j++) {
        if (d0.u16[j] <= thr.u16[j]) {
            gem |= 1 << j;
        }
        if (d1.u16[j] <= thr.u16[j]) {
            gem |= 1 << (j + 16);
        }
    }
    return gem;
}

// hadd does not cross lanes
inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
    simd16uint16 c;
    c.u16[0] = a.u16[0] + a.u16[1];
    c.u16[1] = a.u16[2] + a.u16[3];
    c.u16[2] = a.u16[4] + a.u16[5];
    c.u16[3] = a.u16[6] + a.u16[7];
    c.u16[4] = b.u16[0] + b.u16[1];
    c.u16[5] = b.u16[2] + b.u16[3];
    c.u16[6] = b.u16[4] + b.u16[5];
    c.u16[7] = b.u16[6] + b.u16[7];

    c.u16[8] = a.u16[8] + a.u16[9];
    c.u16[9] = a.u16[10] + a.u16[11];
    c.u16[10] = a.u16[12] + a.u16[13];
    c.u16[11] = a.u16[14] + a.u16[15];
    c.u16[12] = b.u16[8] + b.u16[9];
    c.u16[13] = b.u16[10] + b.u16[11];
    c.u16[14] = b.u16[12] + b.u16[13];
    c.u16[15] = b.u16[14] + b.u16[15];

    return c;
}

// Vectorized version of the following code:
//   for (size_t i = 0; i < n; i++) {
//      bool flag = (candidateValues[i] < currentValues[i]);
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
//   }
// Max indices evaluation is inaccurate in case of equal values (the index of
// the last equal value is saved instead of the first one), but this behavior
// saves instructions.
inline void cmplt_min_max_fast(
        const simd16uint16 candidateValues,
        const simd16uint16 candidateIndices,
        const simd16uint16 currentValues,
        const simd16uint16 currentIndices,
        simd16uint16& minValues,
        simd16uint16& minIndices,
        simd16uint16& maxValues,
        simd16uint16& maxIndices) {
    for (size_t i = 0; i < 16; i++) {
        bool flag = (candidateValues.u16[i] < currentValues.u16[i]);
        minValues.u16[i] = flag ? candidateValues.u16[i] : currentValues.u16[i];
        minIndices.u16[i] =
                flag ? candidateIndices.u16[i] : currentIndices.u16[i];
        maxValues.u16[i] =
                !flag ? candidateValues.u16[i] : currentValues.u16[i];
        maxIndices.u16[i] =
                !flag ? candidateIndices.u16[i] : currentIndices.u16[i];
    }
}

// vector of 32 unsigned 8-bit integers
struct simd32uint8 : simd256bit {
    simd32uint8() {}

    explicit simd32uint8(int x) {
        set1(x);
    }

    explicit simd32uint8(uint8_t x) {
        set1(x);
    }
    template <
            uint8_t _0,
            uint8_t _1,
            uint8_t _2,
            uint8_t _3,
            uint8_t _4,
            uint8_t _5,
            uint8_t _6,
            uint8_t _7,
            uint8_t _8,
            uint8_t _9,
            uint8_t _10,
            uint8_t _11,
            uint8_t _12,
            uint8_t _13,
            uint8_t _14,
            uint8_t _15,
            uint8_t _16,
            uint8_t _17,
            uint8_t _18,
            uint8_t _19,
            uint8_t _20,
            uint8_t _21,
            uint8_t _22,
            uint8_t _23,
            uint8_t _24,
            uint8_t _25,
            uint8_t _26,
            uint8_t _27,
            uint8_t _28,
            uint8_t _29,
            uint8_t _30,
            uint8_t _31>
    static simd32uint8 create() {
        simd32uint8 ret;
        ret.u8[0] = _0;
        ret.u8[1] = _1;
        ret.u8[2] = _2;
        ret.u8[3] = _3;
        ret.u8[4] = _4;
        ret.u8[5] = _5;
        ret.u8[6] = _6;
        ret.u8[7] = _7;
        ret.u8[8] = _8;
        ret.u8[9] = _9;
        ret.u8[10] = _10;
        ret.u8[11] = _11;
        ret.u8[12] = _12;
        ret.u8[13] = _13;
        ret.u8[14] = _14;
        ret.u8[15] = _15;
        ret.u8[16] = _16;
        ret.u8[17] = _17;
        ret.u8[18] = _18;
        ret.u8[19] = _19;
        ret.u8[20] = _20;
        ret.u8[21] = _21;
        ret.u8[22] = _22;
        ret.u8[23] = _23;
        ret.u8[24] = _24;
        ret.u8[25] = _25;
        ret.u8[26] = _26;
        ret.u8[27] = _27;
        ret.u8[28] = _28;
        ret.u8[29] = _29;
        ret.u8[30] = _30;
        ret.u8[31] = _31;
        return ret;
    }

    explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}

    explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}

    std::string elements_to_string(const char* fmt) const {
        char res[1000], *ptr = res;
        for (int i = 0; i < 32; i++) {
            ptr += sprintf(ptr, fmt, u8[i]);
        }
        // strip last ,
        ptr[-1] = 0;
        return std::string(res);
    }

    std::string hex() const {
        return elements_to_string("%02x,");
    }

    std::string dec() const {
        return elements_to_string("%3d,");
    }

    void set1(uint8_t x) {
        for (int j = 0; j < 32; j++) {
            u8[j] = x;
        }
    }

    template <typename F>
    static simd32uint8 binary_func(
            const simd32uint8& a,
            const simd32uint8& b,
            F&& f) {
        simd32uint8 c;
        for (int j = 0; j < 32; j++) {
            c.u8[j] = f(a.u8[j], b.u8[j]);
        }
        return c;
    }

    simd32uint8 operator&(const simd256bit& other) const {
        return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
            return a & b;
        });
    }

    simd32uint8 operator+(const simd32uint8& other) const {
        return binary_func(
                *this, other, [](uint8_t a, uint8_t b) { return a + b; });
    }

    // The very important operation that everything relies on
    simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
        simd32uint8 c;
        // The original for loop:
        //   for (int j = 0; j < 32; j++) {
        //        if (idx.u8[j] & 0x80) {
        //            c.u8[j] = 0;
        //        } else {
        //            uint8_t i = idx.u8[j] & 15;
        //            if (j < 16) {
        //                c.u8[j] = u8[i];
        //            } else {
        //                c.u8[j] = u8[16 + i];
        //            }
        //        }

        // The following function was re-written for Power 10
        // The loop was unrolled to remove the if (j < 16) statement by doing
        // the j and j + 16 iterations in parallel. The additional unrolling
        // for j + 1 and j + 17, reduces the execution time on Power 10 by
        // about 50% as the instruction scheduling allows on average 2X more
        // instructions to be issued per cycle.

        for (int j = 0; j < 16; j = j + 2) {
            // j < 16, unrolled to depth of 2
            if (idx.u8[j] & 0x80) {
                c.u8[j] = 0;
            } else {
                uint8_t i = idx.u8[j] & 15;
                c.u8[j] = u8[i];
            }

            if (idx.u8[j + 1] & 0x80) {
                c.u8[j + 1] = 0;
            } else {
                uint8_t i = idx.u8[j + 1] & 15;
                c.u8[j + 1] = u8[i];
            }

            // j >= 16, unrolled to depth of 2
            if (idx.u8[j + 16] & 0x80) {
                c.u8[j + 16] = 0;
            } else {
                uint8_t i = idx.u8[j + 16] & 15;
                c.u8[j + 16] = u8[i + 16];
            }

            if (idx.u8[j + 17] & 0x80) {
                c.u8[j + 17] = 0;
            } else {
                uint8_t i = idx.u8[j + 17] & 15;
                c.u8[j + 17] = u8[i + 16];
            }
        }
        return c;
    }

    // extract + 0-extend lane
    // this operation is slow (3 cycles)

    simd32uint8 operator+=(const simd32uint8& other) {
        *this = *this + other;
        return *this;
    }

    // for debugging only
    uint8_t operator[](int i) const {
        return u8[i];
    }
};

// convert with saturation
// careful: this does not cross lanes, so the order is weird
inline simd32uint8 uint16_to_uint8_saturate(
        const simd16uint16& a,
        const simd16uint16& b) {
    simd32uint8 c;

    auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };

    for (int i = 0; i < 8; i++) {
        c.u8[i] = saturate_16_to_8(a.u16[i]);
        c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
        c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
        c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
    }
    return c;
}

/// get most significant bit of each byte
inline uint32_t get_MSBs(const simd32uint8& a) {
    uint32_t res = 0;
    for (int i = 0; i < 32; i++) {
        if (a.u8[i] & 0x80) {
            res |= 1 << i;
        }
    }
    return res;
}

/// use MSB of each byte of mask to select a byte between a and b
inline simd32uint8 blendv(
        const simd32uint8& a,
        const simd32uint8& b,
        const simd32uint8& mask) {
    simd32uint8 c;
    for (int i = 0; i < 32; i++) {
        if (mask.u8[i] & 0x80) {
            c.u8[i] = b.u8[i];
        } else {
            c.u8[i] = a.u8[i];
        }
    }
    return c;
}

/// vector of 8 unsigned 32-bit integers
struct simd8uint32 : simd256bit {
    simd8uint32() {}

    explicit simd8uint32(uint32_t x) {
        set1(x);
    }

    explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}

    explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {}

    explicit simd8uint32(
            uint32_t u0,
            uint32_t u1,
            uint32_t u2,
            uint32_t u3,
            uint32_t u4,
            uint32_t u5,
            uint32_t u6,
            uint32_t u7) {
        u32[0] = u0;
        u32[1] = u1;
        u32[2] = u2;
        u32[3] = u3;
        u32[4] = u4;
        u32[5] = u5;
        u32[6] = u6;
        u32[7] = u7;
    }

    simd8uint32 operator+(simd8uint32 other) const {
        simd8uint32 result;
        for (int i = 0; i < 8; i++) {
            result.u32[i] = u32[i] + other.u32[i];
        }
        return result;
    }

    simd8uint32 operator-(simd8uint32 other) const {
        simd8uint32 result;
        for (int i = 0; i < 8; i++) {
            result.u32[i] = u32[i] - other.u32[i];
        }
        return result;
    }

    simd8uint32& operator+=(const simd8uint32& other) {
        for (int i = 0; i < 8; i++) {
            u32[i] += other.u32[i];
        }
        return *this;
    }

    bool operator==(simd8uint32 other) const {
        for (size_t i = 0; i < 8; i++) {
            if (u32[i] != other.u32[i]) {
                return false;
            }
        }

        return true;
    }

    bool operator!=(simd8uint32 other) const {
        return !(*this == other);
    }

    std::string elements_to_string(const char* fmt) const {
        char res[1000], *ptr = res;
        for (int i = 0; i < 8; i++) {
            ptr += sprintf(ptr, fmt, u32[i]);
        }
        // strip last ,
        ptr[-1] = 0;
        return std::string(res);
    }

    std::string hex() const {
        return elements_to_string("%08x,");
    }

    std::string dec() const {
        return elements_to_string("%10d,");
    }

    void set1(uint32_t x) {
        for (int i = 0; i < 8; i++) {
            u32[i] = x;
        }
    }

    simd8uint32 unzip() const {
        const uint32_t ret[] = {
                u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
        return simd8uint32{ret};
    }
};

// Vectorized version of the following code:
//   for (size_t i = 0; i < n; i++) {
//      bool flag = (candidateValues[i] < currentValues[i]);
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
//   }
// Max indices evaluation is inaccurate in case of equal values (the index of
// the last equal value is saved instead of the first one), but this behavior
// saves instructions.
inline void cmplt_min_max_fast(
        const simd8uint32 candidateValues,
        const simd8uint32 candidateIndices,
        const simd8uint32 currentValues,
        const simd8uint32 currentIndices,
        simd8uint32& minValues,
        simd8uint32& minIndices,
        simd8uint32& maxValues,
        simd8uint32& maxIndices) {
    for (size_t i = 0; i < 8; i++) {
        bool flag = (candidateValues.u32[i] < currentValues.u32[i]);
        minValues.u32[i] = flag ? candidateValues.u32[i] : currentValues.u32[i];
        minIndices.u32[i] =
                flag ? candidateIndices.u32[i] : currentIndices.u32[i];
        maxValues.u32[i] =
                !flag ? candidateValues.u32[i] : currentValues.u32[i];
        maxIndices.u32[i] =
                !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
    }
}

struct simd8float32 : simd256bit {
    simd8float32() {}

    explicit simd8float32(const simd256bit& x) : simd256bit(x) {}

    explicit simd8float32(float x) {
        set1(x);
    }

    explicit simd8float32(const float* x) {
        loadu((void*)x);
    }

    void set1(float x) {
        for (int i = 0; i < 8; i++) {
            f32[i] = x;
        }
    }

    explicit simd8float32(
            float f0,
            float f1,
            float f2,
            float f3,
            float f4,
            float f5,
            float f6,
            float f7) {
        f32[0] = f0;
        f32[1] = f1;
        f32[2] = f2;
        f32[3] = f3;
        f32[4] = f4;
        f32[5] = f5;
        f32[6] = f6;
        f32[7] = f7;
    }

    template <typename F>
    static simd8float32 binary_func(
            const simd8float32& a,
            const simd8float32& b,
            F&& f) {
        simd8float32 c;
        for (int j = 0; j < 8; j++) {
            c.f32[j] = f(a.f32[j], b.f32[j]);
        }
        return c;
    }

    simd8float32 operator*(const simd8float32& other) const {
        return binary_func(
                *this, other, [](float a, float b) { return a * b; });
    }

    simd8float32 operator+(const simd8float32& other) const {
        return binary_func(
                *this, other, [](float a, float b) { return a + b; });
    }

    simd8float32 operator-(const simd8float32& other) const {
        return binary_func(
                *this, other, [](float a, float b) { return a - b; });
    }

    simd8float32& operator+=(const simd8float32& other) {
        for (size_t i = 0; i < 8; i++) {
            f32[i] += other.f32[i];
        }

        return *this;
    }

    bool operator==(simd8float32 other) const {
        for (size_t i = 0; i < 8; i++) {
            if (f32[i] != other.f32[i]) {
                return false;
            }
        }

        return true;
    }

    bool operator!=(simd8float32 other) const {
        return !(*this == other);
    }

    std::string tostring() const {
        char res[1000], *ptr = res;
        for (int i = 0; i < 8; i++) {
            ptr += sprintf(ptr, "%g,", f32[i]);
        }
        // strip last ,
        ptr[-1] = 0;
        return std::string(res);
    }
};

// hadd does not cross lanes
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;
    c.f32[0] = a.f32[0] + a.f32[1];
    c.f32[1] = a.f32[2] + a.f32[3];
    c.f32[2] = b.f32[0] + b.f32[1];
    c.f32[3] = b.f32[2] + b.f32[3];

    c.f32[4] = a.f32[4] + a.f32[5];
    c.f32[5] = a.f32[6] + a.f32[7];
    c.f32[6] = b.f32[4] + b.f32[5];
    c.f32[7] = b.f32[6] + b.f32[7];

    return c;
}

inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;
    c.f32[0] = a.f32[0];
    c.f32[1] = b.f32[0];
    c.f32[2] = a.f32[1];
    c.f32[3] = b.f32[1];

    c.f32[4] = a.f32[4];
    c.f32[5] = b.f32[4];
    c.f32[6] = a.f32[5];
    c.f32[7] = b.f32[5];

    return c;
}

inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;
    c.f32[0] = a.f32[2];
    c.f32[1] = b.f32[2];
    c.f32[2] = a.f32[3];
    c.f32[3] = b.f32[3];

    c.f32[4] = a.f32[6];
    c.f32[5] = b.f32[6];
    c.f32[6] = a.f32[7];
    c.f32[7] = b.f32[7];

    return c;
}

// compute a * b + c
inline simd8float32 fmadd(
        const simd8float32& a,
        const simd8float32& b,
        const simd8float32& c) {
    simd8float32 res;
    for (int i = 0; i < 8; i++) {
        res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
    }
    return res;
}

namespace {

// get even float32's of a and b, interleaved
simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;

    c.f32[0] = a.f32[0];
    c.f32[1] = a.f32[2];
    c.f32[2] = b.f32[0];
    c.f32[3] = b.f32[2];

    c.f32[4] = a.f32[4];
    c.f32[5] = a.f32[6];
    c.f32[6] = b.f32[4];
    c.f32[7] = b.f32[6];

    return c;
}

// get odd float32's of a and b, interleaved
simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;

    c.f32[0] = a.f32[1];
    c.f32[1] = a.f32[3];
    c.f32[2] = b.f32[1];
    c.f32[3] = b.f32[3];

    c.f32[4] = a.f32[5];
    c.f32[5] = a.f32[7];
    c.f32[6] = b.f32[5];
    c.f32[7] = b.f32[7];

    return c;
}

// 3 cycles
// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;

    c.f32[0] = a.f32[0];
    c.f32[1] = a.f32[1];
    c.f32[2] = a.f32[2];
    c.f32[3] = a.f32[3];

    c.f32[4] = b.f32[0];
    c.f32[5] = b.f32[1];
    c.f32[6] = b.f32[2];
    c.f32[7] = b.f32[3];

    return c;
}

simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
    simd8float32 c;

    c.f32[0] = a.f32[4];
    c.f32[1] = a.f32[5];
    c.f32[2] = a.f32[6];
    c.f32[3] = a.f32[7];

    c.f32[4] = b.f32[4];
    c.f32[5] = b.f32[5];
    c.f32[6] = b.f32[6];
    c.f32[7] = b.f32[7];

    return c;
}

// The following primitive is a vectorized version of the following code
// snippet:
//   float lowestValue = HUGE_VAL;
//   uint lowestIndex = 0;
//   for (size_t i = 0; i < n; i++) {
//     if (values[i] < lowestValue) {
//       lowestValue = values[i];
//       lowestIndex = i;
//     }
//   }
// Vectorized version can be implemented via two operations: cmp and blend
// with something like this:
//   lowestValues = [HUGE_VAL; 8];
//   lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
//   for (size_t i = 0; i < n; i += 8) {
//     auto comparison = cmp(values + i, lowestValues);
//     lowestValues = blend(
//         comparison,
//         values + i,
//         lowestValues);
//     lowestIndices = blend(
//         comparison,
//         i + {0, 1, 2, 3, 4, 5, 6, 7},
//         lowestIndices);
//     lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
//   }
// The problem is that blend primitive needs very different instruction
// order for AVX and ARM.
// So, let's introduce a combination of these two in order to avoid
// confusion for ppl who write in low-level SIMD instructions. Additionally,
// these two ops (cmp and blend) are very often used together.
inline void cmplt_and_blend_inplace(
        const simd8float32 candidateValues,
        const simd8uint32 candidateIndices,
        simd8float32& lowestValues,
        simd8uint32& lowestIndices) {
    for (size_t j = 0; j < 8; j++) {
        bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
        if (comparison) {
            lowestValues.f32[j] = candidateValues.f32[j];
            lowestIndices.u32[j] = candidateIndices.u32[j];
        }
    }
}

// Vectorized version of the following code:
//   for (size_t i = 0; i < n; i++) {
//      bool flag = (candidateValues[i] < currentValues[i]);
//      minValues[i] = flag ? candidateValues[i] : currentValues[i];
//      minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
//      maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
//      maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
//   }
// Max indices evaluation is inaccurate in case of equal values (the index of
// the last equal value is saved instead of the first one), but this behavior
// saves instructions.
inline void cmplt_min_max_fast(
        const simd8float32 candidateValues,
        const simd8uint32 candidateIndices,
        const simd8float32 currentValues,
        const simd8uint32 currentIndices,
        simd8float32& minValues,
        simd8uint32& minIndices,
        simd8float32& maxValues,
        simd8uint32& maxIndices) {
    for (size_t i = 0; i < 8; i++) {
        bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
        minValues.f32[i] = flag ? candidateValues.f32[i] : currentValues.f32[i];
        minIndices.u32[i] =
                flag ? candidateIndices.u32[i] : currentIndices.u32[i];
        maxValues.f32[i] =
                !flag ? candidateValues.f32[i] : currentValues.f32[i];
        maxIndices.u32[i] =
                !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
    }
}

} // namespace

} // namespace faiss