faiss 0.2.7 → 0.3.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
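Most of the delta is in the vendored Faiss C++ sources. The excerpt below is from data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp (+444 -104): the scalar-quantizer kernels gain FAISS_ALWAYS_INLINE annotations and aarch64 (NEON) code paths alongside the existing AVX2 ones, and ScalarQuantizer::train_residual is removed. As a quick smoke test after upgrading the gem, a minimal round trip through the documented IndexFlatL2 API might look like the sketch below (the dimension and vectors are arbitrary; this assumes the interface shown in the gem's README):

    require "faiss"

    # Build a tiny L2 index and run one query to confirm the upgraded
    # native extension compiles, loads, and returns sane results.
    index = Faiss::IndexFlatL2.new(3)
    index.add([[1, 1, 2], [2, 1, 2], [3, 1, 2]])
    distances, ids = index.search([[1, 1, 2]], 2)
    puts ids.to_a.inspect # nearest neighbors of the single query vector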
data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp

@@ -65,42 +65,65 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
  */

 struct Codec8bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i] = (int)(255 * x);
     }

-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (code[i] + 0.5f) / 255.0f;
     }

 #ifdef __AVX2__
-    static __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
+        const uint64_t c8 = *(uint64_t*)(code + i);
+
+        const __m128i i8 = _mm_set1_epi64x(c8);
+        const __m256i i32 = _mm256_cvtepu8_epi32(i8);
+        const __m256 f8 = _mm256_cvtepi32_ps(i32);
+        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
+        const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
+        return _mm256_fmadd_ps(f8, one_255, half_one_255);
+    }
+#endif
+
+#ifdef __aarch64__
+    static FAISS_ALWAYS_INLINE float32x4x2_t
+    decode_8_components(const uint8_t* code, int i) {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = decode_component(code, i + j);
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        float32x4x2_t res = vzipq_f32(res1, res2);
+        return vuzpq_f32(res.val[0], res.val[1]);
     }
 #endif
 };

 struct Codec4bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
     }

-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
     }

 #ifdef __AVX2__
-    static __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
         uint32_t c4 = *(uint32_t*)(code + (i >> 1));
         uint32_t mask = 0x0f0f0f0f;
         uint32_t c4ev = c4 & mask;
@@ -120,10 +143,27 @@ struct Codec4bit {
         return _mm256_mul_ps(f8, one_255);
     }
 #endif
+
+#ifdef __aarch64__
+    static FAISS_ALWAYS_INLINE float32x4x2_t
+    decode_8_components(const uint8_t* code, int i) {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = decode_component(code, i + j);
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        float32x4x2_t res = vzipq_f32(res1, res2);
+        return vuzpq_f32(res.val[0], res.val[1]);
+    }
+#endif
 };

 struct Codec6bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         int bits = (int)(x * 63.0);
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -144,7 +184,9 @@ struct Codec6bit {
         }
     }

-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         uint8_t bits;
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -170,7 +212,7 @@ struct Codec6bit {

     /* Load 6 bytes that represent 8 6-bit values, return them as a
      * 8*32 bit vector register */
-    static __m256i load6(const uint16_t* code16) {
+    static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
         const __m128i perm = _mm_set_epi8(
                 -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
         const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
@@ -189,18 +231,45 @@ struct Codec6bit {
         return c5;
     }

-    static __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
+        // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
+        // // for the reference, maybe, it becomes used oned day.
+        // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
+        // const uint32_t* data32 = (const uint32_t*)data16;
+        // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
+        // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
+        // const __m128i i8 = _mm_set1_epi64x(vext);
+        // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
+        // const __m256 f8 = _mm256_cvtepi32_ps(i32);
+        // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
+        // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
+        // return _mm256_fmadd_ps(f8, one_255, half_one_255);
+
         __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
         __m256 f8 = _mm256_cvtepi32_ps(i8);
         // this could also be done with bit manipulations but it is
         // not obviously faster
-        return _mm256_mul_ps(f8, one_63);
+        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
+        const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
+        return _mm256_fmadd_ps(f8, one_255, half_one_255);
     }

 #endif
+
+#ifdef __aarch64__
+    static FAISS_ALWAYS_INLINE float32x4x2_t
+    decode_8_components(const uint8_t* code, int i) {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = decode_component(code, i + j);
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        float32x4x2_t res = vzipq_f32(res1, res2);
+        return vuzpq_f32(res.val[0], res.val[1]);
+    }
+#endif
 };

 /*******************************************************************
@@ -242,7 +311,8 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
         }
     }

-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         float xi = Codec::decode_component(code, i);
         return vmin + xi * vdiff;
     }
@@ -255,11 +325,36 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
             : QuantizerTemplate<Codec, true, 1>(d, trained) {}

-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m256 xi = Codec::decode_8_components(code, i);
+        return _mm256_fmadd_ps(
+                xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin));
+    }
+};
+
+#endif
+
+#ifdef __aarch64__
+
+template <class Codec>
+struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
+    QuantizerTemplate(size_t d, const std::vector<float>& trained)
+            : QuantizerTemplate<Codec, true, 1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        float32x4x2_t xi = Codec::decode_8_components(code, i);
+        float32x4x2_t res = vzipq_f32(
+                vfmaq_f32(
+                        vdupq_n_f32(this->vmin),
+                        xi.val[0],
+                        vdupq_n_f32(this->vdiff)),
+                vfmaq_f32(
+                        vdupq_n_f32(this->vmin),
+                        xi.val[1],
+                        vdupq_n_f32(this->vdiff)));
+        return vuzpq_f32(res.val[0], res.val[1]);
     }
 };

@@ -296,7 +391,8 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
         }
     }

-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         float xi = Codec::decode_component(code, i);
         return vmin[i] + xi * vdiff[i];
     }
@@ -309,11 +405,36 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
             : QuantizerTemplate<Codec, false, 1>(d, trained) {}

-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m256 xi = Codec::decode_8_components(code, i);
+        return _mm256_fmadd_ps(
+                xi,
+                _mm256_loadu_ps(this->vdiff + i),
+                _mm256_loadu_ps(this->vmin + i));
+    }
+};
+
+#endif
+
+#ifdef __aarch64__
+
+template <class Codec>
+struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
+    QuantizerTemplate(size_t d, const std::vector<float>& trained)
+            : QuantizerTemplate<Codec, false, 1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        float32x4x2_t xi = Codec::decode_8_components(code, i);
+
+        float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
+        float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
+
+        float32x4x2_t res = vzipq_f32(
+                vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
+                vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
+        return vuzpq_f32(res.val[0], res.val[1]);
     }
 };

@@ -344,7 +465,8 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
         }
     }

-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         return decode_fp16(((uint16_t*)code)[i]);
     }
 };
@@ -356,7 +478,8 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
     QuantizerFP16(size_t d, const std::vector<float>& trained)
             : QuantizerFP16<1>(d, trained) {}

-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
         return _mm256_cvtph_ps(codei);
     }
@@ -364,6 +487,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {

 #endif

+#ifdef __aarch64__
+
+template <>
+struct QuantizerFP16<8> : QuantizerFP16<1> {
+    QuantizerFP16(size_t d, const std::vector<float>& trained)
+            : QuantizerFP16<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
+        return vzipq_f32(
+                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
+                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
+    }
+};
+#endif
+
 /*******************************************************************
  * 8bit_direct quantizer
  *******************************************************************/
@@ -390,7 +530,8 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
         }
     }

-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         return code[i];
     }
 };
@@ -402,7 +543,8 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
     Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
             : Quantizer8bitDirect<1>(d, trained) {}

-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
         __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
         return _mm256_cvtepi32_ps(y8); // 8 * float32
@@ -411,6 +553,28 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {

 #endif

+#ifdef __aarch64__
+
+template <>
+struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
+    Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirect<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = code[i + j];
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        float32x4x2_t res = vzipq_f32(res1, res2);
+        return vuzpq_f32(res.val[0], res.val[1]);
+    }
+};
+
+#endif
+
 template <int SIMDWIDTH>
 ScalarQuantizer::SQuantizer* select_quantizer_1(
         QuantizerType qtype,
@@ -486,7 +650,7 @@ void train_Uniform(
     } else if (rs == ScalarQuantizer::RS_quantiles) {
         std::vector<float> x_copy(n);
         memcpy(x_copy.data(), x, n * sizeof(*x));
+        // TODO just do a quickselect
         std::sort(x_copy.begin(), x_copy.end());
         int o = int(rs_arg * n);
         if (o < 0)
@@ -632,22 +796,22 @@ struct SimilarityL2<1> {

     float accu;

-    void begin() {
+    FAISS_ALWAYS_INLINE void begin() {
         accu = 0;
         yi = y;
     }

-    void add_component(float x) {
+    FAISS_ALWAYS_INLINE void add_component(float x) {
         float tmp = *yi++ - x;
         accu += tmp * tmp;
     }

-    void add_component_2(float x1, float x2) {
+    FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
         float tmp = x1 - x2;
         accu += tmp * tmp;
     }

-    float result() {
+    FAISS_ALWAYS_INLINE float result() {
         return accu;
     }
 };
@@ -663,34 +827,89 @@ struct SimilarityL2<8> {
     explicit SimilarityL2(const float* y) : y(y) {}
     __m256 accu8;

-    void begin_8() {
+    FAISS_ALWAYS_INLINE void begin_8() {
         accu8 = _mm256_setzero_ps();
         yi = y;
     }

-    void add_8_components(__m256 x) {
+    FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
         __m256 yiv = _mm256_loadu_ps(yi);
         yi += 8;
         __m256 tmp = _mm256_sub_ps(yiv, x);
+        accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
     }

+    FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) {
+        __m256 tmp = _mm256_sub_ps(y_2, x);
+        accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
     }

-    float result_8() {
+    FAISS_ALWAYS_INLINE float result_8() {
+        const __m128 sum = _mm_add_ps(
+                _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
+        const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
+        const __m128 v1 = _mm_add_ps(sum, v0);
+        __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+        const __m128 v3 = _mm_add_ps(v1, v2);
+        return _mm_cvtss_f32(v3);
     }
 };

 #endif

+#ifdef __aarch64__
+template <>
+struct SimilarityL2<8> {
+    static constexpr int simdwidth = 8;
+    static constexpr MetricType metric_type = METRIC_L2;
+
+    const float *y, *yi;
+    explicit SimilarityL2(const float* y) : y(y) {}
+    float32x4x2_t accu8;
+
+    FAISS_ALWAYS_INLINE void begin_8() {
+        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        yi = y;
+    }
+
+    FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
+        float32x4x2_t yiv = vld1q_f32_x2(yi);
+        yi += 8;
+
+        float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
+        float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
+
+        float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
+        float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
+
+        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
+        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+    }
+
+    FAISS_ALWAYS_INLINE void add_8_components_2(
+            float32x4x2_t x,
+            float32x4x2_t y) {
+        float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
+        float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
+
+        float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
+        float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
+
+        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
+        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+    }
+
+    FAISS_ALWAYS_INLINE float result_8() {
+        float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
+        float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
+
+        float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
+        float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
+        return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
+    }
+};
+#endif
+
 template <int SIMDWIDTH>
 struct SimilarityIP {};

@@ -704,20 +923,20 @@ struct SimilarityIP<1> {

     explicit SimilarityIP(const float* y) : y(y) {}

-    void begin() {
+    FAISS_ALWAYS_INLINE void begin() {
         accu = 0;
         yi = y;
     }

-    void add_component(float x) {
+    FAISS_ALWAYS_INLINE void add_component(float x) {
         accu += *yi++ * x;
     }

-    void add_component_2(float x1, float x2) {
+    FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
         accu += x1 * x2;
     }

-    float result() {
+    FAISS_ALWAYS_INLINE float result() {
         return accu;
     }
 };
@@ -737,27 +956,79 @@ struct SimilarityIP<8> {

     __m256 accu8;

-    void begin_8() {
+    FAISS_ALWAYS_INLINE void begin_8() {
         accu8 = _mm256_setzero_ps();
         yi = y;
     }

-    void add_8_components(__m256 x) {
+    FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
         __m256 yiv = _mm256_loadu_ps(yi);
         yi += 8;
+        accu8 = _mm256_fmadd_ps(yiv, x, accu8);
+    }
+
+    FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) {
+        accu8 = _mm256_fmadd_ps(x1, x2, accu8);
+    }
+
+    FAISS_ALWAYS_INLINE float result_8() {
+        const __m128 sum = _mm_add_ps(
+                _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
+        const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
+        const __m128 v1 = _mm_add_ps(sum, v0);
+        __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+        const __m128 v3 = _mm_add_ps(v1, v2);
+        return _mm_cvtss_f32(v3);
+    }
+};
+#endif
+
+#ifdef __aarch64__
+
+template <>
+struct SimilarityIP<8> {
+    static constexpr int simdwidth = 8;
+    static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
+
+    const float *y, *yi;
+
+    explicit SimilarityIP(const float* y) : y(y) {}
+    float32x4x2_t accu8;
+
+    FAISS_ALWAYS_INLINE void begin_8() {
+        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        yi = y;
+    }
+
+    FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
+        float32x4x2_t yiv = vld1q_f32_x2(yi);
+        yi += 8;
+
+        float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
+        float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
+        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
+        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
     }

+    FAISS_ALWAYS_INLINE void add_8_components_2(
+            float32x4x2_t x1,
+            float32x4x2_t x2) {
+        float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
+        float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
+        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
+        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
     }

-    float result_8() {
+    FAISS_ALWAYS_INLINE float result_8() {
+        float32x4x2_t sum_tmp = vzipq_f32(
+                vpaddq_f32(accu8.val[0], accu8.val[0]),
+                vpaddq_f32(accu8.val[1], accu8.val[1]));
+        float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
+        float32x4x2_t sum2_tmp = vzipq_f32(
+                vpaddq_f32(sum.val[0], sum.val[0]),
+                vpaddq_f32(sum.val[1], sum.val[1]));
+        float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
+        return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
     }
 };
 #endif
@@ -864,6 +1135,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {

 #endif

+#ifdef __aarch64__
+
+template <class Quantizer, class Similarity>
+struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
+    using Sim = Similarity;
+
+    Quantizer quant;
+
+    DCTemplate(size_t d, const std::vector<float>& trained)
+            : quant(d, trained) {}
+    float compute_distance(const float* x, const uint8_t* code) const {
+        Similarity sim(x);
+        sim.begin_8();
+        for (size_t i = 0; i < quant.d; i += 8) {
+            float32x4x2_t xi = quant.reconstruct_8_components(code, i);
+            sim.add_8_components(xi);
+        }
+        return sim.result_8();
+    }
+
+    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+            const {
+        Similarity sim(nullptr);
+        sim.begin_8();
+        for (size_t i = 0; i < quant.d; i += 8) {
+            float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
+            float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
+            sim.add_8_components_2(x1, x2);
+        }
+        return sim.result_8();
+    }
+
+    void set_query(const float* x) final {
+        q = x;
+    }
+
+    float symmetric_dis(idx_t i, idx_t j) override {
+        return compute_code_distance(
+                codes + i * code_size, codes + j * code_size);
+    }
+
+    float query_to_code(const uint8_t* code) const final {
+        return compute_distance(q, code);
+    }
+};
+#endif
+
 /*******************************************************************
  * DistanceComputerByte: computes distances in the integer domain
  *******************************************************************/
@@ -980,6 +1298,54 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {

 #endif

+#ifdef __aarch64__
+
+template <class Similarity>
+struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
+    using Sim = Similarity;
+
+    int d;
+    std::vector<uint8_t> tmp;
+
+    DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
+
+    int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+            const {
+        int accu = 0;
+        for (int i = 0; i < d; i++) {
+            if (Sim::metric_type == METRIC_INNER_PRODUCT) {
+                accu += int(code1[i]) * code2[i];
+            } else {
+                int diff = int(code1[i]) - code2[i];
+                accu += diff * diff;
+            }
+        }
+        return accu;
+    }
+
+    void set_query(const float* x) final {
+        for (int i = 0; i < d; i++) {
+            tmp[i] = int(x[i]);
+        }
+    }
+
+    int compute_distance(const float* x, const uint8_t* code) {
+        set_query(x);
+        return compute_code_distance(tmp.data(), code);
+    }
+
+    float symmetric_dis(idx_t i, idx_t j) override {
+        return compute_code_distance(
+                codes + i * code_size, codes + j * code_size);
+    }
+
+    float query_to_code(const uint8_t* code) const final {
+        return compute_code_distance(tmp.data(), code);
+    }
+};
+
+#endif
+
 /*******************************************************************
  * select_distance_computer: runtime selection of template
  * specialization
@@ -1115,34 +1481,8 @@ void ScalarQuantizer::train(size_t n, const float* x) {
     }
 }

-void ScalarQuantizer::train_residual(
-        size_t n,
-        const float* x,
-        Index* quantizer,
-        bool by_residual,
-        bool verbose) {
-    const float* x_in = x;
-
-    // 100k points more than enough
-    x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
-
-    ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
-
-    if (by_residual) {
-        std::vector<idx_t> idx(n);
-        quantizer->assign(n, x, idx.data());
-
-        std::vector<float> residuals(n * d);
-        quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
-
-        train(n, residuals.data());
-    } else {
-        train(n, x);
-    }
-}
-
 ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
+#if defined(USE_F16C) || defined(__aarch64__)
     if (d % 8 == 0) {
         return select_quantizer_1<8>(qtype, d, trained);
     } else
@@ -1173,7 +1513,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
 SQDistanceComputer* ScalarQuantizer::get_distance_computer(
         MetricType metric) const {
     FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+#if defined(USE_F16C) || defined(__aarch64__)
     if (d % 8 == 0) {
         if (metric == METRIC_L2) {
             return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1204,7 +1544,6 @@ template <class DCClass, int use_sel>
 struct IVFSQScannerIP : InvertedListScanner {
     DCClass dc;
     bool by_residual;
-    const IDSelector* sel;

     float accu0; /// added to all distances

@@ -1215,9 +1554,11 @@ struct IVFSQScannerIP : InvertedListScanner {
             bool store_pairs,
             const IDSelector* sel,
             bool by_residual)
+            : dc(d, trained), by_residual(by_residual), accu0(0) {
         this->store_pairs = store_pairs;
+        this->sel = sel;
         this->code_size = code_size;
+        this->keep_max = true;
     }

     void set_query(const float* query) override {
@@ -1288,7 +1629,6 @@ struct IVFSQScannerL2 : InvertedListScanner {

     bool by_residual;
     const Index* quantizer;
-    const IDSelector* sel;
     const float* x; /// current query

     std::vector<float> tmp;
@@ -1304,10 +1644,10 @@ struct IVFSQScannerL2 : InvertedListScanner {
             : dc(d, trained),
               by_residual(by_residual),
               quantizer(quantizer),
-              sel(sel),
               x(nullptr),
               tmp(d) {
         this->store_pairs = store_pairs;
+        this->sel = sel;
         this->code_size = code_size;
     }

@@ -1509,7 +1849,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
         bool store_pairs,
         const IDSelector* sel,
         bool by_residual) const {
+#if defined(USE_F16C) || defined(__aarch64__)
     if (d % 8 == 0) {
         return sel0_InvertedListScanner<8>(
                 mt, this, quantizer, store_pairs, sel, by_residual);