faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -113,6 +113,74 @@ void fvec_L2sqr_ny_ref(
|
|
|
113
113
|
}
|
|
114
114
|
}
|
|
115
115
|
|
|
116
|
+
void fvec_L2sqr_ny_y_transposed_ref(
|
|
117
|
+
float* dis,
|
|
118
|
+
const float* x,
|
|
119
|
+
const float* y,
|
|
120
|
+
const float* y_sqlen,
|
|
121
|
+
size_t d,
|
|
122
|
+
size_t d_offset,
|
|
123
|
+
size_t ny) {
|
|
124
|
+
float x_sqlen = 0;
|
|
125
|
+
for (size_t j = 0; j < d; j++) {
|
|
126
|
+
x_sqlen += x[j] * x[j];
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
for (size_t i = 0; i < ny; i++) {
|
|
130
|
+
float dp = 0;
|
|
131
|
+
for (size_t j = 0; j < d; j++) {
|
|
132
|
+
dp += x[j] * y[i + j * d_offset];
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
size_t fvec_L2sqr_ny_nearest_ref(
|
|
140
|
+
float* distances_tmp_buffer,
|
|
141
|
+
const float* x,
|
|
142
|
+
const float* y,
|
|
143
|
+
size_t d,
|
|
144
|
+
size_t ny) {
|
|
145
|
+
fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny);
|
|
146
|
+
|
|
147
|
+
size_t nearest_idx = 0;
|
|
148
|
+
float min_dis = HUGE_VALF;
|
|
149
|
+
|
|
150
|
+
for (size_t i = 0; i < ny; i++) {
|
|
151
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
152
|
+
min_dis = distances_tmp_buffer[i];
|
|
153
|
+
nearest_idx = i;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
return nearest_idx;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
161
|
+
float* distances_tmp_buffer,
|
|
162
|
+
const float* x,
|
|
163
|
+
const float* y,
|
|
164
|
+
const float* y_sqlen,
|
|
165
|
+
size_t d,
|
|
166
|
+
size_t d_offset,
|
|
167
|
+
size_t ny) {
|
|
168
|
+
fvec_L2sqr_ny_y_transposed_ref(
|
|
169
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
170
|
+
|
|
171
|
+
size_t nearest_idx = 0;
|
|
172
|
+
float min_dis = HUGE_VALF;
|
|
173
|
+
|
|
174
|
+
for (size_t i = 0; i < ny; i++) {
|
|
175
|
+
if (distances_tmp_buffer[i] < min_dis) {
|
|
176
|
+
min_dis = distances_tmp_buffer[i];
|
|
177
|
+
nearest_idx = i;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
return nearest_idx;
|
|
182
|
+
}
|
|
183
|
+
|
|
116
184
|
void fvec_inner_products_ny_ref(
|
|
117
185
|
float* ip,
|
|
118
186
|
const float* x,
|
|
@@ -258,6 +326,175 @@ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
|
258
326
|
}
|
|
259
327
|
}
|
|
260
328
|
|
|
329
|
+
#ifdef __AVX2__
|
|
330
|
+
|
|
331
|
+
// Specialized versions for AVX2 for any CPUs that support gather/scatter.
|
|
332
|
+
// Todo: implement fvec_op_ny_Dxxx in the same way.
|
|
333
|
+
|
|
334
|
+
template <>
|
|
335
|
+
void fvec_op_ny_D4<ElementOpIP>(
|
|
336
|
+
float* dis,
|
|
337
|
+
const float* x,
|
|
338
|
+
const float* y,
|
|
339
|
+
size_t ny) {
|
|
340
|
+
const size_t ny8 = ny / 8;
|
|
341
|
+
size_t i = 0;
|
|
342
|
+
|
|
343
|
+
if (ny8 > 0) {
|
|
344
|
+
// process 8 D4-vectors per loop.
|
|
345
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
346
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
347
|
+
|
|
348
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
349
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
350
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
351
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
352
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
353
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
354
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
355
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
356
|
+
|
|
357
|
+
const __m256i indices0 =
|
|
358
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
359
|
+
|
|
360
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
361
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
362
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
363
|
+
|
|
364
|
+
// collect dim 0 for 8 D4-vectors.
|
|
365
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
366
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
367
|
+
// collect dim 1 for 8 D4-vectors.
|
|
368
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
369
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
370
|
+
// collect dim 2 for 8 D4-vectors.
|
|
371
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
372
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
373
|
+
// collect dim 3 for 8 D4-vectors.
|
|
374
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
375
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
376
|
+
|
|
377
|
+
// compute distances
|
|
378
|
+
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
379
|
+
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
380
|
+
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
381
|
+
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
382
|
+
|
|
383
|
+
// distances[0] = (x[0] * y[(i * 8 + 0) * 4 + 0]) +
|
|
384
|
+
// (x[1] * y[(i * 8 + 0) * 4 + 1]) +
|
|
385
|
+
// (x[2] * y[(i * 8 + 0) * 4 + 2]) +
|
|
386
|
+
// (x[3] * y[(i * 8 + 0) * 4 + 3])
|
|
387
|
+
// ...
|
|
388
|
+
// distances[7] = (x[0] * y[(i * 8 + 7) * 4 + 0]) +
|
|
389
|
+
// (x[1] * y[(i * 8 + 7) * 4 + 1]) +
|
|
390
|
+
// (x[2] * y[(i * 8 + 7) * 4 + 2]) +
|
|
391
|
+
// (x[3] * y[(i * 8 + 7) * 4 + 3])
|
|
392
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
393
|
+
|
|
394
|
+
y += 32;
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
if (i < ny) {
|
|
399
|
+
// process leftovers
|
|
400
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
401
|
+
|
|
402
|
+
for (; i < ny; i++) {
|
|
403
|
+
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
404
|
+
y += 4;
|
|
405
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
406
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
407
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
408
|
+
}
|
|
409
|
+
}
|
|
410
|
+
}
|
|
411
|
+
|
|
412
|
+
template <>
|
|
413
|
+
void fvec_op_ny_D4<ElementOpL2>(
|
|
414
|
+
float* dis,
|
|
415
|
+
const float* x,
|
|
416
|
+
const float* y,
|
|
417
|
+
size_t ny) {
|
|
418
|
+
const size_t ny8 = ny / 8;
|
|
419
|
+
size_t i = 0;
|
|
420
|
+
|
|
421
|
+
if (ny8 > 0) {
|
|
422
|
+
// process 8 D4-vectors per loop.
|
|
423
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
424
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
425
|
+
|
|
426
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
427
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
428
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
429
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
430
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
431
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
432
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
433
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
434
|
+
|
|
435
|
+
const __m256i indices0 =
|
|
436
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
437
|
+
|
|
438
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
439
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
440
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
441
|
+
|
|
442
|
+
// collect dim 0 for 8 D4-vectors.
|
|
443
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
444
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
445
|
+
// collect dim 1 for 8 D4-vectors.
|
|
446
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
447
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
448
|
+
// collect dim 2 for 8 D4-vectors.
|
|
449
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
450
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
451
|
+
// collect dim 3 for 8 D4-vectors.
|
|
452
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
453
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
454
|
+
|
|
455
|
+
// compute differences
|
|
456
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
457
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
458
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
459
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
460
|
+
|
|
461
|
+
// compute squares of differences
|
|
462
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
463
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
464
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
465
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
466
|
+
|
|
467
|
+
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
|
468
|
+
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
469
|
+
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
470
|
+
// (x[3] - y[(i * 8 + 0) * 4 + 3]) ^ 2
|
|
471
|
+
// ...
|
|
472
|
+
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
473
|
+
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
474
|
+
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
475
|
+
// (x[3] - y[(i * 8 + 7) * 4 + 3]) ^ 2
|
|
476
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
477
|
+
|
|
478
|
+
y += 32;
|
|
479
|
+
}
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
if (i < ny) {
|
|
483
|
+
// process leftovers
|
|
484
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
485
|
+
|
|
486
|
+
for (; i < ny; i++) {
|
|
487
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
488
|
+
y += 4;
|
|
489
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
490
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
491
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
492
|
+
}
|
|
493
|
+
}
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
#endif
|
|
497
|
+
|
|
261
498
|
template <class ElementOp>
|
|
262
499
|
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
263
500
|
__m128 x0 = _mm_loadu_ps(x);
|
|
@@ -345,6 +582,324 @@ void fvec_inner_products_ny(
|
|
|
345
582
|
#undef DISPATCH
|
|
346
583
|
}
|
|
347
584
|
|
|
585
|
+
#ifdef __AVX2__
|
|
586
|
+
size_t fvec_L2sqr_ny_nearest_D4(
|
|
587
|
+
float* distances_tmp_buffer,
|
|
588
|
+
const float* x,
|
|
589
|
+
const float* y,
|
|
590
|
+
size_t ny) {
|
|
591
|
+
// this implementation does not use distances_tmp_buffer.
|
|
592
|
+
|
|
593
|
+
// current index being processed
|
|
594
|
+
size_t i = 0;
|
|
595
|
+
|
|
596
|
+
// min distance and the index of the closest vector so far
|
|
597
|
+
float current_min_distance = HUGE_VALF;
|
|
598
|
+
size_t current_min_index = 0;
|
|
599
|
+
|
|
600
|
+
// process 8 D4-vectors per loop.
|
|
601
|
+
const size_t ny8 = ny / 8;
|
|
602
|
+
|
|
603
|
+
if (ny8 > 0) {
|
|
604
|
+
// track min distance and the closest vector independently
|
|
605
|
+
// for each of 8 AVX2 components.
|
|
606
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
607
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
608
|
+
|
|
609
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
610
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
611
|
+
|
|
612
|
+
//
|
|
613
|
+
_mm_prefetch(y, _MM_HINT_NTA);
|
|
614
|
+
_mm_prefetch(y + 16, _MM_HINT_NTA);
|
|
615
|
+
|
|
616
|
+
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
617
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
618
|
+
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
619
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
620
|
+
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
621
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
622
|
+
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
623
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
624
|
+
|
|
625
|
+
const __m256i indices0 =
|
|
626
|
+
_mm256_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112);
|
|
627
|
+
|
|
628
|
+
for (; i < ny8 * 8; i += 8) {
|
|
629
|
+
_mm_prefetch(y + 32, _MM_HINT_NTA);
|
|
630
|
+
_mm_prefetch(y + 48, _MM_HINT_NTA);
|
|
631
|
+
|
|
632
|
+
// collect dim 0 for 8 D4-vectors.
|
|
633
|
+
// v0 = (y[(i * 8 + 0) * 4 + 0], ..., y[(i * 8 + 7) * 4 + 0])
|
|
634
|
+
const __m256 v0 = _mm256_i32gather_ps(y, indices0, 1);
|
|
635
|
+
// collect dim 1 for 8 D4-vectors.
|
|
636
|
+
// v1 = (y[(i * 8 + 0) * 4 + 1], ..., y[(i * 8 + 7) * 4 + 1])
|
|
637
|
+
const __m256 v1 = _mm256_i32gather_ps(y + 1, indices0, 1);
|
|
638
|
+
// collect dim 2 for 8 D4-vectors.
|
|
639
|
+
// v2 = (y[(i * 8 + 0) * 4 + 2], ..., y[(i * 8 + 7) * 4 + 2])
|
|
640
|
+
const __m256 v2 = _mm256_i32gather_ps(y + 2, indices0, 1);
|
|
641
|
+
// collect dim 3 for 8 D4-vectors.
|
|
642
|
+
// v3 = (y[(i * 8 + 0) * 4 + 3], ..., y[(i * 8 + 7) * 4 + 3])
|
|
643
|
+
const __m256 v3 = _mm256_i32gather_ps(y + 3, indices0, 1);
|
|
644
|
+
|
|
645
|
+
// compute differences
|
|
646
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
647
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
648
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
649
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
650
|
+
|
|
651
|
+
// compute squares of differences
|
|
652
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
653
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
654
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
655
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
656
|
+
|
|
657
|
+
// distances[0] = (x[0] - y[(i * 8 + 0) * 4 + 0]) ^ 2 +
|
|
658
|
+
// (x[1] - y[(i * 8 + 0) * 4 + 1]) ^ 2 +
|
|
659
|
+
// (x[2] - y[(i * 8 + 0) * 4 + 2]) ^ 2 +
|
|
660
|
+
// (x[3] - y[(i * 8 + 0) * 4 + 3]) ^ 2
|
|
661
|
+
// ...
|
|
662
|
+
// distances[7] = (x[0] - y[(i * 8 + 7) * 4 + 0]) ^ 2 +
|
|
663
|
+
// (x[1] - y[(i * 8 + 7) * 4 + 1]) ^ 2 +
|
|
664
|
+
// (x[2] - y[(i * 8 + 7) * 4 + 2]) ^ 2 +
|
|
665
|
+
// (x[3] - y[(i * 8 + 7) * 4 + 3]) ^ 2
|
|
666
|
+
|
|
667
|
+
// compare the new distances to the min distances
|
|
668
|
+
// for each of 8 AVX2 components.
|
|
669
|
+
__m256 comparison =
|
|
670
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
671
|
+
|
|
672
|
+
// update min distances and indices with closest vectors if needed.
|
|
673
|
+
min_distances =
|
|
674
|
+
_mm256_blendv_ps(distances, min_distances, comparison);
|
|
675
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
676
|
+
_mm256_castsi256_ps(current_indices),
|
|
677
|
+
_mm256_castsi256_ps(min_indices),
|
|
678
|
+
comparison));
|
|
679
|
+
|
|
680
|
+
// update current indices values. Basically, +8 to each of the
|
|
681
|
+
// 8 AVX2 components.
|
|
682
|
+
current_indices =
|
|
683
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
684
|
+
|
|
685
|
+
// scroll y forward (8 vectors 4 DIM each).
|
|
686
|
+
y += 32;
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
// dump values and find the minimum distance / minimum index
|
|
690
|
+
float min_distances_scalar[8];
|
|
691
|
+
uint32_t min_indices_scalar[8];
|
|
692
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
693
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
694
|
+
|
|
695
|
+
for (size_t j = 0; j < 8; j++) {
|
|
696
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
697
|
+
current_min_distance = min_distances_scalar[j];
|
|
698
|
+
current_min_index = min_indices_scalar[j];
|
|
699
|
+
}
|
|
700
|
+
}
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
if (i < ny) {
|
|
704
|
+
// process leftovers
|
|
705
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
706
|
+
|
|
707
|
+
for (; i < ny; i++) {
|
|
708
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
709
|
+
y += 4;
|
|
710
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
711
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
712
|
+
|
|
713
|
+
const auto distance = _mm_cvtss_f32(accu);
|
|
714
|
+
|
|
715
|
+
if (current_min_distance > distance) {
|
|
716
|
+
current_min_distance = distance;
|
|
717
|
+
current_min_index = i;
|
|
718
|
+
}
|
|
719
|
+
}
|
|
720
|
+
}
|
|
721
|
+
|
|
722
|
+
return current_min_index;
|
|
723
|
+
}
|
|
724
|
+
#else
|
|
725
|
+
size_t fvec_L2sqr_ny_nearest_D4(
|
|
726
|
+
float* distances_tmp_buffer,
|
|
727
|
+
const float* x,
|
|
728
|
+
const float* y,
|
|
729
|
+
size_t ny) {
|
|
730
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny);
|
|
731
|
+
}
|
|
732
|
+
#endif
|
|
733
|
+
|
|
734
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
735
|
+
float* distances_tmp_buffer,
|
|
736
|
+
const float* x,
|
|
737
|
+
const float* y,
|
|
738
|
+
size_t d,
|
|
739
|
+
size_t ny) {
|
|
740
|
+
// optimized for a few special cases
|
|
741
|
+
#define DISPATCH(dval) \
|
|
742
|
+
case dval: \
|
|
743
|
+
return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny);
|
|
744
|
+
|
|
745
|
+
switch (d) {
|
|
746
|
+
DISPATCH(4)
|
|
747
|
+
default:
|
|
748
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
749
|
+
}
|
|
750
|
+
#undef DISPATCH
|
|
751
|
+
}
|
|
752
|
+
|
|
753
|
+
#ifdef __AVX2__
|
|
754
|
+
template <size_t DIM>
|
|
755
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
|
|
756
|
+
float* distances_tmp_buffer,
|
|
757
|
+
const float* x,
|
|
758
|
+
const float* y,
|
|
759
|
+
const float* y_sqlen,
|
|
760
|
+
const size_t d_offset,
|
|
761
|
+
size_t ny) {
|
|
762
|
+
// this implementation does not use distances_tmp_buffer.
|
|
763
|
+
|
|
764
|
+
// current index being processed
|
|
765
|
+
size_t i = 0;
|
|
766
|
+
|
|
767
|
+
// min distance and the index of the closest vector so far
|
|
768
|
+
float current_min_distance = HUGE_VALF;
|
|
769
|
+
size_t current_min_index = 0;
|
|
770
|
+
|
|
771
|
+
// process 8 vectors per loop.
|
|
772
|
+
const size_t ny8 = ny / 8;
|
|
773
|
+
|
|
774
|
+
if (ny8 > 0) {
|
|
775
|
+
// track min distance and the closest vector independently
|
|
776
|
+
// for each of 8 AVX2 components.
|
|
777
|
+
__m256 min_distances = _mm256_set1_ps(HUGE_VALF);
|
|
778
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
779
|
+
|
|
780
|
+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
781
|
+
const __m256i indices_increment = _mm256_set1_epi32(8);
|
|
782
|
+
|
|
783
|
+
// m[j] = (2 * x[j], ..., 2 * x[j])
|
|
784
|
+
__m256 m[DIM];
|
|
785
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
786
|
+
m[j] = _mm256_set1_ps(x[j]);
|
|
787
|
+
m[j] = _mm256_add_ps(m[j], m[j]);
|
|
788
|
+
}
|
|
789
|
+
|
|
790
|
+
for (; i < ny8 * 8; i += 8) {
|
|
791
|
+
// collect dim 0 for 8 DIM-dimensional vectors.
|
|
792
|
+
const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
|
|
793
|
+
// compute dot products
|
|
794
|
+
__m256 dp = _mm256_mul_ps(m[0], v0);
|
|
795
|
+
|
|
796
|
+
for (size_t j = 1; j < DIM; j++) {
|
|
797
|
+
// collect dim j for 8 DIM-dimensional vectors.
|
|
798
|
+
const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
|
|
799
|
+
dp = _mm256_fmadd_ps(m[j], vj, dp);
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
// compute y^2 - (2 * x, y), which is sufficient for looking for the
|
|
803
|
+
// lowest distance.
|
|
804
|
+
// x^2 is the constant that can be avoided.
|
|
805
|
+
const __m256 distances =
|
|
806
|
+
_mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);
|
|
807
|
+
|
|
808
|
+
// compare the new distances to the min distances
|
|
809
|
+
// for each of 8 AVX2 components.
|
|
810
|
+
const __m256 comparison =
|
|
811
|
+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);
|
|
812
|
+
|
|
813
|
+
// update min distances and indices with closest vectors if needed.
|
|
814
|
+
min_distances =
|
|
815
|
+
_mm256_blendv_ps(distances, min_distances, comparison);
|
|
816
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
817
|
+
_mm256_castsi256_ps(current_indices),
|
|
818
|
+
_mm256_castsi256_ps(min_indices),
|
|
819
|
+
comparison));
|
|
820
|
+
|
|
821
|
+
// update current indices values. Basically, +8 to each of the
|
|
822
|
+
// 8 AVX2 components.
|
|
823
|
+
current_indices =
|
|
824
|
+
_mm256_add_epi32(current_indices, indices_increment);
|
|
825
|
+
|
|
826
|
+
// scroll y and y_sqlen forward.
|
|
827
|
+
y += 8;
|
|
828
|
+
y_sqlen += 8;
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
// dump values and find the minimum distance / minimum index
|
|
832
|
+
float min_distances_scalar[8];
|
|
833
|
+
uint32_t min_indices_scalar[8];
|
|
834
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
835
|
+
_mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);
|
|
836
|
+
|
|
837
|
+
for (size_t j = 0; j < 8; j++) {
|
|
838
|
+
if (current_min_distance > min_distances_scalar[j]) {
|
|
839
|
+
current_min_distance = min_distances_scalar[j];
|
|
840
|
+
current_min_index = min_indices_scalar[j];
|
|
841
|
+
}
|
|
842
|
+
}
|
|
843
|
+
}
|
|
844
|
+
|
|
845
|
+
if (i < ny) {
|
|
846
|
+
// process leftovers
|
|
847
|
+
for (; i < ny; i++) {
|
|
848
|
+
float dp = 0;
|
|
849
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
850
|
+
dp += x[j] * y[j * d_offset];
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
// compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
|
854
|
+
// lowest distance.
|
|
855
|
+
const float distance = y_sqlen[0] - 2 * dp;
|
|
856
|
+
|
|
857
|
+
if (current_min_distance > distance) {
|
|
858
|
+
current_min_distance = distance;
|
|
859
|
+
current_min_index = i;
|
|
860
|
+
}
|
|
861
|
+
|
|
862
|
+
y += 1;
|
|
863
|
+
y_sqlen += 1;
|
|
864
|
+
}
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
return current_min_index;
|
|
868
|
+
}
|
|
869
|
+
#endif
|
|
870
|
+
|
|
871
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
872
|
+
float* distances_tmp_buffer,
|
|
873
|
+
const float* x,
|
|
874
|
+
const float* y,
|
|
875
|
+
const float* y_sqlen,
|
|
876
|
+
size_t d,
|
|
877
|
+
size_t d_offset,
|
|
878
|
+
size_t ny) {
|
|
879
|
+
// optimized for a few special cases
|
|
880
|
+
#ifdef __AVX2__
|
|
881
|
+
#define DISPATCH(dval) \
|
|
882
|
+
case dval: \
|
|
883
|
+
return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
|
|
884
|
+
distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
|
|
885
|
+
|
|
886
|
+
switch (d) {
|
|
887
|
+
DISPATCH(1)
|
|
888
|
+
DISPATCH(2)
|
|
889
|
+
DISPATCH(4)
|
|
890
|
+
DISPATCH(8)
|
|
891
|
+
default:
|
|
892
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
893
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
894
|
+
}
|
|
895
|
+
#undef DISPATCH
|
|
896
|
+
#else
|
|
897
|
+
// non-AVX2 case
|
|
898
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
899
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
900
|
+
#endif
|
|
901
|
+
}
|
|
902
|
+
|
|
348
903
|
#endif
|
|
349
904
|
|
|
350
905
|
#ifdef USE_AVX
|
|
@@ -590,8 +1145,7 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
|
590
1145
|
float32x4_t sq = vsubq_f32(xi, yi);
|
|
591
1146
|
accux4 = vfmaq_f32(accux4, sq, sq);
|
|
592
1147
|
}
|
|
593
|
-
|
|
594
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1148
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
595
1149
|
for (; i < d; ++i) {
|
|
596
1150
|
float32_t xi = x[i];
|
|
597
1151
|
float32_t yi = y[i];
|
|
@@ -610,8 +1164,7 @@ float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
|
610
1164
|
float32x4_t yi = vld1q_f32(y + i);
|
|
611
1165
|
accux4 = vfmaq_f32(accux4, xi, yi);
|
|
612
1166
|
}
|
|
613
|
-
|
|
614
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1167
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
615
1168
|
for (; i < d; ++i) {
|
|
616
1169
|
float32_t xi = x[i];
|
|
617
1170
|
float32_t yi = y[i];
|
|
@@ -628,8 +1181,7 @@ float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
|
628
1181
|
float32x4_t xi = vld1q_f32(x + i);
|
|
629
1182
|
accux4 = vfmaq_f32(accux4, xi, xi);
|
|
630
1183
|
}
|
|
631
|
-
|
|
632
|
-
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
1184
|
+
float32_t accux1 = vaddvq_f32(accux4);
|
|
633
1185
|
for (; i < d; ++i) {
|
|
634
1186
|
float32_t xi = x[i];
|
|
635
1187
|
accux1 += xi * xi;
|
|
@@ -647,6 +1199,27 @@ void fvec_L2sqr_ny(
|
|
|
647
1199
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
648
1200
|
}
|
|
649
1201
|
|
|
1202
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
1203
|
+
float* distances_tmp_buffer,
|
|
1204
|
+
const float* x,
|
|
1205
|
+
const float* y,
|
|
1206
|
+
size_t d,
|
|
1207
|
+
size_t ny) {
|
|
1208
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
1209
|
+
}
|
|
1210
|
+
|
|
1211
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
1212
|
+
float* distances_tmp_buffer,
|
|
1213
|
+
const float* x,
|
|
1214
|
+
const float* y,
|
|
1215
|
+
const float* y_sqlen,
|
|
1216
|
+
size_t d,
|
|
1217
|
+
size_t d_offset,
|
|
1218
|
+
size_t ny) {
|
|
1219
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
1220
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
1221
|
+
}
|
|
1222
|
+
|
|
650
1223
|
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
651
1224
|
return fvec_L1_ref(x, y, d);
|
|
652
1225
|
}
|
|
@@ -696,6 +1269,27 @@ void fvec_L2sqr_ny(
|
|
|
696
1269
|
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
697
1270
|
}
|
|
698
1271
|
|
|
1272
|
+
size_t fvec_L2sqr_ny_nearest(
|
|
1273
|
+
float* distances_tmp_buffer,
|
|
1274
|
+
const float* x,
|
|
1275
|
+
const float* y,
|
|
1276
|
+
size_t d,
|
|
1277
|
+
size_t ny) {
|
|
1278
|
+
return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny);
|
|
1279
|
+
}
|
|
1280
|
+
|
|
1281
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed(
|
|
1282
|
+
float* distances_tmp_buffer,
|
|
1283
|
+
const float* x,
|
|
1284
|
+
const float* y,
|
|
1285
|
+
const float* y_sqlen,
|
|
1286
|
+
size_t d,
|
|
1287
|
+
size_t d_offset,
|
|
1288
|
+
size_t ny) {
|
|
1289
|
+
return fvec_L2sqr_ny_nearest_y_transposed_ref(
|
|
1290
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
1291
|
+
}
|
|
1292
|
+
|
|
699
1293
|
void fvec_inner_products_ny(
|
|
700
1294
|
float* dis,
|
|
701
1295
|
const float* x,
|
|
@@ -721,6 +1315,61 @@ static inline void fvec_madd_ref(
|
|
|
721
1315
|
c[i] = a[i] + bf * b[i];
|
|
722
1316
|
}
|
|
723
1317
|
|
|
1318
|
+
#ifdef __AVX2__
/// AVX2 implementation of c[i] = a[i] + bf * b[i] for i in [0, n).
///
/// Full groups of 8 floats are processed with unaligned loads/stores and a
/// fused multiply-add. The remaining n % 8 elements are processed with a
/// masked load/store in which only the first n % 8 lanes are enabled.
///
/// @param n   number of elements
/// @param a   first input array (n floats)
/// @param bf  scalar multiplier applied to b
/// @param b   second input array (n floats)
/// @param c   output array (n floats); declared __restrict with respect to
///            a and b
static inline void fvec_madd_avx2(
        const size_t n,
        const float* __restrict a,
        const float bf,
        const float* __restrict b,
        float* __restrict c) {
    const size_t n_tail = n % 8;
    const size_t n_full = n - n_tail;

    const __m256 bfmm = _mm256_set1_ps(bf);

    size_t idx = 0;
    for (; idx < n_full; idx += 8) {
        const __m256 ax = _mm256_loadu_ps(a + idx);
        const __m256 bx = _mm256_loadu_ps(b + idx);
        // fmadd(p, q, r) = p * q + r, i.e. bf * b + a
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_storeu_ps(c + idx, abmul);
    }

    if (n_tail > 0) {
        // Lane i is enabled (all bits set) iff i < n_tail. Computing the
        // mask keeps it const and always initialized, instead of relying on
        // a switch over the seven possible tail lengths.
        const __m256i lane_ids = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i mask = _mm256_cmpgt_epi32(
                _mm256_set1_epi32(static_cast<int>(n_tail)), lane_ids);

        const __m256 ax = _mm256_maskload_ps(a + idx, mask);
        const __m256 bx = _mm256_maskload_ps(b + idx, mask);
        const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
        _mm256_maskstore_ps(c + idx, mask, abmul);
    }
}
#endif
|
|
1372
|
+
|
|
724
1373
|
#ifdef __SSE3__
|
|
725
1374
|
|
|
726
1375
|
static inline void fvec_madd_sse(
|
|
@@ -744,10 +1393,30 @@ static inline void fvec_madd_sse(
|
|
|
744
1393
|
}
|
|
745
1394
|
|
|
746
1395
|
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
1396
|
+
#ifdef __AVX2__
|
|
1397
|
+
fvec_madd_avx2(n, a, bf, b, c);
|
|
1398
|
+
#else
|
|
747
1399
|
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
748
1400
|
fvec_madd_sse(n, a, bf, b, c);
|
|
749
1401
|
else
|
|
750
1402
|
fvec_madd_ref(n, a, bf, b, c);
|
|
1403
|
+
#endif
|
|
1404
|
+
}
|
|
1405
|
+
|
|
1406
|
+
#elif defined(__aarch64__)
|
|
1407
|
+
|
|
1408
|
+
/// NEON implementation of c[i] = a[i] + bf * b[i] for i in [0, n).
///
/// Elements are processed four at a time with a fused multiply-add; the
/// remaining n % 4 elements are handled by a scalar loop.
///
/// @param n   number of elements
/// @param a   first input array
/// @param bf  scalar multiplier applied to b
/// @param b   second input array
/// @param c   output array
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
    // Largest multiple of 4 that does not exceed n.
    const size_t vec_end = n & ~static_cast<size_t>(3);
    const float32x4_t scale = vdupq_n_f32(bf);

    size_t pos = 0;
    while (pos < vec_end) {
        const float32x4_t va = vld1q_f32(a + pos);
        const float32x4_t vb = vld1q_f32(b + pos);
        // vfmaq_f32(acc, p, q) = acc + p * q
        vst1q_f32(c + pos, vfmaq_f32(va, scale, vb));
        pos += 4;
    }

    // Scalar tail.
    while (pos < n) {
        c[pos] = a[pos] + bf * b[pos];
        ++pos;
    }
}
|
|
752
1421
|
|
|
753
1422
|
#else
|
|
@@ -843,6 +1512,57 @@ int fvec_madd_and_argmin(
|
|
|
843
1512
|
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
844
1513
|
}
|
|
845
1514
|
|
|
1515
|
+
#elif defined(__aarch64__)
|
|
1516
|
+
|
|
1517
|
+
/// NEON implementation: writes c[i] = a[i] + bf * b[i] for i in [0, n) and
/// returns the index of the smallest c[i].
///
/// The running minimum starts at 1e20 and the running index at -1, so -1 is
/// returned when no element is strictly smaller than 1e20 (including n == 0).
/// Among equal minima the lowest index is returned.
///
/// @param n   number of elements
/// @param a   first input array
/// @param bf  scalar multiplier applied to b
/// @param b   second input array
/// @param c   output array
/// @return index of the minimum of c, or -1 as described above
int fvec_madd_and_argmin(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    const size_t vec_end = n - (n & 3);
    const float32x4_t scale = vdupq_n_f32(bf);
    const uint32x4_t step = vdupq_n_u32(4);
    const uint32_t first_ids[] = {0, 1, 2, 3};

    // Per-lane running minimum and the index at which it was first reached.
    float32x4_t lane_min = vdupq_n_f32(1e20);
    uint32x4_t lane_argmin = vdupq_n_u32(static_cast<uint32_t>(-1));
    // Indices of the four elements currently held in the lanes.
    uint32x4_t lane_ids = vld1q_u32(first_ids);

    size_t pos = 0;
    for (; pos < vec_end; pos += 4) {
        const float32x4_t va = vld1q_f32(a + pos);
        const float32x4_t vb = vld1q_f32(b + pos);
        const float32x4_t sum = vfmaq_f32(va, scale, vb);
        vst1q_f32(c + pos, sum);

        // Strict comparison: a lane keeps its earlier index on ties.
        const uint32x4_t improved = vcltq_f32(sum, lane_min);
        lane_min = vminq_f32(sum, lane_min);
        // Bit select: take lane_ids where improved, else keep lane_argmin.
        lane_argmin = vbslq_u32(improved, lane_ids, lane_argmin);
        lane_ids = vaddq_u32(lane_ids, step);
    }

    // Horizontal reduction: the global minimum, then the smallest index among
    // the lanes that hold it (other lanes are replaced by UINT32_MAX).
    float best = vminvq_f32(lane_min);
    const uint32x4_t holds_best = vceqq_f32(lane_min, vdupq_n_f32(best));
    uint32_t best_idx = vminvq_u32(vbslq_u32(
            holds_best,
            lane_argmin,
            vdupq_n_u32(std::numeric_limits<uint32_t>::max())));

    // Scalar tail for the last n % 4 elements.
    for (; pos < n; ++pos) {
        c[pos] = a[pos] + bf * b[pos];
        if (c[pos] < best) {
            best = c[pos];
            best_idx = static_cast<uint32_t>(pos);
        }
    }
    return static_cast<int>(best_idx);
}
|
|
1565
|
+
|
|
846
1566
|
#else
|
|
847
1567
|
|
|
848
1568
|
int fvec_madd_and_argmin(
|