faiss 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
@@ -5,13 +5,12 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#include <faiss/utils/distances.h>
|
11
9
|
|
12
10
|
#include <algorithm>
|
13
11
|
#include <cassert>
|
14
12
|
#include <cmath>
|
13
|
+
#include <cstddef>
|
15
14
|
#include <cstdio>
|
16
15
|
#include <cstring>
|
17
16
|
|
@@ -64,7 +63,7 @@ void fvec_norms_L2(
|
|
64
63
|
const float* __restrict x,
|
65
64
|
size_t d,
|
66
65
|
size_t nx) {
|
67
|
-
#pragma omp parallel for
|
66
|
+
#pragma omp parallel for if (nx > 10000)
|
68
67
|
for (int64_t i = 0; i < nx; i++) {
|
69
68
|
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
|
70
69
|
}
|
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
|
|
75
74
|
const float* __restrict x,
|
76
75
|
size_t d,
|
77
76
|
size_t nx) {
|
78
|
-
#pragma omp parallel for
|
77
|
+
#pragma omp parallel for if (nx > 10000)
|
79
78
|
for (int64_t i = 0; i < nx; i++)
|
80
79
|
nr[i] = fvec_norm_L2sqr(x + i * d, d);
|
81
80
|
}
|
82
81
|
|
83
|
-
|
84
|
-
|
82
|
+
// The following is a workaround to a problem
|
83
|
+
// in OpenMP in fbcode. The crash occurs
|
84
|
+
// inside OMP when IndexIVFSpectralHash::set_query()
|
85
|
+
// calls fvec_renorm_L2. set_query() is always
|
86
|
+
// calling this function with nx == 1, so even
|
87
|
+
// the omp version should run single threaded,
|
88
|
+
// as per the if condition of the omp pragma.
|
89
|
+
// Instead, the omp version crashes inside OMP.
|
90
|
+
// The workaround below is explicitly branching
|
91
|
+
// off to a codepath without omp.
|
92
|
+
|
93
|
+
#define FVEC_RENORM_L2_IMPL \
|
94
|
+
float* __restrict xi = x + i * d; \
|
95
|
+
\
|
96
|
+
float nr = fvec_norm_L2sqr(xi, d); \
|
97
|
+
\
|
98
|
+
if (nr > 0) { \
|
99
|
+
size_t j; \
|
100
|
+
const float inv_nr = 1.0 / sqrtf(nr); \
|
101
|
+
for (j = 0; j < d; j++) \
|
102
|
+
xi[j] *= inv_nr; \
|
103
|
+
}
|
104
|
+
|
105
|
+
void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
|
85
106
|
for (int64_t i = 0; i < nx; i++) {
|
86
|
-
|
107
|
+
FVEC_RENORM_L2_IMPL
|
108
|
+
}
|
109
|
+
}
|
87
110
|
|
88
|
-
|
111
|
+
void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
|
112
|
+
#pragma omp parallel for if (nx > 10000)
|
113
|
+
for (int64_t i = 0; i < nx; i++) {
|
114
|
+
FVEC_RENORM_L2_IMPL
|
115
|
+
}
|
116
|
+
}
|
89
117
|
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
}
|
118
|
+
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
119
|
+
if (nx <= 10000) {
|
120
|
+
fvec_renorm_L2_noomp(d, nx, x);
|
121
|
+
} else {
|
122
|
+
fvec_renorm_L2_omp(d, nx, x);
|
96
123
|
}
|
97
124
|
}
|
98
125
|
|
@@ -103,16 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
103
130
|
namespace {
|
104
131
|
|
105
132
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
106
|
-
template <class
|
133
|
+
template <class BlockResultHandler, bool use_sel = false>
|
107
134
|
void exhaustive_inner_product_seq(
|
108
135
|
const float* x,
|
109
136
|
const float* y,
|
110
137
|
size_t d,
|
111
138
|
size_t nx,
|
112
139
|
size_t ny,
|
113
|
-
|
140
|
+
BlockResultHandler& res,
|
114
141
|
const IDSelector* sel = nullptr) {
|
115
|
-
using SingleResultHandler =
|
142
|
+
using SingleResultHandler =
|
143
|
+
typename BlockResultHandler::SingleResultHandler;
|
116
144
|
int nt = std::min(int(nx), omp_get_max_threads());
|
117
145
|
|
118
146
|
FAISS_ASSERT(use_sel == (sel != nullptr));
|
@@ -139,16 +167,17 @@ void exhaustive_inner_product_seq(
|
|
139
167
|
}
|
140
168
|
}
|
141
169
|
|
142
|
-
template <class
|
170
|
+
template <class BlockResultHandler, bool use_sel = false>
|
143
171
|
void exhaustive_L2sqr_seq(
|
144
172
|
const float* x,
|
145
173
|
const float* y,
|
146
174
|
size_t d,
|
147
175
|
size_t nx,
|
148
176
|
size_t ny,
|
149
|
-
|
177
|
+
BlockResultHandler& res,
|
150
178
|
const IDSelector* sel = nullptr) {
|
151
|
-
using SingleResultHandler =
|
179
|
+
using SingleResultHandler =
|
180
|
+
typename BlockResultHandler::SingleResultHandler;
|
152
181
|
int nt = std::min(int(nx), omp_get_max_threads());
|
153
182
|
|
154
183
|
FAISS_ASSERT(use_sel == (sel != nullptr));
|
@@ -174,14 +203,14 @@ void exhaustive_L2sqr_seq(
|
|
174
203
|
}
|
175
204
|
|
176
205
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
177
|
-
template <class
|
206
|
+
template <class BlockResultHandler>
|
178
207
|
void exhaustive_inner_product_blas(
|
179
208
|
const float* x,
|
180
209
|
const float* y,
|
181
210
|
size_t d,
|
182
211
|
size_t nx,
|
183
212
|
size_t ny,
|
184
|
-
|
213
|
+
BlockResultHandler& res) {
|
185
214
|
// BLAS does not like empty matrices
|
186
215
|
if (nx == 0 || ny == 0)
|
187
216
|
return;
|
@@ -230,14 +259,14 @@ void exhaustive_inner_product_blas(
|
|
230
259
|
|
231
260
|
// distance correction is an operator that can be applied to transform
|
232
261
|
// the distances
|
233
|
-
template <class
|
262
|
+
template <class BlockResultHandler>
|
234
263
|
void exhaustive_L2sqr_blas_default_impl(
|
235
264
|
const float* x,
|
236
265
|
const float* y,
|
237
266
|
size_t d,
|
238
267
|
size_t nx,
|
239
268
|
size_t ny,
|
240
|
-
|
269
|
+
BlockResultHandler& res,
|
241
270
|
const float* y_norms = nullptr) {
|
242
271
|
// BLAS does not like empty matrices
|
243
272
|
if (nx == 0 || ny == 0)
|
@@ -313,14 +342,14 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
313
342
|
}
|
314
343
|
}
|
315
344
|
|
316
|
-
template <class
|
345
|
+
template <class BlockResultHandler>
|
317
346
|
void exhaustive_L2sqr_blas(
|
318
347
|
const float* x,
|
319
348
|
const float* y,
|
320
349
|
size_t d,
|
321
350
|
size_t nx,
|
322
351
|
size_t ny,
|
323
|
-
|
352
|
+
BlockResultHandler& res,
|
324
353
|
const float* y_norms = nullptr) {
|
325
354
|
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
|
326
355
|
}
|
@@ -332,7 +361,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
332
361
|
size_t d,
|
333
362
|
size_t nx,
|
334
363
|
size_t ny,
|
335
|
-
|
364
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
336
365
|
const float* y_norms) {
|
337
366
|
// BLAS does not like empty matrices
|
338
367
|
if (nx == 0 || ny == 0)
|
@@ -388,8 +417,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
388
417
|
for (int64_t i = i0; i < i1; i++) {
|
389
418
|
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
390
419
|
|
391
|
-
_mm_prefetch(ip_line, _MM_HINT_NTA);
|
392
|
-
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
|
420
|
+
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
|
421
|
+
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
|
393
422
|
|
394
423
|
// constant
|
395
424
|
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
@@ -416,8 +445,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
416
445
|
|
417
446
|
// process 16 elements per loop
|
418
447
|
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
419
|
-
_mm_prefetch(ip_line + 32, _MM_HINT_NTA);
|
420
|
-
_mm_prefetch(ip_line + 48, _MM_HINT_NTA);
|
448
|
+
_mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
|
449
|
+
_mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
|
421
450
|
|
422
451
|
// load values for norms
|
423
452
|
const __m256 y_norm_0 =
|
@@ -535,13 +564,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
535
564
|
|
536
565
|
// an override if only a single closest point is needed
|
537
566
|
template <>
|
538
|
-
void exhaustive_L2sqr_blas<
|
567
|
+
void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
539
568
|
const float* x,
|
540
569
|
const float* y,
|
541
570
|
size_t d,
|
542
571
|
size_t nx,
|
543
572
|
size_t ny,
|
544
|
-
|
573
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
545
574
|
const float* y_norms) {
|
546
575
|
#if defined(__AVX2__)
|
547
576
|
// use a faster fused kernel if available
|
@@ -562,28 +591,29 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
|
|
562
591
|
|
563
592
|
// run the default implementation
|
564
593
|
exhaustive_L2sqr_blas_default_impl<
|
565
|
-
|
594
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
566
595
|
x, y, d, nx, ny, res, y_norms);
|
567
596
|
#else
|
568
597
|
// run the default implementation
|
569
598
|
exhaustive_L2sqr_blas_default_impl<
|
570
|
-
|
599
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
571
600
|
x, y, d, nx, ny, res, y_norms);
|
572
601
|
#endif
|
573
602
|
}
|
574
603
|
|
575
|
-
template <class
|
604
|
+
template <class BlockResultHandler>
|
576
605
|
void knn_L2sqr_select(
|
577
606
|
const float* x,
|
578
607
|
const float* y,
|
579
608
|
size_t d,
|
580
609
|
size_t nx,
|
581
610
|
size_t ny,
|
582
|
-
|
611
|
+
BlockResultHandler& res,
|
583
612
|
const float* y_norm2,
|
584
613
|
const IDSelector* sel) {
|
585
614
|
if (sel) {
|
586
|
-
exhaustive_L2sqr_seq<
|
615
|
+
exhaustive_L2sqr_seq<BlockResultHandler, true>(
|
616
|
+
x, y, d, nx, ny, res, sel);
|
587
617
|
} else if (nx < distance_compute_blas_threshold) {
|
588
618
|
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
589
619
|
} else {
|
@@ -591,6 +621,25 @@ void knn_L2sqr_select(
|
|
591
621
|
}
|
592
622
|
}
|
593
623
|
|
624
|
+
template <class BlockResultHandler>
|
625
|
+
void knn_inner_product_select(
|
626
|
+
const float* x,
|
627
|
+
const float* y,
|
628
|
+
size_t d,
|
629
|
+
size_t nx,
|
630
|
+
size_t ny,
|
631
|
+
BlockResultHandler& res,
|
632
|
+
const IDSelector* sel) {
|
633
|
+
if (sel) {
|
634
|
+
exhaustive_inner_product_seq<BlockResultHandler, true>(
|
635
|
+
x, y, d, nx, ny, res, sel);
|
636
|
+
} else if (nx < distance_compute_blas_threshold) {
|
637
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
638
|
+
} else {
|
639
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
640
|
+
}
|
641
|
+
}
|
642
|
+
|
594
643
|
} // anonymous namespace
|
595
644
|
|
596
645
|
/*******************************************************
|
@@ -609,7 +658,7 @@ void knn_inner_product(
|
|
609
658
|
size_t nx,
|
610
659
|
size_t ny,
|
611
660
|
size_t k,
|
612
|
-
float*
|
661
|
+
float* vals,
|
613
662
|
int64_t* ids,
|
614
663
|
const IDSelector* sel) {
|
615
664
|
int64_t imin = 0;
|
@@ -622,30 +671,21 @@ void knn_inner_product(
|
|
622
671
|
}
|
623
672
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
624
673
|
knn_inner_products_by_idx(
|
625
|
-
x, y, sela->ids, d, nx, sela->n, k,
|
674
|
+
x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
626
675
|
return;
|
627
676
|
}
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
} else {
|
636
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
637
|
-
}
|
677
|
+
|
678
|
+
if (k == 1) {
|
679
|
+
Top1BlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids);
|
680
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
681
|
+
} else if (k < distance_compute_min_k_reservoir) {
|
682
|
+
HeapBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
683
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
638
684
|
} else {
|
639
|
-
|
640
|
-
|
641
|
-
if (sel) {
|
642
|
-
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
643
|
-
} else if (nx < distance_compute_blas_threshold) {
|
644
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
645
|
-
} else {
|
646
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
647
|
-
}
|
685
|
+
ReservoirBlockResultHandler<CMin<float, int64_t>> res(nx, vals, ids, k);
|
686
|
+
knn_inner_product_select(x, y, d, nx, ny, res, sel);
|
648
687
|
}
|
688
|
+
|
649
689
|
if (imin != 0) {
|
650
690
|
for (size_t i = 0; i < nx * k; i++) {
|
651
691
|
if (ids[i] >= 0) {
|
@@ -687,17 +727,17 @@ void knn_L2sqr(
|
|
687
727
|
sel = nullptr;
|
688
728
|
}
|
689
729
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
690
|
-
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
730
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
691
731
|
return;
|
692
732
|
}
|
693
733
|
if (k == 1) {
|
694
|
-
|
734
|
+
Top1BlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
695
735
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
696
736
|
} else if (k < distance_compute_min_k_reservoir) {
|
697
|
-
|
737
|
+
HeapBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
698
738
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
699
739
|
} else {
|
700
|
-
|
740
|
+
ReservoirBlockResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
701
741
|
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
702
742
|
}
|
703
743
|
if (imin != 0) {
|
@@ -735,7 +775,7 @@ void range_search_L2sqr(
|
|
735
775
|
float radius,
|
736
776
|
RangeSearchResult* res,
|
737
777
|
const IDSelector* sel) {
|
738
|
-
using RH =
|
778
|
+
using RH = RangeSearchBlockResultHandler<CMax<float, int64_t>>;
|
739
779
|
RH resh(res, radius);
|
740
780
|
if (sel) {
|
741
781
|
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
@@ -755,7 +795,7 @@ void range_search_inner_product(
|
|
755
795
|
float radius,
|
756
796
|
RangeSearchResult* res,
|
757
797
|
const IDSelector* sel) {
|
758
|
-
using RH =
|
798
|
+
using RH = RangeSearchBlockResultHandler<CMin<float, int64_t>>;
|
759
799
|
RH resh(res, radius);
|
760
800
|
if (sel) {
|
761
801
|
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
@@ -786,9 +826,11 @@ void fvec_inner_products_by_idx(
|
|
786
826
|
const float* xj = x + j * d;
|
787
827
|
float* __restrict ipj = ip + j * ny;
|
788
828
|
for (size_t i = 0; i < ny; i++) {
|
789
|
-
if (idsj[i] < 0)
|
790
|
-
|
791
|
-
|
829
|
+
if (idsj[i] < 0) {
|
830
|
+
ipj[i] = -INFINITY;
|
831
|
+
} else {
|
832
|
+
ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
|
833
|
+
}
|
792
834
|
}
|
793
835
|
}
|
794
836
|
}
|
@@ -809,9 +851,11 @@ void fvec_L2sqr_by_idx(
|
|
809
851
|
const float* xj = x + j * d;
|
810
852
|
float* __restrict disj = dis + j * ny;
|
811
853
|
for (size_t i = 0; i < ny; i++) {
|
812
|
-
if (idsj[i] < 0)
|
813
|
-
|
814
|
-
|
854
|
+
if (idsj[i] < 0) {
|
855
|
+
disj[i] = INFINITY;
|
856
|
+
} else {
|
857
|
+
disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
|
858
|
+
}
|
815
859
|
}
|
816
860
|
}
|
817
861
|
}
|
@@ -828,6 +872,8 @@ void pairwise_indexed_L2sqr(
|
|
828
872
|
for (int64_t j = 0; j < n; j++) {
|
829
873
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
830
874
|
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
|
875
|
+
} else {
|
876
|
+
dis[j] = INFINITY;
|
831
877
|
}
|
832
878
|
}
|
833
879
|
}
|
@@ -844,6 +890,8 @@ void pairwise_indexed_inner_product(
|
|
844
890
|
for (int64_t j = 0; j < n; j++) {
|
845
891
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
846
892
|
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
|
893
|
+
} else {
|
894
|
+
dis[j] = -INFINITY;
|
847
895
|
}
|
848
896
|
}
|
849
897
|
}
|
@@ -857,6 +905,7 @@ void knn_inner_products_by_idx(
|
|
857
905
|
size_t d,
|
858
906
|
size_t nx,
|
859
907
|
size_t ny,
|
908
|
+
size_t nsubset,
|
860
909
|
size_t k,
|
861
910
|
float* res_vals,
|
862
911
|
int64_t* res_ids,
|
@@ -874,9 +923,10 @@ void knn_inner_products_by_idx(
|
|
874
923
|
int64_t* __restrict idxi = res_ids + i * k;
|
875
924
|
minheap_heapify(k, simi, idxi);
|
876
925
|
|
877
|
-
for (j = 0; j <
|
878
|
-
if (idsi[j] < 0)
|
926
|
+
for (j = 0; j < nsubset; j++) {
|
927
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
879
928
|
break;
|
929
|
+
}
|
880
930
|
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
|
881
931
|
|
882
932
|
if (ip > simi[0]) {
|
@@ -894,6 +944,7 @@ void knn_L2sqr_by_idx(
|
|
894
944
|
size_t d,
|
895
945
|
size_t nx,
|
896
946
|
size_t ny,
|
947
|
+
size_t nsubset,
|
897
948
|
size_t k,
|
898
949
|
float* res_vals,
|
899
950
|
int64_t* res_ids,
|
@@ -908,7 +959,10 @@ void knn_L2sqr_by_idx(
|
|
908
959
|
float* __restrict simi = res_vals + i * k;
|
909
960
|
int64_t* __restrict idxi = res_ids + i * k;
|
910
961
|
maxheap_heapify(k, simi, idxi);
|
911
|
-
for (size_t j = 0; j <
|
962
|
+
for (size_t j = 0; j < nsubset; j++) {
|
963
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
964
|
+
break;
|
965
|
+
}
|
912
966
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
913
967
|
|
914
968
|
if (disij < simi[0]) {
|
@@ -36,6 +36,34 @@ float fvec_L1(const float* x, const float* y, size_t d);
|
|
36
36
|
/// infinity distance
|
37
37
|
float fvec_Linf(const float* x, const float* y, size_t d);
|
38
38
|
|
39
|
+
/// Special version of inner product that computes 4 distances
|
40
|
+
/// between x and yi, which is performance oriented.
|
41
|
+
void fvec_inner_product_batch_4(
|
42
|
+
const float* x,
|
43
|
+
const float* y0,
|
44
|
+
const float* y1,
|
45
|
+
const float* y2,
|
46
|
+
const float* y3,
|
47
|
+
const size_t d,
|
48
|
+
float& dis0,
|
49
|
+
float& dis1,
|
50
|
+
float& dis2,
|
51
|
+
float& dis3);
|
52
|
+
|
53
|
+
/// Special version of L2sqr that computes 4 distances
|
54
|
+
/// between x and yi, which is performance oriented.
|
55
|
+
void fvec_L2sqr_batch_4(
|
56
|
+
const float* x,
|
57
|
+
const float* y0,
|
58
|
+
const float* y1,
|
59
|
+
const float* y2,
|
60
|
+
const float* y3,
|
61
|
+
const size_t d,
|
62
|
+
float& dis0,
|
63
|
+
float& dis1,
|
64
|
+
float& dis2,
|
65
|
+
float& dis3);
|
66
|
+
|
39
67
|
/** Compute pairwise distances between sets of vectors
|
40
68
|
*
|
41
69
|
* @param d dimension of the vectors
|
@@ -170,8 +198,16 @@ void fvec_sub(size_t d, const float* a, const float* b, float* c);
|
|
170
198
|
* Compute a subset of distances
|
171
199
|
***************************************************************************/
|
172
200
|
|
173
|
-
|
174
|
-
|
201
|
+
/** compute the inner product between x and a subset y of ny vectors defined by
|
202
|
+
* ids
|
203
|
+
*
|
204
|
+
* ip(i, j) = inner_product(x(i, :), y(ids(i, j), :))
|
205
|
+
*
|
206
|
+
* @param ip output array, size nx * ny
|
207
|
+
* @param x first-term vector, size nx * d
|
208
|
+
* @param y second-term vector, size (max(ids) + 1) * d
|
209
|
+
* @param ids ids to sample from y, size nx * ny
|
210
|
+
*/
|
175
211
|
void fvec_inner_products_by_idx(
|
176
212
|
float* ip,
|
177
213
|
const float* x,
|
@@ -181,7 +217,16 @@ void fvec_inner_products_by_idx(
|
|
181
217
|
size_t nx,
|
182
218
|
size_t ny);
|
183
219
|
|
184
|
-
|
220
|
+
/** compute the squared L2 distances between x and a subset y of ny vectors
|
221
|
+
* defined by ids
|
222
|
+
*
|
223
|
+
* dis(i, j) = L2sqr(x(i, :), y(ids(i, j), :))
|
224
|
+
*
|
225
|
+
* @param dis output array, size nx * ny
|
226
|
+
* @param x first-term vector, size nx * d
|
227
|
+
* @param y second-term vector, size (max(ids) + 1) * d
|
228
|
+
* @param ids ids to sample from y, size nx * ny
|
229
|
+
*/
|
185
230
|
void fvec_L2sqr_by_idx(
|
186
231
|
float* dis,
|
187
232
|
const float* x,
|
@@ -208,7 +253,14 @@ void pairwise_indexed_L2sqr(
|
|
208
253
|
const int64_t* iy,
|
209
254
|
float* dis);
|
210
255
|
|
211
|
-
|
256
|
+
/** compute dis[j] = inner_product(x[ix[j]], y[iy[j]]) for all j=0..n-1
|
257
|
+
*
|
258
|
+
* @param x size (max(ix) + 1, d)
|
259
|
+
* @param y size (max(iy) + 1, d)
|
260
|
+
* @param ix size n
|
261
|
+
* @param iy size n
|
262
|
+
* @param dis size n
|
263
|
+
*/
|
212
264
|
void pairwise_indexed_inner_product(
|
213
265
|
size_t d,
|
214
266
|
size_t n,
|
@@ -324,6 +376,7 @@ void knn_inner_products_by_idx(
|
|
324
376
|
const int64_t* subset,
|
325
377
|
size_t d,
|
326
378
|
size_t nx,
|
379
|
+
size_t ny,
|
327
380
|
size_t nsubset,
|
328
381
|
size_t k,
|
329
382
|
float* vals,
|
@@ -346,6 +399,7 @@ void knn_L2sqr_by_idx(
|
|
346
399
|
const int64_t* subset,
|
347
400
|
size_t d,
|
348
401
|
size_t nx,
|
402
|
+
size_t ny,
|
349
403
|
size_t nsubset,
|
350
404
|
size_t k,
|
351
405
|
float* vals,
|
@@ -406,4 +460,27 @@ void compute_PQ_dis_tables_dsub2(
|
|
406
460
|
* Templatized versions of distance functions
|
407
461
|
***************************************************************************/
|
408
462
|
|
463
|
+
/***************************************************************************
|
464
|
+
* Misc matrix and vector manipulation functions
|
465
|
+
***************************************************************************/
|
466
|
+
|
467
|
+
/** compute c := a + bf * b for a, b and c tables
|
468
|
+
*
|
469
|
+
* @param n size of the tables
|
470
|
+
* @param a size n
|
471
|
+
* @param b size n
|
472
|
+
* @param c result table, size n
|
473
|
+
*/
|
474
|
+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
|
475
|
+
|
476
|
+
/** same as fvec_madd, also return index of the min of the result table
|
477
|
+
* @return index of the min of table c
|
478
|
+
*/
|
479
|
+
int fvec_madd_and_argmin(
|
480
|
+
size_t n,
|
481
|
+
const float* a,
|
482
|
+
float bf,
|
483
|
+
const float* b,
|
484
|
+
float* c);
|
485
|
+
|
409
486
|
} // namespace faiss
|
@@ -9,7 +9,7 @@
|
|
9
9
|
|
10
10
|
#include <faiss/utils/distances_fused/avx512.h>
|
11
11
|
|
12
|
-
#ifdef
|
12
|
+
#ifdef __AVX512F__
|
13
13
|
|
14
14
|
#include <immintrin.h>
|
15
15
|
|
@@ -68,7 +68,7 @@ void kernel(
|
|
68
68
|
const float* const __restrict y,
|
69
69
|
const float* const __restrict y_transposed,
|
70
70
|
size_t ny,
|
71
|
-
|
71
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
72
72
|
const float* __restrict y_norms,
|
73
73
|
size_t i) {
|
74
74
|
const size_t ny_p =
|
@@ -231,7 +231,7 @@ void exhaustive_L2sqr_fused_cmax(
|
|
231
231
|
const float* const __restrict y,
|
232
232
|
size_t nx,
|
233
233
|
size_t ny,
|
234
|
-
|
234
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
235
235
|
const float* __restrict y_norms) {
|
236
236
|
// BLAS does not like empty matrices
|
237
237
|
if (nx == 0 || ny == 0) {
|
@@ -275,7 +275,7 @@ void exhaustive_L2sqr_fused_cmax(
|
|
275
275
|
x, y, y_transposed.data(), ny, res, y_norms, i);
|
276
276
|
}
|
277
277
|
|
278
|
-
// Does nothing for
|
278
|
+
// Does nothing for Top1BlockResultHandler, but
|
279
279
|
// keeping the call for the consistency.
|
280
280
|
res.end_multiple();
|
281
281
|
InterruptCallback::check();
|
@@ -289,7 +289,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
|
|
289
289
|
size_t d,
|
290
290
|
size_t nx,
|
291
291
|
size_t ny,
|
292
|
-
|
292
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
293
293
|
const float* y_norms) {
|
294
294
|
// process only cases with certain dimensionalities
|
295
295
|
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
#include <faiss/utils/Heap.h>
|
18
18
|
|
19
|
-
#ifdef
|
19
|
+
#ifdef __AVX512F__
|
20
20
|
|
21
21
|
namespace faiss {
|
22
22
|
|
@@ -28,7 +28,7 @@ bool exhaustive_L2sqr_fused_cmax_AVX512(
|
|
28
28
|
size_t d,
|
29
29
|
size_t nx,
|
30
30
|
size_t ny,
|
31
|
-
|
31
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
32
32
|
const float* y_norms);
|
33
33
|
|
34
34
|
} // namespace faiss
|
@@ -20,14 +20,14 @@ bool exhaustive_L2sqr_fused_cmax(
|
|
20
20
|
size_t d,
|
21
21
|
size_t nx,
|
22
22
|
size_t ny,
|
23
|
-
|
23
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
24
24
|
const float* y_norms) {
|
25
25
|
if (nx == 0 || ny == 0) {
|
26
26
|
// nothing to do
|
27
27
|
return true;
|
28
28
|
}
|
29
29
|
|
30
|
-
#ifdef
|
30
|
+
#ifdef __AVX512F__
|
31
31
|
// avx512 kernel
|
32
32
|
return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
|
33
33
|
#elif defined(__AVX2__) || defined(__aarch64__)
|