faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -17,8 +17,13 @@
|
|
|
17
17
|
|
|
18
18
|
#include <omp.h>
|
|
19
19
|
|
|
20
|
+
#ifdef __AVX2__
|
|
21
|
+
#include <immintrin.h>
|
|
22
|
+
#endif
|
|
23
|
+
|
|
20
24
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
21
25
|
#include <faiss/impl/FaissAssert.h>
|
|
26
|
+
#include <faiss/impl/IDSelector.h>
|
|
22
27
|
#include <faiss/impl/ResultHandler.h>
|
|
23
28
|
|
|
24
29
|
#ifndef FINTEGER
|
|
@@ -96,17 +101,20 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
|
96
101
|
namespace {
|
|
97
102
|
|
|
98
103
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
99
|
-
template <class ResultHandler>
|
|
104
|
+
template <class ResultHandler, bool use_sel = false>
|
|
100
105
|
void exhaustive_inner_product_seq(
|
|
101
106
|
const float* x,
|
|
102
107
|
const float* y,
|
|
103
108
|
size_t d,
|
|
104
109
|
size_t nx,
|
|
105
110
|
size_t ny,
|
|
106
|
-
ResultHandler& res
|
|
111
|
+
ResultHandler& res,
|
|
112
|
+
const IDSelector* sel = nullptr) {
|
|
107
113
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
108
114
|
int nt = std::min(int(nx), omp_get_max_threads());
|
|
109
115
|
|
|
116
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
117
|
+
|
|
110
118
|
#pragma omp parallel num_threads(nt)
|
|
111
119
|
{
|
|
112
120
|
SingleResultHandler resi(res);
|
|
@@ -117,27 +125,32 @@ void exhaustive_inner_product_seq(
|
|
|
117
125
|
|
|
118
126
|
resi.begin(i);
|
|
119
127
|
|
|
120
|
-
for (size_t j = 0; j < ny; j
|
|
128
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
129
|
+
if (use_sel && !sel->is_member(j)) {
|
|
130
|
+
continue;
|
|
131
|
+
}
|
|
121
132
|
float ip = fvec_inner_product(x_i, y_j, d);
|
|
122
133
|
resi.add_result(ip, j);
|
|
123
|
-
y_j += d;
|
|
124
134
|
}
|
|
125
135
|
resi.end();
|
|
126
136
|
}
|
|
127
137
|
}
|
|
128
138
|
}
|
|
129
139
|
|
|
130
|
-
template <class ResultHandler>
|
|
140
|
+
template <class ResultHandler, bool use_sel = false>
|
|
131
141
|
void exhaustive_L2sqr_seq(
|
|
132
142
|
const float* x,
|
|
133
143
|
const float* y,
|
|
134
144
|
size_t d,
|
|
135
145
|
size_t nx,
|
|
136
146
|
size_t ny,
|
|
137
|
-
ResultHandler& res
|
|
147
|
+
ResultHandler& res,
|
|
148
|
+
const IDSelector* sel = nullptr) {
|
|
138
149
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
139
150
|
int nt = std::min(int(nx), omp_get_max_threads());
|
|
140
151
|
|
|
152
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
153
|
+
|
|
141
154
|
#pragma omp parallel num_threads(nt)
|
|
142
155
|
{
|
|
143
156
|
SingleResultHandler resi(res);
|
|
@@ -146,10 +159,12 @@ void exhaustive_L2sqr_seq(
|
|
|
146
159
|
const float* x_i = x + i * d;
|
|
147
160
|
const float* y_j = y;
|
|
148
161
|
resi.begin(i);
|
|
149
|
-
for (size_t j = 0; j < ny; j
|
|
162
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
163
|
+
if (use_sel && !sel->is_member(j)) {
|
|
164
|
+
continue;
|
|
165
|
+
}
|
|
150
166
|
float disij = fvec_L2sqr(x_i, y_j, d);
|
|
151
167
|
resi.add_result(disij, j);
|
|
152
|
-
y_j += d;
|
|
153
168
|
}
|
|
154
169
|
resi.end();
|
|
155
170
|
}
|
|
@@ -296,6 +311,232 @@ void exhaustive_L2sqr_blas(
|
|
|
296
311
|
}
|
|
297
312
|
}
|
|
298
313
|
|
|
314
|
+
#ifdef __AVX2__
/* AVX2 override of exhaustive_L2sqr_blas for the single-nearest-neighbor
 * case (SingleBestResultHandler with the CMax / L2 comparator).
 *
 * Inner products between a block of queries [i0, i1) and a block of database
 * vectors [j0, j1) are computed with sgemm, as in the generic BLAS path.
 * Instead of forwarding the whole block to res.add_results(), every query row
 * is scanned with 8-wide AVX2 registers that keep one running minimum (and
 * its local index) per lane. The 8 lanes are reduced in scalar code,
 * leftovers are processed one by one, and the winner is merged into res with
 * res.add_result().
 *
 * y_norms: squared L2 norms of the ny database vectors. When nullptr they
 * are computed here.
 */
template <>
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        SingleBestResultHandler<CMax<float, int64_t>>& res,
        const float* y_norms) {
    // BLAS does not like empty matrices
    if (nx == 0 || ny == 0)
        return;

    /* block sizes */
    const size_t bs_x = distance_compute_blas_query_bs;
    const size_t bs_y = distance_compute_blas_database_bs;
    std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
    std::unique_ptr<float[]> x_norms(new float[nx]);
    std::unique_ptr<float[]> del2;

    fvec_norms_L2sqr(x_norms.get(), x, d, nx);

    if (!y_norms) {
        float* y_norms2 = new float[ny];
        del2.reset(y_norms2);
        fvec_norms_L2sqr(y_norms2, y, d, ny);
        y_norms = y_norms2;
    }

    // Loop-invariant vector constants (read-only inside the parallel loop).
    // plus_inf is +infinity built from its IEEE-754 single-precision bit
    // pattern, so no extra header is needed.
    const __m256 plus_inf =
            _mm256_castsi256_ps(_mm256_set1_epi32(0x7f800000));
    const __m256 mul_minus2 = _mm256_set1_ps(-2);
    const __m256i indices_delta = _mm256_set1_epi32(8);

    for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
        size_t i1 = i0 + bs_x;
        if (i1 > nx)
            i1 = nx;

        res.begin_multiple(i0, i1);

        for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
            size_t j1 = j0 + bs_y;
            if (j1 > ny)
                j1 = ny;
            /* compute the actual dot products */
            {
                float one = 1, zero = 0;
                FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
                sgemm_("Transpose",
                       "Not transpose",
                       &nyi,
                       &nxi,
                       &di,
                       &one,
                       y + j0 * d,
                       &di,
                       x + i0 * d,
                       &di,
                       &zero,
                       ip_block.get(),
                       &nyi);
            }
#pragma omp parallel for
            for (int64_t i = i0; i < i1; i++) {
                // row of inner products <x_i, y_j> for j in [j0, j1)
                float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

                _mm_prefetch(ip_line, _MM_HINT_NTA);
                _mm_prefetch(ip_line + 16, _MM_HINT_NTA);

                // Per-lane running minima of y_norms[j] - 2 * <x_i, y_j>.
                // x_norms[i] is constant for a given i, so it is left out
                // of the SIMD computations and added back in the reduction.
                //
                // The lanes are seeded with +inf, NOT with the best distance
                // from earlier database blocks: a lane that never improves
                // would otherwise report local index 0 (= j0) with the value
                // (res.dis_tab[i] - x_norms[i]) + x_norms[i], which after
                // float rounding can land slightly below res.dis_tab[i] and
                // replace the true best index with j0.
                __m256 min_distances = plus_inf;

                // local indices, relative to j0 (value 0 means j0)
                __m256i min_indices = _mm256_set1_epi32(0);
                __m256i current_indices =
                        _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

                const size_t count = j1 - j0;
                const size_t count16 = (count / 16) * 16;
                size_t idx_j = 0;

                // process 16 elements (two AVX2 registers) per iteration
                for (; idx_j < count16; idx_j += 16, ip_line += 16) {
                    _mm_prefetch(ip_line + 32, _MM_HINT_NTA);
                    _mm_prefetch(ip_line + 48, _MM_HINT_NTA);

                    const __m256 y_norm_0 =
                            _mm256_loadu_ps(y_norms + j0 + idx_j + 0);
                    const __m256 y_norm_1 =
                            _mm256_loadu_ps(y_norms + j0 + idx_j + 8);

                    const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
                    const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);

                    // dis = y_norms[j] - 2 * <x_i, y_j>  (x_norms[i] omitted)
                    const __m256 distances_0 =
                            _mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
                    const __m256 distances_1 =
                            _mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);

                    // First 8 lanes: keep the current minimum where it is
                    // <= the new distance (so the earliest index wins ties),
                    // otherwise take the new distance and its index.
                    const __m256 keep_0 = _mm256_cmp_ps(
                            min_distances, distances_0, _CMP_LE_OS);
                    min_distances =
                            _mm256_blendv_ps(distances_0, min_distances, keep_0);
                    min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                            _mm256_castsi256_ps(current_indices),
                            _mm256_castsi256_ps(min_indices),
                            keep_0));
                    current_indices =
                            _mm256_add_epi32(current_indices, indices_delta);

                    // Second 8 lanes, same update rule.
                    const __m256 keep_1 = _mm256_cmp_ps(
                            min_distances, distances_1, _CMP_LE_OS);
                    min_distances =
                            _mm256_blendv_ps(distances_1, min_distances, keep_1);
                    min_indices = _mm256_castps_si256(_mm256_blendv_ps(
                            _mm256_castsi256_ps(current_indices),
                            _mm256_castsi256_ps(min_indices),
                            keep_1));
                    current_indices =
                            _mm256_add_epi32(current_indices, indices_delta);
                }

                // dump the lanes and reduce them to one (distance, index)
                float min_distances_scalar[8];
                uint32_t min_indices_scalar[8];
                _mm256_storeu_ps(min_distances_scalar, min_distances);
                _mm256_storeu_si256(
                        (__m256i*)(min_indices_scalar), min_indices);

                float current_min_distance = res.dis_tab[i];
                // int64_t matches res.ids_tab and keeps global indices
                // (j0 + local index) from being truncated to 32 bits.
                int64_t current_min_index = res.ids_tab[i];

                // This unusual comparison is needed to maintain the behavior
                // of the original implementation: if two indices are
                // represented with equal distance values, then
                // the index with the min value is returned.
                for (size_t jv = 0; jv < 8; jv++) {
                    // add back the x_norms[i] term omitted in the lanes
                    float distance_candidate =
                            min_distances_scalar[jv] + x_norms[i];

                    // negative values can occur for identical vectors
                    // due to roundoff errors.
                    if (distance_candidate < 0)
                        distance_candidate = 0;

                    const int64_t index_candidate =
                            min_indices_scalar[jv] + j0;

                    if (current_min_distance > distance_candidate) {
                        current_min_distance = distance_candidate;
                        current_min_index = index_candidate;
                    } else if (
                            current_min_distance == distance_candidate &&
                            current_min_index > index_candidate) {
                        current_min_index = index_candidate;
                    }
                }

                // process leftovers (count is not a multiple of 16)
                for (; idx_j < count; idx_j++, ip_line++) {
                    float ip = *ip_line;
                    float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
                    // negative values can occur for identical vectors
                    // due to roundoff errors.
                    if (dis < 0)
                        dis = 0;

                    if (current_min_distance > dis) {
                        current_min_distance = dis;
                        current_min_index = idx_j + j0;
                    }
                }

                // merge this block's winner into the handler
                res.add_result(i, current_min_distance, current_min_index);
            }
        }
        // close the query block, as the generic BLAS path does
        res.end_multiple();
        InterruptCallback::check();
    }
}
#endif
|
|
520
|
+
|
|
521
|
+
template <class ResultHandler>
|
|
522
|
+
void knn_L2sqr_select(
|
|
523
|
+
const float* x,
|
|
524
|
+
const float* y,
|
|
525
|
+
size_t d,
|
|
526
|
+
size_t nx,
|
|
527
|
+
size_t ny,
|
|
528
|
+
ResultHandler& res,
|
|
529
|
+
const float* y_norm2,
|
|
530
|
+
const IDSelector* sel) {
|
|
531
|
+
if (sel) {
|
|
532
|
+
exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
|
|
533
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
534
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
535
|
+
} else {
|
|
536
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
|
|
299
540
|
} // anonymous namespace
|
|
300
541
|
|
|
301
542
|
/*******************************************************
|
|
@@ -313,24 +554,63 @@ void knn_inner_product(
|
|
|
313
554
|
size_t d,
|
|
314
555
|
size_t nx,
|
|
315
556
|
size_t ny,
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
557
|
+
size_t k,
|
|
558
|
+
float* val,
|
|
559
|
+
int64_t* ids,
|
|
560
|
+
const IDSelector* sel) {
|
|
561
|
+
int64_t imin = 0;
|
|
562
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
|
563
|
+
imin = std::max(selr->imin, int64_t(0));
|
|
564
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
|
565
|
+
ny = imax - imin;
|
|
566
|
+
y += d * imin;
|
|
567
|
+
sel = nullptr;
|
|
568
|
+
}
|
|
569
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
570
|
+
knn_inner_products_by_idx(
|
|
571
|
+
x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
|
|
572
|
+
return;
|
|
573
|
+
}
|
|
574
|
+
if (k < distance_compute_min_k_reservoir) {
|
|
575
|
+
using RH = HeapResultHandler<CMin<float, int64_t>>;
|
|
576
|
+
RH res(nx, val, ids, k);
|
|
577
|
+
if (sel) {
|
|
578
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
579
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
321
580
|
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
322
581
|
} else {
|
|
323
582
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
324
583
|
}
|
|
325
584
|
} else {
|
|
326
|
-
ReservoirResultHandler<CMin<float, int64_t
|
|
327
|
-
|
|
328
|
-
if (
|
|
329
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
585
|
+
using RH = ReservoirResultHandler<CMin<float, int64_t>>;
|
|
586
|
+
RH res(nx, val, ids, k);
|
|
587
|
+
if (sel) {
|
|
588
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
589
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
590
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
|
330
591
|
} else {
|
|
331
592
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
332
593
|
}
|
|
333
594
|
}
|
|
595
|
+
if (imin != 0) {
|
|
596
|
+
for (size_t i = 0; i < nx * k; i++) {
|
|
597
|
+
if (ids[i] >= 0) {
|
|
598
|
+
ids[i] += imin;
|
|
599
|
+
}
|
|
600
|
+
}
|
|
601
|
+
}
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
void knn_inner_product(
|
|
605
|
+
const float* x,
|
|
606
|
+
const float* y,
|
|
607
|
+
size_t d,
|
|
608
|
+
size_t nx,
|
|
609
|
+
size_t ny,
|
|
610
|
+
float_minheap_array_t* res,
|
|
611
|
+
const IDSelector* sel) {
|
|
612
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
|
613
|
+
knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
|
|
334
614
|
}
|
|
335
615
|
|
|
336
616
|
void knn_L2sqr(
|
|
@@ -339,28 +619,55 @@ void knn_L2sqr(
|
|
|
339
619
|
size_t d,
|
|
340
620
|
size_t nx,
|
|
341
621
|
size_t ny,
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
622
|
+
size_t k,
|
|
623
|
+
float* vals,
|
|
624
|
+
int64_t* ids,
|
|
625
|
+
const float* y_norm2,
|
|
626
|
+
const IDSelector* sel) {
|
|
627
|
+
int64_t imin = 0;
|
|
628
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
|
629
|
+
imin = std::max(selr->imin, int64_t(0));
|
|
630
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
|
631
|
+
ny = imax - imin;
|
|
632
|
+
y += d * imin;
|
|
633
|
+
sel = nullptr;
|
|
634
|
+
}
|
|
635
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
636
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
|
637
|
+
return;
|
|
638
|
+
}
|
|
639
|
+
if (k == 1) {
|
|
640
|
+
SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
|
641
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
642
|
+
} else if (k < distance_compute_min_k_reservoir) {
|
|
643
|
+
HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
644
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
353
645
|
} else {
|
|
354
|
-
ReservoirResultHandler<CMax<float, int64_t>> res(
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
646
|
+
ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
647
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
648
|
+
}
|
|
649
|
+
if (imin != 0) {
|
|
650
|
+
for (size_t i = 0; i < nx * k; i++) {
|
|
651
|
+
if (ids[i] >= 0) {
|
|
652
|
+
ids[i] += imin;
|
|
653
|
+
}
|
|
360
654
|
}
|
|
361
655
|
}
|
|
362
656
|
}
|
|
363
657
|
|
|
658
|
+
void knn_L2sqr(
|
|
659
|
+
const float* x,
|
|
660
|
+
const float* y,
|
|
661
|
+
size_t d,
|
|
662
|
+
size_t nx,
|
|
663
|
+
size_t ny,
|
|
664
|
+
float_maxheap_array_t* res,
|
|
665
|
+
const float* y_norm2,
|
|
666
|
+
const IDSelector* sel) {
|
|
667
|
+
FAISS_THROW_IF_NOT(res->nh == nx);
|
|
668
|
+
knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
|
|
669
|
+
}
|
|
670
|
+
|
|
364
671
|
/***************************************************************************
|
|
365
672
|
* Range search
|
|
366
673
|
***************************************************************************/
|
|
@@ -372,10 +679,14 @@ void range_search_L2sqr(
|
|
|
372
679
|
size_t nx,
|
|
373
680
|
size_t ny,
|
|
374
681
|
float radius,
|
|
375
|
-
RangeSearchResult* res
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
682
|
+
RangeSearchResult* res,
|
|
683
|
+
const IDSelector* sel) {
|
|
684
|
+
using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
|
|
685
|
+
RH resh(res, radius);
|
|
686
|
+
if (sel) {
|
|
687
|
+
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
688
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
689
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
|
|
379
690
|
} else {
|
|
380
691
|
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
|
381
692
|
}
|
|
@@ -388,9 +699,13 @@ void range_search_inner_product(
|
|
|
388
699
|
size_t nx,
|
|
389
700
|
size_t ny,
|
|
390
701
|
float radius,
|
|
391
|
-
RangeSearchResult* res
|
|
392
|
-
|
|
393
|
-
|
|
702
|
+
RangeSearchResult* res,
|
|
703
|
+
const IDSelector* sel) {
|
|
704
|
+
using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
|
|
705
|
+
RH resh(res, radius);
|
|
706
|
+
if (sel) {
|
|
707
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
708
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
394
709
|
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
|
395
710
|
} else {
|
|
396
711
|
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
|
@@ -488,16 +803,21 @@ void knn_inner_products_by_idx(
|
|
|
488
803
|
size_t d,
|
|
489
804
|
size_t nx,
|
|
490
805
|
size_t ny,
|
|
491
|
-
|
|
492
|
-
|
|
806
|
+
size_t k,
|
|
807
|
+
float* res_vals,
|
|
808
|
+
int64_t* res_ids,
|
|
809
|
+
int64_t ld_ids) {
|
|
810
|
+
if (ld_ids < 0) {
|
|
811
|
+
ld_ids = ny;
|
|
812
|
+
}
|
|
493
813
|
|
|
494
|
-
#pragma omp parallel for
|
|
814
|
+
#pragma omp parallel for if (nx > 100)
|
|
495
815
|
for (int64_t i = 0; i < nx; i++) {
|
|
496
816
|
const float* x_ = x + i * d;
|
|
497
|
-
const int64_t* idsi = ids + i *
|
|
817
|
+
const int64_t* idsi = ids + i * ld_ids;
|
|
498
818
|
size_t j;
|
|
499
|
-
float* __restrict simi =
|
|
500
|
-
int64_t* __restrict idxi =
|
|
819
|
+
float* __restrict simi = res_vals + i * k;
|
|
820
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
501
821
|
minheap_heapify(k, simi, idxi);
|
|
502
822
|
|
|
503
823
|
for (j = 0; j < ny; j++) {
|
|
@@ -520,16 +840,20 @@ void knn_L2sqr_by_idx(
|
|
|
520
840
|
size_t d,
|
|
521
841
|
size_t nx,
|
|
522
842
|
size_t ny,
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
843
|
+
size_t k,
|
|
844
|
+
float* res_vals,
|
|
845
|
+
int64_t* res_ids,
|
|
846
|
+
int64_t ld_ids) {
|
|
847
|
+
if (ld_ids < 0) {
|
|
848
|
+
ld_ids = ny;
|
|
849
|
+
}
|
|
850
|
+
#pragma omp parallel for if (nx > 100)
|
|
527
851
|
for (int64_t i = 0; i < nx; i++) {
|
|
528
852
|
const float* x_ = x + i * d;
|
|
529
|
-
const int64_t* __restrict idsi = ids + i *
|
|
530
|
-
float* __restrict simi =
|
|
531
|
-
int64_t* __restrict idxi =
|
|
532
|
-
maxheap_heapify(
|
|
853
|
+
const int64_t* __restrict idsi = ids + i * ld_ids;
|
|
854
|
+
float* __restrict simi = res_vals + i * k;
|
|
855
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
856
|
+
maxheap_heapify(k, simi, idxi);
|
|
533
857
|
for (size_t j = 0; j < ny; j++) {
|
|
534
858
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
|
535
859
|
|
|
@@ -537,7 +861,7 @@ void knn_L2sqr_by_idx(
|
|
|
537
861
|
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
|
538
862
|
}
|
|
539
863
|
}
|
|
540
|
-
maxheap_reorder(
|
|
864
|
+
maxheap_reorder(k, simi, idxi);
|
|
541
865
|
}
|
|
542
866
|
}
|
|
543
867
|
|