faiss 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +29 -3
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
|
@@ -23,6 +23,10 @@
|
|
|
23
23
|
#include <immintrin.h>
|
|
24
24
|
#endif
|
|
25
25
|
|
|
26
|
+
#ifdef __AVX2__
|
|
27
|
+
#include <faiss/utils/transpose/transpose-avx2-inl.h>
|
|
28
|
+
#endif
|
|
29
|
+
|
|
26
30
|
#ifdef __aarch64__
|
|
27
31
|
#include <arm_neon.h>
|
|
28
32
|
#endif
|
|
@@ -56,16 +60,6 @@ namespace faiss {
|
|
|
56
60
|
* Reference implementations
|
|
57
61
|
*/
|
|
58
62
|
|
|
59
|
-
float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
|
|
60
|
-
size_t i;
|
|
61
|
-
float res = 0;
|
|
62
|
-
for (i = 0; i < d; i++) {
|
|
63
|
-
const float tmp = x[i] - y[i];
|
|
64
|
-
res += tmp * tmp;
|
|
65
|
-
}
|
|
66
|
-
return res;
|
|
67
|
-
}
|
|
68
|
-
|
|
69
63
|
float fvec_L1_ref(const float* x, const float* y, size_t d) {
|
|
70
64
|
size_t i;
|
|
71
65
|
float res = 0;
|
|
@@ -85,22 +79,6 @@ float fvec_Linf_ref(const float* x, const float* y, size_t d) {
|
|
|
85
79
|
return res;
|
|
86
80
|
}
|
|
87
81
|
|
|
88
|
-
float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
|
|
89
|
-
size_t i;
|
|
90
|
-
float res = 0;
|
|
91
|
-
for (i = 0; i < d; i++)
|
|
92
|
-
res += x[i] * y[i];
|
|
93
|
-
return res;
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
float fvec_norm_L2sqr_ref(const float* x, size_t d) {
|
|
97
|
-
size_t i;
|
|
98
|
-
double res = 0;
|
|
99
|
-
for (i = 0; i < d; i++)
|
|
100
|
-
res += x[i] * x[i];
|
|
101
|
-
return res;
|
|
102
|
-
}
|
|
103
|
-
|
|
104
82
|
void fvec_L2sqr_ny_ref(
|
|
105
83
|
float* dis,
|
|
106
84
|
const float* x,
|
|
@@ -203,6 +181,48 @@ void fvec_inner_products_ny_ref(
|
|
|
203
181
|
}
|
|
204
182
|
}
|
|
205
183
|
|
|
184
|
+
/*********************************************************
|
|
185
|
+
* Autovectorized implementations
|
|
186
|
+
*/
|
|
187
|
+
|
|
188
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
189
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
190
|
+
float res = 0.F;
|
|
191
|
+
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
192
|
+
for (size_t i = 0; i != d; ++i) {
|
|
193
|
+
res += x[i] * y[i];
|
|
194
|
+
}
|
|
195
|
+
return res;
|
|
196
|
+
}
|
|
197
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
198
|
+
|
|
199
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
200
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
201
|
+
// the double in the _ref is suspected to be a typo. Some of the manual
|
|
202
|
+
// implementations this replaces used float.
|
|
203
|
+
float res = 0;
|
|
204
|
+
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
205
|
+
for (size_t i = 0; i != d; ++i) {
|
|
206
|
+
res += x[i] * x[i];
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
return res;
|
|
210
|
+
}
|
|
211
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
212
|
+
|
|
213
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
214
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
215
|
+
size_t i;
|
|
216
|
+
float res = 0;
|
|
217
|
+
FAISS_PRAGMA_IMPRECISE_LOOP
|
|
218
|
+
for (i = 0; i < d; i++) {
|
|
219
|
+
const float tmp = x[i] - y[i];
|
|
220
|
+
res += tmp * tmp;
|
|
221
|
+
}
|
|
222
|
+
return res;
|
|
223
|
+
}
|
|
224
|
+
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
225
|
+
|
|
206
226
|
/*********************************************************
|
|
207
227
|
* SSE and AVX implementations
|
|
208
228
|
*/
|
|
@@ -225,25 +245,6 @@ static inline __m128 masked_read(int d, const float* x) {
|
|
|
225
245
|
// cannot use AVX2 _mm_mask_set1_epi32
|
|
226
246
|
}
|
|
227
247
|
|
|
228
|
-
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
229
|
-
__m128 mx;
|
|
230
|
-
__m128 msum1 = _mm_setzero_ps();
|
|
231
|
-
|
|
232
|
-
while (d >= 4) {
|
|
233
|
-
mx = _mm_loadu_ps(x);
|
|
234
|
-
x += 4;
|
|
235
|
-
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
236
|
-
d -= 4;
|
|
237
|
-
}
|
|
238
|
-
|
|
239
|
-
mx = masked_read(d, x);
|
|
240
|
-
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
241
|
-
|
|
242
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
243
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
244
|
-
return _mm_cvtss_f32(msum1);
|
|
245
|
-
}
|
|
246
|
-
|
|
247
248
|
namespace {
|
|
248
249
|
|
|
249
250
|
/// Function that does a component-wise operation between x and y
|
|
@@ -354,25 +355,25 @@ void fvec_op_ny_D4<ElementOpIP>(
|
|
|
354
355
|
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
355
356
|
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
356
357
|
|
|
357
|
-
const __m256i indices0 =
|
|
358
|
-
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
359
|
-
|
|
360
358
|
for (i = 0; i < ny8 * 8; i += 8) {
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
359
|
+
// load 8x4 matrix and transpose it in registers.
|
|
360
|
+
// the typical bottleneck is memory access, so
|
|
361
|
+
// let's trade instructions for the bandwidth.
|
|
362
|
+
|
|
363
|
+
__m256 v0;
|
|
364
|
+
__m256 v1;
|
|
365
|
+
__m256 v2;
|
|
366
|
+
__m256 v3;
|
|
367
|
+
|
|
368
|
+
transpose_8x4(
|
|
369
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
370
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
371
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
372
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
373
|
+
v0,
|
|
374
|
+
v1,
|
|
375
|
+
v2,
|
|
376
|
+
v3);
|
|
376
377
|
|
|
377
378
|
// compute distances
|
|
378
379
|
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
@@ -380,15 +381,7 @@ void fvec_op_ny_D4<ElementOpIP>(
|
|
|
380
381
|
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
381
382
|
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
382
383
|
|
|
383
|
-
//
|
|
384
|
-
// (x[1] * y[(i * 8 + 0) * 4 + 1]) +
|
|
385
|
-
// (x[2] * y[(i * 8 + 0) * 4 + 2]) +
|
|
386
|
-
// (x[3] * y[(i * 8 + 0) * 4 + 3])
|
|
387
|
-
// ...
|
|
388
|
-
// distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
|
|
389
|
-
// (x[1] * y[(i * 8 + 7) * 4 + 1]) +
|
|
390
|
-
// (x[2] * y[(i * 8 + 7) * 4 + 2]) +
|
|
391
|
-
// (x[3] * y[(i * 8 + 7) * 4 + 3])
|
|
384
|
+
// store
|
|
392
385
|
_mm256_storeu_ps(dis + i, distances);
|
|
393
386
|
|
|
394
387
|
y += 32;
|
|
@@ -432,25 +425,25 @@ void fvec_op_ny_D4<ElementOpL2>(
|
|
|
432
425
|
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
433
426
|
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
434
427
|
|
|
435
|
-
const __m256i indices0 =
|
|
436
|
-
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
437
|
-
|
|
438
428
|
for (i = 0; i < ny8 * 8; i += 8) {
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
429
|
+
// load 8x4 matrix and transpose it in registers.
|
|
430
|
+
// the typical bottleneck is memory access, so
|
|
431
|
+
// let's trade instructions for the bandwidth.
|
|
432
|
+
|
|
433
|
+
__m256 v0;
|
|
434
|
+
__m256 v1;
|
|
435
|
+
__m256 v2;
|
|
436
|
+
__m256 v3;
|
|
437
|
+
|
|
438
|
+
transpose_8x4(
|
|
439
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
440
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
441
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
442
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
443
|
+
v0,
|
|
444
|
+
v1,
|
|
445
|
+
v2,
|
|
446
|
+
v3);
|
|
454
447
|
|
|
455
448
|
// compute differences
|
|
456
449
|
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
@@ -464,15 +457,7 @@ void fvec_op_ny_D4<ElementOpL2>(
|
|
|
464
457
|
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
465
458
|
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
466
459
|
|
|
467
|
-
//
|
|
468
|
-
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
469
|
-
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
470
|
-
// (x[3] - y[(i * 8 + 0) * 4 + 3])
|
|
471
|
-
// ...
|
|
472
|
-
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
473
|
-
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
474
|
-
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
475
|
-
// (x[3] - y[(i * 8 + 7) * 4 + 3])
|
|
460
|
+
// store
|
|
476
461
|
_mm256_storeu_ps(dis + i, distances);
|
|
477
462
|
|
|
478
463
|
y += 32;
|
|
@@ -583,6 +568,228 @@ void fvec_inner_products_ny(
|
|
|
583
568
|
}
|
|
584
569
|
|
|
585
570
|
#ifdef __AVX2__
|
|
571
|
+
template <size_t DIM>
|
|
572
|
+
void fvec_L2sqr_ny_y_transposed_D(
|
|
573
|
+
float* distances,
|
|
574
|
+
const float* x,
|
|
575
|
+
const float* y,
|
|
576
|
+
const float* y_sqlen,
|
|
577
|
+
const size_t d_offset,
|
|
578
|
+
size_t ny) {
|
|
579
|
+
// current index being processed
|
|
580
|
+
size_t i = 0;
|
|
581
|
+
|
|
582
|
+
// squared length of x
|
|
583
|
+
float x_sqlen = 0;
|
|
584
|
+
;
|
|
585
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
586
|
+
x_sqlen += x[j] * x[j];
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
// process 8 vectors per loop.
|
|
590
|
+
const size_t ny8 = ny / 8;
|
|
591
|
+
|
|
592
|
+
if (ny8 > 0) {
|
|
593
|
+
// m[i] = (2 * x[i], ... 2 * x[i])
|
|
594
|
+
__m256 m[DIM];
|
|
595
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
596
|
+
m[j] = _mm256_set1_ps(x[j]);
|
|
597
|
+
m[j] = _mm256_add_ps(m[j], m[j]);
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
__m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
|
|
601
|
+
|
|
602
|
+
for (; i < ny8 * 8; i += 8) {
|
|
603
|
+
// collect dim 0 for 8 D4-vectors.
|
|
604
|
+
const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
|
|
605
|
+
|
|
606
|
+
// compute dot products
|
|
607
|
+
// this is x^2 - 2x[0]*y[0]
|
|
608
|
+
__m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
|
|
609
|
+
|
|
610
|
+
for (size_t j = 1; j < DIM; j++) {
|
|
611
|
+
// collect dim j for 8 D4-vectors.
|
|
612
|
+
const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
|
|
613
|
+
dp = _mm256_fnmadd_ps(m[j], vj, dp);
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
// we've got x^2 - (2x, y) at this point
|
|
617
|
+
|
|
618
|
+
// y^2 - (2x, y) + x^2
|
|
619
|
+
__m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
|
|
620
|
+
|
|
621
|
+
_mm256_storeu_ps(distances + i, distances_v);
|
|
622
|
+
|
|
623
|
+
// scroll y and y_sqlen forward.
|
|
624
|
+
y += 8;
|
|
625
|
+
y_sqlen += 8;
|
|
626
|
+
}
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
if (i < ny) {
|
|
630
|
+
// process leftovers
|
|
631
|
+
for (; i < ny; i++) {
|
|
632
|
+
float dp = 0;
|
|
633
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
634
|
+
dp += x[j] * y[j * d_offset];
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
// compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
|
638
|
+
// lowest distance.
|
|
639
|
+
const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
|
|
640
|
+
distances[i] = distance;
|
|
641
|
+
|
|
642
|
+
y += 1;
|
|
643
|
+
y_sqlen += 1;
|
|
644
|
+
}
|
|
645
|
+
}
|
|
646
|
+
}
|
|
647
|
+
#endif
|
|
648
|
+
|
|
649
|
+
void fvec_L2sqr_ny_transposed(
|
|
650
|
+
float* dis,
|
|
651
|
+
const float* x,
|
|
652
|
+
const float* y,
|
|
653
|
+
const float* y_sqlen,
|
|
654
|
+
size_t d,
|
|
655
|
+
size_t d_offset,
|
|
656
|
+
size_t ny) {
|
|
657
|
+
// optimized for a few special cases
|
|
658
|
+
|
|
659
|
+
#ifdef __AVX2__
|
|
660
|
+
#define DISPATCH(dval) \
|
|
661
|
+
case dval: \
|
|
662
|
+
return fvec_L2sqr_ny_y_transposed_D<dval>( \
|
|
663
|
+
dis, x, y, y_sqlen, d_offset, ny);
|
|
664
|
+
|
|
665
|
+
switch (d) {
|
|
666
|
+
DISPATCH(1)
|
|
667
|
+
DISPATCH(2)
|
|
668
|
+
DISPATCH(4)
|
|
669
|
+
DISPATCH(8)
|
|
670
|
+
default:
|
|
671
|
+
return fvec_L2sqr_ny_y_transposed_ref(
|
|
672
|
+
dis, x, y, y_sqlen, d, d_offset, ny);
|
|
673
|
+
}
|
|
674
|
+
#undef DISPATCH
|
|
675
|
+
#else
|
|
676
|
+
// non-AVX2 case
|
|
677
|
+
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
678
|
+
#endif
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
#ifdef __AVX2__
|
|
682
|
+
|
|
683
|
+
size_t fvec_L2sqr_ny_nearest_D2(
|
|
684
|
+
float* distances_tmp_buffer,
|
|
685
|
+
const float* x,
|
|
686
|
+
const float* y,
|
|
687
|
+
size_t ny) {
|
|
688
|
+
// this implementation does not use distances_tmp_buffer.
|
|
689
|
+
|
|
690
|
+
// current index being processed
|
|
691
|
+
size_t i = 0;
|
|
692
|
+
|
|
693
|
+
// min distance and the index of the closest vector so far
|
|
694
|
+
float current_min_distance = HUGE_VALF;
|
|
695
|
+
size_t current_min_index = 0;
|
|
696
|
+
|
|
697
|
+
// process 8 D2-vectors per loop.
|
|
698
|
+
const size_t ny8 = ny / 8;
|
|
699
|
+
if (ny8 > 0) {
|
|
700
|
+
_mm_prefetch(y, _MM_HINT_T0);
|
|
701
|
+
_mm_prefetch(y + 16, _MM_HINT_T0);
|
|
702
|
+
|
|
703
|
+
// track min distance and the closest vector independently
|
|
704
|
+
// for each of 8 AVX2 components.
|
|
705
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
706
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
707
|
+
|
|
708
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
709
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
710
|
+
|
|
711
|
+
// 1 value per register
|
|
712
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
713
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
714
|
+
|
|
715
|
+
for (; i < ny8 * 8; i += 8) {
|
|
716
|
+
_mm_prefetch(y + 32, _MM_HINT_T0);
|
|
717
|
+
|
|
718
|
+
__m256 v0;
|
|
719
|
+
__m256 v1;
|
|
720
|
+
|
|
721
|
+
transpose_8x2(
|
|
722
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
723
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
724
|
+
v0,
|
|
725
|
+
v1);
|
|
726
|
+
|
|
727
|
+
// compute differences
|
|
728
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
729
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
730
|
+
|
|
731
|
+
// compute squares of differences
|
|
732
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
733
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
734
|
+
|
|
735
|
+
// compare the new distances to the min distances
|
|
736
|
+
// for each of 8 AVX2 components.
|
|
737
|
+
__m256 comparison =
|
|
738
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
739
|
+
|
|
740
|
+
// update min distances and indices with closest vectors if needed.
|
|
741
|
+
min_distances = _mm256_min_ps(distances, min_distances);
|
|
742
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
743
|
+
_mm256_castsi256_ps(current_indices),
|
|
744
|
+
_mm256_castsi256_ps(min_indices),
|
|
745
|
+
comparison));
|
|
746
|
+
|
|
747
|
+
// update current indices values. Basically, +8 to each of the
|
|
748
|
+
// 8 AVX2 components.
|
|
749
|
+
current_indices =
|
|
750
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
751
|
+
|
|
752
|
+
// scroll y forward (8 vectors 2 DIM each).
|
|
753
|
+
y += 16;
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
// dump values and find the minimum distance / minimum index
|
|
757
|
+
float min_distances_scalar[8];
|
|
758
|
+
uint32_t min_indices_scalar[8];
|
|
759
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
760
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
761
|
+
|
|
762
|
+
for (size_t j = 0; j < 8; j++) {
|
|
763
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
764
|
+
current_min_distance = min_distances_scalar[j];
|
|
765
|
+
current_min_index = min_indices_scalar[j];
|
|
766
|
+
}
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
if (i < ny) {
|
|
771
|
+
// process leftovers.
|
|
772
|
+
// the following code is not optimal, but it is rarely invoked.
|
|
773
|
+
float x0 = x[0];
|
|
774
|
+
float x1 = x[1];
|
|
775
|
+
|
|
776
|
+
for (; i < ny; i++) {
|
|
777
|
+
float sub0 = x0 - y[0];
|
|
778
|
+
float sub1 = x1 - y[1];
|
|
779
|
+
float distance = sub0 * sub0 + sub1 * sub1;
|
|
780
|
+
|
|
781
|
+
y += 2;
|
|
782
|
+
|
|
783
|
+
if (current_min_distance > distance) {
|
|
784
|
+
current_min_distance = distance;
|
|
785
|
+
current_min_index = i;
|
|
786
|
+
}
|
|
787
|
+
}
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
return current_min_index;
|
|
791
|
+
}
|
|
792
|
+
|
|
586
793
|
size_t fvec_L2sqr_ny_nearest_D4(
|
|
587
794
|
float* distances_tmp_buffer,
|
|
588
795
|
const float* x,
|
|
@@ -609,38 +816,27 @@ size_t fvec_L2sqr_ny_nearest_D4(
|
|
|
609
816
|
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
610
817
|
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
611
818
|
|
|
612
|
-
//
|
|
613
|
-
_mm_prefetch(y, _MM_HINT_NTA);
|
|
614
|
-
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
615
|
-
|
|
616
|
-
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
819
|
+
// 1 value per register
|
|
617
820
|
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
618
|
-
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
619
821
|
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
620
|
-
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
621
822
|
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
622
|
-
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
623
823
|
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
624
824
|
|
|
625
|
-
const __m256i indices0 =
|
|
626
|
-
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
627
|
-
|
|
628
825
|
for (; i < ny8 * 8; i += 8) {
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
826
|
+
__m256 v0;
|
|
827
|
+
__m256 v1;
|
|
828
|
+
__m256 v2;
|
|
829
|
+
__m256 v3;
|
|
830
|
+
|
|
831
|
+
transpose_8x4(
|
|
832
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
833
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
834
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
835
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
836
|
+
v0,
|
|
837
|
+
v1,
|
|
838
|
+
v2,
|
|
839
|
+
v3);
|
|
644
840
|
|
|
645
841
|
// compute differences
|
|
646
842
|
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
@@ -654,24 +850,13 @@ size_t fvec_L2sqr_ny_nearest_D4(
|
|
|
654
850
|
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
655
851
|
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
656
852
|
|
|
657
|
-
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
|
658
|
-
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
659
|
-
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
660
|
-
// (x[3] - y[(i * 8 + 0) * 4 + 3])
|
|
661
|
-
// ...
|
|
662
|
-
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
663
|
-
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
664
|
-
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
665
|
-
// (x[3] - y[(i * 8 + 7) * 4 + 3])
|
|
666
|
-
|
|
667
853
|
// compare the new distances to the min distances
|
|
668
854
|
// for each of 8 AVX2 components.
|
|
669
855
|
__m256 comparison =
|
|
670
856
|
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
671
857
|
|
|
672
858
|
// update min distances and indices with closest vectors if needed.
|
|
673
|
-
min_distances =
|
|
674
|
-
_mm256_blendv_ps(distances, min_distances, comparison);
|
|
859
|
+
min_distances = _mm256_min_ps(distances, min_distances);
|
|
675
860
|
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
676
861
|
_mm256_castsi256_ps(current_indices),
|
|
677
862
|
_mm256_castsi256_ps(min_indices),
|
|
@@ -721,7 +906,168 @@ size_t fvec_L2sqr_ny_nearest_D4(
|
|
|
721
906
|
|
|
722
907
|
return current_min_index;
|
|
723
908
|
}
|
|
909
|
+
|
|
910
|
+
size_t fvec_L2sqr_ny_nearest_D8(
|
|
911
|
+
float* distances_tmp_buffer,
|
|
912
|
+
const float* x,
|
|
913
|
+
const float* y,
|
|
914
|
+
size_t ny) {
|
|
915
|
+
// this implementation does not use distances_tmp_buffer.
|
|
916
|
+
|
|
917
|
+
// current index being processed
|
|
918
|
+
size_t i = 0;
|
|
919
|
+
|
|
920
|
+
// min distance and the index of the closest vector so far
|
|
921
|
+
float current_min_distance = HUGE_VALF;
|
|
922
|
+
size_t current_min_index = 0;
|
|
923
|
+
|
|
924
|
+
// process 8 D8-vectors per loop.
|
|
925
|
+
const size_t ny8 = ny / 8;
|
|
926
|
+
if (ny8 > 0) {
|
|
927
|
+
// track min distance and the closest vector independently
|
|
928
|
+
// for each of 8 AVX2 components.
|
|
929
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
930
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
931
|
+
|
|
932
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
933
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
934
|
+
|
|
935
|
+
// 1 value per register
|
|
936
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
937
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
938
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
939
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
940
|
+
|
|
941
|
+
const __m256 m4 = _mm256_set1_ps(x[4]);
|
|
942
|
+
const __m256 m5 = _mm256_set1_ps(x[5]);
|
|
943
|
+
const __m256 m6 = _mm256_set1_ps(x[6]);
|
|
944
|
+
const __m256 m7 = _mm256_set1_ps(x[7]);
|
|
945
|
+
|
|
946
|
+
for (; i < ny8 * 8; i += 8) {
|
|
947
|
+
__m256 v0;
|
|
948
|
+
__m256 v1;
|
|
949
|
+
__m256 v2;
|
|
950
|
+
__m256 v3;
|
|
951
|
+
__m256 v4;
|
|
952
|
+
__m256 v5;
|
|
953
|
+
__m256 v6;
|
|
954
|
+
__m256 v7;
|
|
955
|
+
|
|
956
|
+
transpose_8x8(
|
|
957
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
958
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
959
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
960
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
961
|
+
_mm256_loadu_ps(y + 4 * 8),
|
|
962
|
+
_mm256_loadu_ps(y + 5 * 8),
|
|
963
|
+
_mm256_loadu_ps(y + 6 * 8),
|
|
964
|
+
_mm256_loadu_ps(y + 7 * 8),
|
|
965
|
+
v0,
|
|
966
|
+
v1,
|
|
967
|
+
v2,
|
|
968
|
+
v3,
|
|
969
|
+
v4,
|
|
970
|
+
v5,
|
|
971
|
+
v6,
|
|
972
|
+
v7);
|
|
973
|
+
|
|
974
|
+
// compute differences
|
|
975
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
976
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
977
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
978
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
979
|
+
const __m256 d4 = _mm256_sub_ps(m4, v4);
|
|
980
|
+
const __m256 d5 = _mm256_sub_ps(m5, v5);
|
|
981
|
+
const __m256 d6 = _mm256_sub_ps(m6, v6);
|
|
982
|
+
const __m256 d7 = _mm256_sub_ps(m7, v7);
|
|
983
|
+
|
|
984
|
+
// compute squares of differences
|
|
985
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
986
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
987
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
988
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
989
|
+
distances = _mm256_fmadd_ps(d4, d4, distances);
|
|
990
|
+
distances = _mm256_fmadd_ps(d5, d5, distances);
|
|
991
|
+
distances = _mm256_fmadd_ps(d6, d6, distances);
|
|
992
|
+
distances = _mm256_fmadd_ps(d7, d7, distances);
|
|
993
|
+
|
|
994
|
+
// compare the new distances to the min distances
|
|
995
|
+
// for each of 8 AVX2 components.
|
|
996
|
+
__m256 comparison =
|
|
997
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
998
|
+
|
|
999
|
+
// update min distances and indices with closest vectors if needed.
|
|
1000
|
+
min_distances = _mm256_min_ps(distances, min_distances);
|
|
1001
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
1002
|
+
_mm256_castsi256_ps(current_indices),
|
|
1003
|
+
_mm256_castsi256_ps(min_indices),
|
|
1004
|
+
comparison));
|
|
1005
|
+
|
|
1006
|
+
// update current indices values. Basically, +8 to each of the
|
|
1007
|
+
// 8 AVX2 components.
|
|
1008
|
+
current_indices =
|
|
1009
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
1010
|
+
|
|
1011
|
+
// scroll y forward (8 vectors 8 DIM each).
|
|
1012
|
+
y += 64;
|
|
1013
|
+
}
|
|
1014
|
+
|
|
1015
|
+
// dump values and find the minimum distance / minimum index
|
|
1016
|
+
float min_distances_scalar[8];
|
|
1017
|
+
uint32_t min_indices_scalar[8];
|
|
1018
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
1019
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
1020
|
+
|
|
1021
|
+
for (size_t j = 0; j < 8; j++) {
|
|
1022
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
1023
|
+
current_min_distance = min_distances_scalar[j];
|
|
1024
|
+
current_min_index = min_indices_scalar[j];
|
|
1025
|
+
}
|
|
1026
|
+
}
|
|
1027
|
+
}
|
|
1028
|
+
|
|
1029
|
+
if (i < ny) {
|
|
1030
|
+
// process leftovers
|
|
1031
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
|
1032
|
+
|
|
1033
|
+
for (; i < ny; i++) {
|
|
1034
|
+
__m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
|
|
1035
|
+
__m256 accu = _mm256_mul_ps(sub, sub);
|
|
1036
|
+
y += 8;
|
|
1037
|
+
|
|
1038
|
+
// horitontal sum
|
|
1039
|
+
const __m256 h0 = _mm256_hadd_ps(accu, accu);
|
|
1040
|
+
const __m256 h1 = _mm256_hadd_ps(h0, h0);
|
|
1041
|
+
|
|
1042
|
+
// extract high and low __m128 regs from __m256
|
|
1043
|
+
const __m128 h2 = _mm256_extractf128_ps(h1, 1);
|
|
1044
|
+
const __m128 h3 = _mm256_castps256_ps128(h1);
|
|
1045
|
+
|
|
1046
|
+
// get a final hsum into all 4 regs
|
|
1047
|
+
const __m128 h4 = _mm_add_ss(h2, h3);
|
|
1048
|
+
|
|
1049
|
+
// extract f[0] from __m128
|
|
1050
|
+
const float distance = _mm_cvtss_f32(h4);
|
|
1051
|
+
|
|
1052
|
+
if (current_min_distance > distance) {
|
|
1053
|
+
current_min_distance = distance;
|
|
1054
|
+
current_min_index = i;
|
|
1055
|
+
}
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
return current_min_index;
|
|
1060
|
+
}
|
|
1061
|
+
|
|
724
1062
|
#else
|
|
1063
|
+
size_t fvec_L2sqr_ny_nearest_D2(
|
|
1064
|
+
float* distances_tmp_buffer,
|
|
1065
|
+
const float* x,
|
|
1066
|
+
const float* y,
|
|
1067
|
+
size_t ny) {
|
|
1068
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
|
|
1069
|
+
}
|
|
1070
|
+
|
|
725
1071
|
size_t fvec_L2sqr_ny_nearest_D4(
|
|
726
1072
|
float* distances_tmp_buffer,
|
|
727
1073
|
const float* x,
|
|
@@ -729,6 +1075,14 @@ size_t fvec_L2sqr_ny_nearest_D4(
|
|
|
729
1075
|
size_t ny) {
|
|
730
1076
|
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
|
|
731
1077
|
}
|
|
1078
|
+
|
|
1079
|
+
size_t fvec_L2sqr_ny_nearest_D8(
|
|
1080
|
+
float* distances_tmp_buffer,
|
|
1081
|
+
const float* x,
|
|
1082
|
+
const float* y,
|
|
1083
|
+
size_t ny) {
|
|
1084
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
|
|
1085
|
+
}
|
|
732
1086
|
#endif
|
|
733
1087
|
|
|
734
1088
|
size_t fvec_L2sqr_ny_nearest(
|
|
@@ -743,7 +1097,9 @@ size_t fvec_L2sqr_ny_nearest(
|
|
|
743
1097
|
return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
|
|
744
1098
|
|
|
745
1099
|
switch (d) {
|
|
1100
|
+
DISPATCH(2)
|
|
746
1101
|
DISPATCH(4)
|
|
1102
|
+
DISPATCH(8)
|
|
747
1103
|
default:
|
|
748
1104
|
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
749
1105
|
}
|
|
@@ -919,79 +1275,6 @@ static inline __m256 masked_read_8(int d, const float* x) {
|
|
|
919
1275
|
}
|
|
920
1276
|
}
|
|
921
1277
|
|
|
922
|
-
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
923
|
-
__m256 msum1 = _mm256_setzero_ps();
|
|
924
|
-
|
|
925
|
-
while (d >= 8) {
|
|
926
|
-
__m256 mx = _mm256_loadu_ps(x);
|
|
927
|
-
x += 8;
|
|
928
|
-
__m256 my = _mm256_loadu_ps(y);
|
|
929
|
-
y += 8;
|
|
930
|
-
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
|
|
931
|
-
d -= 8;
|
|
932
|
-
}
|
|
933
|
-
|
|
934
|
-
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
935
|
-
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
936
|
-
|
|
937
|
-
if (d >= 4) {
|
|
938
|
-
__m128 mx = _mm_loadu_ps(x);
|
|
939
|
-
x += 4;
|
|
940
|
-
__m128 my = _mm_loadu_ps(y);
|
|
941
|
-
y += 4;
|
|
942
|
-
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
943
|
-
d -= 4;
|
|
944
|
-
}
|
|
945
|
-
|
|
946
|
-
if (d > 0) {
|
|
947
|
-
__m128 mx = masked_read(d, x);
|
|
948
|
-
__m128 my = masked_read(d, y);
|
|
949
|
-
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
950
|
-
}
|
|
951
|
-
|
|
952
|
-
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
953
|
-
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
954
|
-
return _mm_cvtss_f32(msum2);
|
|
955
|
-
}
|
|
956
|
-
|
|
957
|
-
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
958
|
-
__m256 msum1 = _mm256_setzero_ps();
|
|
959
|
-
|
|
960
|
-
while (d >= 8) {
|
|
961
|
-
__m256 mx = _mm256_loadu_ps(x);
|
|
962
|
-
x += 8;
|
|
963
|
-
__m256 my = _mm256_loadu_ps(y);
|
|
964
|
-
y += 8;
|
|
965
|
-
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
|
|
966
|
-
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
|
|
967
|
-
d -= 8;
|
|
968
|
-
}
|
|
969
|
-
|
|
970
|
-
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
971
|
-
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
972
|
-
|
|
973
|
-
if (d >= 4) {
|
|
974
|
-
__m128 mx = _mm_loadu_ps(x);
|
|
975
|
-
x += 4;
|
|
976
|
-
__m128 my = _mm_loadu_ps(y);
|
|
977
|
-
y += 4;
|
|
978
|
-
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
979
|
-
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
980
|
-
d -= 4;
|
|
981
|
-
}
|
|
982
|
-
|
|
983
|
-
if (d > 0) {
|
|
984
|
-
__m128 mx = masked_read(d, x);
|
|
985
|
-
__m128 my = masked_read(d, y);
|
|
986
|
-
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
987
|
-
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
988
|
-
}
|
|
989
|
-
|
|
990
|
-
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
991
|
-
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
992
|
-
return _mm_cvtss_f32(msum2);
|
|
993
|
-
}
|
|
994
|
-
|
|
995
1278
|
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
996
1279
|
__m256 msum1 = _mm256_setzero_ps();
|
|
997
1280
|
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
|
@@ -1082,113 +1365,8 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
|
1082
1365
|
return fvec_Linf_ref(x, y, d);
|
|
1083
1366
|
}
|
|
1084
1367
|
|
|
1085
|
-
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
1086
|
-
__m128 msum1 = _mm_setzero_ps();
|
|
1087
|
-
|
|
1088
|
-
while (d >= 4) {
|
|
1089
|
-
__m128 mx = _mm_loadu_ps(x);
|
|
1090
|
-
x += 4;
|
|
1091
|
-
__m128 my = _mm_loadu_ps(y);
|
|
1092
|
-
y += 4;
|
|
1093
|
-
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
1094
|
-
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
1095
|
-
d -= 4;
|
|
1096
|
-
}
|
|
1097
|
-
|
|
1098
|
-
if (d > 0) {
|
|
1099
|
-
// add the last 1, 2 or 3 values
|
|
1100
|
-
__m128 mx = masked_read(d, x);
|
|
1101
|
-
__m128 my = masked_read(d, y);
|
|
1102
|
-
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
1103
|
-
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
1104
|
-
}
|
|
1105
|
-
|
|
1106
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
1107
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
1108
|
-
return _mm_cvtss_f32(msum1);
|
|
1109
|
-
}
|
|
1110
|
-
|
|
1111
|
-
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
1112
|
-
__m128 mx, my;
|
|
1113
|
-
__m128 msum1 = _mm_setzero_ps();
|
|
1114
|
-
|
|
1115
|
-
while (d >= 4) {
|
|
1116
|
-
mx = _mm_loadu_ps(x);
|
|
1117
|
-
x += 4;
|
|
1118
|
-
my = _mm_loadu_ps(y);
|
|
1119
|
-
y += 4;
|
|
1120
|
-
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
|
|
1121
|
-
d -= 4;
|
|
1122
|
-
}
|
|
1123
|
-
|
|
1124
|
-
// add the last 1, 2, or 3 values
|
|
1125
|
-
mx = masked_read(d, x);
|
|
1126
|
-
my = masked_read(d, y);
|
|
1127
|
-
__m128 prod = _mm_mul_ps(mx, my);
|
|
1128
|
-
|
|
1129
|
-
msum1 = _mm_add_ps(msum1, prod);
|
|
1130
|
-
|
|
1131
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
1132
|
-
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
1133
|
-
return _mm_cvtss_f32(msum1);
|
|
1134
|
-
}
|
|
1135
|
-
|
|
1136
1368
|
#elif defined(__aarch64__)
|
|
1137
1369
|
|
|
1138
|
-
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
1139
|
-
float32x4_t accux4 = vdupq_n_f32(0);
|
|
1140
|
-
const size_t d_simd = d - (d & 3);
|
|
1141
|
-
size_t i;
|
|
1142
|
-
for (i = 0; i < d_simd; i += 4) {
|
|
1143
|
-
float32x4_t xi = vld1q_f32(x + i);
|
|
1144
|
-
float32x4_t yi = vld1q_f32(y + i);
|
|
1145
|
-
float32x4_t sq = vsubq_f32(xi, yi);
|
|
1146
|
-
accux4 = vfmaq_f32(accux4, sq, sq);
|
|
1147
|
-
}
|
|
1148
|
-
float32_t accux1 = vaddvq_f32(accux4);
|
|
1149
|
-
for (; i < d; ++i) {
|
|
1150
|
-
float32_t xi = x[i];
|
|
1151
|
-
float32_t yi = y[i];
|
|
1152
|
-
float32_t sq = xi - yi;
|
|
1153
|
-
accux1 += sq * sq;
|
|
1154
|
-
}
|
|
1155
|
-
return accux1;
|
|
1156
|
-
}
|
|
1157
|
-
|
|
1158
|
-
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
1159
|
-
float32x4_t accux4 = vdupq_n_f32(0);
|
|
1160
|
-
const size_t d_simd = d - (d & 3);
|
|
1161
|
-
size_t i;
|
|
1162
|
-
for (i = 0; i < d_simd; i += 4) {
|
|
1163
|
-
float32x4_t xi = vld1q_f32(x + i);
|
|
1164
|
-
float32x4_t yi = vld1q_f32(y + i);
|
|
1165
|
-
accux4 = vfmaq_f32(accux4, xi, yi);
|
|
1166
|
-
}
|
|
1167
|
-
float32_t accux1 = vaddvq_f32(accux4);
|
|
1168
|
-
for (; i < d; ++i) {
|
|
1169
|
-
float32_t xi = x[i];
|
|
1170
|
-
float32_t yi = y[i];
|
|
1171
|
-
accux1 += xi * yi;
|
|
1172
|
-
}
|
|
1173
|
-
return accux1;
|
|
1174
|
-
}
|
|
1175
|
-
|
|
1176
|
-
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
1177
|
-
float32x4_t accux4 = vdupq_n_f32(0);
|
|
1178
|
-
const size_t d_simd = d - (d & 3);
|
|
1179
|
-
size_t i;
|
|
1180
|
-
for (i = 0; i < d_simd; i += 4) {
|
|
1181
|
-
float32x4_t xi = vld1q_f32(x + i);
|
|
1182
|
-
accux4 = vfmaq_f32(accux4, xi, xi);
|
|
1183
|
-
}
|
|
1184
|
-
float32_t accux1 = vaddvq_f32(accux4);
|
|
1185
|
-
for (; i < d; ++i) {
|
|
1186
|
-
float32_t xi = x[i];
|
|
1187
|
-
accux1 += xi * xi;
|
|
1188
|
-
}
|
|
1189
|
-
return accux1;
|
|
1190
|
-
}
|
|
1191
|
-
|
|
1192
1370
|
// not optimized for ARM
|
|
1193
1371
|
void fvec_L2sqr_ny(
|
|
1194
1372
|
float* dis,
|
|
@@ -1199,6 +1377,17 @@ void fvec_L2sqr_ny(
|
|
|
1199
1377
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
1200
1378
|
}
|
|
1201
1379
|
|
|
1380
|
+
void fvec_L2sqr_ny_transposed(
|
|
1381
|
+
float* dis,
|
|
1382
|
+
const float* x,
|
|
1383
|
+
const float* y,
|
|
1384
|
+
const float* y_sqlen,
|
|
1385
|
+
size_t d,
|
|
1386
|
+
size_t d_offset,
|
|
1387
|
+
size_t ny) {
|
|
1388
|
+
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
1389
|
+
}
|
|
1390
|
+
|
|
1202
1391
|
size_t fvec_L2sqr_ny_nearest(
|
|
1203
1392
|
float* distances_tmp_buffer,
|
|
1204
1393
|
const float* x,
|
|
@@ -1240,10 +1429,6 @@ void fvec_inner_products_ny(
|
|
|
1240
1429
|
#else
|
|
1241
1430
|
// scalar implementation
|
|
1242
1431
|
|
|
1243
|
-
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
1244
|
-
return fvec_L2sqr_ref(x, y, d);
|
|
1245
|
-
}
|
|
1246
|
-
|
|
1247
1432
|
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
1248
1433
|
return fvec_L1_ref(x, y, d);
|
|
1249
1434
|
}
|
|
@@ -1252,14 +1437,6 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
|
1252
1437
|
return fvec_Linf_ref(x, y, d);
|
|
1253
1438
|
}
|
|
1254
1439
|
|
|
1255
|
-
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
1256
|
-
return fvec_inner_product_ref(x, y, d);
|
|
1257
|
-
}
|
|
1258
|
-
|
|
1259
|
-
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
1260
|
-
return fvec_norm_L2sqr_ref(x, d);
|
|
1261
|
-
}
|
|
1262
|
-
|
|
1263
1440
|
void fvec_L2sqr_ny(
|
|
1264
1441
|
float* dis,
|
|
1265
1442
|
const float* x,
|
|
@@ -1269,6 +1446,17 @@ void fvec_L2sqr_ny(
|
|
|
1269
1446
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
1270
1447
|
}
|
|
1271
1448
|
|
|
1449
|
+
void fvec_L2sqr_ny_transposed(
|
|
1450
|
+
float* dis,
|
|
1451
|
+
const float* x,
|
|
1452
|
+
const float* y,
|
|
1453
|
+
const float* y_sqlen,
|
|
1454
|
+
size_t d,
|
|
1455
|
+
size_t d_offset,
|
|
1456
|
+
size_t ny) {
|
|
1457
|
+
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
|
1458
|
+
}
|
|
1459
|
+
|
|
1272
1460
|
size_t fvec_L2sqr_ny_nearest(
|
|
1273
1461
|
float* distances_tmp_buffer,
|
|
1274
1462
|
const float* x,
|