faiss 0.2.4 → 0.2.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
@@ -17,8 +17,13 @@
|
|
17
17
|
|
18
18
|
#include <omp.h>
|
19
19
|
|
20
|
+
#ifdef __AVX2__
|
21
|
+
#include <immintrin.h>
|
22
|
+
#endif
|
23
|
+
|
20
24
|
#include <faiss/impl/AuxIndexStructures.h>
|
21
25
|
#include <faiss/impl/FaissAssert.h>
|
26
|
+
#include <faiss/impl/IDSelector.h>
|
22
27
|
#include <faiss/impl/ResultHandler.h>
|
23
28
|
|
24
29
|
#ifndef FINTEGER
|
@@ -96,17 +101,20 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
96
101
|
namespace {
|
97
102
|
|
98
103
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
99
|
-
template <class ResultHandler>
|
104
|
+
template <class ResultHandler, bool use_sel = false>
|
100
105
|
void exhaustive_inner_product_seq(
|
101
106
|
const float* x,
|
102
107
|
const float* y,
|
103
108
|
size_t d,
|
104
109
|
size_t nx,
|
105
110
|
size_t ny,
|
106
|
-
ResultHandler& res
|
111
|
+
ResultHandler& res,
|
112
|
+
const IDSelector* sel = nullptr) {
|
107
113
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
108
114
|
int nt = std::min(int(nx), omp_get_max_threads());
|
109
115
|
|
116
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
117
|
+
|
110
118
|
#pragma omp parallel num_threads(nt)
|
111
119
|
{
|
112
120
|
SingleResultHandler resi(res);
|
@@ -117,27 +125,32 @@ void exhaustive_inner_product_seq(
|
|
117
125
|
|
118
126
|
resi.begin(i);
|
119
127
|
|
120
|
-
for (size_t j = 0; j < ny; j
|
128
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
129
|
+
if (use_sel && !sel->is_member(j)) {
|
130
|
+
continue;
|
131
|
+
}
|
121
132
|
float ip = fvec_inner_product(x_i, y_j, d);
|
122
133
|
resi.add_result(ip, j);
|
123
|
-
y_j += d;
|
124
134
|
}
|
125
135
|
resi.end();
|
126
136
|
}
|
127
137
|
}
|
128
138
|
}
|
129
139
|
|
130
|
-
template <class ResultHandler>
|
140
|
+
template <class ResultHandler, bool use_sel = false>
|
131
141
|
void exhaustive_L2sqr_seq(
|
132
142
|
const float* x,
|
133
143
|
const float* y,
|
134
144
|
size_t d,
|
135
145
|
size_t nx,
|
136
146
|
size_t ny,
|
137
|
-
ResultHandler& res
|
147
|
+
ResultHandler& res,
|
148
|
+
const IDSelector* sel = nullptr) {
|
138
149
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
139
150
|
int nt = std::min(int(nx), omp_get_max_threads());
|
140
151
|
|
152
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
153
|
+
|
141
154
|
#pragma omp parallel num_threads(nt)
|
142
155
|
{
|
143
156
|
SingleResultHandler resi(res);
|
@@ -146,10 +159,12 @@ void exhaustive_L2sqr_seq(
|
|
146
159
|
const float* x_i = x + i * d;
|
147
160
|
const float* y_j = y;
|
148
161
|
resi.begin(i);
|
149
|
-
for (size_t j = 0; j < ny; j
|
162
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
163
|
+
if (use_sel && !sel->is_member(j)) {
|
164
|
+
continue;
|
165
|
+
}
|
150
166
|
float disij = fvec_L2sqr(x_i, y_j, d);
|
151
167
|
resi.add_result(disij, j);
|
152
|
-
y_j += d;
|
153
168
|
}
|
154
169
|
resi.end();
|
155
170
|
}
|
@@ -296,6 +311,232 @@ void exhaustive_L2sqr_blas(
|
|
296
311
|
}
|
297
312
|
}
|
298
313
|
|
314
|
+
#ifdef __AVX2__
|
315
|
+
// an override for AVX2 if only a single closest point is needed.
|
316
|
+
template <>
|
317
|
+
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
|
318
|
+
const float* x,
|
319
|
+
const float* y,
|
320
|
+
size_t d,
|
321
|
+
size_t nx,
|
322
|
+
size_t ny,
|
323
|
+
SingleBestResultHandler<CMax<float, int64_t>>& res,
|
324
|
+
const float* y_norms) {
|
325
|
+
// BLAS does not like empty matrices
|
326
|
+
if (nx == 0 || ny == 0)
|
327
|
+
return;
|
328
|
+
|
329
|
+
/* block sizes */
|
330
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
331
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
332
|
+
// const size_t bs_x = 16, bs_y = 16;
|
333
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
334
|
+
std::unique_ptr<float[]> x_norms(new float[nx]);
|
335
|
+
std::unique_ptr<float[]> del2;
|
336
|
+
|
337
|
+
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
338
|
+
|
339
|
+
if (!y_norms) {
|
340
|
+
float* y_norms2 = new float[ny];
|
341
|
+
del2.reset(y_norms2);
|
342
|
+
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
343
|
+
y_norms = y_norms2;
|
344
|
+
}
|
345
|
+
|
346
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
347
|
+
size_t i1 = i0 + bs_x;
|
348
|
+
if (i1 > nx)
|
349
|
+
i1 = nx;
|
350
|
+
|
351
|
+
res.begin_multiple(i0, i1);
|
352
|
+
|
353
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
354
|
+
size_t j1 = j0 + bs_y;
|
355
|
+
if (j1 > ny)
|
356
|
+
j1 = ny;
|
357
|
+
/* compute the actual dot products */
|
358
|
+
{
|
359
|
+
float one = 1, zero = 0;
|
360
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
361
|
+
sgemm_("Transpose",
|
362
|
+
"Not transpose",
|
363
|
+
&nyi,
|
364
|
+
&nxi,
|
365
|
+
&di,
|
366
|
+
&one,
|
367
|
+
y + j0 * d,
|
368
|
+
&di,
|
369
|
+
x + i0 * d,
|
370
|
+
&di,
|
371
|
+
&zero,
|
372
|
+
ip_block.get(),
|
373
|
+
&nyi);
|
374
|
+
}
|
375
|
+
#pragma omp parallel for
|
376
|
+
for (int64_t i = i0; i < i1; i++) {
|
377
|
+
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
378
|
+
|
379
|
+
_mm_prefetch(ip_line, _MM_HINT_NTA);
|
380
|
+
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
|
381
|
+
|
382
|
+
// constant
|
383
|
+
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
384
|
+
|
385
|
+
// Track 8 min distances + 8 min indices.
|
386
|
+
// All the distances tracked do not take x_norms[i]
|
387
|
+
// into account in order to get rid of extra
|
388
|
+
// _mm256_add_ps(x_norms[i], ...) instructions
|
389
|
+
// is distance computations.
|
390
|
+
__m256 min_distances =
|
391
|
+
_mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
|
392
|
+
|
393
|
+
// these indices are local and are relative to j0.
|
394
|
+
// so, value 0 means j0.
|
395
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
396
|
+
|
397
|
+
__m256i current_indices =
|
398
|
+
_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
399
|
+
const __m256i indices_delta = _mm256_set1_epi32(8);
|
400
|
+
|
401
|
+
// current j index
|
402
|
+
size_t idx_j = 0;
|
403
|
+
size_t count = j1 - j0;
|
404
|
+
|
405
|
+
// process 16 elements per loop
|
406
|
+
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
407
|
+
_mm_prefetch(ip_line + 32, _MM_HINT_NTA);
|
408
|
+
_mm_prefetch(ip_line + 48, _MM_HINT_NTA);
|
409
|
+
|
410
|
+
// load values for norms
|
411
|
+
const __m256 y_norm_0 =
|
412
|
+
_mm256_loadu_ps(y_norms + idx_j + j0 + 0);
|
413
|
+
const __m256 y_norm_1 =
|
414
|
+
_mm256_loadu_ps(y_norms + idx_j + j0 + 8);
|
415
|
+
|
416
|
+
// load values for dot products
|
417
|
+
const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
|
418
|
+
const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
|
419
|
+
|
420
|
+
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
421
|
+
// x_norm[i] was dropped off because it is a constant for a
|
422
|
+
// given i. We'll deal with it later.
|
423
|
+
__m256 distances_0 =
|
424
|
+
_mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
|
425
|
+
__m256 distances_1 =
|
426
|
+
_mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
|
427
|
+
|
428
|
+
// compare the new distances to the min distances
|
429
|
+
// for each of the first group of 8 AVX2 components.
|
430
|
+
const __m256 comparison_0 = _mm256_cmp_ps(
|
431
|
+
min_distances, distances_0, _CMP_LE_OS);
|
432
|
+
|
433
|
+
// update min distances and indices with closest vectors if
|
434
|
+
// needed.
|
435
|
+
min_distances = _mm256_blendv_ps(
|
436
|
+
distances_0, min_distances, comparison_0);
|
437
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
438
|
+
_mm256_castsi256_ps(current_indices),
|
439
|
+
_mm256_castsi256_ps(min_indices),
|
440
|
+
comparison_0));
|
441
|
+
current_indices =
|
442
|
+
_mm256_add_epi32(current_indices, indices_delta);
|
443
|
+
|
444
|
+
// compare the new distances to the min distances
|
445
|
+
// for each of the second group of 8 AVX2 components.
|
446
|
+
const __m256 comparison_1 = _mm256_cmp_ps(
|
447
|
+
min_distances, distances_1, _CMP_LE_OS);
|
448
|
+
|
449
|
+
// update min distances and indices with closest vectors if
|
450
|
+
// needed.
|
451
|
+
min_distances = _mm256_blendv_ps(
|
452
|
+
distances_1, min_distances, comparison_1);
|
453
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
454
|
+
_mm256_castsi256_ps(current_indices),
|
455
|
+
_mm256_castsi256_ps(min_indices),
|
456
|
+
comparison_1));
|
457
|
+
current_indices =
|
458
|
+
_mm256_add_epi32(current_indices, indices_delta);
|
459
|
+
}
|
460
|
+
|
461
|
+
// dump values and find the minimum distance / minimum index
|
462
|
+
float min_distances_scalar[8];
|
463
|
+
uint32_t min_indices_scalar[8];
|
464
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
465
|
+
_mm256_storeu_si256(
|
466
|
+
(__m256i*)(min_indices_scalar), min_indices);
|
467
|
+
|
468
|
+
float current_min_distance = res.dis_tab[i];
|
469
|
+
uint32_t current_min_index = res.ids_tab[i];
|
470
|
+
|
471
|
+
// This unusual comparison is needed to maintain the behavior
|
472
|
+
// of the original implementation: if two indices are
|
473
|
+
// represented with equal distance values, then
|
474
|
+
// the index with the min value is returned.
|
475
|
+
for (size_t jv = 0; jv < 8; jv++) {
|
476
|
+
// add missing x_norms[i]
|
477
|
+
float distance_candidate =
|
478
|
+
min_distances_scalar[jv] + x_norms[i];
|
479
|
+
|
480
|
+
// negative values can occur for identical vectors
|
481
|
+
// due to roundoff errors.
|
482
|
+
if (distance_candidate < 0)
|
483
|
+
distance_candidate = 0;
|
484
|
+
|
485
|
+
int64_t index_candidate = min_indices_scalar[jv] + j0;
|
486
|
+
|
487
|
+
if (current_min_distance > distance_candidate) {
|
488
|
+
current_min_distance = distance_candidate;
|
489
|
+
current_min_index = index_candidate;
|
490
|
+
} else if (
|
491
|
+
current_min_distance == distance_candidate &&
|
492
|
+
current_min_index > index_candidate) {
|
493
|
+
current_min_index = index_candidate;
|
494
|
+
}
|
495
|
+
}
|
496
|
+
|
497
|
+
// process leftovers
|
498
|
+
for (; idx_j < count; idx_j++, ip_line++) {
|
499
|
+
float ip = *ip_line;
|
500
|
+
float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
|
501
|
+
// negative values can occur for identical vectors
|
502
|
+
// due to roundoff errors.
|
503
|
+
if (dis < 0)
|
504
|
+
dis = 0;
|
505
|
+
|
506
|
+
if (current_min_distance > dis) {
|
507
|
+
current_min_distance = dis;
|
508
|
+
current_min_index = idx_j + j0;
|
509
|
+
}
|
510
|
+
}
|
511
|
+
|
512
|
+
//
|
513
|
+
res.add_result(i, current_min_distance, current_min_index);
|
514
|
+
}
|
515
|
+
}
|
516
|
+
InterruptCallback::check();
|
517
|
+
}
|
518
|
+
}
|
519
|
+
#endif
|
520
|
+
|
521
|
+
template <class ResultHandler>
|
522
|
+
void knn_L2sqr_select(
|
523
|
+
const float* x,
|
524
|
+
const float* y,
|
525
|
+
size_t d,
|
526
|
+
size_t nx,
|
527
|
+
size_t ny,
|
528
|
+
ResultHandler& res,
|
529
|
+
const float* y_norm2,
|
530
|
+
const IDSelector* sel) {
|
531
|
+
if (sel) {
|
532
|
+
exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
|
533
|
+
} else if (nx < distance_compute_blas_threshold) {
|
534
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
535
|
+
} else {
|
536
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
537
|
+
}
|
538
|
+
}
|
539
|
+
|
299
540
|
} // anonymous namespace
|
300
541
|
|
301
542
|
/*******************************************************
|
@@ -313,24 +554,63 @@ void knn_inner_product(
|
|
313
554
|
size_t d,
|
314
555
|
size_t nx,
|
315
556
|
size_t ny,
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
557
|
+
size_t k,
|
558
|
+
float* val,
|
559
|
+
int64_t* ids,
|
560
|
+
const IDSelector* sel) {
|
561
|
+
int64_t imin = 0;
|
562
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
563
|
+
imin = std::max(selr->imin, int64_t(0));
|
564
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
565
|
+
ny = imax - imin;
|
566
|
+
y += d * imin;
|
567
|
+
sel = nullptr;
|
568
|
+
}
|
569
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
570
|
+
knn_inner_products_by_idx(
|
571
|
+
x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
|
572
|
+
return;
|
573
|
+
}
|
574
|
+
if (k < distance_compute_min_k_reservoir) {
|
575
|
+
using RH = HeapResultHandler<CMin<float, int64_t>>;
|
576
|
+
RH res(nx, val, ids, k);
|
577
|
+
if (sel) {
|
578
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
579
|
+
} else if (nx < distance_compute_blas_threshold) {
|
321
580
|
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
322
581
|
} else {
|
323
582
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
324
583
|
}
|
325
584
|
} else {
|
326
|
-
ReservoirResultHandler<CMin<float, int64_t
|
327
|
-
|
328
|
-
if (
|
329
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
585
|
+
using RH = ReservoirResultHandler<CMin<float, int64_t>>;
|
586
|
+
RH res(nx, val, ids, k);
|
587
|
+
if (sel) {
|
588
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
589
|
+
} else if (nx < distance_compute_blas_threshold) {
|
590
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
330
591
|
} else {
|
331
592
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
332
593
|
}
|
333
594
|
}
|
595
|
+
if (imin != 0) {
|
596
|
+
for (size_t i = 0; i < nx * k; i++) {
|
597
|
+
if (ids[i] >= 0) {
|
598
|
+
ids[i] += imin;
|
599
|
+
}
|
600
|
+
}
|
601
|
+
}
|
602
|
+
}
|
603
|
+
|
604
|
+
void knn_inner_product(
|
605
|
+
const float* x,
|
606
|
+
const float* y,
|
607
|
+
size_t d,
|
608
|
+
size_t nx,
|
609
|
+
size_t ny,
|
610
|
+
float_minheap_array_t* res,
|
611
|
+
const IDSelector* sel) {
|
612
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
613
|
+
knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
|
334
614
|
}
|
335
615
|
|
336
616
|
void knn_L2sqr(
|
@@ -339,28 +619,55 @@ void knn_L2sqr(
|
|
339
619
|
size_t d,
|
340
620
|
size_t nx,
|
341
621
|
size_t ny,
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
622
|
+
size_t k,
|
623
|
+
float* vals,
|
624
|
+
int64_t* ids,
|
625
|
+
const float* y_norm2,
|
626
|
+
const IDSelector* sel) {
|
627
|
+
int64_t imin = 0;
|
628
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
629
|
+
imin = std::max(selr->imin, int64_t(0));
|
630
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
631
|
+
ny = imax - imin;
|
632
|
+
y += d * imin;
|
633
|
+
sel = nullptr;
|
634
|
+
}
|
635
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
636
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
637
|
+
return;
|
638
|
+
}
|
639
|
+
if (k == 1) {
|
640
|
+
SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
641
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
642
|
+
} else if (k < distance_compute_min_k_reservoir) {
|
643
|
+
HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
644
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
353
645
|
} else {
|
354
|
-
ReservoirResultHandler<CMax<float, int64_t>> res(
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
646
|
+
ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
647
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
648
|
+
}
|
649
|
+
if (imin != 0) {
|
650
|
+
for (size_t i = 0; i < nx * k; i++) {
|
651
|
+
if (ids[i] >= 0) {
|
652
|
+
ids[i] += imin;
|
653
|
+
}
|
360
654
|
}
|
361
655
|
}
|
362
656
|
}
|
363
657
|
|
658
|
+
void knn_L2sqr(
|
659
|
+
const float* x,
|
660
|
+
const float* y,
|
661
|
+
size_t d,
|
662
|
+
size_t nx,
|
663
|
+
size_t ny,
|
664
|
+
float_maxheap_array_t* res,
|
665
|
+
const float* y_norm2,
|
666
|
+
const IDSelector* sel) {
|
667
|
+
FAISS_THROW_IF_NOT(res->nh == nx);
|
668
|
+
knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
|
669
|
+
}
|
670
|
+
|
364
671
|
/***************************************************************************
|
365
672
|
* Range search
|
366
673
|
***************************************************************************/
|
@@ -372,10 +679,14 @@ void range_search_L2sqr(
|
|
372
679
|
size_t nx,
|
373
680
|
size_t ny,
|
374
681
|
float radius,
|
375
|
-
RangeSearchResult* res
|
376
|
-
|
377
|
-
|
378
|
-
|
682
|
+
RangeSearchResult* res,
|
683
|
+
const IDSelector* sel) {
|
684
|
+
using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
|
685
|
+
RH resh(res, radius);
|
686
|
+
if (sel) {
|
687
|
+
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
688
|
+
} else if (nx < distance_compute_blas_threshold) {
|
689
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
|
379
690
|
} else {
|
380
691
|
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
381
692
|
}
|
@@ -388,9 +699,13 @@ void range_search_inner_product(
|
|
388
699
|
size_t nx,
|
389
700
|
size_t ny,
|
390
701
|
float radius,
|
391
|
-
RangeSearchResult* res
|
392
|
-
|
393
|
-
|
702
|
+
RangeSearchResult* res,
|
703
|
+
const IDSelector* sel) {
|
704
|
+
using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
|
705
|
+
RH resh(res, radius);
|
706
|
+
if (sel) {
|
707
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
708
|
+
} else if (nx < distance_compute_blas_threshold) {
|
394
709
|
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
395
710
|
} else {
|
396
711
|
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
@@ -488,16 +803,21 @@ void knn_inner_products_by_idx(
|
|
488
803
|
size_t d,
|
489
804
|
size_t nx,
|
490
805
|
size_t ny,
|
491
|
-
|
492
|
-
|
806
|
+
size_t k,
|
807
|
+
float* res_vals,
|
808
|
+
int64_t* res_ids,
|
809
|
+
int64_t ld_ids) {
|
810
|
+
if (ld_ids < 0) {
|
811
|
+
ld_ids = ny;
|
812
|
+
}
|
493
813
|
|
494
|
-
#pragma omp parallel for
|
814
|
+
#pragma omp parallel for if (nx > 100)
|
495
815
|
for (int64_t i = 0; i < nx; i++) {
|
496
816
|
const float* x_ = x + i * d;
|
497
|
-
const int64_t* idsi = ids + i *
|
817
|
+
const int64_t* idsi = ids + i * ld_ids;
|
498
818
|
size_t j;
|
499
|
-
float* __restrict simi =
|
500
|
-
int64_t* __restrict idxi =
|
819
|
+
float* __restrict simi = res_vals + i * k;
|
820
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
501
821
|
minheap_heapify(k, simi, idxi);
|
502
822
|
|
503
823
|
for (j = 0; j < ny; j++) {
|
@@ -520,16 +840,20 @@ void knn_L2sqr_by_idx(
|
|
520
840
|
size_t d,
|
521
841
|
size_t nx,
|
522
842
|
size_t ny,
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
843
|
+
size_t k,
|
844
|
+
float* res_vals,
|
845
|
+
int64_t* res_ids,
|
846
|
+
int64_t ld_ids) {
|
847
|
+
if (ld_ids < 0) {
|
848
|
+
ld_ids = ny;
|
849
|
+
}
|
850
|
+
#pragma omp parallel for if (nx > 100)
|
527
851
|
for (int64_t i = 0; i < nx; i++) {
|
528
852
|
const float* x_ = x + i * d;
|
529
|
-
const int64_t* __restrict idsi = ids + i *
|
530
|
-
float* __restrict simi =
|
531
|
-
int64_t* __restrict idxi =
|
532
|
-
maxheap_heapify(
|
853
|
+
const int64_t* __restrict idsi = ids + i * ld_ids;
|
854
|
+
float* __restrict simi = res_vals + i * k;
|
855
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
856
|
+
maxheap_heapify(k, simi, idxi);
|
533
857
|
for (size_t j = 0; j < ny; j++) {
|
534
858
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
535
859
|
|
@@ -537,7 +861,7 @@ void knn_L2sqr_by_idx(
|
|
537
861
|
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
538
862
|
}
|
539
863
|
}
|
540
|
-
maxheap_reorder(
|
864
|
+
maxheap_reorder(k, simi, idxi);
|
541
865
|
}
|
542
866
|
}
|
543
867
|
|