faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -17,8 +17,13 @@
|
|
|
17
17
|
|
|
18
18
|
#include <omp.h>
|
|
19
19
|
|
|
20
|
+
#ifdef __AVX2__
|
|
21
|
+
#include <immintrin.h>
|
|
22
|
+
#endif
|
|
23
|
+
|
|
20
24
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
21
25
|
#include <faiss/impl/FaissAssert.h>
|
|
26
|
+
#include <faiss/impl/IDSelector.h>
|
|
22
27
|
#include <faiss/impl/ResultHandler.h>
|
|
23
28
|
|
|
24
29
|
#ifndef FINTEGER
|
|
@@ -96,17 +101,21 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
|
96
101
|
namespace {
|
|
97
102
|
|
|
98
103
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
99
|
-
template <class ResultHandler>
|
|
104
|
+
template <class ResultHandler, bool use_sel = false>
|
|
100
105
|
void exhaustive_inner_product_seq(
|
|
101
106
|
const float* x,
|
|
102
107
|
const float* y,
|
|
103
108
|
size_t d,
|
|
104
109
|
size_t nx,
|
|
105
110
|
size_t ny,
|
|
106
|
-
ResultHandler& res
|
|
111
|
+
ResultHandler& res,
|
|
112
|
+
const IDSelector* sel = nullptr) {
|
|
107
113
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
114
|
+
int nt = std::min(int(nx), omp_get_max_threads());
|
|
108
115
|
|
|
109
|
-
|
|
116
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
117
|
+
|
|
118
|
+
#pragma omp parallel num_threads(nt)
|
|
110
119
|
{
|
|
111
120
|
SingleResultHandler resi(res);
|
|
112
121
|
#pragma omp for
|
|
@@ -116,27 +125,33 @@ void exhaustive_inner_product_seq(
|
|
|
116
125
|
|
|
117
126
|
resi.begin(i);
|
|
118
127
|
|
|
119
|
-
for (size_t j = 0; j < ny; j
|
|
128
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
129
|
+
if (use_sel && !sel->is_member(j)) {
|
|
130
|
+
continue;
|
|
131
|
+
}
|
|
120
132
|
float ip = fvec_inner_product(x_i, y_j, d);
|
|
121
133
|
resi.add_result(ip, j);
|
|
122
|
-
y_j += d;
|
|
123
134
|
}
|
|
124
135
|
resi.end();
|
|
125
136
|
}
|
|
126
137
|
}
|
|
127
138
|
}
|
|
128
139
|
|
|
129
|
-
template <class ResultHandler>
|
|
140
|
+
template <class ResultHandler, bool use_sel = false>
|
|
130
141
|
void exhaustive_L2sqr_seq(
|
|
131
142
|
const float* x,
|
|
132
143
|
const float* y,
|
|
133
144
|
size_t d,
|
|
134
145
|
size_t nx,
|
|
135
146
|
size_t ny,
|
|
136
|
-
ResultHandler& res
|
|
147
|
+
ResultHandler& res,
|
|
148
|
+
const IDSelector* sel = nullptr) {
|
|
137
149
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
150
|
+
int nt = std::min(int(nx), omp_get_max_threads());
|
|
151
|
+
|
|
152
|
+
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
138
153
|
|
|
139
|
-
#pragma omp parallel
|
|
154
|
+
#pragma omp parallel num_threads(nt)
|
|
140
155
|
{
|
|
141
156
|
SingleResultHandler resi(res);
|
|
142
157
|
#pragma omp for
|
|
@@ -144,10 +159,12 @@ void exhaustive_L2sqr_seq(
|
|
|
144
159
|
const float* x_i = x + i * d;
|
|
145
160
|
const float* y_j = y;
|
|
146
161
|
resi.begin(i);
|
|
147
|
-
for (size_t j = 0; j < ny; j
|
|
162
|
+
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
163
|
+
if (use_sel && !sel->is_member(j)) {
|
|
164
|
+
continue;
|
|
165
|
+
}
|
|
148
166
|
float disij = fvec_L2sqr(x_i, y_j, d);
|
|
149
167
|
resi.add_result(disij, j);
|
|
150
|
-
y_j += d;
|
|
151
168
|
}
|
|
152
169
|
resi.end();
|
|
153
170
|
}
|
|
@@ -294,6 +311,232 @@ void exhaustive_L2sqr_blas(
|
|
|
294
311
|
}
|
|
295
312
|
}
|
|
296
313
|
|
|
314
|
+
#ifdef __AVX2__
|
|
315
|
+
// an override for AVX2 if only a single closest point is needed.
|
|
316
|
+
template <>
|
|
317
|
+
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
|
|
318
|
+
const float* x,
|
|
319
|
+
const float* y,
|
|
320
|
+
size_t d,
|
|
321
|
+
size_t nx,
|
|
322
|
+
size_t ny,
|
|
323
|
+
SingleBestResultHandler<CMax<float, int64_t>>& res,
|
|
324
|
+
const float* y_norms) {
|
|
325
|
+
// BLAS does not like empty matrices
|
|
326
|
+
if (nx == 0 || ny == 0)
|
|
327
|
+
return;
|
|
328
|
+
|
|
329
|
+
/* block sizes */
|
|
330
|
+
const size_t bs_x = distance_compute_blas_query_bs;
|
|
331
|
+
const size_t bs_y = distance_compute_blas_database_bs;
|
|
332
|
+
// const size_t bs_x = 16, bs_y = 16;
|
|
333
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
334
|
+
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
335
|
+
std::unique_ptr<float[]> del2;
|
|
336
|
+
|
|
337
|
+
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
338
|
+
|
|
339
|
+
if (!y_norms) {
|
|
340
|
+
float* y_norms2 = new float[ny];
|
|
341
|
+
del2.reset(y_norms2);
|
|
342
|
+
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
343
|
+
y_norms = y_norms2;
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
347
|
+
size_t i1 = i0 + bs_x;
|
|
348
|
+
if (i1 > nx)
|
|
349
|
+
i1 = nx;
|
|
350
|
+
|
|
351
|
+
res.begin_multiple(i0, i1);
|
|
352
|
+
|
|
353
|
+
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
354
|
+
size_t j1 = j0 + bs_y;
|
|
355
|
+
if (j1 > ny)
|
|
356
|
+
j1 = ny;
|
|
357
|
+
/* compute the actual dot products */
|
|
358
|
+
{
|
|
359
|
+
float one = 1, zero = 0;
|
|
360
|
+
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
361
|
+
sgemm_("Transpose",
|
|
362
|
+
"Not transpose",
|
|
363
|
+
&nyi,
|
|
364
|
+
&nxi,
|
|
365
|
+
&di,
|
|
366
|
+
&one,
|
|
367
|
+
y + j0 * d,
|
|
368
|
+
&di,
|
|
369
|
+
x + i0 * d,
|
|
370
|
+
&di,
|
|
371
|
+
&zero,
|
|
372
|
+
ip_block.get(),
|
|
373
|
+
&nyi);
|
|
374
|
+
}
|
|
375
|
+
#pragma omp parallel for
|
|
376
|
+
for (int64_t i = i0; i < i1; i++) {
|
|
377
|
+
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
378
|
+
|
|
379
|
+
_mm_prefetch(ip_line, _MM_HINT_NTA);
|
|
380
|
+
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
|
|
381
|
+
|
|
382
|
+
// constant
|
|
383
|
+
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
|
384
|
+
|
|
385
|
+
// Track 8 min distances + 8 min indices.
|
|
386
|
+
// All the distances tracked do not take x_norms[i]
|
|
387
|
+
// into account in order to get rid of extra
|
|
388
|
+
// _mm256_add_ps(x_norms[i], ...) instructions
|
|
389
|
+
// is distance computations.
|
|
390
|
+
__m256 min_distances =
|
|
391
|
+
_mm256_set1_ps(res.dis_tab[i] - x_norms[i]);
|
|
392
|
+
|
|
393
|
+
// these indices are local and are relative to j0.
|
|
394
|
+
// so, value 0 means j0.
|
|
395
|
+
__m256i min_indices = _mm256_set1_epi32(0);
|
|
396
|
+
|
|
397
|
+
__m256i current_indices =
|
|
398
|
+
_mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
399
|
+
const __m256i indices_delta = _mm256_set1_epi32(8);
|
|
400
|
+
|
|
401
|
+
// current j index
|
|
402
|
+
size_t idx_j = 0;
|
|
403
|
+
size_t count = j1 - j0;
|
|
404
|
+
|
|
405
|
+
// process 16 elements per loop
|
|
406
|
+
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
|
407
|
+
_mm_prefetch(ip_line + 32, _MM_HINT_NTA);
|
|
408
|
+
_mm_prefetch(ip_line + 48, _MM_HINT_NTA);
|
|
409
|
+
|
|
410
|
+
// load values for norms
|
|
411
|
+
const __m256 y_norm_0 =
|
|
412
|
+
_mm256_loadu_ps(y_norms + idx_j + j0 + 0);
|
|
413
|
+
const __m256 y_norm_1 =
|
|
414
|
+
_mm256_loadu_ps(y_norms + idx_j + j0 + 8);
|
|
415
|
+
|
|
416
|
+
// load values for dot products
|
|
417
|
+
const __m256 ip_0 = _mm256_loadu_ps(ip_line + 0);
|
|
418
|
+
const __m256 ip_1 = _mm256_loadu_ps(ip_line + 8);
|
|
419
|
+
|
|
420
|
+
// compute dis = y_norm[j] - 2 * dot(x_norm[i], y_norm[j]).
|
|
421
|
+
// x_norm[i] was dropped off because it is a constant for a
|
|
422
|
+
// given i. We'll deal with it later.
|
|
423
|
+
__m256 distances_0 =
|
|
424
|
+
_mm256_fmadd_ps(ip_0, mul_minus2, y_norm_0);
|
|
425
|
+
__m256 distances_1 =
|
|
426
|
+
_mm256_fmadd_ps(ip_1, mul_minus2, y_norm_1);
|
|
427
|
+
|
|
428
|
+
// compare the new distances to the min distances
|
|
429
|
+
// for each of the first group of 8 AVX2 components.
|
|
430
|
+
const __m256 comparison_0 = _mm256_cmp_ps(
|
|
431
|
+
min_distances, distances_0, _CMP_LE_OS);
|
|
432
|
+
|
|
433
|
+
// update min distances and indices with closest vectors if
|
|
434
|
+
// needed.
|
|
435
|
+
min_distances = _mm256_blendv_ps(
|
|
436
|
+
distances_0, min_distances, comparison_0);
|
|
437
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
438
|
+
_mm256_castsi256_ps(current_indices),
|
|
439
|
+
_mm256_castsi256_ps(min_indices),
|
|
440
|
+
comparison_0));
|
|
441
|
+
current_indices =
|
|
442
|
+
_mm256_add_epi32(current_indices, indices_delta);
|
|
443
|
+
|
|
444
|
+
// compare the new distances to the min distances
|
|
445
|
+
// for each of the second group of 8 AVX2 components.
|
|
446
|
+
const __m256 comparison_1 = _mm256_cmp_ps(
|
|
447
|
+
min_distances, distances_1, _CMP_LE_OS);
|
|
448
|
+
|
|
449
|
+
// update min distances and indices with closest vectors if
|
|
450
|
+
// needed.
|
|
451
|
+
min_distances = _mm256_blendv_ps(
|
|
452
|
+
distances_1, min_distances, comparison_1);
|
|
453
|
+
min_indices = _mm256_castps_si256(_mm256_blendv_ps(
|
|
454
|
+
_mm256_castsi256_ps(current_indices),
|
|
455
|
+
_mm256_castsi256_ps(min_indices),
|
|
456
|
+
comparison_1));
|
|
457
|
+
current_indices =
|
|
458
|
+
_mm256_add_epi32(current_indices, indices_delta);
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
// dump values and find the minimum distance / minimum index
|
|
462
|
+
float min_distances_scalar[8];
|
|
463
|
+
uint32_t min_indices_scalar[8];
|
|
464
|
+
_mm256_storeu_ps(min_distances_scalar, min_distances);
|
|
465
|
+
_mm256_storeu_si256(
|
|
466
|
+
(__m256i*)(min_indices_scalar), min_indices);
|
|
467
|
+
|
|
468
|
+
float current_min_distance = res.dis_tab[i];
|
|
469
|
+
uint32_t current_min_index = res.ids_tab[i];
|
|
470
|
+
|
|
471
|
+
// This unusual comparison is needed to maintain the behavior
|
|
472
|
+
// of the original implementation: if two indices are
|
|
473
|
+
// represented with equal distance values, then
|
|
474
|
+
// the index with the min value is returned.
|
|
475
|
+
for (size_t jv = 0; jv < 8; jv++) {
|
|
476
|
+
// add missing x_norms[i]
|
|
477
|
+
float distance_candidate =
|
|
478
|
+
min_distances_scalar[jv] + x_norms[i];
|
|
479
|
+
|
|
480
|
+
// negative values can occur for identical vectors
|
|
481
|
+
// due to roundoff errors.
|
|
482
|
+
if (distance_candidate < 0)
|
|
483
|
+
distance_candidate = 0;
|
|
484
|
+
|
|
485
|
+
int64_t index_candidate = min_indices_scalar[jv] + j0;
|
|
486
|
+
|
|
487
|
+
if (current_min_distance > distance_candidate) {
|
|
488
|
+
current_min_distance = distance_candidate;
|
|
489
|
+
current_min_index = index_candidate;
|
|
490
|
+
} else if (
|
|
491
|
+
current_min_distance == distance_candidate &&
|
|
492
|
+
current_min_index > index_candidate) {
|
|
493
|
+
current_min_index = index_candidate;
|
|
494
|
+
}
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
// process leftovers
|
|
498
|
+
for (; idx_j < count; idx_j++, ip_line++) {
|
|
499
|
+
float ip = *ip_line;
|
|
500
|
+
float dis = x_norms[i] + y_norms[idx_j + j0] - 2 * ip;
|
|
501
|
+
// negative values can occur for identical vectors
|
|
502
|
+
// due to roundoff errors.
|
|
503
|
+
if (dis < 0)
|
|
504
|
+
dis = 0;
|
|
505
|
+
|
|
506
|
+
if (current_min_distance > dis) {
|
|
507
|
+
current_min_distance = dis;
|
|
508
|
+
current_min_index = idx_j + j0;
|
|
509
|
+
}
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
//
|
|
513
|
+
res.add_result(i, current_min_distance, current_min_index);
|
|
514
|
+
}
|
|
515
|
+
}
|
|
516
|
+
InterruptCallback::check();
|
|
517
|
+
}
|
|
518
|
+
}
|
|
519
|
+
#endif
|
|
520
|
+
|
|
521
|
+
template <class ResultHandler>
|
|
522
|
+
void knn_L2sqr_select(
|
|
523
|
+
const float* x,
|
|
524
|
+
const float* y,
|
|
525
|
+
size_t d,
|
|
526
|
+
size_t nx,
|
|
527
|
+
size_t ny,
|
|
528
|
+
ResultHandler& res,
|
|
529
|
+
const float* y_norm2,
|
|
530
|
+
const IDSelector* sel) {
|
|
531
|
+
if (sel) {
|
|
532
|
+
exhaustive_L2sqr_seq<ResultHandler, true>(x, y, d, nx, ny, res, sel);
|
|
533
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
534
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
535
|
+
} else {
|
|
536
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
|
|
297
540
|
} // anonymous namespace
|
|
298
541
|
|
|
299
542
|
/*******************************************************
|
|
@@ -311,24 +554,63 @@ void knn_inner_product(
|
|
|
311
554
|
size_t d,
|
|
312
555
|
size_t nx,
|
|
313
556
|
size_t ny,
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
557
|
+
size_t k,
|
|
558
|
+
float* val,
|
|
559
|
+
int64_t* ids,
|
|
560
|
+
const IDSelector* sel) {
|
|
561
|
+
int64_t imin = 0;
|
|
562
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
|
563
|
+
imin = std::max(selr->imin, int64_t(0));
|
|
564
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
|
565
|
+
ny = imax - imin;
|
|
566
|
+
y += d * imin;
|
|
567
|
+
sel = nullptr;
|
|
568
|
+
}
|
|
569
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
570
|
+
knn_inner_products_by_idx(
|
|
571
|
+
x, y, sela->ids, d, nx, sela->n, k, val, ids, 0);
|
|
572
|
+
return;
|
|
573
|
+
}
|
|
574
|
+
if (k < distance_compute_min_k_reservoir) {
|
|
575
|
+
using RH = HeapResultHandler<CMin<float, int64_t>>;
|
|
576
|
+
RH res(nx, val, ids, k);
|
|
577
|
+
if (sel) {
|
|
578
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
579
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
319
580
|
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
320
581
|
} else {
|
|
321
582
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
322
583
|
}
|
|
323
584
|
} else {
|
|
324
|
-
ReservoirResultHandler<CMin<float, int64_t
|
|
325
|
-
|
|
326
|
-
if (
|
|
327
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
585
|
+
using RH = ReservoirResultHandler<CMin<float, int64_t>>;
|
|
586
|
+
RH res(nx, val, ids, k);
|
|
587
|
+
if (sel) {
|
|
588
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
589
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
590
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
|
328
591
|
} else {
|
|
329
592
|
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
330
593
|
}
|
|
331
594
|
}
|
|
595
|
+
if (imin != 0) {
|
|
596
|
+
for (size_t i = 0; i < nx * k; i++) {
|
|
597
|
+
if (ids[i] >= 0) {
|
|
598
|
+
ids[i] += imin;
|
|
599
|
+
}
|
|
600
|
+
}
|
|
601
|
+
}
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
void knn_inner_product(
|
|
605
|
+
const float* x,
|
|
606
|
+
const float* y,
|
|
607
|
+
size_t d,
|
|
608
|
+
size_t nx,
|
|
609
|
+
size_t ny,
|
|
610
|
+
float_minheap_array_t* res,
|
|
611
|
+
const IDSelector* sel) {
|
|
612
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
|
613
|
+
knn_inner_product(x, y, d, nx, ny, res->k, res->val, res->ids, sel);
|
|
332
614
|
}
|
|
333
615
|
|
|
334
616
|
void knn_L2sqr(
|
|
@@ -337,28 +619,55 @@ void knn_L2sqr(
|
|
|
337
619
|
size_t d,
|
|
338
620
|
size_t nx,
|
|
339
621
|
size_t ny,
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
622
|
+
size_t k,
|
|
623
|
+
float* vals,
|
|
624
|
+
int64_t* ids,
|
|
625
|
+
const float* y_norm2,
|
|
626
|
+
const IDSelector* sel) {
|
|
627
|
+
int64_t imin = 0;
|
|
628
|
+
if (auto selr = dynamic_cast<const IDSelectorRange*>(sel)) {
|
|
629
|
+
imin = std::max(selr->imin, int64_t(0));
|
|
630
|
+
int64_t imax = std::min(selr->imax, int64_t(ny));
|
|
631
|
+
ny = imax - imin;
|
|
632
|
+
y += d * imin;
|
|
633
|
+
sel = nullptr;
|
|
634
|
+
}
|
|
635
|
+
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
636
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
|
637
|
+
return;
|
|
638
|
+
}
|
|
639
|
+
if (k == 1) {
|
|
640
|
+
SingleBestResultHandler<CMax<float, int64_t>> res(nx, vals, ids);
|
|
641
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
642
|
+
} else if (k < distance_compute_min_k_reservoir) {
|
|
643
|
+
HeapResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
644
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
351
645
|
} else {
|
|
352
|
-
ReservoirResultHandler<CMax<float, int64_t>> res(
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
646
|
+
ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
647
|
+
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
648
|
+
}
|
|
649
|
+
if (imin != 0) {
|
|
650
|
+
for (size_t i = 0; i < nx * k; i++) {
|
|
651
|
+
if (ids[i] >= 0) {
|
|
652
|
+
ids[i] += imin;
|
|
653
|
+
}
|
|
358
654
|
}
|
|
359
655
|
}
|
|
360
656
|
}
|
|
361
657
|
|
|
658
|
+
void knn_L2sqr(
|
|
659
|
+
const float* x,
|
|
660
|
+
const float* y,
|
|
661
|
+
size_t d,
|
|
662
|
+
size_t nx,
|
|
663
|
+
size_t ny,
|
|
664
|
+
float_maxheap_array_t* res,
|
|
665
|
+
const float* y_norm2,
|
|
666
|
+
const IDSelector* sel) {
|
|
667
|
+
FAISS_THROW_IF_NOT(res->nh == nx);
|
|
668
|
+
knn_L2sqr(x, y, d, nx, ny, res->k, res->val, res->ids, y_norm2, sel);
|
|
669
|
+
}
|
|
670
|
+
|
|
362
671
|
/***************************************************************************
|
|
363
672
|
* Range search
|
|
364
673
|
***************************************************************************/
|
|
@@ -370,10 +679,14 @@ void range_search_L2sqr(
|
|
|
370
679
|
size_t nx,
|
|
371
680
|
size_t ny,
|
|
372
681
|
float radius,
|
|
373
|
-
RangeSearchResult* res
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
682
|
+
RangeSearchResult* res,
|
|
683
|
+
const IDSelector* sel) {
|
|
684
|
+
using RH = RangeSearchResultHandler<CMax<float, int64_t>>;
|
|
685
|
+
RH resh(res, radius);
|
|
686
|
+
if (sel) {
|
|
687
|
+
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
688
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
689
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
|
|
377
690
|
} else {
|
|
378
691
|
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
|
379
692
|
}
|
|
@@ -386,9 +699,13 @@ void range_search_inner_product(
|
|
|
386
699
|
size_t nx,
|
|
387
700
|
size_t ny,
|
|
388
701
|
float radius,
|
|
389
|
-
RangeSearchResult* res
|
|
390
|
-
|
|
391
|
-
|
|
702
|
+
RangeSearchResult* res,
|
|
703
|
+
const IDSelector* sel) {
|
|
704
|
+
using RH = RangeSearchResultHandler<CMin<float, int64_t>>;
|
|
705
|
+
RH resh(res, radius);
|
|
706
|
+
if (sel) {
|
|
707
|
+
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
708
|
+
} else if (nx < distance_compute_blas_threshold) {
|
|
392
709
|
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
|
393
710
|
} else {
|
|
394
711
|
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
|
@@ -486,16 +803,21 @@ void knn_inner_products_by_idx(
|
|
|
486
803
|
size_t d,
|
|
487
804
|
size_t nx,
|
|
488
805
|
size_t ny,
|
|
489
|
-
|
|
490
|
-
|
|
806
|
+
size_t k,
|
|
807
|
+
float* res_vals,
|
|
808
|
+
int64_t* res_ids,
|
|
809
|
+
int64_t ld_ids) {
|
|
810
|
+
if (ld_ids < 0) {
|
|
811
|
+
ld_ids = ny;
|
|
812
|
+
}
|
|
491
813
|
|
|
492
|
-
#pragma omp parallel for
|
|
814
|
+
#pragma omp parallel for if (nx > 100)
|
|
493
815
|
for (int64_t i = 0; i < nx; i++) {
|
|
494
816
|
const float* x_ = x + i * d;
|
|
495
|
-
const int64_t* idsi = ids + i *
|
|
817
|
+
const int64_t* idsi = ids + i * ld_ids;
|
|
496
818
|
size_t j;
|
|
497
|
-
float* __restrict simi =
|
|
498
|
-
int64_t* __restrict idxi =
|
|
819
|
+
float* __restrict simi = res_vals + i * k;
|
|
820
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
499
821
|
minheap_heapify(k, simi, idxi);
|
|
500
822
|
|
|
501
823
|
for (j = 0; j < ny; j++) {
|
|
@@ -518,16 +840,20 @@ void knn_L2sqr_by_idx(
|
|
|
518
840
|
size_t d,
|
|
519
841
|
size_t nx,
|
|
520
842
|
size_t ny,
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
843
|
+
size_t k,
|
|
844
|
+
float* res_vals,
|
|
845
|
+
int64_t* res_ids,
|
|
846
|
+
int64_t ld_ids) {
|
|
847
|
+
if (ld_ids < 0) {
|
|
848
|
+
ld_ids = ny;
|
|
849
|
+
}
|
|
850
|
+
#pragma omp parallel for if (nx > 100)
|
|
525
851
|
for (int64_t i = 0; i < nx; i++) {
|
|
526
852
|
const float* x_ = x + i * d;
|
|
527
|
-
const int64_t* __restrict idsi = ids + i *
|
|
528
|
-
float* __restrict simi =
|
|
529
|
-
int64_t* __restrict idxi =
|
|
530
|
-
maxheap_heapify(
|
|
853
|
+
const int64_t* __restrict idsi = ids + i * ld_ids;
|
|
854
|
+
float* __restrict simi = res_vals + i * k;
|
|
855
|
+
int64_t* __restrict idxi = res_ids + i * k;
|
|
856
|
+
maxheap_heapify(k, simi, idxi);
|
|
531
857
|
for (size_t j = 0; j < ny; j++) {
|
|
532
858
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
|
533
859
|
|
|
@@ -535,7 +861,7 @@ void knn_L2sqr_by_idx(
|
|
|
535
861
|
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
|
536
862
|
}
|
|
537
863
|
}
|
|
538
|
-
maxheap_reorder(
|
|
864
|
+
maxheap_reorder(k, simi, idxi);
|
|
539
865
|
}
|
|
540
866
|
}
|
|
541
867
|
|