faiss 0.5.2 → 0.6.0
This diff shows the changes between publicly released package versions as they appear in their respective supported registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -0,0 +1,1185 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/utils/distances.h>
|
|
9
|
+
|
|
10
|
+
#include <immintrin.h>
|
|
11
|
+
|
|
12
|
+
#define AUTOVEC_LEVEL SIMDLevel::AVX2
|
|
13
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
14
|
+
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
|
|
15
|
+
|
|
16
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
17
|
+
#include <faiss/utils/simd_impl/distances_sse-inl.h>
|
|
18
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
19
|
+
#include <faiss/utils/transpose/transpose-avx2-inl.h>
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
|
|
23
|
+
template <>
|
|
24
|
+
void fvec_madd<SIMDLevel::AVX2>(
|
|
25
|
+
const size_t n,
|
|
26
|
+
const float* __restrict a,
|
|
27
|
+
const float bf,
|
|
28
|
+
const float* __restrict b,
|
|
29
|
+
float* __restrict c) {
|
|
30
|
+
//
|
|
31
|
+
const size_t n8 = n / 8;
|
|
32
|
+
const size_t n_for_masking = n % 8;
|
|
33
|
+
|
|
34
|
+
const __m256 bfmm = _mm256_set1_ps(bf);
|
|
35
|
+
|
|
36
|
+
size_t idx = 0;
|
|
37
|
+
for (idx = 0; idx < n8 * 8; idx += 8) {
|
|
38
|
+
const __m256 ax = _mm256_loadu_ps(a + idx);
|
|
39
|
+
const __m256 bx = _mm256_loadu_ps(b + idx);
|
|
40
|
+
const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
|
|
41
|
+
_mm256_storeu_ps(c + idx, abmul);
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
if (n_for_masking > 0) {
|
|
45
|
+
__m256i mask;
|
|
46
|
+
switch (n_for_masking) {
|
|
47
|
+
case 1:
|
|
48
|
+
mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1);
|
|
49
|
+
break;
|
|
50
|
+
case 2:
|
|
51
|
+
mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1);
|
|
52
|
+
break;
|
|
53
|
+
case 3:
|
|
54
|
+
mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1);
|
|
55
|
+
break;
|
|
56
|
+
case 4:
|
|
57
|
+
mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1);
|
|
58
|
+
break;
|
|
59
|
+
case 5:
|
|
60
|
+
mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
|
|
61
|
+
break;
|
|
62
|
+
case 6:
|
|
63
|
+
mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1);
|
|
64
|
+
break;
|
|
65
|
+
case 7:
|
|
66
|
+
mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1);
|
|
67
|
+
break;
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
const __m256 ax = _mm256_maskload_ps(a + idx, mask);
|
|
71
|
+
const __m256 bx = _mm256_maskload_ps(b + idx, mask);
|
|
72
|
+
const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax);
|
|
73
|
+
_mm256_maskstore_ps(c + idx, mask, abmul);
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
// Squared L2 distances between one query x (DIM floats) and ny database
// vectors stored in a transposed, dimension-major layout: component j of
// vector i is at y[j * d_offset + i]. y_sqlen[i] holds the precomputed
// squared norm of vector i. Output: distances[i] = ||x||^2 - 2(x,y_i) + ||y_i||^2.
template <size_t DIM>
void fvec_L2sqr_ny_y_transposed_D(
        float* distances,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // current index being processed
    size_t i = 0;

    // squared length of x
    float x_sqlen = 0;
    for (size_t j = 0; j < DIM; j++) {
        x_sqlen += x[j] * x[j];
    }

    // process 8 vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // m[i] = (2 * x[i], ... 2 * x[i])
        __m256 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm256_set1_ps(x[j]);
            m[j] = _mm256_add_ps(m[j], m[j]);
        }

        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);

        for (; i < ny8 * 8; i += 8) {
            // collect dim 0 for 8 D4-vectors.
            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);

            // compute dot products
            // this is x^2 - 2x[0]*y[0]  (fnmadd: x_sqlen - m[0]*v0)
            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);

            for (size_t j = 1; j < DIM; j++) {
                // collect dim j for 8 D4-vectors.
                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
                dp = _mm256_fnmadd_ps(m[j], vj, dp);
            }

            // we've got x^2 - (2x, y) at this point

            // y^2 - (2x, y) + x^2
            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);

            _mm256_storeu_ps(distances + i, distances_v);

            // scroll y and y_sqlen forward.
            y += 8;
            y_sqlen += 8;
        }
    }

    if (i < ny) {
        // process leftovers one vector at a time (scalar)
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
            // lowest distance.
            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
            distances[i] = distance;

            y += 1;
            y_sqlen += 1;
        }
    }
}
|
|
152
|
+
|
|
153
|
+
// AVX2 entry point for L2sqr against a transposed (dimension-major) y.
// Specialized kernels exist for d in {1, 2, 4, 8}; any other dimension
// falls back to the SIMDLevel::NONE implementation.
template <>
void fvec_L2sqr_ny_transposed<SIMDLevel::AVX2>(
        float* dis,
        const float* x,
        const float* y,
        const float* y_sqlen,
        size_t d,
        size_t d_offset,
        size_t ny) {
    // optimized for a few special cases
#define DISPATCH(dval)                                 \
    case dval:                                         \
        return fvec_L2sqr_ny_y_transposed_D<dval>(     \
                dis, x, y, y_sqlen, d_offset, ny);

    switch (d) {
        DISPATCH(1)
        DISPATCH(2)
        DISPATCH(4)
        DISPATCH(8)
        default:
            return fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
                    dis, x, y, y_sqlen, d, d_offset, ny);
    }
