faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -5,13 +5,12 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/utils/distances.h>
|
|
11
9
|
|
|
12
10
|
#include <algorithm>
|
|
13
11
|
#include <cassert>
|
|
14
12
|
#include <cmath>
|
|
13
|
+
#include <cstddef>
|
|
15
14
|
#include <cstdio>
|
|
16
15
|
#include <cstring>
|
|
17
16
|
|
|
@@ -64,7 +63,7 @@ void fvec_norms_L2(
|
|
|
64
63
|
const float* __restrict x,
|
|
65
64
|
size_t d,
|
|
66
65
|
size_t nx) {
|
|
67
|
-
#pragma omp parallel for
|
|
66
|
+
#pragma omp parallel for if (nx > 10000)
|
|
68
67
|
for (int64_t i = 0; i < nx; i++) {
|
|
69
68
|
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
|
|
70
69
|
}
|
|
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
|
|
|
75
74
|
const float* __restrict x,
|
|
76
75
|
size_t d,
|
|
77
76
|
size_t nx) {
|
|
78
|
-
#pragma omp parallel for
|
|
77
|
+
#pragma omp parallel for if (nx > 10000)
|
|
79
78
|
for (int64_t i = 0; i < nx; i++)
|
|
80
79
|
nr[i] = fvec_norm_L2sqr(x + i * d, d);
|
|
81
80
|
}
|
|
82
81
|
|
|
83
|
-
|
|
84
|
-
|
|
82
|
+
// The following is a workaround to a problem
|
|
83
|
+
// in OpenMP in fbcode. The crash occurs
|
|
84
|
+
// inside OMP when IndexIVFSpectralHash::set_query()
|
|
85
|
+
// calls fvec_renorm_L2. set_query() is always
|
|
86
|
+
// calling this function with nx == 1, so even
|
|
87
|
+
// the omp version should run single threaded,
|
|
88
|
+
// as per the if condition of the omp pragma.
|
|
89
|
+
// Instead, the omp version crashes inside OMP.
|
|
90
|
+
// The workaround below is explicitly branching
|
|
91
|
+
// off to a codepath without omp.
|
|
92
|
+
|
|
93
|
+
#define FVEC_RENORM_L2_IMPL \
|
|
94
|
+
float* __restrict xi = x + i * d; \
|
|
95
|
+
\
|
|
96
|
+
float nr = fvec_norm_L2sqr(xi, d); \
|
|
97
|
+
\
|
|
98
|
+
if (nr > 0) { \
|
|
99
|
+
size_t j; \
|
|
100
|
+
const float inv_nr = 1.0 / sqrtf(nr); \
|
|
101
|
+
for (j = 0; j < d; j++) \
|
|
102
|
+
xi[j] *= inv_nr; \
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
|
|
85
106
|
for (int64_t i = 0; i < nx; i++) {
|
|
86
|
-
|
|
107
|
+
FVEC_RENORM_L2_IMPL
|
|
108
|
+
}
|
|
109
|
+
}
|
|
87
110
|
|
|
88
|
-
|
|
111
|
+
void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
|
|
112
|
+
#pragma omp parallel for if (nx > 10000)
|
|
113
|
+
for (int64_t i = 0; i < nx; i++) {
|
|
114
|
+
FVEC_RENORM_L2_IMPL
|
|
115
|
+
}
|
|
116
|
+
}
|
|
89
117
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
}
|
|
118
|
+
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
119
|
+
if (nx <= 10000) {
|
|
120
|
+
fvec_renorm_L2_noomp(d, nx, x);
|
|
121
|
+
} else {
|
|
122
|
+
fvec_renorm_L2_omp(d, nx, x);
|
|
96
123
|
}
|
|
97
124
|
}
|
|
98
125
|
|
|
@@ -103,16 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
|
103
130
|
namespace {
|
|
104
131
|
|
|
105
132
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
106
|
-
template <class
|
|
133
|
+
template <class BlockResultHandler, bool use_sel = false>
|
|
107
134
|
void exhaustive_inner_product_seq(
|
|
108
135
|
const float* x,
|
|
109
136
|
const float* y,
|
|
110
137
|
size_t d,
|
|
111
138
|
size_t nx,
|
|
112
139
|
size_t ny,
|
|
113
|
-
|
|
140
|
+
BlockResultHandler& res,
|
|
114
141
|
const IDSelector* sel = nullptr) {
|
|
115
|
-
using SingleResultHandler =
|
|
142
|
+
using SingleResultHandler =
|
|
143
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
116
144
|
int nt = std::min(int(nx), omp_get_max_threads());
|
|
117
145
|
|
|
118
146
|
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
@@ -139,16 +167,17 @@ void exhaustive_inner_product_seq(
|
|
|
139
167
|
}
|
|
140
168
|
}
|
|
141
169
|
|
|
142
|
-
template <class
|
|
170
|
+
template <class BlockResultHandler, bool use_sel = false>
|
|
143
171
|
void exhaustive_L2sqr_seq(
|
|
144
172
|
const float* x,
|
|
145
173
|
const float* y,
|
|
146
174
|
size_t d,
|
|
147
175
|
size_t nx,
|
|
148
176
|
size_t ny,
|
|
149
|
-
|
|
177
|
+
BlockResultHandler& res,
|
|
150
178
|
const IDSelector* sel = nullptr) {
|
|
151
|
-
using SingleResultHandler =
|
|
179
|
+
using SingleResultHandler =
|
|
180
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
152
181
|
int nt = std::min(int(nx), omp_get_max_threads());
|
|
153
182
|
|
|
154
183
|
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
@@ -174,14 +203,14 @@ void exhaustive_L2sqr_seq(
|
|
|
174
203
|
}
|
|
175
204
|
|
|
176
205
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
177
|
-
template <class
|
|
206
|
+
template <class BlockResultHandler>
|
|
178
207
|
void exhaustive_inner_product_blas(
|
|
179
208
|
const float* x,
|
|
180
209
|
const float* y,
|
|
181
210
|
size_t d,
|
|
182
211
|
size_t nx,
|
|
183
212
|
size_t ny,
|
|
184
|
-
|
|
213
|
+
BlockResultHandler& res) {
|
|
185
214
|
// BLAS does not like empty matrices
|
|
186
215
|
if (nx == 0 || ny == 0)
|
|
187
216
|
return;
|
|
@@ -230,14 +259,14 @@ void exhaustive_inner_product_blas(
|
|
|
230
259
|
|
|
231
260
|
// distance correction is an operator that can be applied to transform
|
|
232
261
|
// the distances
|
|
233
|
-
template <class
|
|
262
|
+
template <class BlockResultHandler>
|
|
234
263
|
void exhaustive_L2sqr_blas_default_impl(
|
|
235
264
|
const float* x,
|
|
236
265
|
const float* y,
|
|
237
266
|
size_t d,
|
|
238
267
|
size_t nx,
|
|
239
268
|
size_t ny,
|
|
240
|
-
|
|
269
|
+
BlockResultHandler& res,
|
|
241
270
|
const float* y_norms = nullptr) {
|
|
242
271
|
// BLAS does not like empty matrices
|
|
243
272
|
if (nx == 0 || ny == 0)
|
|
@@ -313,14 +342,14 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
|
313
342
|
}
|
|
314
343
|
}
|
|
315
344
|
|
|
316
|
-
template <class
|
|
345
|
+
template <class BlockResultHandler>
|
|
317
346
|
void exhaustive_L2sqr_blas(
|
|
318
347
|
const float* x,
|
|
319
348
|
const float* y,
|
|
320
349
|
size_t d,
|
|
321
350
|
size_t nx,
|
|
322
351
|
size_t ny,
|
|
323
|
-
|
|
352
|
+
BlockResultHandler& res,
|
|
324
353
|
const float* y_norms = nullptr) {
|
|
325
354
|
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
|
|
326
355
|
}
|
|
@@ -332,7 +361,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
332
361
|
size_t d,
|
|
333
362
|
size_t nx,
|
|
334
363
|
size_t ny,
|
|
335
|
-
|
|
364
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
336
365
|
const float* y_norms) {
|
|
337
366
|
// BLAS does not like empty matrices
|
|
338
367
|
if (nx == 0 || ny == 0)
|
|
@@ -388,8 +417,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
388
417
|
for (int64_t i = i0; i < i1; i++) {
|
|
389
418
|
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
390
419
|
|
|
391
|
-
_mm_prefetch(ip_line, _MM_HINT_NTA);
|
|
392
|
-
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
|
|
420
|
+
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
|
|
421
|
+
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
|
|
393
422
|
|
|
394
423
|
// constant
|
|
395
424
|
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
|
@@ -416,8 +445,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
416
445
|
|
|
417
446
|
// process 16 elements per loop
|
|
418
447
|
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
|
419
|
-
_mm_prefetch(ip_line + 32, _MM_HINT_NTA);
|
|
420
|
-
_mm_prefetch(ip_line + 48, _MM_HINT_NTA);
|
|
448
|
+
_mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
|
|
449
|
+
_mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
|
|
421
450
|
|
|
422
451
|
// load values for norms
|
|
423
452
|
const __m256 y_norm_0 =
|
|
@@ -535,13 +564,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
535
564
|
|
|
536
565
|
// an override if only a single closest point is needed
|
|
537
566
|
template <>
|
|
538
|
-
void exhaustive_L2sqr_blas<
|
|
567
|
+
void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
539
568
|
const float* x,
|
|
540
569
|
const float* y,
|
|
541
570
|
size_t d,
|
|
542
571
|
size_t nx,
|
|
543
572
|
size_t ny,
|
|
544
|
-
|
|
573
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
545
574
|
const float* y_norms) {
|
|
546
575
|
#if defined(__AVX2__)
|
|
547
576
|
// use a faster fused kernel if available
|
|
@@ -562,28 +591,29 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
|
|
|
562
591
|
|
|
563
592
|
// run the default implementation
|
|
564
593
|
exhaustive_L2sqr_blas_default_impl<
|
|
565
|
-
|
|
594
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
566
595
|
x, y, d, nx, ny, res, y_norms);
|
|
567
596
|
#else
|
|
568
597
|
// run the default implementation
|
|
569
598
|
exhaustive_L2sqr_blas_default_impl<
|
|
570
|
-
|
|
599
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
571
600
|
x, y, d, nx, ny, res, y_norms);
|
|
572
601
|
#endif
|
|
573
602
|
}
|
|
574
603
|
|
|
575
|
-
template <class
|
|
604
|
+
template <class BlockResultHandler>
|
|
576
605
|
void knn_L2sqr_select(
|
|
577
606
|
const float* x,
|
|
578
607
|
const float* y,
|
|
579
608
|
size_t d,
|
|
580
609
|
size_t nx,
|
|
581
610
|
size_t ny,
|
|
582
|
-
|
|
611
|
+
BlockResultHandler& res,
|
|
583
612
|
const float* y_norm2,
|
|
584
613
|
const IDSelector* sel) {
|
|
585
614
|
if (sel) {
|
|
586
|
-
exhaustive_L2sqr_seq<
|
|
615
|
+
exhaustive_L2sqr_seq<BlockResultHandler, true>(
|
|
616
|
+
x, y, d, nx, ny, res, sel);
|
|
587
617
|
} else if (nx < distance_compute_blas_threshold) {
|
|
588
618
|
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
589
619
|
} else {
|
|
@@ -591,6 +621,25 @@ void knn_L2sqr_select(
|
|
|
591
621
|
}
|
|
592
622
|
}
|
|
593
623
|
|
|
624
|
+
template <class BlockResultHandler>
|
|
625
|
+
void knn_inner_product_select(
|
|
626
|
+
const float* x,
|
|
627
|
+
const float* y,
|
|
628
|
+
size_t d,
|
|
629
|
+
size_t nx,
|
|
630
|
+
size_t ny,
|
|
631
|
+
BlockResultHandler& res,
|
|
632
|
+
const IDSelector* sel) {
|
|
633
|
+
if (sel) {
|
|
634
|
+
exhaustive_inner_product_seq<BlockResultHandler, true>(
|
|
635
|
+
x, y, d, nx, ny, res, sel);
|
|
636
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
637
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
638
|
+
} else {
|
|
639
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
640
|
+
}
|
|
641
|
+
}
|
|
642
|
+
|
|
594
643
|
} // anonymous namespace
|
|
595
644
|
|
|
596
645
|
/*******************************************************
|
|
@@ -609,7 +658,7 @@ void knn_inner_product(
|
|
|
609
658
|
size_t nx,
|
|
610
659
|
size_t ny,
|
|
611
660
|
size_t k,
|
|
612
|
-
float*
|
|
661
|
+
float* vals,
|
|
613
662
|
int64_t* ids,
|
|
614
663
|
const IDSelector* sel) {
|
|
615
664
|
int64_t imin = 0;
|
|
@@ -622,30 +671,21 @@ void knn_inner_product(
|
|
|
622
671
|
}
|
|
623
672
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
624
673
|
knn_inner_products_by_idx(
|
|
625
|
-
x, y, sela->ids, d, nx, sela->n, k,
|
|
674
|
+
x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
|
626
675
|
return;
|
|
627
676
|
}
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
} else {
|
|
636
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
637
|
-
}
|
|
677
|
+
|
|
678
|
+
if (k == 1) {
|
|
679
|
+
Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
|
|
680
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
|
681
|
+
} else if (k < distance_compute_min_k_reservoir) {
|
|
682
|
+
HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
|
683
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
|
638
684
|
} else {
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
if (sel) {
|
|
642
|
-
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
643
|
-
} else if (nx < distance_compute_blas_threshold) {
|
|
644
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
|
645
|
-
} else {
|
|
646
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
647
|
-
}
|
|
685
|
+
ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
|
686
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
|
648
687
|
}
|
|
688
|
+
|
|
649
689
|
if (imin != 0) {
|
|
650
690
|
for (size_t i = 0; i < nx * k; i++) {
|
|
651
691
|
if (ids[i] >= 0) {
|
|
@@ -687,17 +727,17 @@ void knn_L2sqr(
|
|
|
687
727
|
sel = nullptr;
|
|
688
728
|
}
|
|
689
729
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
690
|
-
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
|
730
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
|
691
731
|
return;
|
|
692
732
|
}
|
|
693
733
|
if (k == 1) {
|
|
694
|
-
|
|
734
|
+
Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
|
695
735
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
696
736
|
} else if (k < distance_compute_min_k_reservoir) {
|
|
697
|
-
|
|
737
|
+
HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
698
738
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
699
739
|
} else {
|
|
700
|
-
|
|
740
|
+
ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
701
741
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
702
742
|
}
|
|
703
743
|
if (imin != 0) {
|
|
@@ -735,7 +775,7 @@ void range_search_L2sqr(
|
|
|
735
775
|
float radius,
|
|
736
776
|
RangeSearchResult* res,
|
|
737
777
|
const IDSelector* sel) {
|
|
738
|
-
using RH =
|
|
778
|
+
using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
|
|
739
779
|
RH resh(res, radius);
|
|
740
780
|
if (sel) {
|
|
741
781
|
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
@@ -755,7 +795,7 @@ void range_search_inner_product(
|
|
|
755
795
|
float radius,
|
|
756
796
|
RangeSearchResult* res,
|
|
757
797
|
const IDSelector* sel) {
|
|
758
|
-
using RH =
|
|
798
|
+
using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
|
|
759
799
|
RH resh(res, radius);
|
|
760
800
|
if (sel) {
|
|
761
801
|
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
@@ -786,9 +826,11 @@ void fvec_inner_products_by_idx(
|
|
|
786
826
|
const float* xj = x + j * d;
|
|
787
827
|
float* __restrict ipj = ip + j * ny;
|
|
788
828
|
for (size_t i = 0; i < ny; i++) {
|
|
789
|
-
if (idsj[i] < 0)
|
|
790
|
-
|
|
791
|
-
|
|
829
|
+
if (idsj[i] < 0) {
|
|
830
|
+
ipj[i] = -INFINITY;
|
|
831
|
+
} else {
|
|
832
|
+
ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
|
|
833
|
+
}
|
|
792
834
|
}
|
|
793
835
|
}
|
|
794
836
|
}
|
|
@@ -809,9 +851,11 @@ void fvec_L2sqr_by_idx(
|
|
|
809
851
|
const float* xj = x + j * d;
|
|
810
852
|
float* __restrict disj = dis + j * ny;
|
|
811
853
|
for (size_t i = 0; i < ny; i++) {
|
|
812
|
-
if (idsj[i] < 0)
|
|
813
|
-
|
|
814
|
-
|
|
854
|
+
if (idsj[i] < 0) {
|
|
855
|
+
disj[i] = INFINITY;
|
|
856
|
+
} else {
|
|
857
|
+
disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
|
|
858
|
+
}
|
|
815
859
|
}
|
|
816
860
|
}
|
|
817
861
|
}
|
|
@@ -828,6 +872,8 @@ void pairwise_indexed_L2sqr(
|
|
|
828
872
|
for (int64_t j = 0; j < n; j++) {
|
|
829
873
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
830
874
|
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
|
|
875
|
+
} else {
|
|
876
|
+
dis[j] = INFINITY;
|
|
831
877
|
}
|
|
832
878
|
}
|
|
833
879
|
}
|
|
@@ -844,6 +890,8 @@ void pairwise_indexed_inner_product(
|
|
|
844
890
|
for (int64_t j = 0; j < n; j++) {
|
|
845
891
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
846
892
|
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
|
|
893
|
+
} else {
|
|
894
|
+
dis[j] = -INFINITY;
|
|
847
895
|
}
|
|
848
896
|
}
|
|
849
897
|
}
|
|
@@ -857,6 +905,7 @@ void knn_inner_products_by_idx(
|
|
|
857
905
|
size_t d,
|
|
858
906
|
size_t nx,
|
|
859
907
|
size_t ny,
|
|
908
|
+
size_t nsubset,
|
|
860
909
|
size_t k,
|
|
861
910
|
float* res_vals,
|
|
862
911
|
int64_t* res_ids,
|
|
@@ -874,9 +923,10 @@ void knn_inner_products_by_idx(
|
|
|
874
923
|
int64_t* __restrict idxi = res_ids + i * k;
|
|
875
924
|
minheap_heapify(k, simi, idxi);
|
|
876
925
|
|
|
877
|
-
for (j = 0; j <
|
|
878
|
-
if (idsi[j] < 0)
|
|
926
|
+
for (j = 0; j < nsubset; j++) {
|
|
927
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
|
879
928
|
break;
|
|
929
|
+
}
|
|
880
930
|
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
|
|
881
931
|
|
|
882
932
|
if (ip > simi[0]) {
|
|
@@ -894,6 +944,7 @@ void knn_L2sqr_by_idx(
|
|
|
894
944
|
size_t d,
|
|
895
945
|
size_t nx,
|
|
896
946
|
size_t ny,
|
|
947
|
+
size_t nsubset,
|
|
897
948
|
size_t k,
|
|
898
949
|
float* res_vals,
|
|
899
950
|
int64_t* res_ids,
|
|
@@ -908,7 +959,10 @@ void knn_L2sqr_by_idx(
|
|
|
908
959
|
float* __restrict simi = res_vals + i * k;
|
|
909
960
|
int64_t* __restrict idxi = res_ids + i * k;
|
|
910
961
|
maxheap_heapify(k, simi, idxi);
|
|
911
|
-
for (size_t j = 0; j <
|
|
962
|
+
for (size_t j = 0; j < nsubset; j++) {
|
|
963
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
|
964
|
+
break;
|
|
965
|
+
}
|
|
912
966
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
|
913
967
|
|
|
914
968
|
if (disij < simi[0]) {
|
|
@@ -36,6 +36,34 @@ float fvec_L1(const float* x, const float* y, size_t d);
|
|
|
36
36
|
/// infinity distance
|
|
37
37
|
float fvec_Linf(const float* x, const float* y, size_t d);
|
|
38
38
|
|
|
39
|
+
/// Special version of inner product that computes 4 distances
|
|
40
|
+
/// between x and yi, which is performance oriented.
|
|
41
|
+
void fvec_inner_product_batch_4(
|
|
42
|
+
const float* x,
|
|
43
|
+
const float* y0,
|
|
44
|
+
const float* y1,
|
|
45
|
+
const float* y2,
|
|
46
|
+
const float* y3,
|
|
47
|
+
const size_t d,
|
|
48
|
+
float& dis0,
|
|
49
|
+
float& dis1,
|
|
50
|
+
float& dis2,
|
|
51
|
+
float& dis3);
|
|
52
|
+
|
|
53
|
+
/// Special version of L2sqr that computes 4 distances
|
|
54
|
+
/// between x and yi, which is performance oriented.
|
|
55
|
+
void fvec_L2sqr_batch_4(
|
|
56
|
+
const float* x,
|
|
57
|
+
const float* y0,
|
|
58
|
+
const float* y1,
|
|
59
|
+
const float* y2,
|
|
60
|
+
const float* y3,
|
|
61
|
+
const size_t d,
|
|
62
|
+
float& dis0,
|
|
63
|
+
float& dis1,
|
|
64
|
+
float& dis2,
|
|
65
|
+
float& dis3);
|
|
66
|
+
|
|
39
67
|
/** Compute pairwise distances between sets of vectors
|
|
40
68
|
*
|
|
41
69
|
* @param d dimension of the vectors
|
|
@@ -170,8 +198,16 @@ void fvec_sub(size_t d, const float* a, const float* b, float* c);
|
|
|
170
198
|
* Compute a subset of distances
|
|
171
199
|
***************************************************************************/
|
|
172
200
|
|
|
173
|
-
|
|
174
|
-
|
|
201
|
+
/** compute the inner product between x and a subset y of ny vectors defined by
|
|
202
|
+
* ids
|
|
203
|
+
*
|
|
204
|
+
* ip(i, j) = inner_product(x(i, :), y(ids(i, j), :))
|
|
205
|
+
*
|
|
206
|
+
* @param ip output array, size nx * ny
|
|
207
|
+
* @param x first-term vector, size nx * d
|
|
208
|
+
* @param y second-term vector, size (max(ids) + 1) * d
|
|
209
|
+
* @param ids ids to sample from y, size nx * ny
|
|
210
|
+
*/
|
|
175
211
|
void fvec_inner_products_by_idx(
|
|
176
212
|
float* ip,
|
|
177
213
|
const float* x,
|
|
@@ -181,7 +217,16 @@ void fvec_inner_products_by_idx(
|
|
|
181
217
|
size_t nx,
|
|
182
218
|
size_t ny);
|
|
183
219
|
|
|
184
|
-
|
|
220
|
+
/** compute the squared L2 distances between x and a subset y of ny vectors
|
|
221
|
+
* defined by ids
|
|
222
|
+
*
|
|
223
|
+
* dis(i, j) = inner_product(x(i, :), y(ids(i, j), :))
|
|
224
|
+
*
|
|
225
|
+
* @param dis output array, size nx * ny
|
|
226
|
+
* @param x first-term vector, size nx * d
|
|
227
|
+
* @param y second-term vector, size (max(ids) + 1) * d
|
|
228
|
+
* @param ids ids to sample from y, size nx * ny
|
|
229
|
+
*/
|
|
185
230
|
void fvec_L2sqr_by_idx(
|
|
186
231
|
float* dis,
|
|
187
232
|
const float* x,
|
|
@@ -208,7 +253,14 @@ void pairwise_indexed_L2sqr(
|
|
|
208
253
|
const int64_t* iy,
|
|
209
254
|
float* dis);
|
|
210
255
|
|
|
211
|
-
|
|
256
|
+
/** compute dis[j] = inner_product(x[ix[j]], y[iy[j]]) forall j=0..n-1
|
|
257
|
+
*
|
|
258
|
+
* @param x size (max(ix) + 1, d)
|
|
259
|
+
* @param y size (max(iy) + 1, d)
|
|
260
|
+
* @param ix size n
|
|
261
|
+
* @param iy size n
|
|
262
|
+
* @param dis size n
|
|
263
|
+
*/
|
|
212
264
|
void pairwise_indexed_inner_product(
|
|
213
265
|
size_t d,
|
|
214
266
|
size_t n,
|
|
@@ -324,6 +376,7 @@ void knn_inner_products_by_idx(
|
|
|
324
376
|
const int64_t* subset,
|
|
325
377
|
size_t d,
|
|
326
378
|
size_t nx,
|
|
379
|
+
size_t ny,
|
|
327
380
|
size_t nsubset,
|
|
328
381
|
size_t k,
|
|
329
382
|
float* vals,
|
|
@@ -346,6 +399,7 @@ void knn_L2sqr_by_idx(
|
|
|
346
399
|
const int64_t* subset,
|
|
347
400
|
size_t d,
|
|
348
401
|
size_t nx,
|
|
402
|
+
size_t ny,
|
|
349
403
|
size_t nsubset,
|
|
350
404
|
size_t k,
|
|
351
405
|
float* vals,
|
|
@@ -406,4 +460,27 @@ void compute_PQ_dis_tables_dsub2(
|
|
|
406
460
|
* Templatized versions of distance functions
|
|
407
461
|
***************************************************************************/
|
|
408
462
|
|
|
463
|
+
/***************************************************************************
|
|
464
|
+
* Misc matrix and vector manipulation functions
|
|
465
|
+
***************************************************************************/
|
|
466
|
+
|
|
467
|
+
/** compute c := a + bf * b for a, b and c tables
|
|
468
|
+
*
|
|
469
|
+
* @param n size of the tables
|
|
470
|
+
* @param a size n
|
|
471
|
+
* @param b size n
|
|
472
|
+
* @param c restult table, size n
|
|
473
|
+
*/
|
|
474
|
+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
|
|
475
|
+
|
|
476
|
+
/** same as fvec_madd, also return index of the min of the result table
|
|
477
|
+
* @return index of the min of table c
|
|
478
|
+
*/
|
|
479
|
+
int fvec_madd_and_argmin(
|
|
480
|
+
size_t n,
|
|
481
|
+
const float* a,
|
|
482
|
+
float bf,
|
|
483
|
+
const float* b,
|
|
484
|
+
float* c);
|
|
485
|
+
|
|
409
486
|
} // namespace faiss
|
|
@@ -9,7 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/utils/distances_fused/avx512.h>
|
|
11
11
|
|
|
12
|
-
#ifdef
|
|
12
|
+
#ifdef __AVX512F__
|
|
13
13
|
|
|
14
14
|
#include <immintrin.h>
|
|
15
15
|
|
|
@@ -68,7 +68,7 @@ void kernel(
|
|
|
68
68
|
const float* const __restrict y,
|
|
69
69
|
const float* const __restrict y_transposed,
|
|
70
70
|
size_t ny,
|
|
71
|
-
|
|
71
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
72
72
|
const float* __restrict y_norms,
|
|
73
73
|
size_t i) {
|
|
74
74
|
const size_t ny_p =
|
|
@@ -231,7 +231,7 @@ void exhaustive_L2sqr_fused_cmax(
|
|
|
231
231
|
const float* const __restrict y,
|
|
232
232
|
size_t nx,
|
|
233
233
|
size_t ny,
|
|
234
|
-
|
|
234
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
235
235
|
const float* __restrict y_norms) {
|
|
236
236
|
// BLAS does not like empty matrices
|
|
237
237
|
if (nx == 0 || ny == 0) {
|
|
@@ -275,7 +275,7 @@ void exhaustive_L2sqr_fused_cmax(
|
|
|
275
275
|
x, y, y_transposed.data(), ny, res, y_norms, i);
|
|
276
276
|
}
|
|
277
277
|
|
|
278
|
-
// Does nothing for
|
|
278
|
+
// Does nothing for Top1BlockResultHandler, but
|
|
279
279
|
// keeping the call for the consistency.
|
|
280
280
|
res.end_multiple();
|
|
281
281
|
InterruptCallback::check();
|
|
@@ -289,7 +289,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
|
|
|
289
289
|
size_t d,
|
|
290
290
|
size_t nx,
|
|
291
291
|
size_t ny,
|
|
292
|
-
|
|
292
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
293
293
|
const float* y_norms) {
|
|
294
294
|
// process only cases with certain dimensionalities
|
|
295
295
|
|
|
@@ -16,7 +16,7 @@
|
|
|
16
16
|
|
|
17
17
|
#include <faiss/utils/Heap.h>
|
|
18
18
|
|
|
19
|
-
#ifdef
|
|
19
|
+
#ifdef __AVX512F__
|
|
20
20
|
|
|
21
21
|
namespace faiss {
|
|
22
22
|
|
|
@@ -28,7 +28,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
|
|
|
28
28
|
size_t d,
|
|
29
29
|
size_t nx,
|
|
30
30
|
size_t ny,
|
|
31
|
-
|
|
31
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
32
32
|
const float* y_norms);
|
|
33
33
|
|
|
34
34
|
} // namespace faiss
|
|
@@ -20,14 +20,14 @@ bool exhaustive_L2sqr_fused_cmax(
|
|
|
20
20
|
size_t d,
|
|
21
21
|
size_t nx,
|
|
22
22
|
size_t ny,
|
|
23
|
-
|
|
23
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
24
24
|
const float* y_norms) {
|
|
25
25
|
if (nx == 0 || ny == 0) {
|
|
26
26
|
// nothing to do
|
|
27
27
|
return true;
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
-
#ifdef
|
|
30
|
+
#ifdef __AVX512F__
|
|
31
31
|
// avx512 kernel
|
|
32
32
|
return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
|
|
33
33
|
#elif defined(__AVX2__) || defined(__aarch64__)
|