faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -7,8 +7,8 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/IndexFastScan.h>
|
|
9
9
|
|
|
10
|
-
#include <limits.h>
|
|
11
10
|
#include <cassert>
|
|
11
|
+
#include <climits>
|
|
12
12
|
#include <memory>
|
|
13
13
|
|
|
14
14
|
#include <omp.h>
|
|
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
|
|
|
37
37
|
|
|
38
38
|
void IndexFastScan::init_fastscan(
|
|
39
39
|
int d,
|
|
40
|
-
size_t
|
|
41
|
-
size_t
|
|
40
|
+
size_t M_2,
|
|
41
|
+
size_t nbits_2,
|
|
42
42
|
MetricType metric,
|
|
43
43
|
int bbs) {
|
|
44
|
-
FAISS_THROW_IF_NOT(
|
|
44
|
+
FAISS_THROW_IF_NOT(nbits_2 == 4);
|
|
45
45
|
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
46
46
|
this->d = d;
|
|
47
|
-
this->M =
|
|
48
|
-
this->nbits =
|
|
47
|
+
this->M = M_2;
|
|
48
|
+
this->nbits = nbits_2;
|
|
49
49
|
this->metric_type = metric;
|
|
50
50
|
this->bbs = bbs;
|
|
51
|
-
ksub = (1 <<
|
|
51
|
+
ksub = (1 << nbits_2);
|
|
52
52
|
|
|
53
|
-
code_size = (
|
|
53
|
+
code_size = (M_2 * nbits_2 + 7) / 8;
|
|
54
54
|
ntotal = ntotal2 = 0;
|
|
55
|
-
M2 = roundup(
|
|
55
|
+
M2 = roundup(M_2, 2);
|
|
56
56
|
is_trained = false;
|
|
57
57
|
}
|
|
58
58
|
|
|
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
|
|
|
158
158
|
|
|
159
159
|
namespace {
|
|
160
160
|
|
|
161
|
-
template <class C, typename dis_t
|
|
161
|
+
template <class C, typename dis_t>
|
|
162
162
|
void estimators_from_tables_generic(
|
|
163
163
|
const IndexFastScan& index,
|
|
164
164
|
const uint8_t* codes,
|
|
@@ -167,25 +167,28 @@ void estimators_from_tables_generic(
|
|
|
167
167
|
size_t k,
|
|
168
168
|
typename C::T* heap_dis,
|
|
169
169
|
int64_t* heap_ids,
|
|
170
|
-
const
|
|
170
|
+
const NormTableScaler* scaler) {
|
|
171
171
|
using accu_t = typename C::T;
|
|
172
172
|
|
|
173
173
|
for (size_t j = 0; j < ncodes; ++j) {
|
|
174
174
|
BitstringReader bsr(codes + j * index.code_size, index.code_size);
|
|
175
175
|
accu_t dis = 0;
|
|
176
176
|
const dis_t* dt = dis_table;
|
|
177
|
-
|
|
177
|
+
int nscale = scaler ? scaler->nscale : 0;
|
|
178
|
+
|
|
179
|
+
for (size_t m = 0; m < index.M - nscale; m++) {
|
|
178
180
|
uint64_t c = bsr.read(index.nbits);
|
|
179
181
|
dis += dt[c];
|
|
180
182
|
dt += index.ksub;
|
|
181
183
|
}
|
|
182
184
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
185
|
+
if (nscale) {
|
|
186
|
+
for (size_t m = 0; m < nscale; m++) {
|
|
187
|
+
uint64_t c = bsr.read(index.nbits);
|
|
188
|
+
dis += scaler->scale_one(dt[c]);
|
|
189
|
+
dt += index.ksub;
|
|
190
|
+
}
|
|
187
191
|
}
|
|
188
|
-
|
|
189
192
|
if (C::cmp(heap_dis[0], dis)) {
|
|
190
193
|
heap_pop<C>(k, heap_dis, heap_ids);
|
|
191
194
|
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
|
@@ -193,6 +196,27 @@ void estimators_from_tables_generic(
|
|
|
193
196
|
}
|
|
194
197
|
}
|
|
195
198
|
|
|
199
|
+
template <class C>
|
|
200
|
+
ResultHandlerCompare<C, false>* make_knn_handler(
|
|
201
|
+
int impl,
|
|
202
|
+
idx_t n,
|
|
203
|
+
idx_t k,
|
|
204
|
+
size_t ntotal,
|
|
205
|
+
float* distances,
|
|
206
|
+
idx_t* labels) {
|
|
207
|
+
using HeapHC = HeapHandler<C, false>;
|
|
208
|
+
using ReservoirHC = ReservoirHandler<C, false>;
|
|
209
|
+
using SingleResultHC = SingleResultHandler<C, false>;
|
|
210
|
+
|
|
211
|
+
if (k == 1) {
|
|
212
|
+
return new SingleResultHC(n, ntotal, distances, labels);
|
|
213
|
+
} else if (impl % 2 == 0) {
|
|
214
|
+
return new HeapHC(n, ntotal, k, distances, labels);
|
|
215
|
+
} else /* if (impl % 2 == 1) */ {
|
|
216
|
+
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
|
|
217
|
+
}
|
|
218
|
+
}
|
|
219
|
+
|
|
196
220
|
} // anonymous namespace
|
|
197
221
|
|
|
198
222
|
using namespace quantize_lut;
|
|
@@ -241,22 +265,21 @@ void IndexFastScan::search(
|
|
|
241
265
|
!params, "search params not supported for this index");
|
|
242
266
|
FAISS_THROW_IF_NOT(k > 0);
|
|
243
267
|
|
|
244
|
-
DummyScaler scaler;
|
|
245
268
|
if (metric_type == METRIC_L2) {
|
|
246
|
-
search_dispatch_implem<true>(n, x, k, distances, labels,
|
|
269
|
+
search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
|
|
247
270
|
} else {
|
|
248
|
-
search_dispatch_implem<false>(n, x, k, distances, labels,
|
|
271
|
+
search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
|
|
249
272
|
}
|
|
250
273
|
}
|
|
251
274
|
|
|
252
|
-
template <bool is_max
|
|
275
|
+
template <bool is_max>
|
|
253
276
|
void IndexFastScan::search_dispatch_implem(
|
|
254
277
|
idx_t n,
|
|
255
278
|
const float* x,
|
|
256
279
|
idx_t k,
|
|
257
280
|
float* distances,
|
|
258
281
|
idx_t* labels,
|
|
259
|
-
const
|
|
282
|
+
const NormTableScaler* scaler) const {
|
|
260
283
|
using Cfloat = typename std::conditional<
|
|
261
284
|
is_max,
|
|
262
285
|
CMax<float, int64_t>,
|
|
@@ -319,14 +342,14 @@ void IndexFastScan::search_dispatch_implem(
|
|
|
319
342
|
}
|
|
320
343
|
}
|
|
321
344
|
|
|
322
|
-
template <class Cfloat
|
|
345
|
+
template <class Cfloat>
|
|
323
346
|
void IndexFastScan::search_implem_234(
|
|
324
347
|
idx_t n,
|
|
325
348
|
const float* x,
|
|
326
349
|
idx_t k,
|
|
327
350
|
float* distances,
|
|
328
351
|
idx_t* labels,
|
|
329
|
-
const
|
|
352
|
+
const NormTableScaler* scaler) const {
|
|
330
353
|
FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
|
|
331
354
|
|
|
332
355
|
const size_t dim12 = ksub * M;
|
|
@@ -378,7 +401,7 @@ void IndexFastScan::search_implem_234(
|
|
|
378
401
|
}
|
|
379
402
|
}
|
|
380
403
|
|
|
381
|
-
template <class C
|
|
404
|
+
template <class C>
|
|
382
405
|
void IndexFastScan::search_implem_12(
|
|
383
406
|
idx_t n,
|
|
384
407
|
const float* x,
|
|
@@ -386,7 +409,8 @@ void IndexFastScan::search_implem_12(
|
|
|
386
409
|
float* distances,
|
|
387
410
|
idx_t* labels,
|
|
388
411
|
int impl,
|
|
389
|
-
const
|
|
412
|
+
const NormTableScaler* scaler) const {
|
|
413
|
+
using RH = ResultHandlerCompare<C, false>;
|
|
390
414
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
391
415
|
|
|
392
416
|
// handle qbs2 blocking by recursive call
|
|
@@ -432,63 +456,31 @@ void IndexFastScan::search_implem_12(
|
|
|
432
456
|
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
|
|
433
457
|
FAISS_THROW_IF_NOT(LUT_nq == n);
|
|
434
458
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
} else {
|
|
440
|
-
handler.disable = bool(skip & 2);
|
|
441
|
-
pq4_accumulate_loop_qbs(
|
|
442
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
|
|
443
|
-
}
|
|
459
|
+
std::unique_ptr<RH> handler(
|
|
460
|
+
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
|
|
461
|
+
handler->disable = bool(skip & 2);
|
|
462
|
+
handler->normalizers = normalizers.get();
|
|
444
463
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
} else
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
|
|
460
|
-
|
|
461
|
-
if (!(skip & 8)) {
|
|
462
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
463
|
-
}
|
|
464
|
-
}
|
|
465
|
-
|
|
466
|
-
} else { // impl == 13
|
|
467
|
-
|
|
468
|
-
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
|
469
|
-
handler.disable = bool(skip & 2);
|
|
470
|
-
|
|
471
|
-
if (skip & 4) {
|
|
472
|
-
// skip
|
|
473
|
-
} else {
|
|
474
|
-
pq4_accumulate_loop_qbs(
|
|
475
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
|
|
476
|
-
}
|
|
477
|
-
|
|
478
|
-
if (!(skip & 8)) {
|
|
479
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
480
|
-
}
|
|
481
|
-
|
|
482
|
-
FastScan_stats.t0 += handler.times[0];
|
|
483
|
-
FastScan_stats.t1 += handler.times[1];
|
|
484
|
-
FastScan_stats.t2 += handler.times[2];
|
|
485
|
-
FastScan_stats.t3 += handler.times[3];
|
|
464
|
+
if (skip & 4) {
|
|
465
|
+
// pass
|
|
466
|
+
} else {
|
|
467
|
+
pq4_accumulate_loop_qbs(
|
|
468
|
+
qbs,
|
|
469
|
+
ntotal2,
|
|
470
|
+
M2,
|
|
471
|
+
codes.get(),
|
|
472
|
+
LUT.get(),
|
|
473
|
+
*handler.get(),
|
|
474
|
+
scaler);
|
|
475
|
+
}
|
|
476
|
+
if (!(skip & 8)) {
|
|
477
|
+
handler->end();
|
|
486
478
|
}
|
|
487
479
|
}
|
|
488
480
|
|
|
489
481
|
FastScanStats FastScan_stats;
|
|
490
482
|
|
|
491
|
-
template <class C
|
|
483
|
+
template <class C>
|
|
492
484
|
void IndexFastScan::search_implem_14(
|
|
493
485
|
idx_t n,
|
|
494
486
|
const float* x,
|
|
@@ -496,7 +488,8 @@ void IndexFastScan::search_implem_14(
|
|
|
496
488
|
float* distances,
|
|
497
489
|
idx_t* labels,
|
|
498
490
|
int impl,
|
|
499
|
-
const
|
|
491
|
+
const NormTableScaler* scaler) const {
|
|
492
|
+
using RH = ResultHandlerCompare<C, false>;
|
|
500
493
|
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
501
494
|
|
|
502
495
|
int qbs2 = qbs == 0 ? 4 : qbs;
|
|
@@ -531,91 +524,29 @@ void IndexFastScan::search_implem_14(
|
|
|
531
524
|
AlignedTable<uint8_t> LUT(n * dim12);
|
|
532
525
|
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
|
|
533
526
|
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
} else {
|
|
539
|
-
handler.disable = bool(skip & 2);
|
|
540
|
-
pq4_accumulate_loop(
|
|
541
|
-
n,
|
|
542
|
-
ntotal2,
|
|
543
|
-
bbs,
|
|
544
|
-
M2,
|
|
545
|
-
codes.get(),
|
|
546
|
-
LUT.get(),
|
|
547
|
-
handler,
|
|
548
|
-
scaler);
|
|
549
|
-
}
|
|
550
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
551
|
-
|
|
552
|
-
} else if (impl == 14) {
|
|
553
|
-
std::vector<uint16_t> tmp_dis(n * k);
|
|
554
|
-
std::vector<int32_t> tmp_ids(n * k);
|
|
555
|
-
|
|
556
|
-
if (skip & 4) {
|
|
557
|
-
// skip
|
|
558
|
-
} else if (k > 1) {
|
|
559
|
-
HeapHandler<C> handler(
|
|
560
|
-
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
|
561
|
-
handler.disable = bool(skip & 2);
|
|
562
|
-
|
|
563
|
-
pq4_accumulate_loop(
|
|
564
|
-
n,
|
|
565
|
-
ntotal2,
|
|
566
|
-
bbs,
|
|
567
|
-
M2,
|
|
568
|
-
codes.get(),
|
|
569
|
-
LUT.get(),
|
|
570
|
-
handler,
|
|
571
|
-
scaler);
|
|
572
|
-
|
|
573
|
-
if (!(skip & 8)) {
|
|
574
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
575
|
-
}
|
|
576
|
-
}
|
|
577
|
-
|
|
578
|
-
} else { // impl == 15
|
|
579
|
-
|
|
580
|
-
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
|
581
|
-
handler.disable = bool(skip & 2);
|
|
582
|
-
|
|
583
|
-
if (skip & 4) {
|
|
584
|
-
// skip
|
|
585
|
-
} else {
|
|
586
|
-
pq4_accumulate_loop(
|
|
587
|
-
n,
|
|
588
|
-
ntotal2,
|
|
589
|
-
bbs,
|
|
590
|
-
M2,
|
|
591
|
-
codes.get(),
|
|
592
|
-
LUT.get(),
|
|
593
|
-
handler,
|
|
594
|
-
scaler);
|
|
595
|
-
}
|
|
527
|
+
std::unique_ptr<RH> handler(
|
|
528
|
+
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
|
|
529
|
+
handler->disable = bool(skip & 2);
|
|
530
|
+
handler->normalizers = normalizers.get();
|
|
596
531
|
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
532
|
+
if (skip & 4) {
|
|
533
|
+
// pass
|
|
534
|
+
} else {
|
|
535
|
+
pq4_accumulate_loop(
|
|
536
|
+
n,
|
|
537
|
+
ntotal2,
|
|
538
|
+
bbs,
|
|
539
|
+
M2,
|
|
540
|
+
codes.get(),
|
|
541
|
+
LUT.get(),
|
|
542
|
+
*handler.get(),
|
|
543
|
+
scaler);
|
|
544
|
+
}
|
|
545
|
+
if (!(skip & 8)) {
|
|
546
|
+
handler->end();
|
|
600
547
|
}
|
|
601
548
|
}
|
|
602
549
|
|
|
603
|
-
template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
|
|
604
|
-
idx_t n,
|
|
605
|
-
const float* x,
|
|
606
|
-
idx_t k,
|
|
607
|
-
float* distances,
|
|
608
|
-
idx_t* labels,
|
|
609
|
-
const NormTableScaler& scaler) const;
|
|
610
|
-
|
|
611
|
-
template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
|
|
612
|
-
idx_t n,
|
|
613
|
-
const float* x,
|
|
614
|
-
idx_t k,
|
|
615
|
-
float* distances,
|
|
616
|
-
idx_t* labels,
|
|
617
|
-
const NormTableScaler& scaler) const;
|
|
618
|
-
|
|
619
550
|
void IndexFastScan::reconstruct(idx_t key, float* recons) const {
|
|
620
551
|
std::vector<uint8_t> code(code_size, 0);
|
|
621
552
|
BitstringWriter bsw(code.data(), code_size);
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
namespace faiss {
|
|
14
14
|
|
|
15
15
|
struct CodePacker;
|
|
16
|
+
struct NormTableScaler;
|
|
16
17
|
|
|
17
18
|
/** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
|
|
18
19
|
*
|
|
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
|
|
|
87
88
|
uint8_t* lut,
|
|
88
89
|
float* normalizers) const;
|
|
89
90
|
|
|
90
|
-
template <bool is_max
|
|
91
|
+
template <bool is_max>
|
|
91
92
|
void search_dispatch_implem(
|
|
92
93
|
idx_t n,
|
|
93
94
|
const float* x,
|
|
94
95
|
idx_t k,
|
|
95
96
|
float* distances,
|
|
96
97
|
idx_t* labels,
|
|
97
|
-
const
|
|
98
|
+
const NormTableScaler* scaler) const;
|
|
98
99
|
|
|
99
|
-
template <class Cfloat
|
|
100
|
+
template <class Cfloat>
|
|
100
101
|
void search_implem_234(
|
|
101
102
|
idx_t n,
|
|
102
103
|
const float* x,
|
|
103
104
|
idx_t k,
|
|
104
105
|
float* distances,
|
|
105
106
|
idx_t* labels,
|
|
106
|
-
const
|
|
107
|
+
const NormTableScaler* scaler) const;
|
|
107
108
|
|
|
108
|
-
template <class C
|
|
109
|
+
template <class C>
|
|
109
110
|
void search_implem_12(
|
|
110
111
|
idx_t n,
|
|
111
112
|
const float* x,
|
|
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
|
|
|
113
114
|
float* distances,
|
|
114
115
|
idx_t* labels,
|
|
115
116
|
int impl,
|
|
116
|
-
const
|
|
117
|
+
const NormTableScaler* scaler) const;
|
|
117
118
|
|
|
118
|
-
template <class C
|
|
119
|
+
template <class C>
|
|
119
120
|
void search_implem_14(
|
|
120
121
|
idx_t n,
|
|
121
122
|
const float* x,
|
|
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
|
|
|
123
124
|
float* distances,
|
|
124
125
|
idx_t* labels,
|
|
125
126
|
int impl,
|
|
126
|
-
const
|
|
127
|
+
const NormTableScaler* scaler) const;
|
|
127
128
|
|
|
128
129
|
void reconstruct(idx_t key, float* recons) const override;
|
|
129
130
|
size_t remove_ids(const IDSelector& sel) override;
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
#include <faiss/utils/Heap.h>
|
|
15
15
|
#include <faiss/utils/distances.h>
|
|
16
16
|
#include <faiss/utils/extra_distances.h>
|
|
17
|
+
#include <faiss/utils/prefetch.h>
|
|
17
18
|
#include <faiss/utils/sorting.h>
|
|
18
19
|
#include <faiss/utils/utils.h>
|
|
19
20
|
#include <cstring>
|
|
@@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
|
|
|
122
123
|
void set_query(const float* x) override {
|
|
123
124
|
q = x;
|
|
124
125
|
}
|
|
126
|
+
|
|
127
|
+
// compute four distances
|
|
128
|
+
void distances_batch_4(
|
|
129
|
+
const idx_t idx0,
|
|
130
|
+
const idx_t idx1,
|
|
131
|
+
const idx_t idx2,
|
|
132
|
+
const idx_t idx3,
|
|
133
|
+
float& dis0,
|
|
134
|
+
float& dis1,
|
|
135
|
+
float& dis2,
|
|
136
|
+
float& dis3) final override {
|
|
137
|
+
ndis += 4;
|
|
138
|
+
|
|
139
|
+
// compute first, assign next
|
|
140
|
+
const float* __restrict y0 =
|
|
141
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
142
|
+
const float* __restrict y1 =
|
|
143
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
144
|
+
const float* __restrict y2 =
|
|
145
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
146
|
+
const float* __restrict y3 =
|
|
147
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
148
|
+
|
|
149
|
+
float dp0 = 0;
|
|
150
|
+
float dp1 = 0;
|
|
151
|
+
float dp2 = 0;
|
|
152
|
+
float dp3 = 0;
|
|
153
|
+
fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
154
|
+
dis0 = dp0;
|
|
155
|
+
dis1 = dp1;
|
|
156
|
+
dis2 = dp2;
|
|
157
|
+
dis3 = dp3;
|
|
158
|
+
}
|
|
125
159
|
};
|
|
126
160
|
|
|
127
161
|
struct FlatIPDis : FlatCodesDistanceComputer {
|
|
@@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
|
|
|
131
165
|
const float* b;
|
|
132
166
|
size_t ndis;
|
|
133
167
|
|
|
134
|
-
float symmetric_dis(idx_t i, idx_t j) override {
|
|
168
|
+
float symmetric_dis(idx_t i, idx_t j) final override {
|
|
135
169
|
return fvec_inner_product(b + j * d, b + i * d, d);
|
|
136
170
|
}
|
|
137
171
|
|
|
138
|
-
float distance_to_code(const uint8_t* code) final {
|
|
172
|
+
float distance_to_code(const uint8_t* code) final override {
|
|
139
173
|
ndis++;
|
|
140
|
-
return fvec_inner_product(q, (float*)code, d);
|
|
174
|
+
return fvec_inner_product(q, (const float*)code, d);
|
|
141
175
|
}
|
|
142
176
|
|
|
143
177
|
explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
|
|
@@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
|
|
|
153
187
|
void set_query(const float* x) override {
|
|
154
188
|
q = x;
|
|
155
189
|
}
|
|
190
|
+
|
|
191
|
+
// compute four distances
|
|
192
|
+
void distances_batch_4(
|
|
193
|
+
const idx_t idx0,
|
|
194
|
+
const idx_t idx1,
|
|
195
|
+
const idx_t idx2,
|
|
196
|
+
const idx_t idx3,
|
|
197
|
+
float& dis0,
|
|
198
|
+
float& dis1,
|
|
199
|
+
float& dis2,
|
|
200
|
+
float& dis3) final override {
|
|
201
|
+
ndis += 4;
|
|
202
|
+
|
|
203
|
+
// compute first, assign next
|
|
204
|
+
const float* __restrict y0 =
|
|
205
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
206
|
+
const float* __restrict y1 =
|
|
207
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
208
|
+
const float* __restrict y2 =
|
|
209
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
210
|
+
const float* __restrict y3 =
|
|
211
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
212
|
+
|
|
213
|
+
float dp0 = 0;
|
|
214
|
+
float dp1 = 0;
|
|
215
|
+
float dp2 = 0;
|
|
216
|
+
float dp3 = 0;
|
|
217
|
+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
218
|
+
dis0 = dp0;
|
|
219
|
+
dis1 = dp1;
|
|
220
|
+
dis2 = dp2;
|
|
221
|
+
dis3 = dp3;
|
|
222
|
+
}
|
|
156
223
|
};
|
|
157
224
|
|
|
158
225
|
} // namespace
|
|
@@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
|
184
251
|
}
|
|
185
252
|
}
|
|
186
253
|
|
|
254
|
+
/***************************************************
|
|
255
|
+
* IndexFlatL2
|
|
256
|
+
***************************************************/
|
|
257
|
+
|
|
258
|
+
namespace {
|
|
259
|
+
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
|
|
260
|
+
size_t d;
|
|
261
|
+
idx_t nb;
|
|
262
|
+
const float* q;
|
|
263
|
+
const float* b;
|
|
264
|
+
size_t ndis;
|
|
265
|
+
|
|
266
|
+
const float* l2norms;
|
|
267
|
+
float query_l2norm;
|
|
268
|
+
|
|
269
|
+
float distance_to_code(const uint8_t* code) final override {
|
|
270
|
+
ndis++;
|
|
271
|
+
return fvec_L2sqr(q, (float*)code, d);
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
float operator()(const idx_t i) final override {
|
|
275
|
+
const float* __restrict y =
|
|
276
|
+
reinterpret_cast<const float*>(codes + i * code_size);
|
|
277
|
+
|
|
278
|
+
prefetch_L2(l2norms + i);
|
|
279
|
+
const float dp0 = fvec_inner_product(q, y, d);
|
|
280
|
+
return query_l2norm + l2norms[i] - 2 * dp0;
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
float symmetric_dis(idx_t i, idx_t j) final override {
|
|
284
|
+
const float* __restrict yi =
|
|
285
|
+
reinterpret_cast<const float*>(codes + i * code_size);
|
|
286
|
+
const float* __restrict yj =
|
|
287
|
+
reinterpret_cast<const float*>(codes + j * code_size);
|
|
288
|
+
|
|
289
|
+
prefetch_L2(l2norms + i);
|
|
290
|
+
prefetch_L2(l2norms + j);
|
|
291
|
+
const float dp0 = fvec_inner_product(yi, yj, d);
|
|
292
|
+
return l2norms[i] + l2norms[j] - 2 * dp0;
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
explicit FlatL2WithNormsDis(
|
|
296
|
+
const IndexFlatL2& storage,
|
|
297
|
+
const float* q = nullptr)
|
|
298
|
+
: FlatCodesDistanceComputer(
|
|
299
|
+
storage.codes.data(),
|
|
300
|
+
storage.code_size),
|
|
301
|
+
d(storage.d),
|
|
302
|
+
nb(storage.ntotal),
|
|
303
|
+
q(q),
|
|
304
|
+
b(storage.get_xb()),
|
|
305
|
+
ndis(0),
|
|
306
|
+
l2norms(storage.cached_l2norms.data()),
|
|
307
|
+
query_l2norm(0) {}
|
|
308
|
+
|
|
309
|
+
void set_query(const float* x) override {
|
|
310
|
+
q = x;
|
|
311
|
+
query_l2norm = fvec_norm_L2sqr(q, d);
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
// compute four distances
|
|
315
|
+
void distances_batch_4(
|
|
316
|
+
const idx_t idx0,
|
|
317
|
+
const idx_t idx1,
|
|
318
|
+
const idx_t idx2,
|
|
319
|
+
const idx_t idx3,
|
|
320
|
+
float& dis0,
|
|
321
|
+
float& dis1,
|
|
322
|
+
float& dis2,
|
|
323
|
+
float& dis3) final override {
|
|
324
|
+
ndis += 4;
|
|
325
|
+
|
|
326
|
+
// compute first, assign next
|
|
327
|
+
const float* __restrict y0 =
|
|
328
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
329
|
+
const float* __restrict y1 =
|
|
330
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
331
|
+
const float* __restrict y2 =
|
|
332
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
333
|
+
const float* __restrict y3 =
|
|
334
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
335
|
+
|
|
336
|
+
prefetch_L2(l2norms + idx0);
|
|
337
|
+
prefetch_L2(l2norms + idx1);
|
|
338
|
+
prefetch_L2(l2norms + idx2);
|
|
339
|
+
prefetch_L2(l2norms + idx3);
|
|
340
|
+
|
|
341
|
+
float dp0 = 0;
|
|
342
|
+
float dp1 = 0;
|
|
343
|
+
float dp2 = 0;
|
|
344
|
+
float dp3 = 0;
|
|
345
|
+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
346
|
+
dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
|
|
347
|
+
dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
|
|
348
|
+
dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
|
|
349
|
+
dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
|
|
350
|
+
}
|
|
351
|
+
};
|
|
352
|
+
|
|
353
|
+
} // namespace
|
|
354
|
+
|
|
355
|
+
void IndexFlatL2::sync_l2norms() {
|
|
356
|
+
cached_l2norms.resize(ntotal);
|
|
357
|
+
fvec_norms_L2sqr(
|
|
358
|
+
cached_l2norms.data(),
|
|
359
|
+
reinterpret_cast<const float*>(codes.data()),
|
|
360
|
+
d,
|
|
361
|
+
ntotal);
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
void IndexFlatL2::clear_l2norms() {
|
|
365
|
+
cached_l2norms.clear();
|
|
366
|
+
cached_l2norms.shrink_to_fit();
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
|
|
370
|
+
if (metric_type == METRIC_L2) {
|
|
371
|
+
if (!cached_l2norms.empty()) {
|
|
372
|
+
return new FlatL2WithNormsDis(*this);
|
|
373
|
+
}
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
return IndexFlat::get_FlatCodesDistanceComputer();
|
|
377
|
+
}
|
|
378
|
+
|
|
187
379
|
/***************************************************
|
|
188
380
|
* IndexFlat1D
|
|
189
381
|
***************************************************/
|