faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/utils/distances.h>
|
|
11
11
|
|
|
12
|
+
#include <algorithm>
|
|
12
13
|
#include <cassert>
|
|
13
14
|
#include <cmath>
|
|
14
15
|
#include <cstdio>
|
|
@@ -112,6 +113,74 @@ void fvec_L2sqr_ny_ref(
|
|
|
112
113
|
}
|
|
113
114
|
}
|
|
114
115
|
|
|
116
|
+
void fvec_L2sqr_ny_y_transposed_ref(
|
|
117
|
+
float* dis,
|
|
118
|
+
const float* x,
|
|
119
|
+
const float* y,
|
|
120
|
+
const float* y_sqlen,
|
|
121
|
+
size_t d,
|
|
122
|
+
size_t d_offset,
|
|
123
|
+
size_t ny) {
|
|
124
|
+
float x_sqlen = 0;
|
|
125
|
+
for (size_t j = 0; j < d; j++) {
|
|
126
|
+
x_sqlen += x[j] * x[j];
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
for (size_t i = 0; i < ny; i++) {
|
|
130
|
+
float dp = 0;
|
|
131
|
+
for (size_t j = 0; j < d; j++) {
|
|
132
|
+
dp += x[j] * y[i + j * d_offset];
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
size_t fvec_L2sqr_ny_nearest_ref(
|
|
140
|
+
float* distances_tmp_buffer,
|
|
141
|
+
const float* x,
|
|
142
|
+
const float* y,
|
|
143
|
+
size_t d,
|
|
144
|
+
size_t ny) {
|
|
145
|
+
fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
|
|
146
|
+
|
|
147
|
+
size_t nearest_idx = 0;
|
|
148
|
+
float min_dis = HUGE_VALF;
|
|
149
|
+
|
|
150
|
+
for (size_t i = 0; i < ny; i++) {
|
|
151
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
152
|
+
min_dis = distances_tmp_buffer[i];
|
|
153
|
+
nearest_idx = i;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
return nearest_idx;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
161
|
+
float* distances_tmp_buffer,
|
|
162
|
+
const float* x,
|
|
163
|
+
const float* y,
|
|
164
|
+
const float* y_sqlen,
|
|
165
|
+
size_t d,
|
|
166
|
+
size_t d_offset,
|
|
167
|
+
size_t ny) {
|
|
168
|
+
fvec_L2sqr_ny_y_transposed_ref(
|
|
169
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
170
|
+
|
|
171
|
+
size_t nearest_idx = 0;
|
|
172
|
+
float min_dis = HUGE_VALF;
|
|
173
|
+
|
|
174
|
+
for (size_t i = 0; i < ny; i++) {
|
|
175
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
176
|
+
min_dis = distances_tmp_buffer[i];
|
|
177
|
+
nearest_idx = i;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
return nearest_idx;
|
|
182
|
+
}
|
|
183
|
+
|
|
115
184
|
void fvec_inner_products_ny_ref(
|
|
116
185
|
float* ip,
|
|
117
186
|
const float* x,
|
|
@@ -257,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
|
257
326
|
}
|
|
258
327
|
}
|
|
259
328
|
|
|
329
|
+
#ifdef __AVX2__
|
|
330
|
+
|
|
331
|
+
// Specialized AVX2 versions for CPUs that support gather instructions.
|
|
332
|
+
// TODO: implement fvec_op_ny_Dxxx in the same way.
|
|
333
|
+
|
|
334
|
+
template <>
|
|
335
|
+
void fvec_op_ny_D4<ElementOpIP>(
|
|
336
|
+
float* dis,
|
|
337
|
+
const float* x,
|
|
338
|
+
const float* y,
|
|
339
|
+
size_t ny) {
|
|
340
|
+
const size_t ny8 = ny / 8;
|
|
341
|
+
size_t i = 0;
|
|
342
|
+
|
|
343
|
+
if (ny8 > 0) {
|
|
344
|
+
// process 8 D4-vectors per loop.
|
|
345
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
346
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
347
|
+
|
|
348
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
349
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
350
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
351
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
352
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
353
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
354
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
355
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
356
|
+
|
|
357
|
+
const __m256i indices0 =
|
|
358
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
359
|
+
|
|
360
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
361
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
362
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
363
|
+
|
|
364
|
+
// collect dim 0 for 8 D4-vectors.
|
|
365
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
366
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
367
|
+
// collect dim 1 for 8 D4-vectors.
|
|
368
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
369
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
370
|
+
// collect dim 2 for 8 D4-vectors.
|
|
371
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
372
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
373
|
+
// collect dim 3 for 8 D4-vectors.
|
|
374
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
375
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
376
|
+
|
|
377
|
+
// compute inner products
|
|
378
|
+
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
379
|
+
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
380
|
+
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
381
|
+
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
382
|
+
|
|
383
|
+
// distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
|
|
384
|
+
// (x[1] * y[(i * 8 + 0) * 4 + 1]) +
|
|
385
|
+
// (x[2] * y[(i * 8 + 0) * 4 + 2]) +
|
|
386
|
+
// (x[3] * y[(i * 8 + 0) * 4 + 3])
|
|
387
|
+
// ...
|
|
388
|
+
// distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
|
|
389
|
+
// (x[1] * y[(i * 8 + 7) * 4 + 1]) +
|
|
390
|
+
// (x[2] * y[(i * 8 + 7) * 4 + 2]) +
|
|
391
|
+
// (x[3] * y[(i * 8 + 7) * 4 + 3])
|
|
392
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
393
|
+
|
|
394
|
+
y += 32;
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
if (i < ny) {
|
|
399
|
+
// process leftovers
|
|
400
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
401
|
+
|
|
402
|
+
for (; i < ny; i++) {
|
|
403
|
+
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
404
|
+
y += 4;
|
|
405
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
406
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
407
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
408
|
+
}
|
|
409
|
+
}
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
template <>
|
|
413
|
+
void fvec_op_ny_D4<ElementOpL2>(
|
|
414
|
+
float* dis,
|
|
415
|
+
const float* x,
|
|
416
|
+
const float* y,
|
|
417
|
+
size_t ny) {
|
|
418
|
+
const size_t ny8 = ny / 8;
|
|
419
|
+
size_t i = 0;
|
|
420
|
+
|
|
421
|
+
if (ny8 > 0) {
|
|
422
|
+
// process 8 D4-vectors per loop.
|
|
423
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
424
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
425
|
+
|
|
426
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
427
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
428
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
429
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
430
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
431
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
432
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
433
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
434
|
+
|
|
435
|
+
const __m256i indices0 =
|
|
436
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
437
|
+
|
|
438
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
439
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
440
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
441
|
+
|
|
442
|
+
// collect dim 0 for 8 D4-vectors.
|
|
443
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
444
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
445
|
+
// collect dim 1 for 8 D4-vectors.
|
|
446
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
447
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
448
|
+
// collect dim 2 for 8 D4-vectors.
|
|
449
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
450
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
451
|
+
// collect dim 3 for 8 D4-vectors.
|
|
452
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
453
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
454
|
+
|
|
455
|
+
// compute differences
|
|
456
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
457
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
458
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
459
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
460
|
+
|
|
461
|
+
// compute squares of differences
|
|
462
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
463
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
464
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
465
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
466
|
+
|
|
467
|
+
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
|
468
|
+
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
469
|
+
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
470
|
+
// (x[3] - y[(i * 8 + 0) * 4 + 3]) ^ 2
|
|
471
|
+
// ...
|
|
472
|
+
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
473
|
+
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
474
|
+
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
475
|
+
// (x[3] - y[(i * 8 + 7) * 4 + 3]) ^ 2
|
|
476
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
477
|
+
|
|
478
|
+
y += 32;
|
|
479
|
+
}
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
if (i < ny) {
|
|
483
|
+
// process leftovers
|
|
484
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
485
|
+
|
|
486
|
+
for (; i < ny; i++) {
|
|
487
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
488
|
+
y += 4;
|
|
489
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
490
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
491
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
492
|
+
}
|
|
493
|
+
}
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
#endif
|
|
497
|
+
|
|
260
498
|
template <class ElementOp>
|
|
261
499
|
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
262
500
|
__m128 x0 = _mm_loadu_ps(x);
|
|
@@ -344,6 +582,324 @@ void fvec_inner_products_ny(
|
|
|
344
582
|
#undef DISPATCH
|
|
345
583
|
}
|
|
346
584
|
|
|
585
|
+
#ifdef __AVX2__
|
|
586
|
+
size_t fvec_L2sqr_ny_nearest_D4(
|
|
587
|
+
float* distances_tmp_buffer,
|
|
588
|
+
const float* x,
|
|
589
|
+
const float* y,
|
|
590
|
+
size_t ny) {
|
|
591
|
+
// this implementation does not use distances_tmp_buffer.
|
|
592
|
+
|
|
593
|
+
// current index being processed
|
|
594
|
+
size_t i = 0;
|
|
595
|
+
|
|
596
|
+
// min distance and the index of the closest vector so far
|
|
597
|
+
float current_min_distance = HUGE_VALF;
|
|
598
|
+
size_t current_min_index = 0;
|
|
599
|
+
|
|
600
|
+
// process 8 D4-vectors per loop.
|
|
601
|
+
const size_t ny8 = ny / 8;
|
|
602
|
+
|
|
603
|
+
if (ny8 > 0) {
|
|
604
|
+
// track min distance and the closest vector independently
|
|
605
|
+
// for each of 8 AVX2 components.
|
|
606
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
607
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
608
|
+
|
|
609
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
610
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
611
|
+
|
|
612
|
+
// prefetch the beginning of y.
|
|
613
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
614
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
615
|
+
|
|
616
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
617
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
618
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
619
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
620
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
621
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
622
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
623
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
624
|
+
|
|
625
|
+
const __m256i indices0 =
|
|
626
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
627
|
+
|
|
628
|
+
for (; i < ny8 * 8; i += 8) {
|
|
629
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
630
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
631
|
+
|
|
632
|
+
// collect dim 0 for 8 D4-vectors.
|
|
633
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
634
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
635
|
+
// collect dim 1 for 8 D4-vectors.
|
|
636
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
637
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
638
|
+
// collect dim 2 for 8 D4-vectors.
|
|
639
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
640
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
641
|
+
// collect dim 3 for 8 D4-vectors.
|
|
642
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
643
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
644
|
+
|
|
645
|
+
// compute differences
|
|
646
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
647
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
648
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
649
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
650
|
+
|
|
651
|
+
// compute squares of differences
|
|
652
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
653
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
654
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
655
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
656
|
+
|
|
657
|
+
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
|
658
|
+
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
659
|
+
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
660
|
+
// (x[3] - y[(i * 8 + 0) * 4 + 3]) ^ 2
|
|
661
|
+
// ...
|
|
662
|
+
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
663
|
+
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
664
|
+
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
665
|
+
// (x[3] - y[(i * 8 + 7) * 4 + 3]) ^ 2
|
|
666
|
+
|
|
667
|
+
// compare the new distances to the min distances
|
|
668
|
+
// for each of 8 AVX2 components.
|
|
669
|
+
__m256 comparison =
|
|
670
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
671
|
+
|
|
672
|
+
// update min distances and indices with closest vectors if needed.
|
|
673
|
+
min_distances =
|
|
674
|
+
_mm256_blendv_ps(distances, min_distances, comparison);
|
|
675
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
676
|
+
_mm256_castsi256_ps(current_indices),
|
|
677
|
+
_mm256_castsi256_ps(min_indices),
|
|
678
|
+
comparison));
|
|
679
|
+
|
|
680
|
+
// update current indices values. Basically, +8 to each of the
|
|
681
|
+
// 8 AVX2 components.
|
|
682
|
+
current_indices =
|
|
683
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
684
|
+
|
|
685
|
+
// scroll y forward (8 vectors 4 DIM each).
|
|
686
|
+
y += 32;
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
// dump values and find the minimum distance / minimum index
|
|
690
|
+
float min_distances_scalar[8];
|
|
691
|
+
uint32_t min_indices_scalar[8];
|
|
692
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
693
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
694
|
+
|
|
695
|
+
for (size_t j = 0; j < 8; j++) {
|
|
696
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
697
|
+
current_min_distance = min_distances_scalar[j];
|
|
698
|
+
current_min_index = min_indices_scalar[j];
|
|
699
|
+
}
|
|
700
|
+
}
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
if (i < ny) {
|
|
704
|
+
// process leftovers
|
|
705
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
706
|
+
|
|
707
|
+
for (; i < ny; i++) {
|
|
708
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
709
|
+
y += 4;
|
|
710
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
711
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
712
|
+
|
|
713
|
+
const auto distance = _mm_cvtss_f32(accu);
|
|
714
|
+
|
|
715
|
+
if (current_min_distance > distance) {
|
|
716
|
+
current_min_distance = distance;
|
|
717
|
+
current_min_index = i;
|
|
718
|
+
}
|
|
719
|
+
}
|
|
720
|
+
}
|
|
721
|
+
|
|
722
|
+
return current_min_index;
|
|
723
|
+
}
|
|
724
|
+
#else
|
|
725
|
+
size_t fvec_L2sqr_ny_nearest_D4(
|
|
726
|
+
float* distances_tmp_buffer,
|
|
727
|
+
const float* x,
|
|
728
|
+
const float* y,
|
|
729
|
+
size_t ny) {
|
|
730
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
|
|
731
|
+
}
|
|
732
|
+
#endif
|
|
733
|
+
|
|
734
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
735
|
+
float* distances_tmp_buffer,
|
|
736
|
+
const float* x,
|
|
737
|
+
const float* y,
|
|
738
|
+
size_t d,
|
|
739
|
+
size_t ny) {
|
|
740
|
+
// optimized for a few special cases
|
|
741
|
+
#define DISPATCH(dval) \
|
|
742
|
+
case dval: \
|
|
743
|
+
return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
|
|
744
|
+
|
|
745
|
+
switch (d) {
|
|
746
|
+
DISPATCH(4)
|
|
747
|
+
default:
|
|
748
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
749
|
+
}
|
|
750
|
+
#undef DISPATCH
|
|
751
|
+
}
|
|
752
|
+
|
|
753
|
+
#ifdef __AVX2__
|
|
754
|
+
template <size_t DIM>
|
|
755
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
|
|
756
|
+
float* distances_tmp_buffer,
|
|
757
|
+
const float* x,
|
|
758
|
+
const float* y,
|
|
759
|
+
const float* y_sqlen,
|
|
760
|
+
const size_t d_offset,
|
|
761
|
+
size_t ny) {
|
|
762
|
+
// this implementation does not use distances_tmp_buffer.
|
|
763
|
+
|
|
764
|
+
// current index being processed
|
|
765
|
+
size_t i = 0;
|
|
766
|
+
|
|
767
|
+
// min distance and the index of the closest vector so far
|
|
768
|
+
float current_min_distance = HUGE_VALF;
|
|
769
|
+
size_t current_min_index = 0;
|
|
770
|
+
|
|
771
|
+
// process 8 vectors per loop.
|
|
772
|
+
const size_t ny8 = ny / 8;
|
|
773
|
+
|
|
774
|
+
if (ny8 > 0) {
|
|
775
|
+
// track min distance and the closest vector independently
|
|
776
|
+
// for each of 8 AVX2 components.
|
|
777
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
778
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
779
|
+
|
|
780
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
781
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
782
|
+
|
|
783
|
+
// m[j] = (2 * x[j], ..., 2 * x[j])
|
|
784
|
+
__m256 m[DIM];
|
|
785
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
786
|
+
m[j] = _mm256_set1_ps(x[j]);
|
|
787
|
+
m[j] = _mm256_add_ps(m[j], m[j]);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
for (; i < ny8 * 8; i += 8) {
|
|
791
|
+
// collect dim 0 for 8 vectors.
|
|
792
|
+
const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
|
|
793
|
+
// compute dot products
|
|
794
|
+
__m256 dp = _mm256_mul_ps(m[0], v0);
|
|
795
|
+
|
|
796
|
+
for (size_t j = 1; j < DIM; j++) {
|
|
797
|
+
// collect dim j for 8 vectors.
|
|
798
|
+
const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
|
|
799
|
+
dp = _mm256_fmadd_ps(m[j], vj, dp);
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
// compute y^2 - (2 * x, y), which is sufficient for looking for the
|
|
803
|
+
// lowest distance.
|
|
804
|
+
// x^2 is the same for every candidate, so it can be omitted.
|
|
805
|
+
const __m256 distances =
|
|
806
|
+
_mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
|
|
807
|
+
|
|
808
|
+
// compare the new distances to the min distances
|
|
809
|
+
// for each of 8 AVX2 components.
|
|
810
|
+
const __m256 comparison =
|
|
811
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
812
|
+
|
|
813
|
+
// update min distances and indices with closest vectors if needed.
|
|
814
|
+
min_distances =
|
|
815
|
+
_mm256_blendv_ps(distances, min_distances, comparison);
|
|
816
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
817
|
+
_mm256_castsi256_ps(current_indices),
|
|
818
|
+
_mm256_castsi256_ps(min_indices),
|
|
819
|
+
comparison));
|
|
820
|
+
|
|
821
|
+
// update current indices values. Basically, +8 to each of the
|
|
822
|
+
// 8 AVX2 components.
|
|
823
|
+
current_indices =
|
|
824
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
825
|
+
|
|
826
|
+
// scroll y and y_sqlen forward.
|
|
827
|
+
y += 8;
|
|
828
|
+
y_sqlen += 8;
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
// dump values and find the minimum distance / minimum index
|
|
832
|
+
float min_distances_scalar[8];
|
|
833
|
+
uint32_t min_indices_scalar[8];
|
|
834
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
835
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
836
|
+
|
|
837
|
+
for (size_t j = 0; j < 8; j++) {
|
|
838
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
839
|
+
current_min_distance = min_distances_scalar[j];
|
|
840
|
+
current_min_index = min_indices_scalar[j];
|
|
841
|
+
}
|
|
842
|
+
}
|
|
843
|
+
}
|
|
844
|
+
|
|
845
|
+
if (i < ny) {
|
|
846
|
+
// process leftovers
|
|
847
|
+
for (; i < ny; i++) {
|
|
848
|
+
float dp = 0;
|
|
849
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
850
|
+
dp += x[j] * y[j * d_offset];
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
// compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
|
854
|
+
// lowest distance.
|
|
855
|
+
const float distance = y_sqlen[0] - 2 * dp;
|
|
856
|
+
|
|
857
|
+
if (current_min_distance > distance) {
|
|
858
|
+
current_min_distance = distance;
|
|
859
|
+
current_min_index = i;
|
|
860
|
+
}
|
|
861
|
+
|
|
862
|
+
y += 1;
|
|
863
|
+
y_sqlen += 1;
|
|
864
|
+
}
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
return current_min_index;
|
|
868
|
+
}
|
|
869
|
+
#endif
|
|
870
|
+
|
|
871
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
872
|
+
float* distances_tmp_buffer,
|
|
873
|
+
const float* x,
|
|
874
|
+
const float* y,
|
|
875
|
+
const float* y_sqlen,
|
|
876
|
+
size_t d,
|
|
877
|
+
size_t d_offset,
|
|
878
|
+
size_t ny) {
|
|
879
|
+
// optimized for a few special cases
|
|
880
|
+
#ifdef __AVX2__
|
|
881
|
+
#define DISPATCH(dval) \
|
|
882
|
+
case dval: \
|
|
883
|
+
return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
|
|
884
|
+
distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
|
|
885
|
+
|
|
886
|
+
switch (d) {
|
|
887
|
+
DISPATCH(1)
|
|
888
|
+
DISPATCH(2)
|
|
889
|
+
DISPATCH(4)
|
|
890
|
+
DISPATCH(8)
|
|
891
|
+
default:
|
|
892
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
893
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
894
|
+
}
|
|
895
|
+
#undef DISPATCH
|
|
896
|
+
#else
|
|
897
|
+
// non-AVX2 case
|
|
898
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
899
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
900
|
+
#endif
|
|
901
|
+
}
|
|
902
|
+
|
|
347
903
|
#endif
|
|
348
904
|
|
|
349
905
|
#ifdef USE_AVX
|
|
@@ -589,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
|
589
1145
|
float32x4_t sq = vsubq_f32(xi, yi);
|
|
590
1146
|
accux4 = vfmaq_f32(accux4, sq, sq);
|
|
591
1147
|
}
|
|
592
|
-
|
|
593
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1148
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
594
1149
|
for (; i < d; ++i) {
|
|
595
1150
|
float32_t xi = x[i];
|
|
596
1151
|
float32_t yi = y[i];
|
|
@@ -609,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
|
609
1164
|
float32x4_t yi = vld1q_f32(y + i);
|
|
610
1165
|
accux4 = vfmaq_f32(accux4, xi, yi);
|
|
611
1166
|
}
|
|
612
|
-
|
|
613
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1167
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
614
1168
|
for (; i < d; ++i) {
|
|
615
1169
|
float32_t xi = x[i];
|
|
616
1170
|
float32_t yi = y[i];
|
|
@@ -627,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
|
627
1181
|
float32x4_t xi = vld1q_f32(x + i);
|
|
628
1182
|
accux4 = vfmaq_f32(accux4, xi, xi);
|
|
629
1183
|
}
|
|
630
|
-
|
|
631
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1184
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
632
1185
|
for (; i < d; ++i) {
|
|
633
1186
|
float32_t xi = x[i];
|
|
634
1187
|
accux1 += xi * xi;
|
|
@@ -646,6 +1199,27 @@ void fvec_L2sqr_ny(
|
|
|
646
1199
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
647
1200
|
}
|
|
648
1201
|
|
|
1202
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
1203
|
+
float* distances_tmp_buffer,
|
|
1204
|
+
const float* x,
|
|
1205
|
+
const float* y,
|
|
1206
|
+
size_t d,
|
|
1207
|
+
size_t ny) {
|
|
1208
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
1209
|
+
}
|
|
1210
|
+
|
|
1211
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
1212
|
+
float* distances_tmp_buffer,
|
|
1213
|
+
const float* x,
|
|
1214
|
+
const float* y,
|
|
1215
|
+
const float* y_sqlen,
|
|
1216
|
+
size_t d,
|
|
1217
|
+
size_t d_offset,
|
|
1218
|
+
size_t ny) {
|
|
1219
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
1220
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
1221
|
+
}
|
|
1222
|
+
|
|
649
1223
|
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
650
1224
|
return fvec_L1_ref(x, y, d);
|
|
651
1225
|
}
|
|
@@ -695,6 +1269,27 @@ void fvec_L2sqr_ny(
|
|
|
695
1269
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
696
1270
|
}
|
|
697
1271
|
|
|
1272
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
1273
|
+
float* distances_tmp_buffer,
|
|
1274
|
+
const float* x,
|
|
1275
|
+
const float* y,
|
|
1276
|
+
size_t d,
|
|
1277
|
+
size_t ny) {
|
|
1278
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
1279
|
+
}
|
|
1280
|
+
|
|
1281
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
1282
|
+
float* distances_tmp_buffer,
|
|
1283
|
+
const float* x,
|
|
1284
|
+
const float* y,
|
|
1285
|
+
const float* y_sqlen,
|
|
1286
|
+
size_t d,
|
|
1287
|
+
size_t d_offset,
|
|
1288
|
+
size_t ny) {
|
|
1289
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
1290
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
1291
|
+
}
|
|
1292
|
+
|
|
698
1293
|
void fvec_inner_products_ny(
|
|
699
1294
|
float* dis,
|
|
700
1295
|
const float* x,
|
|
@@ -720,6 +1315,61 @@ static inline void fvec_madd_ref(
|
|
|
720
1315
|
c[i] = a[i] + bf * b[i];
|
|
721
1316
|
}
|
|
722
1317
|
|
|
1318
|
+
#ifdef __AVX2__
// Computes c[i] = a[i] + bf * b[i] for every i in [0, n).
//
// Full groups of 8 floats go through unaligned 256-bit loads/stores and a
// fused multiply-add. The remaining n % 8 elements are handled with a single
// masked load / masked store, so nothing past index n - 1 is read or written.
static inline void fvec_madd_avx2(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t tail = n % 8;
    const size_t body_end = n - tail;

    const __m256 scale = _mm256_set1_ps(bf);

    // full 8-wide groups
    size_t pos = 0;
    while (pos < body_end) {
        const __m256 va = _mm256_loadu_ps(a + pos);
        const __m256 vb = _mm256_loadu_ps(b + pos);
        _mm256_storeu_ps(c + pos, _mm256_fmadd_ps(scale, vb, va));
        pos += 8;
    }

    if (tail == 0) {
        return;
    }

    // Lane k is enabled (all bits set) iff k < tail, i.e. the lowest `tail`
    // lanes are active and the upper ones are zeroed.
    const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    const __m256i tail_mask = _mm256_cmpgt_epi32(
            _mm256_set1_epi32(static_cast<int>(tail)), lane_ids);

    const __m256 va = _mm256_maskload_ps(a + pos, tail_mask);
    const __m256 vb = _mm256_maskload_ps(b + pos, tail_mask);
    _mm256_maskstore_ps(c + pos, tail_mask, _mm256_fmadd_ps(scale, vb, va));
}
#endif
|
|
1372
|
+
|
|
723
1373
|
#ifdef __SSE3__
|
|
724
1374
|
|
|
725
1375
|
static inline void fvec_madd_sse(
|
|
@@ -743,10 +1393,30 @@ static inline void fvec_madd_sse(
|
|
|
743
1393
|
}
|
|
744
1394
|
|
|
745
1395
|
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
1396
|
+
#ifdef __AVX2__
|
|
1397
|
+
fvec_madd_avx2(n, a, bf, b, c);
|
|
1398
|
+
#else
|
|
746
1399
|
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
747
1400
|
fvec_madd_sse(n, a, bf, b, c);
|
|
748
1401
|
else
|
|
749
1402
|
fvec_madd_ref(n, a, bf, b, c);
|
|
1403
|
+
#endif
|
|
1404
|
+
}
|
|
1405
|
+
|
|
1406
|
+
#elif defined(__aarch64__)
|
|
1407
|
+
|
|
1408
|
+
// NEON version of c[i] = a[i] + bf * b[i] for i in [0, n).
// Groups of 4 floats use a vector fused multiply-add; the last n % 4
// elements are finished with scalar code.
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    const size_t vec_end = n - (n & 3);
    const float32x4_t scale = vdupq_n_f32(bf);

    size_t pos = 0;
    while (pos < vec_end) {
        const float32x4_t va = vld1q_f32(a + pos);
        const float32x4_t vb = vld1q_f32(b + pos);
        // va + scale * vb
        vst1q_f32(c + pos, vfmaq_f32(va, scale, vb));
        pos += 4;
    }

    // scalar remainder
    while (pos < n) {
        c[pos] = a[pos] + bf * b[pos];
        ++pos;
    }
}
|
|
751
1421
|
|
|
752
1422
|
#else
|
|
@@ -842,6 +1512,57 @@ int fvec_madd_and_argmin(
|
|
|
842
1512
|
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
843
1513
|
}
|
|
844
1514
|
|
|
1515
|
+
#elif defined(__aarch64__)
|
|
1516
|
+
|
|
1517
|
+
// NEON version of: c[i] = a[i] + bf * b[i] for i in [0, n), returning the
// index of the smallest c[i].
//
// Comparisons are strict, so among equal minima the lowest index is kept.
// Values are compared against an initial minimum of 1e20; if no element is
// below it, the initial index (uint32_t)-1 is returned converted to int.
int fvec_madd_and_argmin(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Per-lane running minimum and the index at which it was observed.
    float32x4_t lane_min = vdupq_n_f32(1e20);
    uint32x4_t lane_argmin = vdupq_n_u32(static_cast<uint32_t>(-1));

    const size_t vec_end = n - (n & 3);
    const float32x4_t scale = vdupq_n_f32(bf);
    const uint32x4_t step = vdupq_n_u32(4);
    const uint32_t first_lanes[] = {0, 1, 2, 3};
    uint32x4_t lane_idx = vld1q_u32(first_lanes);

    size_t pos = 0;
    while (pos < vec_end) {
        const float32x4_t va = vld1q_f32(a + pos);
        const float32x4_t vb = vld1q_f32(b + pos);
        const float32x4_t sum = vfmaq_f32(va, scale, vb);
        vst1q_f32(c + pos, sum);

        // Lanes where the new value beats the running minimum take the
        // current index; the others keep what they had.
        const uint32x4_t improved = vcltq_f32(sum, lane_min);
        lane_min = vminq_f32(sum, lane_min);
        lane_argmin = vbslq_u32(improved, lane_idx, lane_argmin);

        lane_idx = vaddq_u32(lane_idx, step);
        pos += 4;
    }

    // Horizontal reduction: take the overall minimum, then the smallest index
    // among the lanes that hold it (other lanes are replaced by UINT32_MAX so
    // they cannot win the index minimum).
    float best = vminvq_f32(lane_min);
    const uint32x4_t holds_best = vceqq_f32(lane_min, vdupq_n_f32(best));
    uint32_t best_idx = vminvq_u32(vbslq_u32(
            holds_best,
            lane_argmin,
            vdupq_n_u32(std::numeric_limits<uint32_t>::max())));

    // scalar remainder
    for (; pos < n; ++pos) {
        c[pos] = a[pos] + bf * b[pos];
        if (c[pos] < best) {
            best = c[pos];
            best_idx = static_cast<uint32_t>(pos);
        }
    }
    return static_cast<int>(best_idx);
}
|
|
1565
|
+
|
|
845
1566
|
#else
|
|
846
1567
|
|
|
847
1568
|
int fvec_madd_and_argmin(
|
|
@@ -973,4 +1694,53 @@ void compute_PQ_dis_tables_dsub2(
|
|
|
973
1694
|
}
|
|
974
1695
|
}
|
|
975
1696
|
|
|
1697
|
+
/*********************************************************
|
|
1698
|
+
* Vector to vector functions
|
|
1699
|
+
*********************************************************/
|
|
1700
|
+
|
|
1701
|
+
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
|
|
1702
|
+
size_t i;
|
|
1703
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
1704
|
+
simd8float32 ci, ai, bi;
|
|
1705
|
+
ai.loadu(a + i);
|
|
1706
|
+
bi.loadu(b + i);
|
|
1707
|
+
ci = ai - bi;
|
|
1708
|
+
ci.storeu(c + i);
|
|
1709
|
+
}
|
|
1710
|
+
// finish non-multiple of 8 remainder
|
|
1711
|
+
for (; i < d; i++) {
|
|
1712
|
+
c[i] = a[i] - b[i];
|
|
1713
|
+
}
|
|
1714
|
+
}
|
|
1715
|
+
|
|
1716
|
+
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
1717
|
+
size_t i;
|
|
1718
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
1719
|
+
simd8float32 ci, ai, bi;
|
|
1720
|
+
ai.loadu(a + i);
|
|
1721
|
+
bi.loadu(b + i);
|
|
1722
|
+
ci = ai + bi;
|
|
1723
|
+
ci.storeu(c + i);
|
|
1724
|
+
}
|
|
1725
|
+
// finish non-multiple of 8 remainder
|
|
1726
|
+
for (; i < d; i++) {
|
|
1727
|
+
c[i] = a[i] + b[i];
|
|
1728
|
+
}
|
|
1729
|
+
}
|
|
1730
|
+
|
|
1731
|
+
void fvec_add(size_t d, const float* a, float b, float* c) {
|
|
1732
|
+
size_t i;
|
|
1733
|
+
simd8float32 bv(b);
|
|
1734
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
1735
|
+
simd8float32 ci, ai, bi;
|
|
1736
|
+
ai.loadu(a + i);
|
|
1737
|
+
ci = ai + bv;
|
|
1738
|
+
ci.storeu(c + i);
|
|
1739
|
+
}
|
|
1740
|
+
// finish non-multiple of 8 remainder
|
|
1741
|
+
for (; i < d; i++) {
|
|
1742
|
+
c[i] = a[i] + b;
|
|
1743
|
+
}
|
|
1744
|
+
}
|
|
1745
|
+
|
|
976
1746
|
} // namespace faiss
|