faiss 0.2.6 → 0.2.7
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +29 -3
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
data/vendor/faiss/faiss/utils/distances_simd.cpp CHANGED

@@ -23,6 +23,10 @@
 #include <immintrin.h>
 #endif
 
+#ifdef __AVX2__
+#include <faiss/utils/transpose/transpose-avx2-inl.h>
+#endif
+
 #ifdef __aarch64__
 #include <arm_neon.h>
 #endif
@@ -56,16 +60,6 @@ namespace faiss {
  * Reference implementations
  */
 
-float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
-    size_t i;
-    float res = 0;
-    for (i = 0; i < d; i++) {
-        const float tmp = x[i] - y[i];
-        res += tmp * tmp;
-    }
-    return res;
-}
-
 float fvec_L1_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
@@ -85,22 +79,6 @@ float fvec_Linf_ref(const float* x, const float* y, size_t d) {
     return res;
 }
 
-float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
-    size_t i;
-    float res = 0;
-    for (i = 0; i < d; i++)
-        res += x[i] * y[i];
-    return res;
-}
-
-float fvec_norm_L2sqr_ref(const float* x, size_t d) {
-    size_t i;
-    double res = 0;
-    for (i = 0; i < d; i++)
-        res += x[i] * x[i];
-    return res;
-}
-
 void fvec_L2sqr_ny_ref(
         float* dis,
         const float* x,
@@ -203,6 +181,48 @@ void fvec_inner_products_ny_ref(
     }
 }
 
+/*********************************************************
+ * Autovectorized implementations
+ */
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float fvec_inner_product(const float* x, const float* y, size_t d) {
+    float res = 0.F;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i != d; ++i) {
+        res += x[i] * y[i];
+    }
+    return res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float fvec_norm_L2sqr(const float* x, size_t d) {
+    // the double in the _ref is suspected to be a typo. Some of the manual
+    // implementations this replaces used float.
+    float res = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i != d; ++i) {
+        res += x[i] * x[i];
+    }
+
+    return res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+float fvec_L2sqr(const float* x, const float* y, size_t d) {
+    size_t i;
+    float res = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (i = 0; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += tmp * tmp;
+    }
+    return res;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
 /*********************************************************
  * SSE and AVX implementations
  */
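The three functions added above drop the hand-written SIMD kernels for fvec_inner_product, fvec_norm_L2sqr and fvec_L2sqr in favor of plain loops wrapped in FAISS_PRAGMA_IMPRECISE_* markers, so the compiler auto-vectorizes them under relaxed floating-point rules (the macros live in faiss/impl/platform_macros.h, which this release also touches). A minimal standalone sketch of the same idea, written against Clang's loop pragma directly; the pragma shown here is an illustrative stand-in, not the actual macro expansion from platform_macros.h:

#include <cstddef>

// Ask the compiler to vectorize the reduction loop. This mirrors what the
// FAISS_PRAGMA_IMPRECISE_LOOP wrapper is there for; the concrete pragma is an
// assumption made for this sketch.
float dot_autovec(const float* x, const float* y, std::size_t d) {
    float res = 0.0f;
#if defined(__clang__)
#pragma clang loop vectorize(enable) interleave(enable)
#endif
    for (std::size_t i = 0; i != d; ++i) {
        res += x[i] * y[i];
    }
    return res;
}

Compiled with -O3 plus relaxed floating-point math (for example -ffast-math), a loop like this can be turned into packed FMA code comparable to the removed intrinsics versions.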
@@ -225,25 +245,6 @@ static inline __m128 masked_read(int d, const float* x) {
     // cannot use AVX2 _mm_mask_set1_epi32
 }
 
-float fvec_norm_L2sqr(const float* x, size_t d) {
-    __m128 mx;
-    __m128 msum1 = _mm_setzero_ps();
-
-    while (d >= 4) {
-        mx = _mm_loadu_ps(x);
-        x += 4;
-        msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
-        d -= 4;
-    }
-
-    mx = masked_read(d, x);
-    msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
-
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    return _mm_cvtss_f32(msum1);
-}
-
 namespace {
 
 /// Function that does a component-wise operation between x and y
@@ -354,25 +355,25 @@ void fvec_op_ny_D4<ElementOpIP>(
     // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
     const __m256 m3 = _mm256_set1_ps(x[3]);
 
-    const __m256i indices0 =
-            _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
     for (i = 0; i < ny8 * 8; i += 8) {
-        // [15 removed lines not preserved in the original page rendering]
+        // load 8x4 matrix and transpose it in registers.
+        // the typical bottleneck is memory access, so
+        // let's trade instructions for the bandwidth.
+
+        __m256 v0;
+        __m256 v1;
+        __m256 v2;
+        __m256 v3;
+
+        transpose_8x4(
+                _mm256_loadu_ps(y + 0 * 8),
+                _mm256_loadu_ps(y + 1 * 8),
+                _mm256_loadu_ps(y + 2 * 8),
+                _mm256_loadu_ps(y + 3 * 8),
+                v0,
+                v1,
+                v2,
+                v3);
 
         // compute distances
         __m256 distances = _mm256_mul_ps(m0, v0);
@@ -380,15 +381,7 @@ void fvec_op_ny_D4<ElementOpIP>(
         distances = _mm256_fmadd_ps(m2, v2, distances);
         distances = _mm256_fmadd_ps(m3, v3, distances);
 
-        //
-        // (x[1] * y[(i * 8 + 0) * 4 + 1]) +
-        // (x[2] * y[(i * 8 + 0) * 4 + 2]) +
-        // (x[3] * y[(i * 8 + 0) * 4 + 3])
-        // ...
-        // distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
-        //        (x[1] * y[(i * 8 + 7) * 4 + 1]) +
-        //        (x[2] * y[(i * 8 + 7) * 4 + 2]) +
-        //        (x[3] * y[(i * 8 + 7) * 4 + 3])
+        // store
         _mm256_storeu_ps(dis + i, distances);
 
         y += 32;
@@ -432,25 +425,25 @@ void fvec_op_ny_D4<ElementOpL2>(
     // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
     const __m256 m3 = _mm256_set1_ps(x[3]);
 
-    const __m256i indices0 =
-            _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
     for (i = 0; i < ny8 * 8; i += 8) {
-        // [15 removed lines not preserved in the original page rendering]
+        // load 8x4 matrix and transpose it in registers.
+        // the typical bottleneck is memory access, so
+        // let's trade instructions for the bandwidth.
+
+        __m256 v0;
+        __m256 v1;
+        __m256 v2;
+        __m256 v3;
+
+        transpose_8x4(
+                _mm256_loadu_ps(y + 0 * 8),
+                _mm256_loadu_ps(y + 1 * 8),
+                _mm256_loadu_ps(y + 2 * 8),
+                _mm256_loadu_ps(y + 3 * 8),
+                v0,
+                v1,
+                v2,
+                v3);
 
         // compute differences
         const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -464,15 +457,7 @@ void fvec_op_ny_D4<ElementOpL2>(
         distances = _mm256_fmadd_ps(d2, d2, distances);
        	distances = _mm256_fmadd_ps(d3, d3, distances);
 
-        //
-        // (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
-        // (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
-        // (x[3] - y[(i * 8 + 0) * 4 + 3])
-        // ...
-        // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
-        //        (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
-        //        (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
-        //        (x[3] - y[(i * 8 + 7) * 4 + 3])
+        // store
         _mm256_storeu_ps(dis + i, distances);
 
         y += 32;
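In both specializations of fvec_op_ny_D4 above, the strided gathers (driven by the removed indices0 table) are replaced by four contiguous 256-bit loads followed by an in-register transpose_8x4 from the new transpose-avx2-inl.h header: the kernel reads 8 four-dimensional y vectors as 32 consecutive floats and, as the added comment says, trades extra shuffle instructions for memory bandwidth. A plain-array sketch of the same data movement, with a hypothetical scalar helper standing in for the AVX2 transpose:

#include <cstddef>

// y holds 8 vectors of dimension 4, stored one after another (vector-major).
// After the transpose, out[j][k] is dimension j of vector k: one row per
// dimension with 8 entries, which is exactly the per-register layout the SIMD
// kernel wants (8 lanes = 8 vectors).
void transpose_8x4_scalar(const float* y, float out[4][8]) {
    for (std::size_t k = 0; k < 8; ++k) {     // vector index
        for (std::size_t j = 0; j < 4; ++j) { // dimension index
            out[j][k] = y[k * 4 + j];
        }
    }
}

With that layout, the 8 distances or dot products are accumulated with one multiply and three fused multiply-adds per iteration, which is what the _mm256_fmadd_ps chain in the hunks above does.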
@@ -583,6 +568,228 @@ void fvec_inner_products_ny(
     }
 }
 #ifdef __AVX2__
+template <size_t DIM>
+void fvec_L2sqr_ny_y_transposed_D(
+        float* distances,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // current index being processed
+    size_t i = 0;
+
+    // squared length of x
+    float x_sqlen = 0;
+    ;
+    for (size_t j = 0; j < DIM; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    // process 8 vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m256 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm256_set1_ps(x[j]);
+            m[j] = _mm256_add_ps(m[j], m[j]);
+        }
+
+        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
+
+        for (; i < ny8 * 8; i += 8) {
+            // collect dim 0 for 8 D4-vectors.
+            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+
+            // compute dot products
+            // this is x^2 - 2x[0]*y[0]
+            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
+
+            for (size_t j = 1; j < DIM; j++) {
+                // collect dim j for 8 D4-vectors.
+                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+                dp = _mm256_fnmadd_ps(m[j], vj, dp);
+            }
+
+            // we've got x^2 - (2x, y) at this point
+
+            // y^2 - (2x, y) + x^2
+            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+            _mm256_storeu_ps(distances + i, distances_v);
+
+            // scroll y and y_sqlen forward.
+            y += 8;
+            y_sqlen += 8;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+#endif
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+
+#ifdef __AVX2__
+#define DISPATCH(dval)                             \
+    case dval:                                     \
+        return fvec_L2sqr_ny_y_transposed_D<dval>( \
+                dis, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_y_transposed_ref(
+                    dis, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
+
+#ifdef __AVX2__
+
+size_t fvec_L2sqr_ny_nearest_D2(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 D2-vectors per loop.
+    const size_t ny8 = ny / 8;
+    if (ny8 > 0) {
+        _mm_prefetch(y, _MM_HINT_T0);
+        _mm_prefetch(y + 16, _MM_HINT_T0);
+
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        // 1 value per register
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+
+        for (; i < ny8 * 8; i += 8) {
+            _mm_prefetch(y + 32, _MM_HINT_T0);
+
+            __m256 v0;
+            __m256 v1;
+
+            transpose_8x2(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    v0,
+                    v1);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances = _mm256_min_ps(distances, min_distances);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y forward (8 vectors 2 DIM each).
+            y += 16;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers.
+        // the following code is not optimal, but it is rarely invoked.
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+
 size_t fvec_L2sqr_ny_nearest_D4(
         float* distances_tmp_buffer,
         const float* x,
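The new fvec_L2sqr_ny_y_transposed_D kernel above relies on the identity ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2: the caller passes precomputed squared norms of the database vectors (y_sqlen) and a dimension-major ("transposed") layout of y, so the inner loop only accumulates the -2*x[j]*y[j] terms with _mm256_fnmadd_ps and then adds y_sqlen. A scalar version of the same computation, written here only to make the contract explicit (the actual fvec_L2sqr_ny_y_transposed_ref is not shown in this diff):

#include <cstddef>

// distances[i] = ||x||^2 - 2 * <x, y_i> + ||y_i||^2 for i in [0, ny), where
// dimension j of vector i is stored at y[j * d_offset + i] and
// y_sqlen[i] = ||y_i||^2 is precomputed by the caller.
void l2sqr_ny_transposed_scalar(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        std::size_t d,
        std::size_t d_offset,
        std::size_t ny) {
    float x_sqlen = 0.0f;
    for (std::size_t j = 0; j < d; ++j) {
        x_sqlen += x[j] * x[j];
    }
    for (std::size_t i = 0; i < ny; ++i) {
        float dp = 0.0f;
        for (std::size_t j = 0; j < d; ++j) {
            dp += x[j] * y[j * d_offset + i];
        }
        distances[i] = x_sqlen - 2.0f * dp + y_sqlen[i];
    }
}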
@@ -609,38 +816,27 @@ size_t fvec_L2sqr_ny_nearest_D4(
         __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
         const __m256i indices_increment = _mm256_set1_epi32(8);
 
-        //
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
-
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
+        // 1 value per register
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
         const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
         const __m256 m3 = _mm256_set1_ps(x[3]);
 
-        const __m256i indices0 =
-                _mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
-
         for (; i < ny8 * 8; i += 8) {
-            // [14 removed lines not preserved in the original page rendering]
-            const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+
+            transpose_8x4(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
 
             // compute differences
             const __m256 d0 = _mm256_sub_ps(m0, v0);
@@ -654,24 +850,13 @@ size_t fvec_L2sqr_ny_nearest_D4(
             distances = _mm256_fmadd_ps(d2, d2, distances);
             distances = _mm256_fmadd_ps(d3, d3, distances);
 
-            // distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
-            //        (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
-            //        (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
-            //        (x[3] - y[(i * 8 + 0) * 4 + 3])
-            // ...
-            // distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
-            //        (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
-            //        (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
-            //        (x[3] - y[(i * 8 + 7) * 4 + 3])
-
             // compare the new distances to the min distances
             // for each of 8 AVX2 components.
             __m256 comparison =
                     _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
 
             // update min distances and indices with closest vectors if needed.
-            min_distances =
-                    _mm256_blendv_ps(distances, min_distances, comparison);
+            min_distances = _mm256_min_ps(distances, min_distances);
             min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                     _mm256_castsi256_ps(current_indices),
                     _mm256_castsi256_ps(min_indices),
@@ -721,7 +906,168 @@ size_t fvec_L2sqr_ny_nearest_D4(
 
     return current_min_index;
 }
+
+size_t fvec_L2sqr_ny_nearest_D8(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    // current index being processed
+    size_t i = 0;
+
+    // min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // process 8 D8-vectors per loop.
+    const size_t ny8 = ny / 8;
+    if (ny8 > 0) {
+        // track min distance and the closest vector independently
+        // for each of 8 AVX2 components.
+        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
+        __m256i min_indices = _mm256_set1_epi32(0);
+
+        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        const __m256i indices_increment = _mm256_set1_epi32(8);
+
+        // 1 value per register
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (; i < ny8 * 8; i += 8) {
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+            const __m256 d4 = _mm256_sub_ps(m4, v4);
+            const __m256 d5 = _mm256_sub_ps(m5, v5);
+            const __m256 d6 = _mm256_sub_ps(m6, v6);
+            const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+            distances = _mm256_fmadd_ps(d4, d4, distances);
+            distances = _mm256_fmadd_ps(d5, d5, distances);
+            distances = _mm256_fmadd_ps(d6, d6, distances);
+            distances = _mm256_fmadd_ps(d7, d7, distances);
+
+            // compare the new distances to the min distances
+            // for each of 8 AVX2 components.
+            __m256 comparison =
+                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
+
+            // update min distances and indices with closest vectors if needed.
+            min_distances = _mm256_min_ps(distances, min_distances);
+            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
+                    _mm256_castsi256_ps(current_indices),
+                    _mm256_castsi256_ps(min_indices),
+                    comparison));
+
+            // update current indices values. Basically, +8 to each of the
+            // 8 AVX2 components.
+            current_indices =
+                    _mm256_add_epi32(current_indices, indices_increment);
+
+            // scroll y forward (8 vectors 8 DIM each).
+            y += 64;
+        }
+
+        // dump values and find the minimum distance / minimum index
+        float min_distances_scalar[8];
+        uint32_t min_indices_scalar[8];
+        _mm256_storeu_ps(min_distances_scalar, min_distances);
+        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 8; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
+            __m256 accu = _mm256_mul_ps(sub, sub);
+            y += 8;
+
+            // horitontal sum
+            const __m256 h0 = _mm256_hadd_ps(accu, accu);
+            const __m256 h1 = _mm256_hadd_ps(h0, h0);
+
+            // extract high and low __m128 regs from __m256
+            const __m128 h2 = _mm256_extractf128_ps(h1, 1);
+            const __m128 h3 = _mm256_castps256_ps128(h1);
+
+            // get a final hsum into all 4 regs
+            const __m128 h4 = _mm_add_ss(h2, h3);
+
+            // extract f[0] from __m128
+            const float distance = _mm_cvtss_f32(h4);
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
+}
+
 #else
+size_t fvec_L2sqr_ny_nearest_D2(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny);
+}
+
 size_t fvec_L2sqr_ny_nearest_D4(
         float* distances_tmp_buffer,
         const float* x,
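The leftover path of the new fvec_L2sqr_ny_nearest_D8 above collapses one __m256 of squared differences to a scalar with two _mm256_hadd_ps calls, a 128-bit lane extract, and an _mm_add_ss. Pulled out as a standalone helper (the name is illustrative, not a faiss symbol), the idiom looks like this:

#include <immintrin.h>

// Horizontal sum of the 8 floats in an AVX2 register, mirroring the
// hadd/extract sequence used in the D8 leftover loop.
static inline float hsum256_ps(__m256 v) {
    const __m256 h0 = _mm256_hadd_ps(v, v);   // pairwise sums within each half
    const __m256 h1 = _mm256_hadd_ps(h0, h0); // 4-element sums within each half
    const __m128 hi = _mm256_extractf128_ps(h1, 1);
    const __m128 lo = _mm256_castps256_ps128(h1);
    return _mm_cvtss_f32(_mm_add_ss(hi, lo)); // combine the two halves
}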
@@ -729,6 +1075,14 @@ size_t fvec_L2sqr_ny_nearest_D4(
         size_t ny) {
     return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
 }
+
+size_t fvec_L2sqr_ny_nearest_D8(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny);
+}
 #endif
 
 size_t fvec_L2sqr_ny_nearest(
@@ -743,7 +1097,9 @@ size_t fvec_L2sqr_ny_nearest(
     return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
 
     switch (d) {
+        DISPATCH(2)
         DISPATCH(4)
+        DISPATCH(8)
         default:
             return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
     }
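fvec_L2sqr_ny_nearest now also dispatches d == 2 and d == 8 to the specialized kernels through the same DISPATCH token-pasting macro, keeping fvec_L2sqr_ny_nearest_ref as the fallback for every other dimensionality. A usage sketch, assuming only the argument order visible in this diff (distances_tmp_buffer, x, y, d, ny) and the faiss::fvec_L2sqr_ny_nearest declaration from faiss/utils/distances.h:

#include <faiss/utils/distances.h>

#include <vector>

// Return the index of the centroid (out of ny, each of dimension d) closest
// to the query x in squared L2 distance. With d == 2, 4 or 8 this now takes
// the specialized AVX2 paths added in this release.
size_t nearest_centroid(
        const float* x,
        const float* centroids,
        size_t d,
        size_t ny) {
    std::vector<float> tmp(ny); // scratch; the specialized kernels may ignore it
    return faiss::fvec_L2sqr_ny_nearest(tmp.data(), x, centroids, d, ny);
}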
@@ -919,79 +1275,6 @@ static inline __m256 masked_read_8(int d, const float* x) {
     }
 }
 
-float fvec_inner_product(const float* x, const float* y, size_t d) {
-    __m256 msum1 = _mm256_setzero_ps();
-
-    while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps(x);
-        x += 8;
-        __m256 my = _mm256_loadu_ps(y);
-        y += 8;
-        msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
-        d -= 8;
-    }
-
-    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
-
-    if (d >= 4) {
-        __m128 mx = _mm_loadu_ps(x);
-        x += 4;
-        __m128 my = _mm_loadu_ps(y);
-        y += 4;
-        msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
-        d -= 4;
-    }
-
-    if (d > 0) {
-        __m128 mx = masked_read(d, x);
-        __m128 my = masked_read(d, y);
-        msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
-    }
-
-    msum2 = _mm_hadd_ps(msum2, msum2);
-    msum2 = _mm_hadd_ps(msum2, msum2);
-    return _mm_cvtss_f32(msum2);
-}
-
-float fvec_L2sqr(const float* x, const float* y, size_t d) {
-    __m256 msum1 = _mm256_setzero_ps();
-
-    while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps(x);
-        x += 8;
-        __m256 my = _mm256_loadu_ps(y);
-        y += 8;
-        const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
-        msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
-        d -= 8;
-    }
-
-    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
-
-    if (d >= 4) {
-        __m128 mx = _mm_loadu_ps(x);
-        x += 4;
-        __m128 my = _mm_loadu_ps(y);
-        y += 4;
-        const __m128 a_m_b1 = _mm_sub_ps(mx, my);
-        msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
-        d -= 4;
-    }
-
-    if (d > 0) {
-        __m128 mx = masked_read(d, x);
-        __m128 my = masked_read(d, y);
-        __m128 a_m_b1 = _mm_sub_ps(mx, my);
-        msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
-    }
-
-    msum2 = _mm_hadd_ps(msum2, msum2);
-    msum2 = _mm_hadd_ps(msum2, msum2);
-    return _mm_cvtss_f32(msum2);
-}
-
 float fvec_L1(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
     __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1082,113 +1365,8 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
     return fvec_Linf_ref(x, y, d);
 }
 
-float fvec_L2sqr(const float* x, const float* y, size_t d) {
-    __m128 msum1 = _mm_setzero_ps();
-
-    while (d >= 4) {
-        __m128 mx = _mm_loadu_ps(x);
-        x += 4;
-        __m128 my = _mm_loadu_ps(y);
-        y += 4;
-        const __m128 a_m_b1 = _mm_sub_ps(mx, my);
-        msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
-        d -= 4;
-    }
-
-    if (d > 0) {
-        // add the last 1, 2 or 3 values
-        __m128 mx = masked_read(d, x);
-        __m128 my = masked_read(d, y);
-        __m128 a_m_b1 = _mm_sub_ps(mx, my);
-        msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
-    }
-
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    return _mm_cvtss_f32(msum1);
-}
-
-float fvec_inner_product(const float* x, const float* y, size_t d) {
-    __m128 mx, my;
-    __m128 msum1 = _mm_setzero_ps();
-
-    while (d >= 4) {
-        mx = _mm_loadu_ps(x);
-        x += 4;
-        my = _mm_loadu_ps(y);
-        y += 4;
-        msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
-        d -= 4;
-    }
-
-    // add the last 1, 2, or 3 values
-    mx = masked_read(d, x);
-    my = masked_read(d, y);
-    __m128 prod = _mm_mul_ps(mx, my);
-
-    msum1 = _mm_add_ps(msum1, prod);
-
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    msum1 = _mm_hadd_ps(msum1, msum1);
-    return _mm_cvtss_f32(msum1);
-}
-
 #elif defined(__aarch64__)
 
-float fvec_L2sqr(const float* x, const float* y, size_t d) {
-    float32x4_t accux4 = vdupq_n_f32(0);
-    const size_t d_simd = d - (d & 3);
-    size_t i;
-    for (i = 0; i < d_simd; i += 4) {
-        float32x4_t xi = vld1q_f32(x + i);
-        float32x4_t yi = vld1q_f32(y + i);
-        float32x4_t sq = vsubq_f32(xi, yi);
-        accux4 = vfmaq_f32(accux4, sq, sq);
-    }
-    float32_t accux1 = vaddvq_f32(accux4);
-    for (; i < d; ++i) {
-        float32_t xi = x[i];
-        float32_t yi = y[i];
-        float32_t sq = xi - yi;
-        accux1 += sq * sq;
-    }
-    return accux1;
-}
-
-float fvec_inner_product(const float* x, const float* y, size_t d) {
-    float32x4_t accux4 = vdupq_n_f32(0);
-    const size_t d_simd = d - (d & 3);
-    size_t i;
-    for (i = 0; i < d_simd; i += 4) {
-        float32x4_t xi = vld1q_f32(x + i);
-        float32x4_t yi = vld1q_f32(y + i);
-        accux4 = vfmaq_f32(accux4, xi, yi);
-    }
-    float32_t accux1 = vaddvq_f32(accux4);
-    for (; i < d; ++i) {
-        float32_t xi = x[i];
-        float32_t yi = y[i];
-        accux1 += xi * yi;
-    }
-    return accux1;
-}
-
-float fvec_norm_L2sqr(const float* x, size_t d) {
-    float32x4_t accux4 = vdupq_n_f32(0);
-    const size_t d_simd = d - (d & 3);
-    size_t i;
-    for (i = 0; i < d_simd; i += 4) {
-        float32x4_t xi = vld1q_f32(x + i);
-        accux4 = vfmaq_f32(accux4, xi, xi);
-    }
-    float32_t accux1 = vaddvq_f32(accux4);
-    for (; i < d; ++i) {
-        float32_t xi = x[i];
-        accux1 += xi * xi;
-    }
-    return accux1;
-}
-
 // not optimized for ARM
 void fvec_L2sqr_ny(
         float* dis,
@@ -1199,6 +1377,17 @@ void fvec_L2sqr_ny(
     fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+}
+
 size_t fvec_L2sqr_ny_nearest(
         float* distances_tmp_buffer,
         const float* x,
@@ -1240,10 +1429,6 @@ void fvec_inner_products_ny(
 #else
 // scalar implementation
 
-float fvec_L2sqr(const float* x, const float* y, size_t d) {
-    return fvec_L2sqr_ref(x, y, d);
-}
-
 float fvec_L1(const float* x, const float* y, size_t d) {
     return fvec_L1_ref(x, y, d);
 }
@@ -1252,14 +1437,6 @@ float fvec_Linf(const float* x, const float* y, size_t d) {
     return fvec_Linf_ref(x, y, d);
 }
 
-float fvec_inner_product(const float* x, const float* y, size_t d) {
-    return fvec_inner_product_ref(x, y, d);
-}
-
-float fvec_norm_L2sqr(const float* x, size_t d) {
-    return fvec_norm_L2sqr_ref(x, d);
-}
-
 void fvec_L2sqr_ny(
         float* dis,
         const float* x,
void fvec_L2sqr_ny(
|
1264
1441
|
float* dis,
|
1265
1442
|
const float* x,
|
@@ -1269,6 +1446,17 @@ void fvec_L2sqr_ny(
|
|
1269
1446
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
1270
1447
|
}
|
1271
1448
|
|
1449
|
+
void fvec_L2sqr_ny_transposed(
|
1450
|
+
float* dis,
|
1451
|
+
const float* x,
|
1452
|
+
const float* y,
|
1453
|
+
const float* y_sqlen,
|
1454
|
+
size_t d,
|
1455
|
+
size_t d_offset,
|
1456
|
+
size_t ny) {
|
1457
|
+
return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
|
1458
|
+
}
|
1459
|
+
|
1272
1460
|
size_t fvec_L2sqr_ny_nearest(
|
1273
1461
|
float* distances_tmp_buffer,
|
1274
1462
|
const float* x,
|