faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/utils/distances_simd.cpp

@@ -23,7 +23,9 @@
 #include <immintrin.h>
 #endif
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+#include <faiss/utils/transpose/transpose-avx512-inl.h>
+#elif defined(__AVX2__)
 #include <faiss/utils/transpose/transpose-avx2-inl.h>
 #endif
 
@@ -223,6 +225,76 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
 }
 FAISS_PRAGMA_IMPRECISE_FUNCTION_END
 
+/// Special version of inner product that computes 4 distances
+/// between x and yi
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void fvec_inner_product_batch_4(
+        const float* __restrict x,
+        const float* __restrict y0,
+        const float* __restrict y1,
+        const float* __restrict y2,
+        const float* __restrict y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3) {
+    float d0 = 0;
+    float d1 = 0;
+    float d2 = 0;
+    float d3 = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        d0 += x[i] * y0[i];
+        d1 += x[i] * y1[i];
+        d2 += x[i] * y2[i];
+        d3 += x[i] * y3[i];
+    }
+
+    dis0 = d0;
+    dis1 = d1;
+    dis2 = d2;
+    dis3 = d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+/// Special version of L2sqr that computes 4 distances
+/// between x and yi, which is performance oriented.
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void fvec_L2sqr_batch_4(
+        const float* x,
+        const float* y0,
+        const float* y1,
+        const float* y2,
+        const float* y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3) {
+    float d0 = 0;
+    float d1 = 0;
+    float d2 = 0;
+    float d3 = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        const float q0 = x[i] - y0[i];
+        const float q1 = x[i] - y1[i];
+        const float q2 = x[i] - y2[i];
+        const float q3 = x[i] - y3[i];
+        d0 += q0 * q0;
+        d1 += q1 * q1;
+        d2 += q2 * q2;
+        d3 += q3 * q3;
+    }
+
+    dis0 = d0;
+    dis1 = d1;
+    dis2 = d2;
+    dis3 = d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
 /*********************************************************
  * SSE and AVX implementations
  */
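The two batch-4 kernels added above are plain scalar loops that the FAISS_PRAGMA_IMPRECISE_* pragmas let the compiler auto-vectorize; they amortize the loads of the query x over four database vectors. A minimal caller-side sketch, assuming the kernels are exposed through faiss/utils/distances.h as the header changes listed for this release suggest (the scan_l2sqr wrapper itself is illustrative, not part of faiss):

```cpp
#include <cstddef>

#include <faiss/utils/distances.h> // fvec_L2sqr, fvec_L2sqr_batch_4

// Compute ||x - y_i||^2 for ny contiguous d-dimensional vectors in y,
// handing four rows at a time to the batch kernel and finishing the
// tail with the single-vector routine.
void scan_l2sqr(
        const float* x,
        const float* y,
        size_t d,
        size_t ny,
        float* dis) {
    size_t i = 0;
    for (; i + 4 <= ny; i += 4) {
        faiss::fvec_L2sqr_batch_4(
                x,
                y + (i + 0) * d,
                y + (i + 1) * d,
                y + (i + 2) * d,
                y + (i + 3) * d,
                d,
                dis[i + 0],
                dis[i + 1],
                dis[i + 2],
                dis[i + 3]);
    }
    for (; i < ny; i++) {
        dis[i] = faiss::fvec_L2sqr(x, y + i * d, d);
    }
}
```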
@@ -236,8 +308,10 @@ static inline __m128 masked_read(int d, const float* x) {
     switch (d) {
         case 3:
             buf[2] = x[2];
+            [[fallthrough]];
         case 2:
             buf[1] = x[1];
+            [[fallthrough]];
         case 1:
             buf[0] = x[0];
     }
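The C++17 [[fallthrough]] attribute marks the cascade in masked_read as deliberate (case 3 also copies elements 1 and 0), which keeps -Wimplicit-fallthrough quiet without changing behavior. A small standalone illustration of the same pattern, using a hypothetical helper rather than faiss code:

```cpp
// Count how many of the first n fields (0..3) of a record are set,
// cascading through the cases on purpose.
int count_present(int n, const bool present[3]) {
    int count = 0;
    switch (n) {
        case 3:
            count += present[2];
            [[fallthrough]]; // deliberately also inspect fields 1 and 0
        case 2:
            count += present[1];
            [[fallthrough]];
        case 1:
            count += present[0];
            break;
        default:
            break;
    }
    return count;
}
```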
@@ -247,6 +321,41 @@ static inline __m128 masked_read(int d, const float* x) {
 
 namespace {
 
+/// helper function
+inline float horizontal_sum(const __m128 v) {
+    // say, v is [x0, x1, x2, x3]
+
+    // v0 is [x2, x3, ..., ...]
+    const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
+    // v1 is [x0 + x2, x1 + x3, ..., ...]
+    const __m128 v1 = _mm_add_ps(v, v0);
+    // v2 is [x1 + x3, ..., .... ,...]
+    __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+    // v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
+    const __m128 v3 = _mm_add_ps(v1, v2);
+    // return v3[0]
+    return _mm_cvtss_f32(v3);
+}
+
+#ifdef __AVX2__
+/// helper function for AVX2
+inline float horizontal_sum(const __m256 v) {
+    // add high and low parts
+    const __m128 v0 =
+            _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+    // perform horizontal sum on v0
+    return horizontal_sum(v0);
+}
+#endif
+
+#ifdef __AVX512F__
+/// helper function for AVX512
+inline float horizontal_sum(const __m512 v) {
+    // performs better than adding the high and low parts
+    return _mm512_reduce_add_ps(v);
+}
+#endif
+
 /// Function that does a component-wise operation between x and y
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
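The new horizontal_sum overloads collapse a whole SIMD register into one float so that the kernels below can finish their scalar leftovers from the same vector accumulators. A self-contained sanity check of the __m128 shuffle/add reduction against a plain scalar sum (illustrative, not part of the diff):

```cpp
#include <immintrin.h>

#include <cassert>
#include <cmath>

// Same reduction as the __m128 horizontal_sum above.
static inline float hsum_ps(__m128 v) {
    const __m128 hi = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); // [x2, x3, ., .]
    const __m128 s1 = _mm_add_ps(v, hi);                             // [x0+x2, x1+x3, ., .]
    const __m128 s2 = _mm_shuffle_ps(s1, s1, _MM_SHUFFLE(0, 0, 0, 1)); // [x1+x3, ., ., .]
    return _mm_cvtss_f32(_mm_add_ps(s1, s2));                        // x0+x1+x2+x3
}

int main() {
    const float x[4] = {1.5f, -2.0f, 4.25f, 0.5f};
    const float expected = x[0] + x[1] + x[2] + x[3];
    assert(std::fabs(hsum_ps(_mm_loadu_ps(x)) - expected) < 1e-6f);
    return 0;
}
```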
@@ -260,6 +369,20 @@ struct ElementOpL2 {
         __m128 tmp = _mm_sub_ps(x, y);
         return _mm_mul_ps(tmp, tmp);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        __m256 tmp = _mm256_sub_ps(x, y);
+        return _mm256_mul_ps(tmp, tmp);
+    }
+#endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        __m512 tmp = _mm512_sub_ps(x, y);
+        return _mm512_mul_ps(tmp, tmp);
+    }
+#endif
 };
 
 /// Function that does a component-wise operation between x and y
@@ -272,6 +395,18 @@ struct ElementOpIP {
     static __m128 op(__m128 x, __m128 y) {
         return _mm_mul_ps(x, y);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        return _mm256_mul_ps(x, y);
+    }
+#endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        return _mm512_mul_ps(x, y);
+    }
+#endif
 };
 
 template <class ElementOp>
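ElementOpIP and ElementOpL2 now provide matching __m128/__m256/__m512 overloads of op(), so one kernel template can serve both metrics and overload resolution picks the register width it is handed. A reduced scalar analogue of that pattern (names are illustrative, not faiss APIs):

```cpp
#include <cstddef>

// Scalar stand-ins for the per-component operations.
struct OpIP {
    static float op(float x, float y) {
        return x * y; // inner-product term
    }
};
struct OpL2 {
    static float op(float x, float y) {
        float t = x - y;
        return t * t; // squared difference
    }
};

// One generic kernel parameterized on the element operation, mirroring the
// fvec_op_ny_D* templates (here D = 4 and the math is scalar).
template <class ElementOp>
void op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
    for (size_t i = 0; i < ny; i++) {
        float accu = 0;
        for (size_t j = 0; j < 4; j++) {
            accu += ElementOp::op(x[j], y[j]);
        }
        y += 4;
        dis[i] = accu;
    }
}

// Usage: op_ny_D4<OpL2>(dis, x, y, ny) or op_ny_D4<OpIP>(dis, x, y, ny).
```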
@@ -314,26 +449,133 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
|
|
|
314
449
|
}
|
|
315
450
|
}
|
|
316
451
|
|
|
317
|
-
|
|
318
|
-
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
319
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
452
|
+
#if defined(__AVX512F__)
|
|
320
453
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
454
|
+
template <>
|
|
455
|
+
void fvec_op_ny_D2<ElementOpIP>(
|
|
456
|
+
float* dis,
|
|
457
|
+
const float* x,
|
|
458
|
+
const float* y,
|
|
459
|
+
size_t ny) {
|
|
460
|
+
const size_t ny16 = ny / 16;
|
|
461
|
+
size_t i = 0;
|
|
462
|
+
|
|
463
|
+
if (ny16 > 0) {
|
|
464
|
+
// process 16 D2-vectors per loop.
|
|
465
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
466
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
467
|
+
|
|
468
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
469
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
470
|
+
|
|
471
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
472
|
+
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
|
473
|
+
|
|
474
|
+
// load 16x2 matrix and transpose it in registers.
|
|
475
|
+
// the typical bottleneck is memory access, so
|
|
476
|
+
// let's trade instructions for the bandwidth.
|
|
477
|
+
|
|
478
|
+
__m512 v0;
|
|
479
|
+
__m512 v1;
|
|
480
|
+
|
|
481
|
+
transpose_16x2(
|
|
482
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
483
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
484
|
+
v0,
|
|
485
|
+
v1);
|
|
486
|
+
|
|
487
|
+
// compute distances (dot product)
|
|
488
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
489
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
490
|
+
|
|
491
|
+
// store
|
|
492
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
493
|
+
|
|
494
|
+
y += 32; // move to the next set of 16x2 elements
|
|
495
|
+
}
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
if (i < ny) {
|
|
499
|
+
// process leftovers
|
|
500
|
+
float x0 = x[0];
|
|
501
|
+
float x1 = x[1];
|
|
502
|
+
|
|
503
|
+
for (; i < ny; i++) {
|
|
504
|
+
float distance = x0 * y[0] + x1 * y[1];
|
|
505
|
+
y += 2;
|
|
506
|
+
dis[i] = distance;
|
|
507
|
+
}
|
|
327
508
|
}
|
|
328
509
|
}
|
|
329
510
|
|
|
330
|
-
|
|
511
|
+
template <>
|
|
512
|
+
void fvec_op_ny_D2<ElementOpL2>(
|
|
513
|
+
float* dis,
|
|
514
|
+
const float* x,
|
|
515
|
+
const float* y,
|
|
516
|
+
size_t ny) {
|
|
517
|
+
const size_t ny16 = ny / 16;
|
|
518
|
+
size_t i = 0;
|
|
519
|
+
|
|
520
|
+
if (ny16 > 0) {
|
|
521
|
+
// process 16 D2-vectors per loop.
|
|
522
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
523
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
524
|
+
|
|
525
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
526
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
527
|
+
|
|
528
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
529
|
+
_mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
|
|
530
|
+
|
|
531
|
+
// load 16x2 matrix and transpose it in registers.
|
|
532
|
+
// the typical bottleneck is memory access, so
|
|
533
|
+
// let's trade instructions for the bandwidth.
|
|
534
|
+
|
|
535
|
+
__m512 v0;
|
|
536
|
+
__m512 v1;
|
|
537
|
+
|
|
538
|
+
transpose_16x2(
|
|
539
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
540
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
541
|
+
v0,
|
|
542
|
+
v1);
|
|
543
|
+
|
|
544
|
+
// compute differences
|
|
545
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
546
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
547
|
+
|
|
548
|
+
// compute squares of differences
|
|
549
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
550
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
551
|
+
|
|
552
|
+
// store
|
|
553
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
554
|
+
|
|
555
|
+
y += 32; // move to the next set of 16x2 elements
|
|
556
|
+
}
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
if (i < ny) {
|
|
560
|
+
// process leftovers
|
|
561
|
+
float x0 = x[0];
|
|
562
|
+
float x1 = x[1];
|
|
563
|
+
|
|
564
|
+
for (; i < ny; i++) {
|
|
565
|
+
float sub0 = x0 - y[0];
|
|
566
|
+
float sub1 = x1 - y[1];
|
|
567
|
+
float distance = sub0 * sub0 + sub1 * sub1;
|
|
568
|
+
|
|
569
|
+
y += 2;
|
|
570
|
+
dis[i] = distance;
|
|
571
|
+
}
|
|
572
|
+
}
|
|
573
|
+
}
|
|
331
574
|
|
|
332
|
-
|
|
333
|
-
// Todo: implement fvec_op_ny_Dxxx in the same way.
|
|
575
|
+
#elif defined(__AVX2__)
|
|
334
576
|
|
|
335
577
|
template <>
|
|
336
|
-
void
|
|
578
|
+
void fvec_op_ny_D2<ElementOpIP>(
|
|
337
579
|
float* dis,
|
|
338
580
|
const float* x,
|
|
339
581
|
const float* y,
|
|
@@ -342,68 +584,55 @@ void fvec_op_ny_D4<ElementOpIP>(
|
|
|
342
584
|
size_t i = 0;
|
|
343
585
|
|
|
344
586
|
if (ny8 > 0) {
|
|
345
|
-
// process 8
|
|
346
|
-
_mm_prefetch(y,
|
|
347
|
-
_mm_prefetch(y + 16,
|
|
587
|
+
// process 8 D2-vectors per loop.
|
|
588
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
589
|
+
_mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
|
|
348
590
|
|
|
349
|
-
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
350
591
|
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
351
|
-
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
352
592
|
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
353
|
-
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
354
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
355
|
-
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
356
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
357
593
|
|
|
358
594
|
for (i = 0; i < ny8 * 8; i += 8) {
|
|
359
|
-
|
|
595
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
596
|
+
|
|
597
|
+
// load 8x2 matrix and transpose it in registers.
|
|
360
598
|
// the typical bottleneck is memory access, so
|
|
361
599
|
// let's trade instructions for the bandwidth.
|
|
362
600
|
|
|
363
601
|
__m256 v0;
|
|
364
602
|
__m256 v1;
|
|
365
|
-
__m256 v2;
|
|
366
|
-
__m256 v3;
|
|
367
603
|
|
|
368
|
-
|
|
604
|
+
transpose_8x2(
|
|
369
605
|
_mm256_loadu_ps(y + 0 * 8),
|
|
370
606
|
_mm256_loadu_ps(y + 1 * 8),
|
|
371
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
372
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
373
607
|
v0,
|
|
374
|
-
v1
|
|
375
|
-
v2,
|
|
376
|
-
v3);
|
|
608
|
+
v1);
|
|
377
609
|
|
|
378
610
|
// compute distances
|
|
379
611
|
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
380
612
|
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
381
|
-
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
382
|
-
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
383
613
|
|
|
384
614
|
// store
|
|
385
615
|
_mm256_storeu_ps(dis + i, distances);
|
|
386
616
|
|
|
387
|
-
y +=
|
|
617
|
+
y += 16;
|
|
388
618
|
}
|
|
389
619
|
}
|
|
390
620
|
|
|
391
621
|
if (i < ny) {
|
|
392
622
|
// process leftovers
|
|
393
|
-
|
|
623
|
+
float x0 = x[0];
|
|
624
|
+
float x1 = x[1];
|
|
394
625
|
|
|
395
626
|
for (; i < ny; i++) {
|
|
396
|
-
|
|
397
|
-
y +=
|
|
398
|
-
|
|
399
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
400
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
627
|
+
float distance = x0 * y[0] + x1 * y[1];
|
|
628
|
+
y += 2;
|
|
629
|
+
dis[i] = distance;
|
|
401
630
|
}
|
|
402
631
|
}
|
|
403
632
|
}
|
|
404
633
|
|
|
405
634
|
template <>
|
|
406
|
-
void
|
|
635
|
+
void fvec_op_ny_D2<ElementOpL2>(
|
|
407
636
|
float* dis,
|
|
408
637
|
const float* x,
|
|
409
638
|
const float* y,
|
|
@@ -412,68 +641,56 @@ void fvec_op_ny_D4<ElementOpL2>(
|
|
|
412
641
|
size_t i = 0;
|
|
413
642
|
|
|
414
643
|
if (ny8 > 0) {
|
|
415
|
-
// process 8
|
|
416
|
-
_mm_prefetch(y,
|
|
417
|
-
_mm_prefetch(y + 16,
|
|
644
|
+
// process 8 D2-vectors per loop.
|
|
645
|
+
_mm_prefetch((const char*)y, _MM_HINT_T0);
|
|
646
|
+
_mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
|
|
418
647
|
|
|
419
|
-
// m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
|
|
420
648
|
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
421
|
-
// m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
|
|
422
649
|
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
423
|
-
// m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
|
|
424
|
-
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
425
|
-
// m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
|
|
426
|
-
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
427
650
|
|
|
428
651
|
for (i = 0; i < ny8 * 8; i += 8) {
|
|
429
|
-
|
|
652
|
+
_mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
|
|
653
|
+
|
|
654
|
+
// load 8x2 matrix and transpose it in registers.
|
|
430
655
|
// the typical bottleneck is memory access, so
|
|
431
656
|
// let's trade instructions for the bandwidth.
|
|
432
657
|
|
|
433
658
|
__m256 v0;
|
|
434
659
|
__m256 v1;
|
|
435
|
-
__m256 v2;
|
|
436
|
-
__m256 v3;
|
|
437
660
|
|
|
438
|
-
|
|
661
|
+
transpose_8x2(
|
|
439
662
|
_mm256_loadu_ps(y + 0 * 8),
|
|
440
663
|
_mm256_loadu_ps(y + 1 * 8),
|
|
441
|
-
_mm256_loadu_ps(y + 2 * 8),
|
|
442
|
-
_mm256_loadu_ps(y + 3 * 8),
|
|
443
664
|
v0,
|
|
444
|
-
v1
|
|
445
|
-
v2,
|
|
446
|
-
v3);
|
|
665
|
+
v1);
|
|
447
666
|
|
|
448
667
|
// compute differences
|
|
449
668
|
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
450
669
|
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
451
|
-
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
452
|
-
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
453
670
|
|
|
454
671
|
// compute squares of differences
|
|
455
672
|
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
456
673
|
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
457
|
-
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
458
|
-
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
459
674
|
|
|
460
675
|
// store
|
|
461
676
|
_mm256_storeu_ps(dis + i, distances);
|
|
462
677
|
|
|
463
|
-
y +=
|
|
678
|
+
y += 16;
|
|
464
679
|
}
|
|
465
680
|
}
|
|
466
681
|
|
|
467
682
|
if (i < ny) {
|
|
468
683
|
// process leftovers
|
|
469
|
-
|
|
684
|
+
float x0 = x[0];
|
|
685
|
+
float x1 = x[1];
|
|
470
686
|
|
|
471
687
|
for (; i < ny; i++) {
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
688
|
+
float sub0 = x0 - y[0];
|
|
689
|
+
float sub1 = x1 - y[1];
|
|
690
|
+
float distance = sub0 * sub0 + sub1 * sub1;
|
|
691
|
+
|
|
692
|
+
y += 2;
|
|
693
|
+
dis[i] = distance;
|
|
477
694
|
}
|
|
478
695
|
}
|
|
479
696
|
}
|
|
@@ -481,77 +698,698 @@ void fvec_op_ny_D4<ElementOpL2>(
|
|
|
481
698
|
#endif
|
|
482
699
|
|
|
483
700
|
template <class ElementOp>
|
|
484
|
-
void
|
|
485
|
-
__m128 x0 = _mm_loadu_ps(x);
|
|
486
|
-
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
487
|
-
|
|
488
|
-
for (size_t i = 0; i < ny; i++) {
|
|
489
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
490
|
-
y += 4;
|
|
491
|
-
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
492
|
-
y += 4;
|
|
493
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
494
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
495
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
496
|
-
}
|
|
497
|
-
}
|
|
498
|
-
|
|
499
|
-
template <class ElementOp>
|
|
500
|
-
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
|
|
701
|
+
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
501
702
|
__m128 x0 = _mm_loadu_ps(x);
|
|
502
|
-
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
503
|
-
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
504
703
|
|
|
505
704
|
for (size_t i = 0; i < ny; i++) {
|
|
506
705
|
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
507
706
|
y += 4;
|
|
508
|
-
|
|
509
|
-
y += 4;
|
|
510
|
-
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
|
|
511
|
-
y += 4;
|
|
512
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
513
|
-
accu = _mm_hadd_ps(accu, accu);
|
|
514
|
-
dis[i] = _mm_cvtss_f32(accu);
|
|
707
|
+
dis[i] = horizontal_sum(accu);
|
|
515
708
|
}
|
|
516
709
|
}
|
|
517
710
|
|
|
518
|
-
|
|
711
|
+
#if defined(__AVX512F__)
|
|
519
712
|
|
|
520
|
-
|
|
713
|
+
template <>
|
|
714
|
+
void fvec_op_ny_D4<ElementOpIP>(
|
|
521
715
|
float* dis,
|
|
522
716
|
const float* x,
|
|
523
717
|
const float* y,
|
|
524
|
-
size_t d,
|
|
525
718
|
size_t ny) {
|
|
526
|
-
|
|
719
|
+
const size_t ny16 = ny / 16;
|
|
720
|
+
size_t i = 0;
|
|
527
721
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
722
|
+
if (ny16 > 0) {
|
|
723
|
+
// process 16 D4-vectors per loop.
|
|
724
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
725
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
726
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
727
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
532
728
|
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
729
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
730
|
+
// load 16x4 matrix and transpose it in registers.
|
|
731
|
+
// the typical bottleneck is memory access, so
|
|
732
|
+
// let's trade instructions for the bandwidth.
|
|
733
|
+
|
|
734
|
+
__m512 v0;
|
|
735
|
+
__m512 v1;
|
|
736
|
+
__m512 v2;
|
|
737
|
+
__m512 v3;
|
|
738
|
+
|
|
739
|
+
transpose_16x4(
|
|
740
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
741
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
742
|
+
_mm512_loadu_ps(y + 2 * 16),
|
|
743
|
+
_mm512_loadu_ps(y + 3 * 16),
|
|
744
|
+
v0,
|
|
745
|
+
v1,
|
|
746
|
+
v2,
|
|
747
|
+
v3);
|
|
748
|
+
|
|
749
|
+
// compute distances
|
|
750
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
751
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
752
|
+
distances = _mm512_fmadd_ps(m2, v2, distances);
|
|
753
|
+
distances = _mm512_fmadd_ps(m3, v3, distances);
|
|
754
|
+
|
|
755
|
+
// store
|
|
756
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
757
|
+
|
|
758
|
+
y += 64; // move to the next set of 16x4 elements
|
|
759
|
+
}
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
if (i < ny) {
|
|
763
|
+
// process leftovers
|
|
764
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
765
|
+
|
|
766
|
+
for (; i < ny; i++) {
|
|
767
|
+
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
768
|
+
y += 4;
|
|
769
|
+
dis[i] = horizontal_sum(accu);
|
|
770
|
+
}
|
|
542
771
|
}
|
|
543
|
-
#undef DISPATCH
|
|
544
772
|
}
|
|
545
773
|
|
|
546
|
-
|
|
774
|
+
template <>
|
|
775
|
+
void fvec_op_ny_D4<ElementOpL2>(
|
|
547
776
|
float* dis,
|
|
548
777
|
const float* x,
|
|
549
778
|
const float* y,
|
|
550
|
-
size_t d,
|
|
551
779
|
size_t ny) {
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
780
|
+
const size_t ny16 = ny / 16;
|
|
781
|
+
size_t i = 0;
|
|
782
|
+
|
|
783
|
+
if (ny16 > 0) {
|
|
784
|
+
// process 16 D4-vectors per loop.
|
|
785
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
786
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
787
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
788
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
789
|
+
|
|
790
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
791
|
+
// load 16x4 matrix and transpose it in registers.
|
|
792
|
+
// the typical bottleneck is memory access, so
|
|
793
|
+
// let's trade instructions for the bandwidth.
|
|
794
|
+
|
|
795
|
+
__m512 v0;
|
|
796
|
+
__m512 v1;
|
|
797
|
+
__m512 v2;
|
|
798
|
+
__m512 v3;
|
|
799
|
+
|
|
800
|
+
transpose_16x4(
|
|
801
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
802
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
803
|
+
_mm512_loadu_ps(y + 2 * 16),
|
|
804
|
+
_mm512_loadu_ps(y + 3 * 16),
|
|
805
|
+
v0,
|
|
806
|
+
v1,
|
|
807
|
+
v2,
|
|
808
|
+
v3);
|
|
809
|
+
|
|
810
|
+
// compute differences
|
|
811
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
812
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
813
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
814
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
815
|
+
|
|
816
|
+
// compute squares of differences
|
|
817
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
818
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
819
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
820
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
821
|
+
|
|
822
|
+
// store
|
|
823
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
824
|
+
|
|
825
|
+
y += 64; // move to the next set of 16x4 elements
|
|
826
|
+
}
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
if (i < ny) {
|
|
830
|
+
// process leftovers
|
|
831
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
832
|
+
|
|
833
|
+
for (; i < ny; i++) {
|
|
834
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
835
|
+
y += 4;
|
|
836
|
+
dis[i] = horizontal_sum(accu);
|
|
837
|
+
}
|
|
838
|
+
}
|
|
839
|
+
}
|
|
840
|
+
|
|
841
|
+
#elif defined(__AVX2__)
|
|
842
|
+
|
|
843
|
+
template <>
|
|
844
|
+
void fvec_op_ny_D4<ElementOpIP>(
|
|
845
|
+
float* dis,
|
|
846
|
+
const float* x,
|
|
847
|
+
const float* y,
|
|
848
|
+
size_t ny) {
|
|
849
|
+
const size_t ny8 = ny / 8;
|
|
850
|
+
size_t i = 0;
|
|
851
|
+
|
|
852
|
+
if (ny8 > 0) {
|
|
853
|
+
// process 8 D4-vectors per loop.
|
|
854
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
855
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
856
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
857
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
858
|
+
|
|
859
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
860
|
+
// load 8x4 matrix and transpose it in registers.
|
|
861
|
+
// the typical bottleneck is memory access, so
|
|
862
|
+
// let's trade instructions for the bandwidth.
|
|
863
|
+
|
|
864
|
+
__m256 v0;
|
|
865
|
+
__m256 v1;
|
|
866
|
+
__m256 v2;
|
|
867
|
+
__m256 v3;
|
|
868
|
+
|
|
869
|
+
transpose_8x4(
|
|
870
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
871
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
872
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
873
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
874
|
+
v0,
|
|
875
|
+
v1,
|
|
876
|
+
v2,
|
|
877
|
+
v3);
|
|
878
|
+
|
|
879
|
+
// compute distances
|
|
880
|
+
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
881
|
+
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
882
|
+
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
883
|
+
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
884
|
+
|
|
885
|
+
// store
|
|
886
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
887
|
+
|
|
888
|
+
y += 32;
|
|
889
|
+
}
|
|
890
|
+
}
|
|
891
|
+
|
|
892
|
+
if (i < ny) {
|
|
893
|
+
// process leftovers
|
|
894
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
895
|
+
|
|
896
|
+
for (; i < ny; i++) {
|
|
897
|
+
__m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
|
|
898
|
+
y += 4;
|
|
899
|
+
dis[i] = horizontal_sum(accu);
|
|
900
|
+
}
|
|
901
|
+
}
|
|
902
|
+
}
|
|
903
|
+
|
|
904
|
+
template <>
|
|
905
|
+
void fvec_op_ny_D4<ElementOpL2>(
|
|
906
|
+
float* dis,
|
|
907
|
+
const float* x,
|
|
908
|
+
const float* y,
|
|
909
|
+
size_t ny) {
|
|
910
|
+
const size_t ny8 = ny / 8;
|
|
911
|
+
size_t i = 0;
|
|
912
|
+
|
|
913
|
+
if (ny8 > 0) {
|
|
914
|
+
// process 8 D4-vectors per loop.
|
|
915
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
916
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
917
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
918
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
919
|
+
|
|
920
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
921
|
+
// load 8x4 matrix and transpose it in registers.
|
|
922
|
+
// the typical bottleneck is memory access, so
|
|
923
|
+
// let's trade instructions for the bandwidth.
|
|
924
|
+
|
|
925
|
+
__m256 v0;
|
|
926
|
+
__m256 v1;
|
|
927
|
+
__m256 v2;
|
|
928
|
+
__m256 v3;
|
|
929
|
+
|
|
930
|
+
transpose_8x4(
|
|
931
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
932
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
933
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
934
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
935
|
+
v0,
|
|
936
|
+
v1,
|
|
937
|
+
v2,
|
|
938
|
+
v3);
|
|
939
|
+
|
|
940
|
+
// compute differences
|
|
941
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
942
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
943
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
944
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
945
|
+
|
|
946
|
+
// compute squares of differences
|
|
947
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
948
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
949
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
950
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
951
|
+
|
|
952
|
+
// store
|
|
953
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
954
|
+
|
|
955
|
+
y += 32;
|
|
956
|
+
}
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
if (i < ny) {
|
|
960
|
+
// process leftovers
|
|
961
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
962
|
+
|
|
963
|
+
for (; i < ny; i++) {
|
|
964
|
+
__m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
|
|
965
|
+
y += 4;
|
|
966
|
+
dis[i] = horizontal_sum(accu);
|
|
967
|
+
}
|
|
968
|
+
}
|
|
969
|
+
}
|
|
970
|
+
|
|
971
|
+
#endif
|
|
972
|
+
|
|
973
|
+
template <class ElementOp>
|
|
974
|
+
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
975
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
976
|
+
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
977
|
+
|
|
978
|
+
for (size_t i = 0; i < ny; i++) {
|
|
979
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
980
|
+
y += 4;
|
|
981
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
982
|
+
y += 4;
|
|
983
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
984
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
985
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
986
|
+
}
|
|
987
|
+
}
|
|
988
|
+
|
|
989
|
+
#if defined(__AVX512F__)
|
|
990
|
+
|
|
991
|
+
template <>
|
|
992
|
+
void fvec_op_ny_D8<ElementOpIP>(
|
|
993
|
+
float* dis,
|
|
994
|
+
const float* x,
|
|
995
|
+
const float* y,
|
|
996
|
+
size_t ny) {
|
|
997
|
+
const size_t ny16 = ny / 16;
|
|
998
|
+
size_t i = 0;
|
|
999
|
+
|
|
1000
|
+
if (ny16 > 0) {
|
|
1001
|
+
// process 16 D16-vectors per loop.
|
|
1002
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1003
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1004
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1005
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1006
|
+
const __m512 m4 = _mm512_set1_ps(x[4]);
|
|
1007
|
+
const __m512 m5 = _mm512_set1_ps(x[5]);
|
|
1008
|
+
const __m512 m6 = _mm512_set1_ps(x[6]);
|
|
1009
|
+
const __m512 m7 = _mm512_set1_ps(x[7]);
|
|
1010
|
+
|
|
1011
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
1012
|
+
// load 16x8 matrix and transpose it in registers.
|
|
1013
|
+
// the typical bottleneck is memory access, so
|
|
1014
|
+
// let's trade instructions for the bandwidth.
|
|
1015
|
+
|
|
1016
|
+
__m512 v0;
|
|
1017
|
+
__m512 v1;
|
|
1018
|
+
__m512 v2;
|
|
1019
|
+
__m512 v3;
|
|
1020
|
+
__m512 v4;
|
|
1021
|
+
__m512 v5;
|
|
1022
|
+
__m512 v6;
|
|
1023
|
+
__m512 v7;
|
|
1024
|
+
|
|
1025
|
+
transpose_16x8(
|
|
1026
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
1027
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
1028
|
+
_mm512_loadu_ps(y + 2 * 16),
|
|
1029
|
+
_mm512_loadu_ps(y + 3 * 16),
|
|
1030
|
+
_mm512_loadu_ps(y + 4 * 16),
|
|
1031
|
+
_mm512_loadu_ps(y + 5 * 16),
|
|
1032
|
+
_mm512_loadu_ps(y + 6 * 16),
|
|
1033
|
+
_mm512_loadu_ps(y + 7 * 16),
|
|
1034
|
+
v0,
|
|
1035
|
+
v1,
|
|
1036
|
+
v2,
|
|
1037
|
+
v3,
|
|
1038
|
+
v4,
|
|
1039
|
+
v5,
|
|
1040
|
+
v6,
|
|
1041
|
+
v7);
|
|
1042
|
+
|
|
1043
|
+
// compute distances
|
|
1044
|
+
__m512 distances = _mm512_mul_ps(m0, v0);
|
|
1045
|
+
distances = _mm512_fmadd_ps(m1, v1, distances);
|
|
1046
|
+
distances = _mm512_fmadd_ps(m2, v2, distances);
|
|
1047
|
+
distances = _mm512_fmadd_ps(m3, v3, distances);
|
|
1048
|
+
distances = _mm512_fmadd_ps(m4, v4, distances);
|
|
1049
|
+
distances = _mm512_fmadd_ps(m5, v5, distances);
|
|
1050
|
+
distances = _mm512_fmadd_ps(m6, v6, distances);
|
|
1051
|
+
distances = _mm512_fmadd_ps(m7, v7, distances);
|
|
1052
|
+
|
|
1053
|
+
// store
|
|
1054
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
1055
|
+
|
|
1056
|
+
y += 128; // 16 floats * 8 rows
|
|
1057
|
+
}
|
|
1058
|
+
}
|
|
1059
|
+
|
|
1060
|
+
if (i < ny) {
|
|
1061
|
+
// process leftovers
|
|
1062
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
|
1063
|
+
|
|
1064
|
+
for (; i < ny; i++) {
|
|
1065
|
+
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
|
|
1066
|
+
y += 8;
|
|
1067
|
+
dis[i] = horizontal_sum(accu);
|
|
1068
|
+
}
|
|
1069
|
+
}
|
|
1070
|
+
}
|
|
1071
|
+
|
|
1072
|
+
template <>
|
|
1073
|
+
void fvec_op_ny_D8<ElementOpL2>(
|
|
1074
|
+
float* dis,
|
|
1075
|
+
const float* x,
|
|
1076
|
+
const float* y,
|
|
1077
|
+
size_t ny) {
|
|
1078
|
+
const size_t ny16 = ny / 16;
|
|
1079
|
+
size_t i = 0;
|
|
1080
|
+
|
|
1081
|
+
if (ny16 > 0) {
|
|
1082
|
+
// process 16 D16-vectors per loop.
|
|
1083
|
+
const __m512 m0 = _mm512_set1_ps(x[0]);
|
|
1084
|
+
const __m512 m1 = _mm512_set1_ps(x[1]);
|
|
1085
|
+
const __m512 m2 = _mm512_set1_ps(x[2]);
|
|
1086
|
+
const __m512 m3 = _mm512_set1_ps(x[3]);
|
|
1087
|
+
const __m512 m4 = _mm512_set1_ps(x[4]);
|
|
1088
|
+
const __m512 m5 = _mm512_set1_ps(x[5]);
|
|
1089
|
+
const __m512 m6 = _mm512_set1_ps(x[6]);
|
|
1090
|
+
const __m512 m7 = _mm512_set1_ps(x[7]);
|
|
1091
|
+
|
|
1092
|
+
for (i = 0; i < ny16 * 16; i += 16) {
|
|
1093
|
+
// load 16x8 matrix and transpose it in registers.
|
|
1094
|
+
// the typical bottleneck is memory access, so
|
|
1095
|
+
// let's trade instructions for the bandwidth.
|
|
1096
|
+
|
|
1097
|
+
__m512 v0;
|
|
1098
|
+
__m512 v1;
|
|
1099
|
+
__m512 v2;
|
|
1100
|
+
__m512 v3;
|
|
1101
|
+
__m512 v4;
|
|
1102
|
+
__m512 v5;
|
|
1103
|
+
__m512 v6;
|
|
1104
|
+
__m512 v7;
|
|
1105
|
+
|
|
1106
|
+
transpose_16x8(
|
|
1107
|
+
_mm512_loadu_ps(y + 0 * 16),
|
|
1108
|
+
_mm512_loadu_ps(y + 1 * 16),
|
|
1109
|
+
_mm512_loadu_ps(y + 2 * 16),
|
|
1110
|
+
_mm512_loadu_ps(y + 3 * 16),
|
|
1111
|
+
_mm512_loadu_ps(y + 4 * 16),
|
|
1112
|
+
_mm512_loadu_ps(y + 5 * 16),
|
|
1113
|
+
_mm512_loadu_ps(y + 6 * 16),
|
|
1114
|
+
_mm512_loadu_ps(y + 7 * 16),
|
|
1115
|
+
v0,
|
|
1116
|
+
v1,
|
|
1117
|
+
v2,
|
|
1118
|
+
v3,
|
|
1119
|
+
v4,
|
|
1120
|
+
v5,
|
|
1121
|
+
v6,
|
|
1122
|
+
v7);
|
|
1123
|
+
|
|
1124
|
+
// compute differences
|
|
1125
|
+
const __m512 d0 = _mm512_sub_ps(m0, v0);
|
|
1126
|
+
const __m512 d1 = _mm512_sub_ps(m1, v1);
|
|
1127
|
+
const __m512 d2 = _mm512_sub_ps(m2, v2);
|
|
1128
|
+
const __m512 d3 = _mm512_sub_ps(m3, v3);
|
|
1129
|
+
const __m512 d4 = _mm512_sub_ps(m4, v4);
|
|
1130
|
+
const __m512 d5 = _mm512_sub_ps(m5, v5);
|
|
1131
|
+
const __m512 d6 = _mm512_sub_ps(m6, v6);
|
|
1132
|
+
const __m512 d7 = _mm512_sub_ps(m7, v7);
|
|
1133
|
+
|
|
1134
|
+
// compute squares of differences
|
|
1135
|
+
__m512 distances = _mm512_mul_ps(d0, d0);
|
|
1136
|
+
distances = _mm512_fmadd_ps(d1, d1, distances);
|
|
1137
|
+
distances = _mm512_fmadd_ps(d2, d2, distances);
|
|
1138
|
+
distances = _mm512_fmadd_ps(d3, d3, distances);
|
|
1139
|
+
distances = _mm512_fmadd_ps(d4, d4, distances);
|
|
1140
|
+
distances = _mm512_fmadd_ps(d5, d5, distances);
|
|
1141
|
+
distances = _mm512_fmadd_ps(d6, d6, distances);
|
|
1142
|
+
distances = _mm512_fmadd_ps(d7, d7, distances);
|
|
1143
|
+
|
|
1144
|
+
// store
|
|
1145
|
+
_mm512_storeu_ps(dis + i, distances);
|
|
1146
|
+
|
|
1147
|
+
y += 128; // 16 floats * 8 rows
|
|
1148
|
+
}
|
|
1149
|
+
}
|
|
1150
|
+
|
|
1151
|
+
if (i < ny) {
|
|
1152
|
+
// process leftovers
|
|
1153
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
|
1154
|
+
|
|
1155
|
+
for (; i < ny; i++) {
|
|
1156
|
+
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
|
1157
|
+
y += 8;
|
|
1158
|
+
dis[i] = horizontal_sum(accu);
|
|
1159
|
+
}
|
|
1160
|
+
}
|
|
1161
|
+
}
|
|
1162
|
+
|
|
1163
|
+
#elif defined(__AVX2__)
|
|
1164
|
+
|
|
1165
|
+
template <>
|
|
1166
|
+
void fvec_op_ny_D8<ElementOpIP>(
|
|
1167
|
+
float* dis,
|
|
1168
|
+
const float* x,
|
|
1169
|
+
const float* y,
|
|
1170
|
+
size_t ny) {
|
|
1171
|
+
const size_t ny8 = ny / 8;
|
|
1172
|
+
size_t i = 0;
|
|
1173
|
+
|
|
1174
|
+
if (ny8 > 0) {
|
|
1175
|
+
// process 8 D8-vectors per loop.
|
|
1176
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
1177
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
1178
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
1179
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
1180
|
+
const __m256 m4 = _mm256_set1_ps(x[4]);
|
|
1181
|
+
const __m256 m5 = _mm256_set1_ps(x[5]);
|
|
1182
|
+
const __m256 m6 = _mm256_set1_ps(x[6]);
|
|
1183
|
+
const __m256 m7 = _mm256_set1_ps(x[7]);
|
|
1184
|
+
|
|
1185
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
1186
|
+
// load 8x8 matrix and transpose it in registers.
|
|
1187
|
+
// the typical bottleneck is memory access, so
|
|
1188
|
+
// let's trade instructions for the bandwidth.
|
|
1189
|
+
|
|
1190
|
+
__m256 v0;
|
|
1191
|
+
__m256 v1;
|
|
1192
|
+
__m256 v2;
|
|
1193
|
+
__m256 v3;
|
|
1194
|
+
__m256 v4;
|
|
1195
|
+
__m256 v5;
|
|
1196
|
+
__m256 v6;
|
|
1197
|
+
__m256 v7;
|
|
1198
|
+
|
|
1199
|
+
transpose_8x8(
|
|
1200
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
1201
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
1202
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
1203
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
1204
|
+
_mm256_loadu_ps(y + 4 * 8),
|
|
1205
|
+
_mm256_loadu_ps(y + 5 * 8),
|
|
1206
|
+
_mm256_loadu_ps(y + 6 * 8),
|
|
1207
|
+
_mm256_loadu_ps(y + 7 * 8),
|
|
1208
|
+
v0,
|
|
1209
|
+
v1,
|
|
1210
|
+
v2,
|
|
1211
|
+
v3,
|
|
1212
|
+
v4,
|
|
1213
|
+
v5,
|
|
1214
|
+
v6,
|
|
1215
|
+
v7);
|
|
1216
|
+
|
|
1217
|
+
// compute distances
|
|
1218
|
+
__m256 distances = _mm256_mul_ps(m0, v0);
|
|
1219
|
+
distances = _mm256_fmadd_ps(m1, v1, distances);
|
|
1220
|
+
distances = _mm256_fmadd_ps(m2, v2, distances);
|
|
1221
|
+
distances = _mm256_fmadd_ps(m3, v3, distances);
|
|
1222
|
+
distances = _mm256_fmadd_ps(m4, v4, distances);
|
|
1223
|
+
distances = _mm256_fmadd_ps(m5, v5, distances);
|
|
1224
|
+
distances = _mm256_fmadd_ps(m6, v6, distances);
|
|
1225
|
+
distances = _mm256_fmadd_ps(m7, v7, distances);
|
|
1226
|
+
|
|
1227
|
+
// store
|
|
1228
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
1229
|
+
|
|
1230
|
+
y += 64;
|
|
1231
|
+
}
|
|
1232
|
+
}
|
|
1233
|
+
|
|
1234
|
+
if (i < ny) {
|
|
1235
|
+
// process leftovers
|
|
1236
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
|
1237
|
+
|
|
1238
|
+
for (; i < ny; i++) {
|
|
1239
|
+
__m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
|
|
1240
|
+
y += 8;
|
|
1241
|
+
dis[i] = horizontal_sum(accu);
|
|
1242
|
+
}
|
|
1243
|
+
}
|
|
1244
|
+
}
|
|
1245
|
+
|
|
1246
|
+
template <>
|
|
1247
|
+
void fvec_op_ny_D8<ElementOpL2>(
|
|
1248
|
+
float* dis,
|
|
1249
|
+
const float* x,
|
|
1250
|
+
const float* y,
|
|
1251
|
+
size_t ny) {
|
|
1252
|
+
const size_t ny8 = ny / 8;
|
|
1253
|
+
size_t i = 0;
|
|
1254
|
+
|
|
1255
|
+
if (ny8 > 0) {
|
|
1256
|
+
// process 8 D8-vectors per loop.
|
|
1257
|
+
const __m256 m0 = _mm256_set1_ps(x[0]);
|
|
1258
|
+
const __m256 m1 = _mm256_set1_ps(x[1]);
|
|
1259
|
+
const __m256 m2 = _mm256_set1_ps(x[2]);
|
|
1260
|
+
const __m256 m3 = _mm256_set1_ps(x[3]);
|
|
1261
|
+
const __m256 m4 = _mm256_set1_ps(x[4]);
|
|
1262
|
+
const __m256 m5 = _mm256_set1_ps(x[5]);
|
|
1263
|
+
const __m256 m6 = _mm256_set1_ps(x[6]);
|
|
1264
|
+
const __m256 m7 = _mm256_set1_ps(x[7]);
|
|
1265
|
+
|
|
1266
|
+
for (i = 0; i < ny8 * 8; i += 8) {
|
|
1267
|
+
// load 8x8 matrix and transpose it in registers.
|
|
1268
|
+
// the typical bottleneck is memory access, so
|
|
1269
|
+
// let's trade instructions for the bandwidth.
|
|
1270
|
+
|
|
1271
|
+
__m256 v0;
|
|
1272
|
+
__m256 v1;
|
|
1273
|
+
__m256 v2;
|
|
1274
|
+
__m256 v3;
|
|
1275
|
+
__m256 v4;
|
|
1276
|
+
__m256 v5;
|
|
1277
|
+
__m256 v6;
|
|
1278
|
+
__m256 v7;
|
|
1279
|
+
|
|
1280
|
+
transpose_8x8(
|
|
1281
|
+
_mm256_loadu_ps(y + 0 * 8),
|
|
1282
|
+
_mm256_loadu_ps(y + 1 * 8),
|
|
1283
|
+
_mm256_loadu_ps(y + 2 * 8),
|
|
1284
|
+
_mm256_loadu_ps(y + 3 * 8),
|
|
1285
|
+
_mm256_loadu_ps(y + 4 * 8),
|
|
1286
|
+
_mm256_loadu_ps(y + 5 * 8),
|
|
1287
|
+
_mm256_loadu_ps(y + 6 * 8),
|
|
1288
|
+
_mm256_loadu_ps(y + 7 * 8),
|
|
1289
|
+
v0,
|
|
1290
|
+
v1,
|
|
1291
|
+
v2,
|
|
1292
|
+
v3,
|
|
1293
|
+
v4,
|
|
1294
|
+
v5,
|
|
1295
|
+
v6,
|
|
1296
|
+
v7);
|
|
1297
|
+
|
|
1298
|
+
// compute differences
|
|
1299
|
+
const __m256 d0 = _mm256_sub_ps(m0, v0);
|
|
1300
|
+
const __m256 d1 = _mm256_sub_ps(m1, v1);
|
|
1301
|
+
const __m256 d2 = _mm256_sub_ps(m2, v2);
|
|
1302
|
+
const __m256 d3 = _mm256_sub_ps(m3, v3);
|
|
1303
|
+
const __m256 d4 = _mm256_sub_ps(m4, v4);
|
|
1304
|
+
const __m256 d5 = _mm256_sub_ps(m5, v5);
|
|
1305
|
+
const __m256 d6 = _mm256_sub_ps(m6, v6);
|
|
1306
|
+
const __m256 d7 = _mm256_sub_ps(m7, v7);
|
|
1307
|
+
|
|
1308
|
+
// compute squares of differences
|
|
1309
|
+
__m256 distances = _mm256_mul_ps(d0, d0);
|
|
1310
|
+
distances = _mm256_fmadd_ps(d1, d1, distances);
|
|
1311
|
+
distances = _mm256_fmadd_ps(d2, d2, distances);
|
|
1312
|
+
distances = _mm256_fmadd_ps(d3, d3, distances);
|
|
1313
|
+
distances = _mm256_fmadd_ps(d4, d4, distances);
|
|
1314
|
+
distances = _mm256_fmadd_ps(d5, d5, distances);
|
|
1315
|
+
distances = _mm256_fmadd_ps(d6, d6, distances);
|
|
1316
|
+
distances = _mm256_fmadd_ps(d7, d7, distances);
|
|
1317
|
+
|
|
1318
|
+
// store
|
|
1319
|
+
_mm256_storeu_ps(dis + i, distances);
|
|
1320
|
+
|
|
1321
|
+
y += 64;
|
|
1322
|
+
}
|
|
1323
|
+
}
|
|
1324
|
+
|
|
1325
|
+
if (i < ny) {
|
|
1326
|
+
// process leftovers
|
|
1327
|
+
__m256 x0 = _mm256_loadu_ps(x);
|
|
1328
|
+
|
|
1329
|
+
for (; i < ny; i++) {
|
|
1330
|
+
__m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
|
|
1331
|
+
y += 8;
|
|
1332
|
+
dis[i] = horizontal_sum(accu);
|
|
1333
|
+
}
|
|
1334
|
+
}
|
|
1335
|
+
}
|
|
1336
|
+
|
|
1337
|
+
#endif
|
|
1338
|
+
|
|
1339
|
+
template <class ElementOp>
|
|
1340
|
+
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
|
|
1341
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
1342
|
+
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
1343
|
+
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
1344
|
+
|
|
1345
|
+
for (size_t i = 0; i < ny; i++) {
|
|
1346
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
1347
|
+
y += 4;
|
|
1348
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
1349
|
+
y += 4;
|
|
1350
|
+
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
|
|
1351
|
+
y += 4;
|
|
1352
|
+
dis[i] = horizontal_sum(accu);
|
|
1353
|
+
}
|
|
1354
|
+
}
|
|
1355
|
+
|
|
1356
|
+
} // anonymous namespace
|
|
1357
|
+
|
|
1358
|
+
void fvec_L2sqr_ny(
|
|
1359
|
+
float* dis,
|
|
1360
|
+
const float* x,
|
|
1361
|
+
const float* y,
|
|
1362
|
+
size_t d,
|
|
1363
|
+
size_t ny) {
|
|
1364
|
+
// optimized for a few special cases
|
|
1365
|
+
|
|
1366
|
+
#define DISPATCH(dval) \
|
|
1367
|
+
case dval: \
|
|
1368
|
+
fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
|
|
1369
|
+
return;
|
|
1370
|
+
|
|
1371
|
+
switch (d) {
|
|
1372
|
+
DISPATCH(1)
|
|
1373
|
+
DISPATCH(2)
|
|
1374
|
+
DISPATCH(4)
|
|
1375
|
+
DISPATCH(8)
|
|
1376
|
+
DISPATCH(12)
|
|
1377
|
+
default:
|
|
1378
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
1379
|
+
return;
|
|
1380
|
+
}
|
|
1381
|
+
#undef DISPATCH
|
|
1382
|
+
}
|
|
1383
|
+
|
|
1384
|
+
void fvec_inner_products_ny(
|
|
1385
|
+
float* dis,
|
|
1386
|
+
const float* x,
|
|
1387
|
+
const float* y,
|
|
1388
|
+
size_t d,
|
|
1389
|
+
size_t ny) {
|
|
1390
|
+
#define DISPATCH(dval) \
|
|
1391
|
+
case dval: \
|
|
1392
|
+
fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
|
|
555
1393
|
return;
|
|
556
1394
|
|
|
557
1395
|
switch (d) {
|
|
@@ -564,121 +1402,506 @@ void fvec_inner_products_ny(
|
|
|
564
1402
|
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
565
1403
|
return;
|
|
566
1404
|
}
|
|
567
|
-
#undef DISPATCH
|
|
1405
|
+
#undef DISPATCH
|
|
1406
|
+
}
|
|
1407
|
+
|
|
1408
|
+
#if defined(__AVX512F__)
|
|
1409
|
+
|
|
1410
|
+
template <size_t DIM>
|
|
1411
|
+
void fvec_L2sqr_ny_y_transposed_D(
|
|
1412
|
+
float* distances,
|
|
1413
|
+
const float* x,
|
|
1414
|
+
const float* y,
|
|
1415
|
+
const float* y_sqlen,
|
|
1416
|
+
const size_t d_offset,
|
|
1417
|
+
size_t ny) {
|
|
1418
|
+
// current index being processed
|
|
1419
|
+
size_t i = 0;
|
|
1420
|
+
|
|
1421
|
+
// squared length of x
|
|
1422
|
+
float x_sqlen = 0;
|
|
1423
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
1424
|
+
x_sqlen += x[j] * x[j];
|
|
1425
|
+
}
|
|
1426
|
+
|
|
1427
|
+
// process 16 vectors per loop
|
|
1428
|
+
const size_t ny16 = ny / 16;
|
|
1429
|
+
|
|
1430
|
+
if (ny16 > 0) {
|
|
1431
|
+
// m[i] = (2 * x[i], ... 2 * x[i])
|
|
1432
|
+
__m512 m[DIM];
|
|
1433
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
1434
|
+
m[j] = _mm512_set1_ps(x[j]);
|
|
1435
|
+
m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
|
|
1436
|
+
}
|
|
1437
|
+
|
|
1438
|
+
__m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen);
|
|
1439
|
+
|
|
1440
|
+
for (; i < ny16 * 16; i += 16) {
|
|
1441
|
+
// Load vectors for 16 dimensions
|
|
1442
|
+
__m512 v[DIM];
|
|
1443
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
1444
|
+
v[j] = _mm512_loadu_ps(y + j * d_offset);
|
|
1445
|
+
}
|
|
1446
|
+
|
|
1447
|
+
// Compute dot products
|
|
1448
|
+
__m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
|
|
1449
|
+
for (size_t j = 1; j < DIM; j++) {
|
|
1450
|
+
dp = _mm512_fnmadd_ps(m[j], v[j], dp);
|
|
1451
|
+
}
|
|
1452
|
+
|
|
1453
|
+
// Compute y^2 - (2 * x, y) + x^2
|
|
1454
|
+
__m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
|
|
1455
|
+
|
|
1456
|
+
_mm512_storeu_ps(distances + i, distances_v);
|
|
1457
|
+
|
|
1458
|
+
// Scroll y and y_sqlen forward
|
|
1459
|
+
y += 16;
|
|
1460
|
+
y_sqlen += 16;
|
|
1461
|
+
}
|
|
1462
|
+
}
|
|
1463
|
+
|
|
1464
|
+
if (i < ny) {
|
|
1465
|
+
// Process leftovers
|
|
1466
|
+
for (; i < ny; i++) {
|
|
1467
|
+
float dp = 0;
|
|
1468
|
+
for (size_t j = 0; j < DIM; j++) {
|
|
1469
|
+
dp += x[j] * y[j * d_offset];
|
|
1470
|
+
}
|
|
1471
|
+
|
|
1472
|
+
// Compute y^2 - 2 * (x, y), which is sufficient for looking for the
|
|
1473
|
+
// lowest distance.
|
|
1474
|
+
const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
|
|
1475
|
+
distances[i] = distance;
|
|
1476
|
+
|
|
1477
|
+
y += 1;
|
|
1478
|
+
y_sqlen += 1;
|
|
1479
|
+
}
|
|
1480
|
+
}
|
|
1481
|
+
}
|
|
+
+#elif defined(__AVX2__)
+
+template <size_t DIM>
+void fvec_L2sqr_ny_y_transposed_D(
+        float* distances,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // current index being processed
+    size_t i = 0;
+
+    // squared length of x
+    float x_sqlen = 0;
+    for (size_t j = 0; j < DIM; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    // process 8 vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m256 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm256_set1_ps(x[j]);
+            m[j] = _mm256_add_ps(m[j], m[j]);
+        }
+
+        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
+
+        for (; i < ny8 * 8; i += 8) {
+            // collect dim 0 for 8 D4-vectors.
+            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+
+            // compute dot products
+            // this is x^2 - 2x[0]*y[0]
+            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
+
+            for (size_t j = 1; j < DIM; j++) {
+                // collect dim j for 8 D4-vectors.
+                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+                dp = _mm256_fnmadd_ps(m[j], vj, dp);
+            }
+
+            // we've got x^2 - (2x, y) at this point
+
+            // y^2 - (2x, y) + x^2
+            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+            _mm256_storeu_ps(distances + i, distances_v);
+
+            // scroll y and y_sqlen forward.
+            y += 8;
+            y_sqlen += 8;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+
+#endif
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+
+#ifdef __AVX2__
+#define DISPATCH(dval)                                  \
+    case dval:                                          \
+        return fvec_L2sqr_ny_y_transposed_D<dval>(      \
+                dis, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_y_transposed_ref(
+                    dis, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
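
For context (not part of the package diff): the fvec_L2sqr_ny_y_transposed_D kernels and the fvec_L2sqr_ny_transposed dispatcher added above compute squared L2 distances against ny vectors stored "transposed", i.e. component j of vector i sits at y[j * d_offset + i], and they use the identity ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2 with the ||y_i||^2 values precomputed in y_sqlen. A minimal scalar sketch of the same computation (hypothetical standalone code, not the faiss API):

#include <cstddef>
#include <cstdio>

// Scalar sketch of the transposed-layout L2 kernel: component j of vector i
// is y[j * d_offset + i]; y_sqlen[i] holds ||y_i||^2.
static void l2sqr_ny_transposed_sketch(
        float* dis,
        const float* x,
        const float* y,
        const float* y_sqlen,
        size_t d,
        size_t d_offset,
        size_t ny) {
    float x_sqlen = 0;
    for (size_t j = 0; j < d; j++) {
        x_sqlen += x[j] * x[j];
    }
    for (size_t i = 0; i < ny; i++) {
        float dp = 0;
        for (size_t j = 0; j < d; j++) {
            dp += x[j] * y[j * d_offset + i];
        }
        // ||x - y_i||^2 = ||x||^2 - 2 <x, y_i> + ||y_i||^2
        dis[i] = x_sqlen - 2 * dp + y_sqlen[i];
    }
}

int main() {
    // two 2-D vectors stored transposed: dim 0 of both vectors, then dim 1
    const float x[2] = {1.0f, 2.0f};
    const float y[4] = {0.0f, 3.0f, 0.0f, 4.0f};
    const float y_sqlen[2] = {0.0f, 25.0f};
    float dis[2];
    l2sqr_ny_transposed_sketch(dis, x, y, y_sqlen, 2, 2, 2);
    std::printf("%f %f\n", dis[0], dis[1]); // expected: 5 and 8
    return 0;
}
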
+
+#if defined(__AVX512F__)
+
+size_t fvec_L2sqr_ny_nearest_D2(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    const size_t ny16 = ny / 16;
+    if (ny16 > 0) {
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 32;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
 }
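
For context (not part of the package diff): fvec_L2sqr_ny_nearest_D2 above returns the index of the 2-dimensional vector in y that is closest to x; the AVX-512 path keeps 16 running minima and 16 running indices and reduces them at the end. A scalar sketch of the same reduction (illustrative only, not the faiss API):

#include <cstddef>
#include <cmath>

// Index of the closest 2-D vector among y[0..ny-1], stored row-major with
// 2 floats per vector. Equivalent to the leftover loop of the kernel above.
static size_t nearest_d2_sketch(const float* x, const float* y, size_t ny) {
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;
    for (size_t i = 0; i < ny; i++) {
        const float sub0 = x[0] - y[2 * i + 0];
        const float sub1 = x[1] - y[2 * i + 1];
        const float distance = sub0 * sub0 + sub1 * sub1;
        if (current_min_distance > distance) {
            current_min_distance = distance;
            current_min_index = i;
        }
    }
    return current_min_index;
}
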
 
-
-
-void fvec_L2sqr_ny_y_transposed_D(
-        float* distances,
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
         const float* x,
         const float* y,
-        const float* y_sqlen,
-        const size_t d_offset,
         size_t ny) {
-    //
+    // this implementation does not use distances_tmp_buffer.
+
     size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
 
-
-    float x_sqlen = 0;
-    ;
-    for (size_t j = 0; j < DIM; j++) {
-        x_sqlen += x[j] * x[j];
-    }
+    const size_t ny16 = ny / 16;
 
-
-
+    if (ny16 > 0) {
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
 
-
-
-
-        for (size_t j = 0; j < DIM; j++) {
-            m[j] = _mm256_set1_ps(x[j]);
-            m[j] = _mm256_add_ps(m[j], m[j]);
-        }
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
 
-
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
 
-    for (; i <
-
-
+        for (; i < ny16 * 16; i += 16) {
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
 
-
-
-
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
 
-
-
-
-
-    }
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
 
-
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
 
-
-
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
 
-
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
 
-
-
-
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 64;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
         }
     }
 
     if (i < ny) {
-
-        for (; i < ny; i++) {
-            float dp = 0;
-            for (size_t j = 0; j < DIM; j++) {
-                dp += x[j] * y[j * d_offset];
-            }
+        __m128 x0 = _mm_loadu_ps(x);
 
-
-
-
-
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            const float distance = horizontal_sum(accu);
 
-
-
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
         }
     }
+
+    return current_min_index;
 }
-#endif
 
-
-float*
+size_t fvec_L2sqr_ny_nearest_D8(
+        float* distances_tmp_buffer,
         const float* x,
         const float* y,
-        const float* y_sqlen,
-        size_t d,
-        size_t d_offset,
         size_t ny) {
-    //
+    // this implementation does not use distances_tmp_buffer.
 
-
-
-
-        return fvec_L2sqr_ny_y_transposed_D<dval>( \
-                dis, x, y, y_sqlen, d_offset, ny);
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
 
-
-
-
-
-
-
-
-
+    const size_t ny16 = ny / 16;
+    if (ny16 > 0) {
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (; i < ny16 * 16; i += 16) {
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+            const __m512 d4 = _mm512_sub_ps(m4, v4);
+            const __m512 d5 = _mm512_sub_ps(m5, v5);
+            const __m512 d6 = _mm512_sub_ps(m6, v6);
+            const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+            distances = _mm512_fmadd_ps(d4, d4, distances);
+            distances = _mm512_fmadd_ps(d5, d5, distances);
+            distances = _mm512_fmadd_ps(d6, d6, distances);
+            distances = _mm512_fmadd_ps(d7, d7, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 128;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
     }
-
-
-
-
-
+
+    if (i < ny) {
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            const float distance = horizontal_sum(accu);
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
 }
 
-#
+#elif defined(__AVX2__)
 
 size_t fvec_L2sqr_ny_nearest_D2(
         float* distances_tmp_buffer,
@@ -697,8 +1920,8 @@ size_t fvec_L2sqr_ny_nearest_D2(
     // process 8 D2-vectors per loop.
     const size_t ny8 = ny / 8;
     if (ny8 > 0) {
-        _mm_prefetch(y, _MM_HINT_T0);
-        _mm_prefetch(y + 16, _MM_HINT_T0);
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
 
         // track min distance and the closest vector independently
         // for each of 8 AVX2 components.
@@ -713,7 +1936,7 @@ size_t fvec_L2sqr_ny_nearest_D2(
         const __m256 m1 = _mm256_set1_ps(x[1]);
 
         for (; i < ny8 * 8; i += 8) {
-            _mm_prefetch(y + 32, _MM_HINT_T0);
+            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
 
             __m256 v0;
             __m256 v1;
@@ -892,10 +2115,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
         for (; i < ny; i++) {
             __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
             y += 4;
-
-            accu = _mm_hadd_ps(accu, accu);
-
-            const auto distance = _mm_cvtss_f32(accu);
+            const float distance = horizontal_sum(accu);
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
@@ -1031,23 +2251,9 @@ size_t fvec_L2sqr_ny_nearest_D8(
         __m256 x0 = _mm256_loadu_ps(x);
 
         for (; i < ny; i++) {
-            __m256
-            __m256 accu = _mm256_mul_ps(sub, sub);
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
             y += 8;
-
-            // horitontal sum
-            const __m256 h0 = _mm256_hadd_ps(accu, accu);
-            const __m256 h1 = _mm256_hadd_ps(h0, h0);
-
-            // extract high and low __m128 regs from __m256
-            const __m128 h2 = _mm256_extractf128_ps(h1, 1);
-            const __m128 h3 = _mm256_castps256_ps128(h1);
-
-            // get a final hsum into all 4 regs
-            const __m128 h4 = _mm_add_ss(h2, h3);
-
-            // extract f[0] from __m128
-            const float distance = _mm_cvtss_f32(h4);
+            const float distance = horizontal_sum(accu);
 
             if (current_min_distance > distance) {
                 current_min_distance = distance;
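
For context (not part of the package diff): the two hunks above replace hand-written horizontal-add sequences with a horizontal_sum helper. A sketch of what such an 8-lane horizontal sum does (assumption: the actual helper is defined elsewhere in this source file and may be implemented differently):

#include <immintrin.h>

// Sum the 8 float lanes of an __m256 into a single float (sketch only).
static inline float horizontal_sum_sketch(const __m256 v) {
    const __m128 lo = _mm256_castps256_ps128(v);   // lanes 0..3
    const __m128 hi = _mm256_extractf128_ps(v, 1); // lanes 4..7
    __m128 s = _mm_add_ps(lo, hi);                 // 4 partial sums
    s = _mm_hadd_ps(s, s);                         // 2 partial sums
    s = _mm_hadd_ps(s, s);                         // full sum in lane 0
    return _mm_cvtss_f32(s);
}
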
@@ -1106,7 +2312,123 @@ size_t fvec_L2sqr_ny_nearest(
 #undef DISPATCH
 }
 
-#
+#if defined(__AVX512F__)
+
+template <size_t DIM>
+size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // This implementation does not use distances_tmp_buffer.
+
+    // Current index being processed
+    size_t i = 0;
+
+    // Min distance and the index of the closest vector so far
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    // Process 16 vectors per loop
+    const size_t ny16 = ny / 16;
+
+    if (ny16 > 0) {
+        // Track min distance and the closest vector independently
+        // for each of 16 AVX-512 components.
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m512 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm512_set1_ps(x[j]);
+            m[j] = _mm512_add_ps(m[j], m[j]);
+        }
+
+        for (; i < ny16 * 16; i += 16) {
+            // Compute dot products
+            const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
+            __m512 dp = _mm512_mul_ps(m[0], v0);
+            for (size_t j = 1; j < DIM; j++) {
+                const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
+                dp = _mm512_fmadd_ps(m[j], vj, dp);
+            }
+
+            // Compute y^2 - (2 * x, y), which is sufficient for looking for the
+            // lowest distance.
+            // x^2 is the constant that can be avoided.
+            const __m512 distances =
+                    _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+            // Compare the new distances to the min distances
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
+
+            // Update min distances and indices with closest vectors if needed
+            min_distances =
+                    _mm512_mask_blend_ps(comparison, distances, min_distances);
+            min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
+                    comparison,
+                    _mm512_castsi512_ps(current_indices),
+                    _mm512_castsi512_ps(min_indices)));
+
+            // Update current indices values. Basically, +16 to each of the 16
+            // AVX-512 components.
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            // Scroll y and y_sqlen forward.
+            y += 16;
+            y_sqlen += 16;
+        }
+
+        // Dump values and find the minimum distance / minimum index
+        float min_distances_scalar[16];
+        uint32_t min_indices_scalar[16];
+        _mm512_storeu_ps(min_distances_scalar, min_distances);
+        _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        // Process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+
+    return current_min_index;
+}
+
+#elif defined(__AVX2__)
+
 template <size_t DIM>
 size_t fvec_L2sqr_ny_nearest_y_transposed_D(
         float* distances_tmp_buffer,
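
For context (not part of the package diff): the new AVX-512 fvec_L2sqr_ny_nearest_y_transposed_D above exploits the fact that ||x||^2 is the same for every candidate, so comparing y_sqlen[i] - 2<x, y_i> is enough to pick the nearest vector. A scalar sketch of that argmin trick (illustrative only, not the faiss API):

#include <cstddef>
#include <cmath>

// Nearest transposed-layout vector: component j of vector i is
// y[j * d_offset + i]; y_sqlen[i] holds ||y_i||^2. The constant ||x||^2
// is omitted because it does not change the argmin.
static size_t nearest_y_transposed_sketch(
        const float* x,
        const float* y,
        const float* y_sqlen,
        size_t d,
        size_t d_offset,
        size_t ny) {
    float current_min = HUGE_VALF;
    size_t current_min_index = 0;
    for (size_t i = 0; i < ny; i++) {
        float dp = 0;
        for (size_t j = 0; j < d; j++) {
            dp += x[j] * y[j * d_offset + i];
        }
        const float partial_distance = y_sqlen[i] - 2 * dp;
        if (current_min > partial_distance) {
            current_min = partial_distance;
            current_min_index = i;
        }
    }
    return current_min_index;
}
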
@@ -1222,6 +2544,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
 
     return current_min_index;
 }
+
 #endif
 
 size_t fvec_L2sqr_ny_nearest_y_transposed(
@@ -1260,21 +2583,6 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
 
 #ifdef USE_AVX
 
-// reads 0 <= d < 8 floats as __m256
-static inline __m256 masked_read_8(int d, const float* x) {
-    assert(0 <= d && d < 8);
-    if (d < 4) {
-        __m256 res = _mm256_setzero_ps();
-        res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
-        return res;
-    } else {
-        __m256 res = _mm256_setzero_ps();
-        res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
-        res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
-        return res;
-    }
-}
-
 float fvec_L1(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
     __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1493,7 +2801,7 @@ void fvec_inner_products_ny(
  * heavily optimized table computations
  ***************************************************************************/
 
-static inline void fvec_madd_ref(
+[[maybe_unused]] static inline void fvec_madd_ref(
         size_t n,
         const float* a,
         float bf,
@@ -1503,7 +2811,39 @@ static inline void fvec_madd_ref(
         c[i] = a[i] + bf * b[i];
 }
 
-#
+#if defined(__AVX512F__)
+
+static inline void fvec_madd_avx512(
+        const size_t n,
+        const float* __restrict a,
+        const float bf,
+        const float* __restrict b,
+        float* __restrict c) {
+    const size_t n16 = n / 16;
+    const size_t n_for_masking = n % 16;
+
+    const __m512 bfmm = _mm512_set1_ps(bf);
+
+    size_t idx = 0;
+    for (idx = 0; idx < n16 * 16; idx += 16) {
+        const __m512 ax = _mm512_loadu_ps(a + idx);
+        const __m512 bx = _mm512_loadu_ps(b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_storeu_ps(c + idx, abmul);
+    }
+
+    if (n_for_masking > 0) {
+        const __mmask16 mask = (1 << n_for_masking) - 1;
+
+        const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
+        const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
+        const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+        _mm512_mask_storeu_ps(c + idx, mask, abmul);
+    }
+}
+
+#elif defined(__AVX2__)
+
 static inline void fvec_madd_avx2(
         const size_t n,
         const float* __restrict a,
@@ -1556,11 +2896,12 @@ static inline void fvec_madd_avx2(
         _mm256_maskstore_ps(c + idx, mask, abmul);
     }
 }
+
 #endif
 
 #ifdef __SSE3__
 
-static inline void fvec_madd_sse(
+[[maybe_unused]] static inline void fvec_madd_sse(
         size_t n,
         const float* a,
         float bf,
@@ -1581,7 +2922,9 @@ static inline void fvec_madd_sse(
 }
 
 void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
-#ifdef
+#ifdef __AVX512F__
+    fvec_madd_avx512(n, a, bf, b, c);
+#elif __AVX2__
     fvec_madd_avx2(n, a, bf, b, c);
 #else
     if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
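
For context (not part of the package diff): fvec_madd computes c[i] = a[i] + bf * b[i] (as in the reference loop above); the new AVX-512 path processes 16 floats per iteration and handles the n % 16 tail with a masked load/store, so no scalar remainder loop is needed. A hypothetical call site (assumes the vendored header faiss/utils/distances.h, which declares fvec_madd in namespace faiss):

#include <vector>
#include <faiss/utils/distances.h>

void fvec_madd_example() {
    std::vector<float> a(37, 1.0f), b(37, 2.0f), c(37, 0.0f);
    // c[i] = a[i] + 0.5f * b[i] == 2.0f for all 37 entries; the 37 % 16 tail
    // goes through the masked path when the AVX-512 kernel is compiled in.
    faiss::fvec_madd(a.size(), a.data(), 0.5f, b.data(), c.data());
}
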
@@ -1807,10 +3150,13 @@ void pq2_8cents_table(
     switch (nout) {
         case 4:
             ip3.storeu(out + 3 * ldo);
+            [[fallthrough]];
         case 3:
             ip2.storeu(out + 2 * ldo);
+            [[fallthrough]];
         case 2:
             ip1.storeu(out + 1 * ldo);
+            [[fallthrough]];
        case 1:
            ip0.storeu(out);
    }