#undef DISPATCH
}
|
|
179
|
+
|
|
180
|
+
namespace {

// 256-bit element op for inner product: per-lane product x*y.
// The `using` declaration keeps the narrower overloads inherited
// from ElementOpIP visible alongside this __m256 overload.
struct AVX2ElementOpIP : public ElementOpIP {
    using ElementOpIP::op;
    static __m256 op(__m256 x, __m256 y) {
        return _mm256_mul_ps(x, y);
    }
};

// 256-bit element op for squared L2: per-lane (x - y)^2.
// The `using` declaration keeps the narrower overloads inherited
// from ElementOpL2 visible alongside this __m256 overload.
struct AVX2ElementOpL2 : public ElementOpL2 {
    using ElementOpL2::op;

    static __m256 op(__m256 x, __m256 y) {
        __m256 tmp = _mm256_sub_ps(x, y);
        return _mm256_mul_ps(tmp, tmp);
    }
};

} // namespace
|
|
199
|
+
|
|
200
|
+
/// helper function for AVX2
/// Reduces the 8 float lanes of v into a single scalar sum.
inline float horizontal_sum(const __m256 v) {
    // add high and low 128-bit halves
    const __m128 v0 =
            _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // perform horizontal sum on v0
    // (resolves to the 128-bit horizontal_sum overload)
    return horizontal_sum(v0);
}
|
|
208
|
+
|
|
209
|
+
// Inner products of a 2-dim query x against ny 2-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x2 block of y and transposing it into per-dimension registers;
// remaining vectors are handled with scalar code.
template <>
void fvec_op_ny_D2<AVX2ElementOpIP>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D2-vectors per loop.
        _mm_prefetch((const char*)y, _MM_HINT_T0);
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);

        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);

        for (i = 0; i < ny8 * 8; i += 8) {
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);

            // load 8x2 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;

            transpose_8x2(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    v0,
                    v1);

            // compute distances
            __m256 distances = _mm256_mul_ps(m0, v0);
            distances = _mm256_fmadd_ps(m1, v1, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 2 floats each
            y += 16;
        }
    }

    if (i < ny) {
        // process leftovers
        float x0 = x[0];
        float x1 = x[1];

        for (; i < ny; i++) {
            float distance = x0 * y[0] + x1 * y[1];
            y += 2;
            dis[i] = distance;
        }
    }
}
|
|
265
|
+
|
|
266
|
+
// Squared L2 distances of a 2-dim query x against ny 2-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x2 block of y and transposing it into per-dimension registers;
// remaining vectors are handled with scalar code.
template <>
void fvec_op_ny_D2<AVX2ElementOpL2>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D2-vectors per loop.
        _mm_prefetch((const char*)y, _MM_HINT_T0);
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);

        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);

        for (i = 0; i < ny8 * 8; i += 8) {
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);

            // load 8x2 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;

            transpose_8x2(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    v0,
                    v1);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 2 floats each
            y += 16;
        }
    }

    if (i < ny) {
        // process leftovers
        float x0 = x[0];
        float x1 = x[1];

        for (; i < ny; i++) {
            float sub0 = x0 - y[0];
            float sub1 = x1 - y[1];
            float distance = sub0 * sub0 + sub1 * sub1;

            y += 2;
            dis[i] = distance;
        }
    }
}
|
|
329
|
+
|
|
330
|
+
// Inner products of a 4-dim query x against ny 4-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x4 block of y and transposing it into per-dimension registers;
// remaining vectors use 128-bit ops plus a horizontal sum.
template <>
void fvec_op_ny_D4<AVX2ElementOpIP>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D4-vectors per loop.
        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);

        for (i = 0; i < ny8 * 8; i += 8) {
            // load 8x4 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;

            transpose_8x4(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    v0,
                    v1,
                    v2,
                    v3);

            // compute distances
            __m256 distances = _mm256_mul_ps(m0, v0);
            distances = _mm256_fmadd_ps(m1, v1, distances);
            distances = _mm256_fmadd_ps(m2, v2, distances);
            distances = _mm256_fmadd_ps(m3, v3, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 4 floats each
            y += 32;
        }
    }

    if (i < ny) {
        // process leftovers; op resolves to an overload inherited
        // from ElementOpIP for __m128 operands
        __m128 x0 = _mm_loadu_ps(x);

        for (; i < ny; i++) {
            __m128 accu = AVX2ElementOpIP::op(x0, _mm_loadu_ps(y));
            y += 4;
            dis[i] = horizontal_sum(accu);
        }
    }
}
|
|
390
|
+
|
|
391
|
+
// Squared L2 distances of a 4-dim query x against ny 4-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x4 block of y and transposing it into per-dimension registers;
// remaining vectors use 128-bit ops plus a horizontal sum.
template <>
void fvec_op_ny_D4<AVX2ElementOpL2>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D4-vectors per loop.
        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);

        for (i = 0; i < ny8 * 8; i += 8) {
            // load 8x4 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;

            transpose_8x4(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    v0,
                    v1,
                    v2,
                    v3);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 4 floats each
            y += 32;
        }
    }

    if (i < ny) {
        // process leftovers; op resolves to an overload inherited
        // from ElementOpL2 for __m128 operands
        __m128 x0 = _mm_loadu_ps(x);

        for (; i < ny; i++) {
            __m128 accu = AVX2ElementOpL2::op(x0, _mm_loadu_ps(y));
            y += 4;
            dis[i] = horizontal_sum(accu);
        }
    }
}
|
|
457
|
+
|
|
458
|
+
// Inner products of an 8-dim query x against ny 8-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x8 block of y and transposing it into per-dimension registers;
// remaining vectors use a single 256-bit op plus a horizontal sum.
template <>
void fvec_op_ny_D8<AVX2ElementOpIP>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D8-vectors per loop.
        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);
        const __m256 m4 = _mm256_set1_ps(x[4]);
        const __m256 m5 = _mm256_set1_ps(x[5]);
        const __m256 m6 = _mm256_set1_ps(x[6]);
        const __m256 m7 = _mm256_set1_ps(x[7]);

        for (i = 0; i < ny8 * 8; i += 8) {
            // load 8x8 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;
            __m256 v4;
            __m256 v5;
            __m256 v6;
            __m256 v7;

            transpose_8x8(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    _mm256_loadu_ps(y + 4 * 8),
                    _mm256_loadu_ps(y + 5 * 8),
                    _mm256_loadu_ps(y + 6 * 8),
                    _mm256_loadu_ps(y + 7 * 8),
                    v0,
                    v1,
                    v2,
                    v3,
                    v4,
                    v5,
                    v6,
                    v7);

            // compute distances
            __m256 distances = _mm256_mul_ps(m0, v0);
            distances = _mm256_fmadd_ps(m1, v1, distances);
            distances = _mm256_fmadd_ps(m2, v2, distances);
            distances = _mm256_fmadd_ps(m3, v3, distances);
            distances = _mm256_fmadd_ps(m4, v4, distances);
            distances = _mm256_fmadd_ps(m5, v5, distances);
            distances = _mm256_fmadd_ps(m6, v6, distances);
            distances = _mm256_fmadd_ps(m7, v7, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 8 floats each
            y += 64;
        }
    }

    if (i < ny) {
        // process leftovers with full-width 256-bit ops
        __m256 x0 = _mm256_loadu_ps(x);

        for (; i < ny; i++) {
            __m256 accu = AVX2ElementOpIP::op(x0, _mm256_loadu_ps(y));
            y += 8;
            dis[i] = horizontal_sum(accu);
        }
    }
}
|
|
538
|
+
|
|
539
|
+
// Squared L2 distances of an 8-dim query x against ny 8-dim vectors y,
// written contiguously. Processes 8 vectors per iteration by loading
// an 8x8 block of y and transposing it into per-dimension registers;
// remaining vectors use a single 256-bit op plus a horizontal sum.
template <>
void fvec_op_ny_D8<AVX2ElementOpL2>(
        float* dis,
        const float* x,
        const float* y,
        size_t ny) {
    const size_t ny8 = ny / 8;
    size_t i = 0;

    if (ny8 > 0) {
        // process 8 D8-vectors per loop.
        // broadcast each query component to all 8 lanes
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);
        const __m256 m4 = _mm256_set1_ps(x[4]);
        const __m256 m5 = _mm256_set1_ps(x[5]);
        const __m256 m6 = _mm256_set1_ps(x[6]);
        const __m256 m7 = _mm256_set1_ps(x[7]);

        for (i = 0; i < ny8 * 8; i += 8) {
            // load 8x8 matrix and transpose it in registers.
            // the typical bottleneck is memory access, so
            // let's trade instructions for the bandwidth.

            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;
            __m256 v4;
            __m256 v5;
            __m256 v6;
            __m256 v7;

            transpose_8x8(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    _mm256_loadu_ps(y + 4 * 8),
                    _mm256_loadu_ps(y + 5 * 8),
                    _mm256_loadu_ps(y + 6 * 8),
                    _mm256_loadu_ps(y + 7 * 8),
                    v0,
                    v1,
                    v2,
                    v3,
                    v4,
                    v5,
                    v6,
                    v7);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);
            const __m256 d4 = _mm256_sub_ps(m4, v4);
            const __m256 d5 = _mm256_sub_ps(m5, v5);
            const __m256 d6 = _mm256_sub_ps(m6, v6);
            const __m256 d7 = _mm256_sub_ps(m7, v7);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);
            distances = _mm256_fmadd_ps(d4, d4, distances);
            distances = _mm256_fmadd_ps(d5, d5, distances);
            distances = _mm256_fmadd_ps(d6, d6, distances);
            distances = _mm256_fmadd_ps(d7, d7, distances);

            // store
            _mm256_storeu_ps(dis + i, distances);

            // advance past 8 vectors of 8 floats each
            y += 64;
        }
    }

    if (i < ny) {
        // process leftovers with full-width 256-bit ops
        __m256 x0 = _mm256_loadu_ps(x);

        for (; i < ny; i++) {
            __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y));
            y += 8;
            dis[i] = horizontal_sum(accu);
        }
    }
}
|
|
629
|
+
|
|
630
|
+
// AVX2 instantiation of fvec_inner_products_ny: forwards to the
// generic reference driver parameterized with the AVX2 inner-product
// element op (which in turn reaches the specialized fvec_op_ny_D*
// kernels above through overload/template selection — see the
// included -inl headers).
template <>
void fvec_inner_products_ny<SIMDLevel::AVX2>(
        float* ip, /* output inner product */
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    fvec_inner_products_ny_ref<AVX2ElementOpIP>(ip, x, y, d, ny);
}
|
|
639
|
+
|
|
640
|
+
template <>
|
|
641
|
+
void fvec_L2sqr_ny<SIMDLevel::AVX2>(
|
|
642
|
+
float* dis,
|
|
643
|
+
const float* x,
|
|
644
|
+
const float* y,
|
|
645
|
+
size_t d,
|
|
646
|
+
size_t ny) {
|
|
647
|
+
fvec_L2sqr_ny_ref<AVX2ElementOpL2>(dis, x, y, d, ny);
|
|
648
|
+
}
|
|
649
|
+
|
|
650
|
+
// AVX2 specialization of fvec_L2sqr_ny_nearest for dimension 2: returns the
// index of the vector among y[0 .. ny-1] (2 floats each, contiguous) with the
// smallest squared L2 distance to the 2-D query x. The main loop handles 8
// candidate vectors per iteration by transposing their components into SIMD
// lanes and tracking a per-lane running minimum plus its index; the up-to-7
// remaining vectors are handled by a scalar tail loop.
template <>
size_t fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX2>(
        float* /*distances_tmp_buffer*/,
        const float* x,
        const float* y,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 D2-vectors per loop.
    const size_t ny8 = ny / 8;
    if (ny8 > 0) {
        // warm up the cache with the first candidates (16 floats = 8 vectors).
        _mm_prefetch((const char*)y, _MM_HINT_T0);
        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);

        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // 1 value per register
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);

        for (; i < ny8 * 8; i += 8) {
            // prefetch the batch after next.
            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);

            __m256 v0;
            __m256 v1;

            // gather component 0 of 8 vectors into v0 and component 1 into v1.
            transpose_8x2(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    v0,
                    v1);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            min_distances = _mm256_min_ps(distances, min_distances);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y forward (8 vectors 2 DIM each).
            y += 16;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers.
        // the following code is not optimal, but it is rarely invoked.
        float x0 = x[0];
        float x1 = x[1];

        for (; i < ny; i++) {
            float sub0 = x0 - y[0];
            float sub1 = x1 - y[1];
            float distance = sub0 * sub0 + sub1 * sub1;

            y += 2;

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }
        }
    }

    return current_min_index;
}
// AVX2 specialization of fvec_L2sqr_ny_nearest for dimension 4: returns the
// index of the vector among y[0 .. ny-1] (4 floats each, contiguous) with the
// smallest squared L2 distance to the 4-D query x. The main loop transposes
// 8 candidate vectors per iteration into SIMD lanes and keeps a per-lane
// running minimum plus its index; the remaining vectors are handled by an
// SSE-based tail loop.
template <>
size_t fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX2>(
        float* /*distances_tmp_buffer*/,
        const float* x,
        const float* y,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 D4-vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // 1 value per register
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);

        for (; i < ny8 * 8; i += 8) {
            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;

            // gather component j of 8 vectors into vj.
            transpose_8x4(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    v0,
                    v1,
                    v2,
                    v3);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            min_distances = _mm256_min_ps(distances, min_distances);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y forward (8 vectors 4 DIM each).
            y += 32;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers
        __m128 x0 = _mm_loadu_ps(x);

        for (; i < ny; i++) {
            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
            y += 4;
            const float distance = horizontal_sum(accu);

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }
        }
    }

    return current_min_index;
}
// AVX2 specialization of fvec_L2sqr_ny_nearest for dimension 8: returns the
// index of the vector among y[0 .. ny-1] (8 floats each, contiguous) with the
// smallest squared L2 distance to the 8-D query x. The main loop transposes
// 8 candidate vectors per iteration into SIMD lanes and keeps a per-lane
// running minimum plus its index; the remaining vectors are handled by an
// AVX2 one-vector-at-a-time tail loop.
template <>
size_t fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX2>(
        float* /*distances_tmp_buffer*/,
        const float* x,
        const float* y,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 D8-vectors per loop.
    const size_t ny8 = ny / 8;
    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // 1 value per register
        const __m256 m0 = _mm256_set1_ps(x[0]);
        const __m256 m1 = _mm256_set1_ps(x[1]);
        const __m256 m2 = _mm256_set1_ps(x[2]);
        const __m256 m3 = _mm256_set1_ps(x[3]);

        const __m256 m4 = _mm256_set1_ps(x[4]);
        const __m256 m5 = _mm256_set1_ps(x[5]);
        const __m256 m6 = _mm256_set1_ps(x[6]);
        const __m256 m7 = _mm256_set1_ps(x[7]);

        for (; i < ny8 * 8; i += 8) {
            __m256 v0;
            __m256 v1;
            __m256 v2;
            __m256 v3;
            __m256 v4;
            __m256 v5;
            __m256 v6;
            __m256 v7;

            // gather component j of 8 vectors into vj.
            transpose_8x8(
                    _mm256_loadu_ps(y + 0 * 8),
                    _mm256_loadu_ps(y + 1 * 8),
                    _mm256_loadu_ps(y + 2 * 8),
                    _mm256_loadu_ps(y + 3 * 8),
                    _mm256_loadu_ps(y + 4 * 8),
                    _mm256_loadu_ps(y + 5 * 8),
                    _mm256_loadu_ps(y + 6 * 8),
                    _mm256_loadu_ps(y + 7 * 8),
                    v0,
                    v1,
                    v2,
                    v3,
                    v4,
                    v5,
                    v6,
                    v7);

            // compute differences
            const __m256 d0 = _mm256_sub_ps(m0, v0);
            const __m256 d1 = _mm256_sub_ps(m1, v1);
            const __m256 d2 = _mm256_sub_ps(m2, v2);
            const __m256 d3 = _mm256_sub_ps(m3, v3);
            const __m256 d4 = _mm256_sub_ps(m4, v4);
            const __m256 d5 = _mm256_sub_ps(m5, v5);
            const __m256 d6 = _mm256_sub_ps(m6, v6);
            const __m256 d7 = _mm256_sub_ps(m7, v7);

            // compute squares of differences
            __m256 distances = _mm256_mul_ps(d0, d0);
            distances = _mm256_fmadd_ps(d1, d1, distances);
            distances = _mm256_fmadd_ps(d2, d2, distances);
            distances = _mm256_fmadd_ps(d3, d3, distances);
            distances = _mm256_fmadd_ps(d4, d4, distances);
            distances = _mm256_fmadd_ps(d5, d5, distances);
            distances = _mm256_fmadd_ps(d6, d6, distances);
            distances = _mm256_fmadd_ps(d7, d7, distances);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            min_distances = _mm256_min_ps(distances, min_distances);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y forward (8 vectors 8 DIM each).
            y += 64;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers
        __m256 x0 = _mm256_loadu_ps(x);

        for (; i < ny; i++) {
            __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y));
            y += 8;
            const float distance = horizontal_sum(accu);

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }
        }
    }

    return current_min_index;
}
template <>
|
|
1015
|
+
size_t fvec_L2sqr_ny_nearest<SIMDLevel::AVX2>(
|
|
1016
|
+
float* distances_tmp_buffer,
|
|
1017
|
+
const float* x,
|
|
1018
|
+
const float* y,
|
|
1019
|
+
size_t d,
|
|
1020
|
+
size_t ny) {
|
|
1021
|
+
return fvec_L2sqr_ny_nearest_x86<SIMDLevel::AVX2>(
|
|
1022
|
+
distances_tmp_buffer,
|
|
1023
|
+
x,
|
|
1024
|
+
y,
|
|
1025
|
+
d,
|
|
1026
|
+
ny,
|
|
1027
|
+
&fvec_L2sqr_ny_nearest_D2<SIMDLevel::AVX2>,
|
|
1028
|
+
&fvec_L2sqr_ny_nearest_D4<SIMDLevel::AVX2>,
|
|
1029
|
+
&fvec_L2sqr_ny_nearest_D8<SIMDLevel::AVX2>);
|
|
1030
|
+
}
|
|
1031
|
+
|
|
1032
|
+
// Nearest-neighbor search over ny DIM-dimensional vectors stored transposed
// (component-major: component j of vector i lives at y[j * d_offset + i]),
// with precomputed squared norms in y_sqlen. Instead of the full squared
// distance, compares y^2 - 2*(x, y); the x^2 term is identical for all
// candidates and can be dropped. Returns the index of the closest vector.
template <size_t DIM>
size_t fvec_L2sqr_ny_nearest_y_transposed_D(
        float* /*distances_tmp_buffer*/,
        const float* x,
        const float* y,
        const float* y_sqlen,
        const size_t d_offset,
        size_t ny) {
    // this implementation does not use distances_tmp_buffer.

    // current index being processed
    size_t i = 0;

    // min distance and the index of the closest vector so far
    float current_min_distance = HUGE_VALF;
    size_t current_min_index = 0;

    // process 8 vectors per loop.
    const size_t ny8 = ny / 8;

    if (ny8 > 0) {
        // track min distance and the closest vector independently
        // for each of 8 AVX2 components.
        __m256 min_distances = _mm256_set1_ps(HUGE_VALF);
        __m256i min_indices = _mm256_set1_epi32(0);

        __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        const __m256i indices_increment = _mm256_set1_epi32(8);

        // m[i] = (2 * x[i], ... 2 * x[i])
        __m256 m[DIM];
        for (size_t j = 0; j < DIM; j++) {
            m[j] = _mm256_set1_ps(x[j]);
            m[j] = _mm256_add_ps(m[j], m[j]);
        }

        for (; i < ny8 * 8; i += 8) {
            // collect dim 0 for 8 vectors.
            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
            // compute dot products
            __m256 dp = _mm256_mul_ps(m[0], v0);

            for (size_t j = 1; j < DIM; j++) {
                // collect dim j for 8 vectors.
                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
                dp = _mm256_fmadd_ps(m[j], vj, dp);
            }

            // compute y^2 - (2 * x, y), which is sufficient for looking for the
            // lowest distance.
            // x^2 is the constant that can be avoided.
            const __m256 distances =
                    _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp);

            // compare the new distances to the min distances
            // for each of 8 AVX2 components.
            const __m256 comparison =
                    _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS);

            // update min distances and indices with closest vectors if needed.
            min_distances =
                    _mm256_blendv_ps(distances, min_distances, comparison);
            min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                    _mm256_castsi256_ps(current_indices),
                    _mm256_castsi256_ps(min_indices),
                    comparison));

            // update current indices values. Basically, +8 to each of the
            // 8 AVX2 components.
            current_indices =
                    _mm256_add_epi32(current_indices, indices_increment);

            // scroll y and y_sqlen forward.
            y += 8;
            y_sqlen += 8;
        }

        // dump values and find the minimum distance / minimum index
        float min_distances_scalar[8];
        uint32_t min_indices_scalar[8];
        _mm256_storeu_ps(min_distances_scalar, min_distances);
        _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices);

        for (size_t j = 0; j < 8; j++) {
            if (current_min_distance > min_distances_scalar[j]) {
                current_min_distance = min_distances_scalar[j];
                current_min_index = min_indices_scalar[j];
            }
        }
    }

    if (i < ny) {
        // process leftovers
        for (; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < DIM; j++) {
                dp += x[j] * y[j * d_offset];
            }

            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
            // lowest distance.
            const float distance = y_sqlen[0] - 2 * dp;

            if (current_min_distance > distance) {
                current_min_distance = distance;
                current_min_index = i;
            }

            y += 1;
            y_sqlen += 1;
        }
    }

    return current_min_index;
}
template <>
|
|
1149
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::AVX2>(
|
|
1150
|
+
float* distances_tmp_buffer,
|
|
1151
|
+
const float* x,
|
|
1152
|
+
const float* y,
|
|
1153
|
+
const float* y_sqlen,
|
|
1154
|
+
size_t d,
|
|
1155
|
+
size_t d_offset,
|
|
1156
|
+
size_t ny) {
|
|
1157
|
+
// optimized for a few special cases
|
|
1158
|
+
#define DISPATCH(dval) \
|
|
1159
|
+
case dval: \
|
|
1160
|
+
return fvec_L2sqr_ny_nearest_y_transposed_D<dval>( \
|
|
1161
|
+
distances_tmp_buffer, x, y, y_sqlen, d_offset, ny);
|
|
1162
|
+
|
|
1163
|
+
switch (d) {
|
|
1164
|
+
DISPATCH(1)
|
|
1165
|
+
DISPATCH(2)
|
|
1166
|
+
DISPATCH(4)
|
|
1167
|
+
DISPATCH(8)
|
|
1168
|
+
default:
|
|
1169
|
+
return fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::NONE>(
|
|
1170
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
1171
|
+
}
|
|
1172
|
+
#undef DISPATCH
|
|
1173
|
+
}
|
|
1174
|
+
|
|
1175
|
+
// AVX2 specialization of fvec_madd_and_argmin: no dedicated AVX2 kernel is
// provided; the SSE implementation is reused as-is.
// NOTE(review): presumably computes c[i] = a[i] + bf * b[i] over n elements
// and returns the index of the minimum c[i] — confirm against
// fvec_madd_and_argmin_sse, which is defined elsewhere.
template <>
int fvec_madd_and_argmin<SIMDLevel::AVX2>(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    return fvec_madd_and_argmin_sse(n, a, bf, b, c);
}
|
|
1185
|
+
} // namespace faiss
|