faiss 0.6.1 → 0.6.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
- data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
- data/vendor/faiss/faiss/factory_tools.cpp +4 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
- data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
- data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
- data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
- data/vendor/faiss/faiss/impl/HNSW.h +51 -13
- data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
- data/vendor/faiss/faiss/impl/Panorama.h +11 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
- data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
- data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
- data/vendor/faiss/faiss/impl/io_macros.h +25 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
- data/vendor/faiss/faiss/index_factory.cpp +5 -1
- data/vendor/faiss/faiss/index_io.h +16 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
- data/vendor/faiss/faiss/utils/bf16.h +34 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
- metadata +12 -2
|
@@ -0,0 +1,559 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#ifdef COMPILE_SIMD_AVX512_SPR
|
|
9
|
+
|
|
10
|
+
#include <immintrin.h>
|
|
11
|
+
|
|
12
|
+
#include <faiss/impl/scalar_quantizer/codecs.h>
|
|
13
|
+
#include <faiss/impl/scalar_quantizer/distance_computers.h>
|
|
14
|
+
#include <faiss/impl/scalar_quantizer/quantizers.h>
|
|
15
|
+
#include <faiss/impl/scalar_quantizer/scanners.h>
|
|
16
|
+
#include <faiss/impl/scalar_quantizer/similarities.h>
|
|
17
|
+
#include <faiss/impl/simdlib/simdlib_avx512.h>
|
|
18
|
+
|
|
19
|
+
#include <faiss/impl/scalar_quantizer/sq-avx512-impl.h>
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
namespace scalar_quantizer {
|
|
23
|
+
|
|
24
|
+
/**********************************************************
|
|
25
|
+
* Codecs — inherit AVX512 implementations
|
|
26
|
+
**********************************************************/
|
|
27
|
+
|
|
28
|
+
template <>
|
|
29
|
+
struct Codec8bit<SIMDLevel::AVX512_SPR> : Codec8bit<SIMDLevel::AVX512> {};
|
|
30
|
+
|
|
31
|
+
template <>
|
|
32
|
+
struct Codec4bit<SIMDLevel::AVX512_SPR> : Codec4bit<SIMDLevel::AVX512> {};
|
|
33
|
+
|
|
34
|
+
template <>
|
|
35
|
+
struct Codec6bit<SIMDLevel::AVX512_SPR> : Codec6bit<SIMDLevel::AVX512> {};
|
|
36
|
+
|
|
37
|
+
/**********************************************************
|
|
38
|
+
* Quantizers — inherit AVX512 implementations
|
|
39
|
+
**********************************************************/
|
|
40
|
+
|
|
41
|
+
template <class Codec>
|
|
42
|
+
struct QuantizerTemplate<
|
|
43
|
+
Codec,
|
|
44
|
+
QuantizerTemplateScaling::UNIFORM,
|
|
45
|
+
SIMDLevel::AVX512_SPR>
|
|
46
|
+
: QuantizerTemplate<
|
|
47
|
+
Codec,
|
|
48
|
+
QuantizerTemplateScaling::UNIFORM,
|
|
49
|
+
SIMDLevel::AVX512> {
|
|
50
|
+
using QuantizerTemplate<
|
|
51
|
+
Codec,
|
|
52
|
+
QuantizerTemplateScaling::UNIFORM,
|
|
53
|
+
SIMDLevel::AVX512>::QuantizerTemplate;
|
|
54
|
+
};
|
|
55
|
+
|
|
56
|
+
template <class Codec>
|
|
57
|
+
struct QuantizerTemplate<
|
|
58
|
+
Codec,
|
|
59
|
+
QuantizerTemplateScaling::NON_UNIFORM,
|
|
60
|
+
SIMDLevel::AVX512_SPR>
|
|
61
|
+
: QuantizerTemplate<
|
|
62
|
+
Codec,
|
|
63
|
+
QuantizerTemplateScaling::NON_UNIFORM,
|
|
64
|
+
SIMDLevel::AVX512> {
|
|
65
|
+
using QuantizerTemplate<
|
|
66
|
+
Codec,
|
|
67
|
+
QuantizerTemplateScaling::NON_UNIFORM,
|
|
68
|
+
SIMDLevel::AVX512>::QuantizerTemplate;
|
|
69
|
+
};
|
|
70
|
+
|
|
71
|
+
template <>
|
|
72
|
+
struct QuantizerFP16<SIMDLevel::AVX512_SPR> : QuantizerFP16<SIMDLevel::AVX512> {
|
|
73
|
+
using QuantizerFP16<SIMDLevel::AVX512>::QuantizerFP16;
|
|
74
|
+
};
|
|
75
|
+
|
|
76
|
+
template <>
|
|
77
|
+
struct QuantizerBF16<SIMDLevel::AVX512_SPR> : QuantizerBF16<SIMDLevel::AVX512> {
|
|
78
|
+
using QuantizerBF16<SIMDLevel::AVX512>::QuantizerBF16;
|
|
79
|
+
|
|
80
|
+
void encode_vector(const float* x, uint8_t* code) const override {
|
|
81
|
+
encode_bf16_simd(x, (uint16_t*)code, this->d);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
void decode_vector(const uint8_t* code, float* x) const override {
|
|
85
|
+
decode_bf16_simd((const uint16_t*)code, x, this->d);
|
|
86
|
+
}
|
|
87
|
+
};
|
|
88
|
+
|
|
89
|
+
template <>
|
|
90
|
+
struct Quantizer8bitDirect<SIMDLevel::AVX512_SPR>
|
|
91
|
+
: Quantizer8bitDirect<SIMDLevel::AVX512> {
|
|
92
|
+
using Quantizer8bitDirect<SIMDLevel::AVX512>::Quantizer8bitDirect;
|
|
93
|
+
};
|
|
94
|
+
|
|
95
|
+
template <>
|
|
96
|
+
struct Quantizer8bitDirectSigned<SIMDLevel::AVX512_SPR>
|
|
97
|
+
: Quantizer8bitDirectSigned<SIMDLevel::AVX512> {
|
|
98
|
+
using Quantizer8bitDirectSigned<
|
|
99
|
+
SIMDLevel::AVX512>::Quantizer8bitDirectSigned;
|
|
100
|
+
};
|
|
101
|
+
|
|
102
|
+
/**********************************************************
|
|
103
|
+
* TurboQuant MSE — inherit AVX512 implementations
|
|
104
|
+
**********************************************************/
|
|
105
|
+
|
|
106
|
+
template <int NBits>
|
|
107
|
+
struct QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512_SPR>
|
|
108
|
+
: QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512> {
|
|
109
|
+
using QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512>::
|
|
110
|
+
QuantizerTurboQuantMSE;
|
|
111
|
+
};
|
|
112
|
+
|
|
113
|
+
/**********************************************************
|
|
114
|
+
* Similarities — inherit AVX512 implementations
|
|
115
|
+
**********************************************************/
|
|
116
|
+
|
|
117
|
+
template <>
|
|
118
|
+
struct SimilarityL2<SIMDLevel::AVX512_SPR> : SimilarityL2<SIMDLevel::AVX512> {
|
|
119
|
+
using SimilarityL2<SIMDLevel::AVX512>::SimilarityL2;
|
|
120
|
+
static constexpr SIMDLevel simd_level = SIMDLevel::AVX512_SPR;
|
|
121
|
+
};
|
|
122
|
+
|
|
123
|
+
template <>
|
|
124
|
+
struct SimilarityIP<SIMDLevel::AVX512_SPR> : SimilarityIP<SIMDLevel::AVX512> {
|
|
125
|
+
using SimilarityIP<SIMDLevel::AVX512>::SimilarityIP;
|
|
126
|
+
static constexpr SIMDLevel simd_level = SIMDLevel::AVX512_SPR;
|
|
127
|
+
};
|
|
128
|
+
|
|
129
|
+
/**********************************************************
|
|
130
|
+
* Generic DCTemplate — delegate to AVX512 implementations
|
|
131
|
+
**********************************************************/
|
|
132
|
+
|
|
133
|
+
template <class Quantizer, class Similarity>
|
|
134
|
+
struct DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512_SPR>
|
|
135
|
+
: DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512> {
|
|
136
|
+
using DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512>::DCTemplate;
|
|
137
|
+
};
|
|
138
|
+
|
|
139
|
+
/**********************************************************
|
|
140
|
+
* DistanceComputerByte: AVX512-VNNI
|
|
141
|
+
*
|
|
142
|
+
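* The inner-product path uses _mm512_dpbusd_epi32 (VNNI) to accumulate
* 64 uint8 products per instruction, which is intended to outpace the
* generic AVX512 path; the L2 path widens to int16 and uses
* _mm512_madd_epi16 on 32 elements per instruction.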
|
|
144
|
+
**********************************************************/
|
|
145
|
+
|
|
146
|
+
template <class Similarity>
|
|
147
|
+
struct DistanceComputerByte<Similarity, SIMDLevel::AVX512_SPR>
|
|
148
|
+
: SQDistanceComputer {
|
|
149
|
+
using Sim = Similarity;
|
|
150
|
+
|
|
151
|
+
int d;
|
|
152
|
+
std::vector<uint8_t> tmp;
|
|
153
|
+
|
|
154
|
+
DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
|
|
155
|
+
|
|
156
|
+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
157
|
+
const {
|
|
158
|
+
if constexpr (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
159
|
+
__m512i accu = _mm512_setzero_si512();
|
|
160
|
+
int i = 0;
|
|
161
|
+
for (; i + 64 <= d; i += 64) {
|
|
162
|
+
__m512i c1 = _mm512_loadu_si512(code1 + i);
|
|
163
|
+
__m512i c2 = _mm512_loadu_si512(code2 + i);
|
|
164
|
+
|
|
165
|
+
__m512i c2_signed = _mm512_sub_epi8(c2, _mm512_set1_epi8(-128));
|
|
166
|
+
accu = _mm512_dpbusd_epi32(accu, c1, c2_signed);
|
|
167
|
+
}
|
|
168
|
+
int32_t sum_c1 = 0;
|
|
169
|
+
for (int j = 0; j < i; j++) {
|
|
170
|
+
sum_c1 += code1[j];
|
|
171
|
+
}
|
|
172
|
+
int32_t result = _mm512_reduce_add_epi32(accu) + 128 * sum_c1;
|
|
173
|
+
|
|
174
|
+
for (; i < d; i++) {
|
|
175
|
+
result += int(code1[i]) * code2[i];
|
|
176
|
+
}
|
|
177
|
+
return result;
|
|
178
|
+
} else {
|
|
179
|
+
__m512i accu = _mm512_setzero_si512();
|
|
180
|
+
int i = 0;
|
|
181
|
+
for (; i + 64 <= d; i += 64) {
|
|
182
|
+
__m256i c1_lo = _mm256_loadu_si256((const __m256i*)(code1 + i));
|
|
183
|
+
__m256i c2_lo = _mm256_loadu_si256((const __m256i*)(code2 + i));
|
|
184
|
+
__m256i c1_hi =
|
|
185
|
+
_mm256_loadu_si256((const __m256i*)(code1 + i + 32));
|
|
186
|
+
__m256i c2_hi =
|
|
187
|
+
_mm256_loadu_si256((const __m256i*)(code2 + i + 32));
|
|
188
|
+
|
|
189
|
+
__m512i c1_16_lo = _mm512_cvtepu8_epi16(c1_lo);
|
|
190
|
+
__m512i c2_16_lo = _mm512_cvtepu8_epi16(c2_lo);
|
|
191
|
+
__m512i diff_lo = _mm512_sub_epi16(c1_16_lo, c2_16_lo);
|
|
192
|
+
|
|
193
|
+
__m512i c1_16_hi = _mm512_cvtepu8_epi16(c1_hi);
|
|
194
|
+
__m512i c2_16_hi = _mm512_cvtepu8_epi16(c2_hi);
|
|
195
|
+
__m512i diff_hi = _mm512_sub_epi16(c1_16_hi, c2_16_hi);
|
|
196
|
+
|
|
197
|
+
accu = _mm512_add_epi32(
|
|
198
|
+
accu, _mm512_madd_epi16(diff_lo, diff_lo));
|
|
199
|
+
accu = _mm512_add_epi32(
|
|
200
|
+
accu, _mm512_madd_epi16(diff_hi, diff_hi));
|
|
201
|
+
}
|
|
202
|
+
for (; i + 32 <= d; i += 32) {
|
|
203
|
+
__m256i c1v = _mm256_loadu_si256((const __m256i*)(code1 + i));
|
|
204
|
+
__m256i c2v = _mm256_loadu_si256((const __m256i*)(code2 + i));
|
|
205
|
+
__m512i c1_16 = _mm512_cvtepu8_epi16(c1v);
|
|
206
|
+
__m512i c2_16 = _mm512_cvtepu8_epi16(c2v);
|
|
207
|
+
__m512i diff = _mm512_sub_epi16(c1_16, c2_16);
|
|
208
|
+
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
|
|
209
|
+
}
|
|
210
|
+
int32_t result = _mm512_reduce_add_epi32(accu);
|
|
211
|
+
|
|
212
|
+
for (; i < d; i++) {
|
|
213
|
+
int diff = int(code1[i]) - code2[i];
|
|
214
|
+
result += diff * diff;
|
|
215
|
+
}
|
|
216
|
+
return result;
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
void set_query(const float* x) final {
|
|
221
|
+
for (int i = 0; i < d; i++) {
|
|
222
|
+
tmp[i] = int(x[i]);
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
int compute_distance(const float* x, const uint8_t* code) {
|
|
227
|
+
set_query(x);
|
|
228
|
+
return compute_code_distance(tmp.data(), code);
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
232
|
+
return compute_code_distance(
|
|
233
|
+
codes + i * code_size, codes + j * code_size);
|
|
234
|
+
}
|
|
235
|
+
|
|
236
|
+
float query_to_code(const uint8_t* code) const final {
|
|
237
|
+
return compute_code_distance(tmp.data(), code);
|
|
238
|
+
}
|
|
239
|
+
};
|
|
240
|
+
|
|
241
|
+
/**********************************************************
|
|
242
|
+
* DistanceComputerByteSigned: AVX512_SPR specialization for
|
|
243
|
+
* QT_8bit_direct_signed.
|
|
244
|
+
*
|
|
245
|
+
* Storage convention (see Quantizer8bitDirectSigned):
|
|
246
|
+
* stored_byte = value + 128, i.e. value = stored_byte - 128
|
|
247
|
+
*
|
|
248
|
+
* L2: (s_a - 128) - (s_b - 128) == s_a - s_b, so the unsigned
|
|
249
|
+
* widened-madd kernel is bit-exact for the signed variant.
|
|
250
|
+
*
|
|
251
|
+
* IP: (s_a - 128) * (s_b - 128)
|
|
252
|
+
* = s_a*s_b - 128*(s_a + s_b) + 16384
|
|
253
|
+
* summed over d components:
|
|
254
|
+
* sum_ip_signed = sum_ip_unsigned
|
|
255
|
+
* - 128 * (sum(s_a) + sum(s_b))
|
|
256
|
+
* + 16384 * d
|
|
257
|
+
* sum(s_a), sum(s_b) are cheap via _mm512_sad_epu8 against zero.
|
|
258
|
+
**********************************************************/
|
|
259
|
+
|
|
260
|
+
template <class Similarity>
|
|
261
|
+
struct DistanceComputerByteSigned<Similarity, SIMDLevel::AVX512_SPR>
|
|
262
|
+
: SQDistanceComputer {
|
|
263
|
+
using Sim = Similarity;
|
|
264
|
+
|
|
265
|
+
int d;
|
|
266
|
+
std::vector<uint8_t> tmp;
|
|
267
|
+
|
|
268
|
+
DistanceComputerByteSigned(int d, const std::vector<float>&)
|
|
269
|
+
: d(d), tmp(d) {}
|
|
270
|
+
|
|
271
|
+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
272
|
+
const {
|
|
273
|
+
if constexpr (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
274
|
+
__m512i accu = _mm512_setzero_si512();
|
|
275
|
+
__m512i sum_a = _mm512_setzero_si512();
|
|
276
|
+
__m512i sum_b = _mm512_setzero_si512();
|
|
277
|
+
const __m512i zero = _mm512_setzero_si512();
|
|
278
|
+
const __m512i bias = _mm512_set1_epi8(-128);
|
|
279
|
+
|
|
280
|
+
int i = 0;
|
|
281
|
+
for (; i + 64 <= d; i += 64) {
|
|
282
|
+
__m512i c1 = _mm512_loadu_si512(code1 + i);
|
|
283
|
+
__m512i c2 = _mm512_loadu_si512(code2 + i);
|
|
284
|
+
|
|
285
|
+
sum_a = _mm512_add_epi64(sum_a, _mm512_sad_epu8(c1, zero));
|
|
286
|
+
sum_b = _mm512_add_epi64(sum_b, _mm512_sad_epu8(c2, zero));
|
|
287
|
+
|
|
288
|
+
__m512i c2_signed = _mm512_sub_epi8(c2, bias);
|
|
289
|
+
accu = _mm512_dpbusd_epi32(accu, c1, c2_signed);
|
|
290
|
+
}
|
|
291
|
+
int32_t sum_c1_for_bias = int32_t(_mm512_reduce_add_epi64(sum_a));
|
|
292
|
+
int32_t result =
|
|
293
|
+
_mm512_reduce_add_epi32(accu) + 128 * sum_c1_for_bias;
|
|
294
|
+
|
|
295
|
+
int32_t tail_sum_a = 0, tail_sum_b = 0;
|
|
296
|
+
for (; i < d; ++i) {
|
|
297
|
+
result += int32_t(code1[i]) * int32_t(code2[i]);
|
|
298
|
+
tail_sum_a += code1[i];
|
|
299
|
+
tail_sum_b += code2[i];
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
int32_t total_sum_a = sum_c1_for_bias + tail_sum_a;
|
|
303
|
+
int32_t total_sum_b =
|
|
304
|
+
int32_t(_mm512_reduce_add_epi64(sum_b)) + tail_sum_b;
|
|
305
|
+
result -= 128 * (total_sum_a + total_sum_b);
|
|
306
|
+
result += 16384 * d;
|
|
307
|
+
return result;
|
|
308
|
+
} else {
|
|
309
|
+
__m512i accu = _mm512_setzero_si512();
|
|
310
|
+
int i = 0;
|
|
311
|
+
for (; i + 64 <= d; i += 64) {
|
|
312
|
+
__m256i c1_lo = _mm256_loadu_si256((const __m256i*)(code1 + i));
|
|
313
|
+
__m256i c2_lo = _mm256_loadu_si256((const __m256i*)(code2 + i));
|
|
314
|
+
__m256i c1_hi =
|
|
315
|
+
_mm256_loadu_si256((const __m256i*)(code1 + i + 32));
|
|
316
|
+
__m256i c2_hi =
|
|
317
|
+
_mm256_loadu_si256((const __m256i*)(code2 + i + 32));
|
|
318
|
+
__m512i diff_lo = _mm512_sub_epi16(
|
|
319
|
+
_mm512_cvtepu8_epi16(c1_lo),
|
|
320
|
+
_mm512_cvtepu8_epi16(c2_lo));
|
|
321
|
+
__m512i diff_hi = _mm512_sub_epi16(
|
|
322
|
+
_mm512_cvtepu8_epi16(c1_hi),
|
|
323
|
+
_mm512_cvtepu8_epi16(c2_hi));
|
|
324
|
+
accu = _mm512_add_epi32(
|
|
325
|
+
accu, _mm512_madd_epi16(diff_lo, diff_lo));
|
|
326
|
+
accu = _mm512_add_epi32(
|
|
327
|
+
accu, _mm512_madd_epi16(diff_hi, diff_hi));
|
|
328
|
+
}
|
|
329
|
+
for (; i + 32 <= d; i += 32) {
|
|
330
|
+
__m256i c1v = _mm256_loadu_si256((const __m256i*)(code1 + i));
|
|
331
|
+
__m256i c2v = _mm256_loadu_si256((const __m256i*)(code2 + i));
|
|
332
|
+
__m512i diff = _mm512_sub_epi16(
|
|
333
|
+
_mm512_cvtepu8_epi16(c1v), _mm512_cvtepu8_epi16(c2v));
|
|
334
|
+
accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
|
|
335
|
+
}
|
|
336
|
+
int32_t result = _mm512_reduce_add_epi32(accu);
|
|
337
|
+
for (; i < d; ++i) {
|
|
338
|
+
int32_t diff = int32_t(code1[i]) - int32_t(code2[i]);
|
|
339
|
+
result += diff * diff;
|
|
340
|
+
}
|
|
341
|
+
return result;
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
void set_query(const float* x) final {
|
|
346
|
+
for (int i = 0; i < d; ++i) {
|
|
347
|
+
tmp[i] = uint8_t(int(x[i]) + 128);
|
|
348
|
+
}
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
int compute_distance(const float* x, const uint8_t* code) {
|
|
352
|
+
set_query(x);
|
|
353
|
+
return compute_code_distance(tmp.data(), code);
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
357
|
+
return compute_code_distance(
|
|
358
|
+
codes + i * code_size, codes + j * code_size);
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
float query_to_code(const uint8_t* code) const final {
|
|
362
|
+
return compute_code_distance(tmp.data(), code);
|
|
363
|
+
}
|
|
364
|
+
};
|
|
365
|
+
|
|
366
|
+
/**********************************************************
|
|
367
|
+
* BF16 native distance helpers using VDPBF16PS
|
|
368
|
+
**********************************************************/
|
|
369
|
+
|
|
370
|
+
/// Dot product of two bf16 vectors of length d using VDPBF16PS.
/// A remainder of fewer than 32 elements is read with a masked load, so no
/// element past index d-1 is touched and the zeroed lanes contribute
/// nothing. For d % 16 == 0 this matches the previous 16-element tail load
/// exactly; other d values are now handled correctly instead of being
/// truncated or over-read.
static FAISS_ALWAYS_INLINE float bf16_vdpbf16ps(
        const uint16_t* a,
        const uint16_t* b,
        size_t d) {
    __m512 acc = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 32 <= d; i += 32) {
        __m512bh va = (__m512bh)_mm512_loadu_epi16(a + i);
        __m512bh vb = (__m512bh)_mm512_loadu_epi16(b + i);
        acc = _mm512_dpbf16_ps(acc, va, vb);
    }
    if (i < d) {
        // 1..31 elements remain: one mask bit per valid element.
        const __mmask32 tail = (__mmask32)((1u << (d - i)) - 1u);
        __m512bh va = (__m512bh)_mm512_maskz_loadu_epi16(tail, a + i);
        __m512bh vb = (__m512bh)_mm512_maskz_loadu_epi16(tail, b + i);
        acc = _mm512_dpbf16_ps(acc, va, vb);
    }
    return _mm512_reduce_add_ps(acc);
}
|
|
393
|
+
|
|
394
|
+
/// Partial squared L2 between a bf16 query and a bf16 code of length d:
/// returns ||code||^2 - 2 * <query, code>. The caller adds ||query||^2.
/// The remainder (< 32 elements) uses a masked load, so no out-of-range
/// element is read for any d.
static FAISS_ALWAYS_INLINE float bf16_L2_asymmetric(
        const uint16_t* query_bf16,
        const uint16_t* code,
        size_t d) {
    __m512 acc_qc = _mm512_setzero_ps();
    __m512 acc_cc = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 32 <= d; i += 32) {
        __m512bh vq = (__m512bh)_mm512_loadu_epi16(query_bf16 + i);
        __m512bh vc = (__m512bh)_mm512_loadu_epi16(code + i);
        acc_qc = _mm512_dpbf16_ps(acc_qc, vq, vc);
        acc_cc = _mm512_dpbf16_ps(acc_cc, vc, vc);
    }
    if (i < d) {
        // 1..31 elements remain: one mask bit per valid element.
        const __mmask32 tail = (__mmask32)((1u << (d - i)) - 1u);
        __m512bh vq = (__m512bh)_mm512_maskz_loadu_epi16(tail, query_bf16 + i);
        __m512bh vc = (__m512bh)_mm512_maskz_loadu_epi16(tail, code + i);
        acc_qc = _mm512_dpbf16_ps(acc_qc, vq, vc);
        acc_cc = _mm512_dpbf16_ps(acc_cc, vc, vc);
    }
    float dot_qc = _mm512_reduce_add_ps(acc_qc);
    float norm_c = _mm512_reduce_add_ps(acc_cc);
    return -2.0f * dot_qc + norm_c;
}
|
|
421
|
+
|
|
422
|
+
/// Squared L2 between two bf16 vectors of length d, computed by expansion
/// as ||a||^2 - 2<a,b> + ||b||^2. Because of rounding in the expansion the
/// result can be marginally negative for near-identical inputs; callers
/// that need a non-negative distance should clamp.
/// The remainder (< 32 elements) uses a masked load, so no out-of-range
/// element is read for any d.
static FAISS_ALWAYS_INLINE float bf16_L2_symmetric(
        const uint16_t* a,
        const uint16_t* b,
        size_t d) {
    __m512 acc_ab = _mm512_setzero_ps();
    __m512 acc_aa = _mm512_setzero_ps();
    __m512 acc_bb = _mm512_setzero_ps();
    size_t i = 0;
    for (; i + 32 <= d; i += 32) {
        __m512bh va = (__m512bh)_mm512_loadu_epi16(a + i);
        __m512bh vb = (__m512bh)_mm512_loadu_epi16(b + i);
        acc_ab = _mm512_dpbf16_ps(acc_ab, va, vb);
        acc_aa = _mm512_dpbf16_ps(acc_aa, va, va);
        acc_bb = _mm512_dpbf16_ps(acc_bb, vb, vb);
    }
    if (i < d) {
        // 1..31 elements remain: one mask bit per valid element.
        const __mmask32 tail = (__mmask32)((1u << (d - i)) - 1u);
        __m512bh va = (__m512bh)_mm512_maskz_loadu_epi16(tail, a + i);
        __m512bh vb = (__m512bh)_mm512_maskz_loadu_epi16(tail, b + i);
        acc_ab = _mm512_dpbf16_ps(acc_ab, va, vb);
        acc_aa = _mm512_dpbf16_ps(acc_aa, va, va);
        acc_bb = _mm512_dpbf16_ps(acc_bb, vb, vb);
    }
    return _mm512_reduce_add_ps(acc_aa) - 2.0f * _mm512_reduce_add_ps(acc_ab) +
            _mm512_reduce_add_ps(acc_bb);
}
|
|
451
|
+
|
|
452
|
+
/**********************************************************
|
|
453
|
+
* BF16 + Inner Product distance computer (SPR)
|
|
454
|
+
**********************************************************/
|
|
455
|
+
|
|
456
|
+
struct DCBF16_IP : SQDistanceComputer {
|
|
457
|
+
using Sim = SimilarityIP<SIMDLevel::AVX512_SPR>;
|
|
458
|
+
|
|
459
|
+
size_t d;
|
|
460
|
+
std::vector<uint16_t> query_bf16;
|
|
461
|
+
|
|
462
|
+
DCBF16_IP(size_t d, const std::vector<float>&) : d(d), query_bf16(d) {}
|
|
463
|
+
|
|
464
|
+
void set_query(const float* x) final {
|
|
465
|
+
q = x;
|
|
466
|
+
encode_bf16_simd(x, query_bf16.data(), d);
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
float query_to_code(const uint8_t* code) const final {
|
|
470
|
+
return bf16_vdpbf16ps(query_bf16.data(), (const uint16_t*)code, d);
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
474
|
+
return bf16_vdpbf16ps(
|
|
475
|
+
(const uint16_t*)(codes + i * code_size),
|
|
476
|
+
(const uint16_t*)(codes + j * code_size),
|
|
477
|
+
d);
|
|
478
|
+
}
|
|
479
|
+
};
|
|
480
|
+
|
|
481
|
+
/**********************************************************
|
|
482
|
+
* BF16 + L2 distance computer (SPR)
|
|
483
|
+
**********************************************************/
|
|
484
|
+
|
|
485
|
+
struct DCBF16_L2 : SQDistanceComputer {
|
|
486
|
+
using Sim = SimilarityL2<SIMDLevel::AVX512_SPR>;
|
|
487
|
+
|
|
488
|
+
size_t d;
|
|
489
|
+
std::vector<uint16_t> query_bf16;
|
|
490
|
+
float query_norm_sq;
|
|
491
|
+
|
|
492
|
+
DCBF16_L2(size_t d, const std::vector<float>&)
|
|
493
|
+
: d(d), query_bf16(d), query_norm_sq(0) {}
|
|
494
|
+
|
|
495
|
+
void set_query(const float* x) final {
|
|
496
|
+
q = x;
|
|
497
|
+
encode_bf16_simd(x, query_bf16.data(), d);
|
|
498
|
+
query_norm_sq = bf16_vdpbf16ps(query_bf16.data(), query_bf16.data(), d);
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
float query_to_code(const uint8_t* code) const final {
|
|
502
|
+
return query_norm_sq +
|
|
503
|
+
bf16_L2_asymmetric(query_bf16.data(), (const uint16_t*)code, d);
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
507
|
+
return bf16_L2_symmetric(
|
|
508
|
+
(const uint16_t*)(codes + i * code_size),
|
|
509
|
+
(const uint16_t*)(codes + j * code_size),
|
|
510
|
+
d);
|
|
511
|
+
}
|
|
512
|
+
};
|
|
513
|
+
|
|
514
|
+
template <>
|
|
515
|
+
struct DCTemplate<
|
|
516
|
+
QuantizerBF16<SIMDLevel::AVX512_SPR>,
|
|
517
|
+
SimilarityIP<SIMDLevel::AVX512_SPR>,
|
|
518
|
+
SIMDLevel::AVX512_SPR> : DCBF16_IP {
|
|
519
|
+
using Sim = SimilarityIP<SIMDLevel::AVX512_SPR>;
|
|
520
|
+
using DCBF16_IP::DCBF16_IP;
|
|
521
|
+
};
|
|
522
|
+
|
|
523
|
+
template <>
|
|
524
|
+
struct DCTemplate<
|
|
525
|
+
QuantizerBF16<SIMDLevel::AVX512_SPR>,
|
|
526
|
+
SimilarityL2<SIMDLevel::AVX512_SPR>,
|
|
527
|
+
SIMDLevel::AVX512_SPR> : DCBF16_L2 {
|
|
528
|
+
using Sim = SimilarityL2<SIMDLevel::AVX512_SPR>;
|
|
529
|
+
using DCBF16_L2::DCBF16_L2;
|
|
530
|
+
};
|
|
531
|
+
|
|
532
|
+
/**********************************************************
|
|
533
|
+
* turboq_masked_sum — delegate to AVX512 implementation
|
|
534
|
+
**********************************************************/
|
|
535
|
+
|
|
536
|
+
template <SIMDLevel SL0>
|
|
537
|
+
float turboq_masked_sum(const float* arr, const uint8_t* bits, size_t d);
|
|
538
|
+
|
|
539
|
+
template <>
|
|
540
|
+
float turboq_masked_sum<SIMDLevel::AVX512>(
|
|
541
|
+
const float* arr,
|
|
542
|
+
const uint8_t* bits,
|
|
543
|
+
size_t d);
|
|
544
|
+
|
|
545
|
+
template <>
|
|
546
|
+
float turboq_masked_sum<SIMDLevel::AVX512_SPR>(
|
|
547
|
+
const float* arr,
|
|
548
|
+
const uint8_t* bits,
|
|
549
|
+
size_t d) {
|
|
550
|
+
return turboq_masked_sum<SIMDLevel::AVX512>(arr, bits, d);
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
} // namespace scalar_quantizer
|
|
554
|
+
} // namespace faiss
|
|
555
|
+
|
|
556
|
+
#define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX512_SPR
|
|
557
|
+
#include <faiss/impl/scalar_quantizer/sq-dispatch.h>
|
|
558
|
+
|
|
559
|
+
#endif // COMPILE_SIMD_AVX512_SPR
|