faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -43,6 +43,8 @@ IndexIVFFastScan::IndexIVFFastScan(
|
|
|
43
43
|
size_t code_size,
|
|
44
44
|
MetricType metric)
|
|
45
45
|
: IndexIVF(quantizer, d, nlist, code_size, metric) {
|
|
46
|
+
// unlike other indexes, we prefer no residuals for performance reasons.
|
|
47
|
+
by_residual = false;
|
|
46
48
|
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
|
|
47
49
|
}
|
|
48
50
|
|
|
@@ -50,6 +52,7 @@ IndexIVFFastScan::IndexIVFFastScan() {
|
|
|
50
52
|
bbs = 0;
|
|
51
53
|
M2 = 0;
|
|
52
54
|
is_trained = false;
|
|
55
|
+
by_residual = false;
|
|
53
56
|
}
|
|
54
57
|
|
|
55
58
|
void IndexIVFFastScan::init_fastscan(
|
|
@@ -79,7 +82,7 @@ void IndexIVFFastScan::init_code_packer() {
|
|
|
79
82
|
bil->packer = get_CodePacker();
|
|
80
83
|
}
|
|
81
84
|
|
|
82
|
-
IndexIVFFastScan::~IndexIVFFastScan()
|
|
85
|
+
IndexIVFFastScan::~IndexIVFFastScan() = default;
|
|
83
86
|
|
|
84
87
|
/*********************************************************
|
|
85
88
|
* Code management functions
|
|
@@ -195,7 +198,7 @@ CodePacker* IndexIVFFastScan::get_CodePacker() const {
|
|
|
195
198
|
|
|
196
199
|
namespace {
|
|
197
200
|
|
|
198
|
-
template <class C, typename dis_t
|
|
201
|
+
template <class C, typename dis_t>
|
|
199
202
|
void estimators_from_tables_generic(
|
|
200
203
|
const IndexIVFFastScan& index,
|
|
201
204
|
const uint8_t* codes,
|
|
@@ -206,22 +209,26 @@ void estimators_from_tables_generic(
|
|
|
206
209
|
size_t k,
|
|
207
210
|
typename C::T* heap_dis,
|
|
208
211
|
int64_t* heap_ids,
|
|
209
|
-
const
|
|
212
|
+
const NormTableScaler* scaler) {
|
|
210
213
|
using accu_t = typename C::T;
|
|
214
|
+
size_t nscale = scaler ? scaler->nscale : 0;
|
|
211
215
|
for (size_t j = 0; j < ncodes; ++j) {
|
|
212
216
|
BitstringReader bsr(codes + j * index.code_size, index.code_size);
|
|
213
217
|
accu_t dis = bias;
|
|
214
218
|
const dis_t* __restrict dt = dis_table;
|
|
215
|
-
|
|
219
|
+
|
|
220
|
+
for (size_t m = 0; m < index.M - nscale; m++) {
|
|
216
221
|
uint64_t c = bsr.read(index.nbits);
|
|
217
222
|
dis += dt[c];
|
|
218
223
|
dt += index.ksub;
|
|
219
224
|
}
|
|
220
225
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
226
|
+
if (scaler) {
|
|
227
|
+
for (size_t m = 0; m < nscale; m++) {
|
|
228
|
+
uint64_t c = bsr.read(index.nbits);
|
|
229
|
+
dis += scaler->scale_one(dt[c]);
|
|
230
|
+
dt += index.ksub;
|
|
231
|
+
}
|
|
225
232
|
}
|
|
226
233
|
|
|
227
234
|
if (C::cmp(heap_dis[0], dis)) {
|
|
@@ -242,18 +249,15 @@ using namespace quantize_lut;
|
|
|
242
249
|
void IndexIVFFastScan::compute_LUT_uint8(
|
|
243
250
|
size_t n,
|
|
244
251
|
const float* x,
|
|
245
|
-
const
|
|
246
|
-
const float* coarse_dis,
|
|
252
|
+
const CoarseQuantized& cq,
|
|
247
253
|
AlignedTable<uint8_t>& dis_tables,
|
|
248
254
|
AlignedTable<uint16_t>& biases,
|
|
249
255
|
float* normalizers) const {
|
|
250
256
|
AlignedTable<float> dis_tables_float;
|
|
251
257
|
AlignedTable<float> biases_float;
|
|
252
258
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
|
|
256
|
-
|
|
259
|
+
compute_LUT(n, x, cq, dis_tables_float, biases_float);
|
|
260
|
+
size_t nprobe = cq.nprobe;
|
|
257
261
|
bool lut_is_3d = lookup_table_is_3d();
|
|
258
262
|
size_t dim123 = ksub * M;
|
|
259
263
|
size_t dim123_2 = ksub * M2;
|
|
@@ -265,8 +269,8 @@ void IndexIVFFastScan::compute_LUT_uint8(
|
|
|
265
269
|
if (biases_float.get()) {
|
|
266
270
|
biases.resize(n * nprobe);
|
|
267
271
|
}
|
|
268
|
-
uint64_t t1 = get_cy();
|
|
269
272
|
|
|
273
|
+
// OMP for MSVC requires i to have signed integral type
|
|
270
274
|
#pragma omp parallel for if (n > 100)
|
|
271
275
|
for (int64_t i = 0; i < n; i++) {
|
|
272
276
|
const float* t_in = dis_tables_float.get() + i * dim123;
|
|
@@ -291,7 +295,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
|
|
|
291
295
|
normalizers + 2 * i,
|
|
292
296
|
normalizers + 2 * i + 1);
|
|
293
297
|
}
|
|
294
|
-
IVFFastScan_stats.t_round += get_cy() - t1;
|
|
295
298
|
}
|
|
296
299
|
|
|
297
300
|
/*********************************************************
|
|
@@ -304,45 +307,195 @@ void IndexIVFFastScan::search(
|
|
|
304
307
|
idx_t k,
|
|
305
308
|
float* distances,
|
|
306
309
|
idx_t* labels,
|
|
307
|
-
const SearchParameters*
|
|
310
|
+
const SearchParameters* params_in) const {
|
|
311
|
+
const IVFSearchParameters* params = nullptr;
|
|
312
|
+
if (params_in) {
|
|
313
|
+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
314
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
315
|
+
params, "IndexIVFFastScan params have incorrect type");
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
search_preassigned(
|
|
319
|
+
n, x, k, nullptr, nullptr, distances, labels, false, params);
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
void IndexIVFFastScan::search_preassigned(
|
|
323
|
+
idx_t n,
|
|
324
|
+
const float* x,
|
|
325
|
+
idx_t k,
|
|
326
|
+
const idx_t* assign,
|
|
327
|
+
const float* centroid_dis,
|
|
328
|
+
float* distances,
|
|
329
|
+
idx_t* labels,
|
|
330
|
+
bool store_pairs,
|
|
331
|
+
const IVFSearchParameters* params,
|
|
332
|
+
IndexIVFStats* stats) const {
|
|
333
|
+
size_t nprobe = this->nprobe;
|
|
334
|
+
if (params) {
|
|
335
|
+
FAISS_THROW_IF_NOT(params->max_codes == 0);
|
|
336
|
+
nprobe = params->nprobe;
|
|
337
|
+
}
|
|
338
|
+
|
|
308
339
|
FAISS_THROW_IF_NOT_MSG(
|
|
309
|
-
!
|
|
340
|
+
!store_pairs, "store_pairs not supported for this index");
|
|
341
|
+
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
|
|
310
342
|
FAISS_THROW_IF_NOT(k > 0);
|
|
311
343
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
344
|
+
const CoarseQuantized cq = {nprobe, centroid_dis, assign};
|
|
345
|
+
search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params);
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
void IndexIVFFastScan::range_search(
|
|
349
|
+
idx_t n,
|
|
350
|
+
const float* x,
|
|
351
|
+
float radius,
|
|
352
|
+
RangeSearchResult* result,
|
|
353
|
+
const SearchParameters* params_in) const {
|
|
354
|
+
size_t nprobe = this->nprobe;
|
|
355
|
+
const IVFSearchParameters* params = nullptr;
|
|
356
|
+
if (params_in) {
|
|
357
|
+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
358
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
359
|
+
params, "IndexIVFFastScan params have incorrect type");
|
|
360
|
+
nprobe = params->nprobe;
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
const CoarseQuantized cq = {nprobe, nullptr, nullptr};
|
|
364
|
+
range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
namespace {
|
|
368
|
+
|
|
369
|
+
template <class C>
|
|
370
|
+
ResultHandlerCompare<C, true>* make_knn_handler_fixC(
|
|
371
|
+
int impl,
|
|
372
|
+
idx_t n,
|
|
373
|
+
idx_t k,
|
|
374
|
+
float* distances,
|
|
375
|
+
idx_t* labels,
|
|
376
|
+
const IDSelector* sel) {
|
|
377
|
+
using HeapHC = HeapHandler<C, true>;
|
|
378
|
+
using ReservoirHC = ReservoirHandler<C, true>;
|
|
379
|
+
using SingleResultHC = SingleResultHandler<C, true>;
|
|
380
|
+
|
|
381
|
+
if (k == 1) {
|
|
382
|
+
return new SingleResultHC(n, 0, distances, labels, sel);
|
|
383
|
+
} else if (impl % 2 == 0) {
|
|
384
|
+
return new HeapHC(n, 0, k, distances, labels, sel);
|
|
385
|
+
} else /* if (impl % 2 == 1) */ {
|
|
386
|
+
return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel);
|
|
387
|
+
}
|
|
388
|
+
}
|
|
389
|
+
|
|
390
|
+
SIMDResultHandlerToFloat* make_knn_handler(
|
|
391
|
+
bool is_max,
|
|
392
|
+
int impl,
|
|
393
|
+
idx_t n,
|
|
394
|
+
idx_t k,
|
|
395
|
+
float* distances,
|
|
396
|
+
idx_t* labels,
|
|
397
|
+
const IDSelector* sel) {
|
|
398
|
+
if (is_max) {
|
|
399
|
+
return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
|
|
400
|
+
impl, n, k, distances, labels, sel);
|
|
315
401
|
} else {
|
|
316
|
-
|
|
402
|
+
return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
|
|
403
|
+
impl, n, k, distances, labels, sel);
|
|
317
404
|
}
|
|
318
405
|
}
|
|
319
406
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
407
|
+
using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
|
|
408
|
+
|
|
409
|
+
struct CoarseQuantizedWithBuffer : CoarseQuantized {
|
|
410
|
+
explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq)
|
|
411
|
+
: CoarseQuantized(cq) {}
|
|
412
|
+
|
|
413
|
+
bool done() const {
|
|
414
|
+
return ids != nullptr;
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
std::vector<idx_t> ids_buffer;
|
|
418
|
+
std::vector<float> dis_buffer;
|
|
419
|
+
|
|
420
|
+
void quantize(
|
|
421
|
+
const Index* quantizer,
|
|
422
|
+
idx_t n,
|
|
423
|
+
const float* x,
|
|
424
|
+
const SearchParameters* quantizer_params) {
|
|
425
|
+
dis_buffer.resize(nprobe * n);
|
|
426
|
+
ids_buffer.resize(nprobe * n);
|
|
427
|
+
quantizer->search(
|
|
428
|
+
n,
|
|
429
|
+
x,
|
|
430
|
+
nprobe,
|
|
431
|
+
dis_buffer.data(),
|
|
432
|
+
ids_buffer.data(),
|
|
433
|
+
quantizer_params);
|
|
434
|
+
dis = dis_buffer.data();
|
|
435
|
+
ids = ids_buffer.data();
|
|
436
|
+
}
|
|
437
|
+
};
|
|
438
|
+
|
|
439
|
+
struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
|
|
440
|
+
size_t i0, i1;
|
|
441
|
+
CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
|
|
442
|
+
: CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
|
|
443
|
+
if (done()) {
|
|
444
|
+
dis += nprobe * i0;
|
|
445
|
+
ids += nprobe * i0;
|
|
446
|
+
}
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
void quantize_slice(
|
|
450
|
+
const Index* quantizer,
|
|
451
|
+
const float* x,
|
|
452
|
+
const SearchParameters* quantizer_params) {
|
|
453
|
+
quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params);
|
|
454
|
+
}
|
|
455
|
+
};
|
|
456
|
+
|
|
457
|
+
int compute_search_nslice(
|
|
458
|
+
const IndexIVFFastScan* index,
|
|
459
|
+
size_t n,
|
|
460
|
+
size_t nprobe) {
|
|
461
|
+
int nslice;
|
|
462
|
+
if (n <= omp_get_max_threads()) {
|
|
463
|
+
nslice = n;
|
|
464
|
+
} else if (index->lookup_table_is_3d()) {
|
|
465
|
+
// make sure we don't make too big LUT tables
|
|
466
|
+
size_t lut_size_per_query = index->M * index->ksub * nprobe *
|
|
467
|
+
(sizeof(float) + sizeof(uint8_t));
|
|
468
|
+
|
|
469
|
+
size_t max_lut_size = precomputed_table_max_bytes;
|
|
470
|
+
// how many queries we can handle within mem budget
|
|
471
|
+
size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
472
|
+
nslice = roundup(
|
|
473
|
+
std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
|
|
474
|
+
} else {
|
|
475
|
+
// LUTs unlikely to be a limiting factor
|
|
476
|
+
nslice = omp_get_max_threads();
|
|
477
|
+
}
|
|
478
|
+
return nslice;
|
|
327
479
|
}
|
|
328
480
|
|
|
329
|
-
|
|
481
|
+
} // namespace
|
|
482
|
+
|
|
330
483
|
void IndexIVFFastScan::search_dispatch_implem(
|
|
331
484
|
idx_t n,
|
|
332
485
|
const float* x,
|
|
333
486
|
idx_t k,
|
|
334
487
|
float* distances,
|
|
335
488
|
idx_t* labels,
|
|
336
|
-
const
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
489
|
+
const CoarseQuantized& cq_in,
|
|
490
|
+
const NormTableScaler* scaler,
|
|
491
|
+
const IVFSearchParameters* params) const {
|
|
492
|
+
const idx_t nprobe = params ? params->nprobe : this->nprobe;
|
|
493
|
+
const IDSelector* sel = (params) ? params->sel : nullptr;
|
|
494
|
+
const SearchParameters* quantizer_params =
|
|
495
|
+
params ? params->quantizer_params : nullptr;
|
|
341
496
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
CMax<uint16_t, int64_t>,
|
|
345
|
-
CMin<uint16_t, int64_t>>::type;
|
|
497
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
498
|
+
using RH = SIMDResultHandlerToFloat;
|
|
346
499
|
|
|
347
500
|
if (n == 0) {
|
|
348
501
|
return;
|
|
@@ -357,70 +510,93 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
357
510
|
} else {
|
|
358
511
|
impl = 10;
|
|
359
512
|
}
|
|
360
|
-
if (k > 20) {
|
|
513
|
+
if (k > 20) { // use reservoir rather than heap
|
|
361
514
|
impl++;
|
|
362
515
|
}
|
|
363
516
|
}
|
|
364
517
|
|
|
518
|
+
bool multiple_threads =
|
|
519
|
+
n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
|
|
520
|
+
if (impl >= 100) {
|
|
521
|
+
multiple_threads = false;
|
|
522
|
+
impl -= 100;
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
CoarseQuantizedWithBuffer cq(cq_in);
|
|
526
|
+
cq.nprobe = nprobe;
|
|
527
|
+
|
|
528
|
+
if (!cq.done() && !multiple_threads) {
|
|
529
|
+
// we do the coarse quantization here execpt when search is
|
|
530
|
+
// sliced over threads (then it is more efficient to have each thread do
|
|
531
|
+
// its own coarse quantization)
|
|
532
|
+
cq.quantize(quantizer, n, x, quantizer_params);
|
|
533
|
+
invlists->prefetch_lists(cq.ids, n * cq.nprobe);
|
|
534
|
+
}
|
|
535
|
+
|
|
365
536
|
if (impl == 1) {
|
|
366
|
-
|
|
537
|
+
if (is_max) {
|
|
538
|
+
search_implem_1<CMax<float, int64_t>>(
|
|
539
|
+
n, x, k, distances, labels, cq, scaler, params);
|
|
540
|
+
} else {
|
|
541
|
+
search_implem_1<CMin<float, int64_t>>(
|
|
542
|
+
n, x, k, distances, labels, cq, scaler, params);
|
|
543
|
+
}
|
|
367
544
|
} else if (impl == 2) {
|
|
368
|
-
|
|
369
|
-
|
|
545
|
+
if (is_max) {
|
|
546
|
+
search_implem_2<CMax<uint16_t, int64_t>>(
|
|
547
|
+
n, x, k, distances, labels, cq, scaler, params);
|
|
548
|
+
} else {
|
|
549
|
+
search_implem_2<CMin<uint16_t, int64_t>>(
|
|
550
|
+
n, x, k, distances, labels, cq, scaler, params);
|
|
551
|
+
}
|
|
370
552
|
} else if (impl >= 10 && impl <= 15) {
|
|
371
553
|
size_t ndis = 0, nlist_visited = 0;
|
|
372
554
|
|
|
373
|
-
if (
|
|
555
|
+
if (!multiple_threads) {
|
|
556
|
+
// clang-format off
|
|
374
557
|
if (impl == 12 || impl == 13) {
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
558
|
+
std::unique_ptr<RH> handler(
|
|
559
|
+
make_knn_handler(
|
|
560
|
+
is_max,
|
|
561
|
+
impl,
|
|
562
|
+
n,
|
|
563
|
+
k,
|
|
564
|
+
distances,
|
|
565
|
+
labels, sel
|
|
566
|
+
)
|
|
567
|
+
);
|
|
568
|
+
search_implem_12(
|
|
569
|
+
n, x, *handler.get(),
|
|
570
|
+
cq, &ndis, &nlist_visited, scaler, params);
|
|
385
571
|
} else if (impl == 14 || impl == 15) {
|
|
386
|
-
search_implem_14
|
|
572
|
+
search_implem_14(
|
|
573
|
+
n, x, k, distances, labels,
|
|
574
|
+
cq, impl, scaler, params);
|
|
387
575
|
} else {
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
576
|
+
std::unique_ptr<RH> handler(
|
|
577
|
+
make_knn_handler(
|
|
578
|
+
is_max,
|
|
579
|
+
impl,
|
|
580
|
+
n,
|
|
581
|
+
k,
|
|
582
|
+
distances,
|
|
393
583
|
labels,
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
584
|
+
sel
|
|
585
|
+
)
|
|
586
|
+
);
|
|
587
|
+
search_implem_10(
|
|
588
|
+
n, x, *handler.get(), cq,
|
|
589
|
+
&ndis, &nlist_visited, scaler, params);
|
|
398
590
|
}
|
|
591
|
+
// clang-format on
|
|
399
592
|
} else {
|
|
400
593
|
// explicitly slice over threads
|
|
401
|
-
int nslice;
|
|
402
|
-
if (
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
|
|
408
|
-
|
|
409
|
-
size_t max_lut_size = precomputed_table_max_bytes;
|
|
410
|
-
// how many queries we can handle within mem budget
|
|
411
|
-
size_t nq_ok =
|
|
412
|
-
std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
413
|
-
nslice =
|
|
414
|
-
roundup(std::max(size_t(n / nq_ok), size_t(1)),
|
|
415
|
-
omp_get_max_threads());
|
|
416
|
-
} else {
|
|
417
|
-
// LUTs unlikely to be a limiting factor
|
|
418
|
-
nslice = omp_get_max_threads();
|
|
419
|
-
}
|
|
420
|
-
if (impl == 14 ||
|
|
421
|
-
impl == 15) { // this might require slicing if there are too
|
|
422
|
-
// many queries (for now we keep this simple)
|
|
423
|
-
search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
|
|
594
|
+
int nslice = compute_search_nslice(this, n, cq.nprobe);
|
|
595
|
+
if (impl == 14 || impl == 15) {
|
|
596
|
+
// this might require slicing if there are too
|
|
597
|
+
// many queries (for now we keep this simple)
|
|
598
|
+
search_implem_14(
|
|
599
|
+
n, x, k, distances, labels, cq, impl, scaler, params);
|
|
424
600
|
} else {
|
|
425
601
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
426
602
|
for (int slice = 0; slice < nslice; slice++) {
|
|
@@ -428,29 +604,23 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
428
604
|
idx_t i1 = n * (slice + 1) / nslice;
|
|
429
605
|
float* dis_i = distances + i0 * k;
|
|
430
606
|
idx_t* lab_i = labels + i0 * k;
|
|
607
|
+
CoarseQuantizedSlice cq_i(cq, i0, i1);
|
|
608
|
+
if (!cq_i.done()) {
|
|
609
|
+
cq_i.quantize_slice(quantizer, x, quantizer_params);
|
|
610
|
+
}
|
|
611
|
+
std::unique_ptr<RH> handler(make_knn_handler(
|
|
612
|
+
is_max, impl, i1 - i0, k, dis_i, lab_i, sel));
|
|
613
|
+
// clang-format off
|
|
431
614
|
if (impl == 12 || impl == 13) {
|
|
432
|
-
search_implem_12
|
|
433
|
-
i1 - i0,
|
|
434
|
-
|
|
435
|
-
k,
|
|
436
|
-
dis_i,
|
|
437
|
-
lab_i,
|
|
438
|
-
impl,
|
|
439
|
-
&ndis,
|
|
440
|
-
&nlist_visited,
|
|
441
|
-
scaler);
|
|
615
|
+
search_implem_12(
|
|
616
|
+
i1 - i0, x + i0 * d, *handler.get(),
|
|
617
|
+
cq_i, &ndis, &nlist_visited, scaler, params);
|
|
442
618
|
} else {
|
|
443
|
-
search_implem_10
|
|
444
|
-
i1 - i0,
|
|
445
|
-
|
|
446
|
-
k,
|
|
447
|
-
dis_i,
|
|
448
|
-
lab_i,
|
|
449
|
-
impl,
|
|
450
|
-
&ndis,
|
|
451
|
-
&nlist_visited,
|
|
452
|
-
scaler);
|
|
619
|
+
search_implem_10(
|
|
620
|
+
i1 - i0, x + i0 * d, *handler.get(),
|
|
621
|
+
cq_i, &ndis, &nlist_visited, scaler, params);
|
|
453
622
|
}
|
|
623
|
+
// clang-format on
|
|
454
624
|
}
|
|
455
625
|
}
|
|
456
626
|
}
|
|
@@ -462,31 +632,149 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
462
632
|
}
|
|
463
633
|
}
|
|
464
634
|
|
|
465
|
-
|
|
635
|
+
void IndexIVFFastScan::range_search_dispatch_implem(
|
|
636
|
+
idx_t n,
|
|
637
|
+
const float* x,
|
|
638
|
+
float radius,
|
|
639
|
+
RangeSearchResult& rres,
|
|
640
|
+
const CoarseQuantized& cq_in,
|
|
641
|
+
const NormTableScaler* scaler,
|
|
642
|
+
const IVFSearchParameters* params) const {
|
|
643
|
+
// const idx_t nprobe = params ? params->nprobe : this->nprobe;
|
|
644
|
+
const IDSelector* sel = (params) ? params->sel : nullptr;
|
|
645
|
+
const SearchParameters* quantizer_params =
|
|
646
|
+
params ? params->quantizer_params : nullptr;
|
|
647
|
+
|
|
648
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
649
|
+
|
|
650
|
+
if (n == 0) {
|
|
651
|
+
return;
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
// actual implementation used
|
|
655
|
+
int impl = implem;
|
|
656
|
+
|
|
657
|
+
if (impl == 0) {
|
|
658
|
+
if (bbs == 32) {
|
|
659
|
+
impl = 12;
|
|
660
|
+
} else {
|
|
661
|
+
impl = 10;
|
|
662
|
+
}
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
CoarseQuantizedWithBuffer cq(cq_in);
|
|
666
|
+
|
|
667
|
+
bool multiple_threads =
|
|
668
|
+
n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
|
|
669
|
+
if (impl >= 100) {
|
|
670
|
+
multiple_threads = false;
|
|
671
|
+
impl -= 100;
|
|
672
|
+
}
|
|
673
|
+
|
|
674
|
+
if (!multiple_threads && !cq.done()) {
|
|
675
|
+
cq.quantize(quantizer, n, x, quantizer_params);
|
|
676
|
+
invlists->prefetch_lists(cq.ids, n * cq.nprobe);
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
size_t ndis = 0, nlist_visited = 0;
|
|
680
|
+
|
|
681
|
+
if (!multiple_threads) { // single thread
|
|
682
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler;
|
|
683
|
+
if (is_max) {
|
|
684
|
+
handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
|
|
685
|
+
rres, radius, 0, sel));
|
|
686
|
+
} else {
|
|
687
|
+
handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
|
|
688
|
+
rres, radius, 0, sel));
|
|
689
|
+
}
|
|
690
|
+
if (impl == 12) {
|
|
691
|
+
search_implem_12(
|
|
692
|
+
n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
|
|
693
|
+
} else if (impl == 10) {
|
|
694
|
+
search_implem_10(
|
|
695
|
+
n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
|
|
696
|
+
} else {
|
|
697
|
+
FAISS_THROW_FMT("Range search implem %d not implemented", impl);
|
|
698
|
+
}
|
|
699
|
+
} else {
|
|
700
|
+
// explicitly slice over threads
|
|
701
|
+
int nslice = compute_search_nslice(this, n, cq.nprobe);
|
|
702
|
+
#pragma omp parallel
|
|
703
|
+
{
|
|
704
|
+
RangeSearchPartialResult pres(&rres);
|
|
705
|
+
|
|
706
|
+
#pragma omp for reduction(+ : ndis, nlist_visited)
|
|
707
|
+
for (int slice = 0; slice < nslice; slice++) {
|
|
708
|
+
idx_t i0 = n * slice / nslice;
|
|
709
|
+
idx_t i1 = n * (slice + 1) / nslice;
|
|
710
|
+
CoarseQuantizedSlice cq_i(cq, i0, i1);
|
|
711
|
+
if (!cq_i.done()) {
|
|
712
|
+
cq_i.quantize_slice(quantizer, x, quantizer_params);
|
|
713
|
+
}
|
|
714
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler;
|
|
715
|
+
if (is_max) {
|
|
716
|
+
handler.reset(new PartialRangeHandler<
|
|
717
|
+
CMax<uint16_t, int64_t>,
|
|
718
|
+
true>(pres, radius, 0, i0, i1, sel));
|
|
719
|
+
} else {
|
|
720
|
+
handler.reset(new PartialRangeHandler<
|
|
721
|
+
CMin<uint16_t, int64_t>,
|
|
722
|
+
true>(pres, radius, 0, i0, i1, sel));
|
|
723
|
+
}
|
|
724
|
+
|
|
725
|
+
if (impl == 12 || impl == 13) {
|
|
726
|
+
search_implem_12(
|
|
727
|
+
i1 - i0,
|
|
728
|
+
x + i0 * d,
|
|
729
|
+
*handler.get(),
|
|
730
|
+
cq_i,
|
|
731
|
+
&ndis,
|
|
732
|
+
&nlist_visited,
|
|
733
|
+
scaler,
|
|
734
|
+
params);
|
|
735
|
+
} else {
|
|
736
|
+
search_implem_10(
|
|
737
|
+
i1 - i0,
|
|
738
|
+
x + i0 * d,
|
|
739
|
+
*handler.get(),
|
|
740
|
+
cq_i,
|
|
741
|
+
&ndis,
|
|
742
|
+
&nlist_visited,
|
|
743
|
+
scaler,
|
|
744
|
+
params);
|
|
745
|
+
}
|
|
746
|
+
}
|
|
747
|
+
pres.finalize();
|
|
748
|
+
}
|
|
749
|
+
}
|
|
750
|
+
|
|
751
|
+
indexIVF_stats.nq += n;
|
|
752
|
+
indexIVF_stats.ndis += ndis;
|
|
753
|
+
indexIVF_stats.nlist += nlist_visited;
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
template <class C>
|
|
466
757
|
void IndexIVFFastScan::search_implem_1(
|
|
467
758
|
idx_t n,
|
|
468
759
|
const float* x,
|
|
469
760
|
idx_t k,
|
|
470
761
|
float* distances,
|
|
471
762
|
idx_t* labels,
|
|
472
|
-
const
|
|
763
|
+
const CoarseQuantized& cq,
|
|
764
|
+
const NormTableScaler* scaler,
|
|
765
|
+
const IVFSearchParameters* params) const {
|
|
473
766
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
474
767
|
|
|
475
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
476
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
477
|
-
|
|
478
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
479
|
-
|
|
480
768
|
size_t dim12 = ksub * M;
|
|
481
769
|
AlignedTable<float> dis_tables;
|
|
482
770
|
AlignedTable<float> biases;
|
|
483
771
|
|
|
484
|
-
compute_LUT(n, x,
|
|
772
|
+
compute_LUT(n, x, cq, dis_tables, biases);
|
|
485
773
|
|
|
486
774
|
bool single_LUT = !lookup_table_is_3d();
|
|
487
775
|
|
|
488
776
|
size_t ndis = 0, nlist_visited = 0;
|
|
489
|
-
|
|
777
|
+
size_t nprobe = cq.nprobe;
|
|
490
778
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
491
779
|
for (idx_t i = 0; i < n; i++) {
|
|
492
780
|
int64_t* heap_ids = labels + i * k;
|
|
@@ -501,7 +789,7 @@ void IndexIVFFastScan::search_implem_1(
|
|
|
501
789
|
if (!single_LUT) {
|
|
502
790
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
503
791
|
}
|
|
504
|
-
idx_t list_no =
|
|
792
|
+
idx_t list_no = cq.ids[i * nprobe + j];
|
|
505
793
|
if (list_no < 0)
|
|
506
794
|
continue;
|
|
507
795
|
size_t ls = orig_invlists->list_size(list_no);
|
|
@@ -533,38 +821,29 @@ void IndexIVFFastScan::search_implem_1(
|
|
|
533
821
|
indexIVF_stats.nlist += nlist_visited;
|
|
534
822
|
}
|
|
535
823
|
|
|
536
|
-
template <class C
|
|
824
|
+
template <class C>
|
|
537
825
|
void IndexIVFFastScan::search_implem_2(
|
|
538
826
|
idx_t n,
|
|
539
827
|
const float* x,
|
|
540
828
|
idx_t k,
|
|
541
829
|
float* distances,
|
|
542
830
|
idx_t* labels,
|
|
543
|
-
const
|
|
831
|
+
const CoarseQuantized& cq,
|
|
832
|
+
const NormTableScaler* scaler,
|
|
833
|
+
const IVFSearchParameters* params) const {
|
|
544
834
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
545
835
|
|
|
546
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
547
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
548
|
-
|
|
549
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
550
|
-
|
|
551
836
|
size_t dim12 = ksub * M2;
|
|
552
837
|
AlignedTable<uint8_t> dis_tables;
|
|
553
838
|
AlignedTable<uint16_t> biases;
|
|
554
839
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
555
840
|
|
|
556
|
-
compute_LUT_uint8(
|
|
557
|
-
n,
|
|
558
|
-
x,
|
|
559
|
-
coarse_ids.get(),
|
|
560
|
-
coarse_dis.get(),
|
|
561
|
-
dis_tables,
|
|
562
|
-
biases,
|
|
563
|
-
normalizers.get());
|
|
841
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
564
842
|
|
|
565
843
|
bool single_LUT = !lookup_table_is_3d();
|
|
566
844
|
|
|
567
845
|
size_t ndis = 0, nlist_visited = 0;
|
|
846
|
+
size_t nprobe = cq.nprobe;
|
|
568
847
|
|
|
569
848
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
570
849
|
for (idx_t i = 0; i < n; i++) {
|
|
@@ -581,7 +860,7 @@ void IndexIVFFastScan::search_implem_2(
|
|
|
581
860
|
if (!single_LUT) {
|
|
582
861
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
583
862
|
}
|
|
584
|
-
idx_t list_no =
|
|
863
|
+
idx_t list_no = cq.ids[i * nprobe + j];
|
|
585
864
|
if (list_no < 0)
|
|
586
865
|
continue;
|
|
587
866
|
size_t ls = orig_invlists->list_size(list_no);
|
|
@@ -626,171 +905,103 @@ void IndexIVFFastScan::search_implem_2(
|
|
|
626
905
|
indexIVF_stats.nlist += nlist_visited;
|
|
627
906
|
}
|
|
628
907
|
|
|
629
|
-
template <class C, class Scaler>
|
|
630
908
|
void IndexIVFFastScan::search_implem_10(
|
|
631
909
|
idx_t n,
|
|
632
910
|
const float* x,
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
idx_t* labels,
|
|
636
|
-
int impl,
|
|
911
|
+
SIMDResultHandlerToFloat& handler,
|
|
912
|
+
const CoarseQuantized& cq,
|
|
637
913
|
size_t* ndis_out,
|
|
638
914
|
size_t* nlist_out,
|
|
639
|
-
const
|
|
640
|
-
|
|
641
|
-
memset(labels, -1, sizeof(idx_t) * k * n);
|
|
642
|
-
|
|
643
|
-
using HeapHC = HeapHandler<C, true>;
|
|
644
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
645
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
646
|
-
|
|
647
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
648
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
649
|
-
|
|
650
|
-
uint64_t times[10];
|
|
651
|
-
memset(times, 0, sizeof(times));
|
|
652
|
-
int ti = 0;
|
|
653
|
-
#define TIC times[ti++] = get_cy()
|
|
654
|
-
TIC;
|
|
655
|
-
|
|
656
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
657
|
-
|
|
658
|
-
TIC;
|
|
659
|
-
|
|
915
|
+
const NormTableScaler* scaler,
|
|
916
|
+
const IVFSearchParameters* params) const {
|
|
660
917
|
size_t dim12 = ksub * M2;
|
|
661
918
|
AlignedTable<uint8_t> dis_tables;
|
|
662
919
|
AlignedTable<uint16_t> biases;
|
|
663
920
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
664
921
|
|
|
665
|
-
compute_LUT_uint8(
|
|
666
|
-
n,
|
|
667
|
-
x,
|
|
668
|
-
coarse_ids.get(),
|
|
669
|
-
coarse_dis.get(),
|
|
670
|
-
dis_tables,
|
|
671
|
-
biases,
|
|
672
|
-
normalizers.get());
|
|
673
|
-
|
|
674
|
-
TIC;
|
|
922
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
675
923
|
|
|
676
924
|
bool single_LUT = !lookup_table_is_3d();
|
|
677
925
|
|
|
678
|
-
|
|
679
|
-
|
|
926
|
+
size_t ndis = 0;
|
|
927
|
+
int qmap1[1];
|
|
680
928
|
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
const uint8_t* LUT = nullptr;
|
|
685
|
-
int qmap1[1] = {0};
|
|
686
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
687
|
-
|
|
688
|
-
if (k == 1) {
|
|
689
|
-
handler.reset(new SingleResultHC(1, 0));
|
|
690
|
-
} else if (impl == 10) {
|
|
691
|
-
handler.reset(new HeapHC(
|
|
692
|
-
1, tmp_distances.get(), labels + i * k, k, 0));
|
|
693
|
-
} else if (impl == 11) {
|
|
694
|
-
handler.reset(new ReservoirHC(1, 0, k, 2 * k));
|
|
695
|
-
} else {
|
|
696
|
-
FAISS_THROW_MSG("invalid");
|
|
697
|
-
}
|
|
929
|
+
handler.q_map = qmap1;
|
|
930
|
+
handler.begin(skip & 16 ? nullptr : normalizers.get());
|
|
931
|
+
size_t nprobe = cq.nprobe;
|
|
698
932
|
|
|
699
|
-
|
|
933
|
+
for (idx_t i = 0; i < n; i++) {
|
|
934
|
+
const uint8_t* LUT = nullptr;
|
|
935
|
+
qmap1[0] = i;
|
|
700
936
|
|
|
701
|
-
|
|
702
|
-
|
|
937
|
+
if (single_LUT) {
|
|
938
|
+
LUT = dis_tables.get() + i * dim12;
|
|
939
|
+
}
|
|
940
|
+
for (idx_t j = 0; j < nprobe; j++) {
|
|
941
|
+
size_t ij = i * nprobe + j;
|
|
942
|
+
if (!single_LUT) {
|
|
943
|
+
LUT = dis_tables.get() + ij * dim12;
|
|
944
|
+
}
|
|
945
|
+
if (biases.get()) {
|
|
946
|
+
handler.dbias = biases.get() + ij;
|
|
703
947
|
}
|
|
704
|
-
for (idx_t j = 0; j < nprobe; j++) {
|
|
705
|
-
size_t ij = i * nprobe + j;
|
|
706
|
-
if (!single_LUT) {
|
|
707
|
-
LUT = dis_tables.get() + ij * dim12;
|
|
708
|
-
}
|
|
709
|
-
if (biases.get()) {
|
|
710
|
-
handler->dbias = biases.get() + ij;
|
|
711
|
-
}
|
|
712
|
-
|
|
713
|
-
idx_t list_no = coarse_ids[ij];
|
|
714
|
-
if (list_no < 0)
|
|
715
|
-
continue;
|
|
716
|
-
size_t ls = invlists->list_size(list_no);
|
|
717
|
-
if (ls == 0)
|
|
718
|
-
continue;
|
|
719
948
|
|
|
720
|
-
|
|
721
|
-
|
|
949
|
+
idx_t list_no = cq.ids[ij];
|
|
950
|
+
if (list_no < 0) {
|
|
951
|
+
continue;
|
|
952
|
+
}
|
|
953
|
+
size_t ls = invlists->list_size(list_no);
|
|
954
|
+
if (ls == 0) {
|
|
955
|
+
continue;
|
|
956
|
+
}
|
|
722
957
|
|
|
723
|
-
|
|
724
|
-
|
|
958
|
+
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
959
|
+
InvertedLists::ScopedIds ids(invlists, list_no);
|
|
725
960
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
729
|
-
pq4_accumulate_loop( \
|
|
730
|
-
1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
|
|
731
|
-
}
|
|
732
|
-
DISPATCH(HeapHC)
|
|
733
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
734
|
-
#undef DISPATCH
|
|
961
|
+
handler.ntotal = ls;
|
|
962
|
+
handler.id_map = ids.get();
|
|
735
963
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
964
|
+
pq4_accumulate_loop(
|
|
965
|
+
1,
|
|
966
|
+
roundup(ls, bbs),
|
|
967
|
+
bbs,
|
|
968
|
+
M2,
|
|
969
|
+
codes.get(),
|
|
970
|
+
LUT,
|
|
971
|
+
handler,
|
|
972
|
+
scaler);
|
|
739
973
|
|
|
740
|
-
|
|
741
|
-
distances + i * k,
|
|
742
|
-
labels + i * k,
|
|
743
|
-
skip & 16 ? nullptr : normalizers.get() + i * 2);
|
|
974
|
+
ndis++;
|
|
744
975
|
}
|
|
745
976
|
}
|
|
977
|
+
|
|
978
|
+
handler.end();
|
|
746
979
|
*ndis_out = ndis;
|
|
747
980
|
*nlist_out = nlist;
|
|
748
981
|
}
|
|
749
982
|
|
|
750
|
-
template <class C, class Scaler>
|
|
751
983
|
void IndexIVFFastScan::search_implem_12(
|
|
752
984
|
idx_t n,
|
|
753
985
|
const float* x,
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
idx_t* labels,
|
|
757
|
-
int impl,
|
|
986
|
+
SIMDResultHandlerToFloat& handler,
|
|
987
|
+
const CoarseQuantized& cq,
|
|
758
988
|
size_t* ndis_out,
|
|
759
989
|
size_t* nlist_out,
|
|
760
|
-
const
|
|
990
|
+
const NormTableScaler* scaler,
|
|
991
|
+
const IVFSearchParameters* params) const {
|
|
761
992
|
if (n == 0) { // does not work well with reservoir
|
|
762
993
|
return;
|
|
763
994
|
}
|
|
764
995
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
765
996
|
|
|
766
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
767
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
768
|
-
|
|
769
|
-
uint64_t times[10];
|
|
770
|
-
memset(times, 0, sizeof(times));
|
|
771
|
-
int ti = 0;
|
|
772
|
-
#define TIC times[ti++] = get_cy()
|
|
773
|
-
TIC;
|
|
774
|
-
|
|
775
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
776
|
-
|
|
777
|
-
TIC;
|
|
778
|
-
|
|
779
997
|
size_t dim12 = ksub * M2;
|
|
780
998
|
AlignedTable<uint8_t> dis_tables;
|
|
781
999
|
AlignedTable<uint16_t> biases;
|
|
782
1000
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
783
1001
|
|
|
784
|
-
compute_LUT_uint8(
|
|
785
|
-
n,
|
|
786
|
-
x,
|
|
787
|
-
coarse_ids.get(),
|
|
788
|
-
coarse_dis.get(),
|
|
789
|
-
dis_tables,
|
|
790
|
-
biases,
|
|
791
|
-
normalizers.get());
|
|
1002
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
792
1003
|
|
|
793
|
-
|
|
1004
|
+
handler.begin(skip & 16 ? nullptr : normalizers.get());
|
|
794
1005
|
|
|
795
1006
|
struct QC {
|
|
796
1007
|
int qno; // sequence number of the query
|
|
@@ -798,14 +1009,15 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
798
1009
|
int rank; // this is the rank'th result of the coarse quantizer
|
|
799
1010
|
};
|
|
800
1011
|
bool single_LUT = !lookup_table_is_3d();
|
|
1012
|
+
size_t nprobe = cq.nprobe;
|
|
801
1013
|
|
|
802
1014
|
std::vector<QC> qcs;
|
|
803
1015
|
{
|
|
804
1016
|
int ij = 0;
|
|
805
1017
|
for (int i = 0; i < n; i++) {
|
|
806
1018
|
for (int j = 0; j < nprobe; j++) {
|
|
807
|
-
if (
|
|
808
|
-
qcs.push_back(QC{i, int(
|
|
1019
|
+
if (cq.ids[ij] >= 0) {
|
|
1020
|
+
qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
|
|
809
1021
|
}
|
|
810
1022
|
ij++;
|
|
811
1023
|
}
|
|
@@ -814,42 +1026,22 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
814
1026
|
return a.list_no < b.list_no;
|
|
815
1027
|
});
|
|
816
1028
|
}
|
|
817
|
-
TIC;
|
|
818
1029
|
|
|
819
1030
|
// prepare the result handlers
|
|
820
1031
|
|
|
821
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
822
|
-
AlignedTable<uint16_t> tmp_distances;
|
|
823
|
-
|
|
824
|
-
using HeapHC = HeapHandler<C, true>;
|
|
825
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
826
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
827
|
-
|
|
828
|
-
if (k == 1) {
|
|
829
|
-
handler.reset(new SingleResultHC(n, 0));
|
|
830
|
-
} else if (impl == 12) {
|
|
831
|
-
tmp_distances.resize(n * k);
|
|
832
|
-
handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
|
|
833
|
-
} else if (impl == 13) {
|
|
834
|
-
handler.reset(new ReservoirHC(n, 0, k, 2 * k));
|
|
835
|
-
}
|
|
836
|
-
|
|
837
1032
|
int qbs2 = this->qbs2 ? this->qbs2 : 11;
|
|
838
1033
|
|
|
839
1034
|
std::vector<uint16_t> tmp_bias;
|
|
840
1035
|
if (biases.get()) {
|
|
841
1036
|
tmp_bias.resize(qbs2);
|
|
842
|
-
handler
|
|
1037
|
+
handler.dbias = tmp_bias.data();
|
|
843
1038
|
}
|
|
844
|
-
TIC;
|
|
845
1039
|
|
|
846
1040
|
size_t ndis = 0;
|
|
847
1041
|
|
|
848
1042
|
size_t i0 = 0;
|
|
849
1043
|
uint64_t t_copy_pack = 0, t_scan = 0;
|
|
850
1044
|
while (i0 < qcs.size()) {
|
|
851
|
-
uint64_t tt0 = get_cy();
|
|
852
|
-
|
|
853
1045
|
// find all queries that access this inverted list
|
|
854
1046
|
int list_no = qcs[i0].list_no;
|
|
855
1047
|
size_t i1 = i0 + 1;
|
|
@@ -897,93 +1089,50 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
897
1089
|
|
|
898
1090
|
// prepare the handler
|
|
899
1091
|
|
|
900
|
-
handler
|
|
901
|
-
handler
|
|
902
|
-
handler
|
|
903
|
-
uint64_t tt1 = get_cy();
|
|
1092
|
+
handler.ntotal = list_size;
|
|
1093
|
+
handler.q_map = q_map.data();
|
|
1094
|
+
handler.id_map = ids.get();
|
|
904
1095
|
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
|
|
910
|
-
}
|
|
911
|
-
DISPATCH(HeapHC)
|
|
912
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
913
|
-
|
|
914
|
-
// prepare for next loop
|
|
915
|
-
i0 = i1;
|
|
916
|
-
|
|
917
|
-
uint64_t tt2 = get_cy();
|
|
918
|
-
t_copy_pack += tt1 - tt0;
|
|
919
|
-
t_scan += tt2 - tt1;
|
|
1096
|
+
pq4_accumulate_loop_qbs(
|
|
1097
|
+
qbs, list_size, M2, codes.get(), LUT.get(), handler, scaler);
|
|
1098
|
+
// prepare for next loop
|
|
1099
|
+
i0 = i1;
|
|
920
1100
|
}
|
|
921
|
-
TIC;
|
|
922
1101
|
|
|
923
|
-
|
|
924
|
-
handler->to_flat_arrays(
|
|
925
|
-
distances, labels, skip & 16 ? nullptr : normalizers.get());
|
|
926
|
-
|
|
927
|
-
TIC;
|
|
1102
|
+
handler.end();
|
|
928
1103
|
|
|
929
1104
|
// these stats are not thread-safe
|
|
930
1105
|
|
|
931
|
-
for (int i = 1; i < ti; i++) {
|
|
932
|
-
IVFFastScan_stats.times[i] += times[i] - times[i - 1];
|
|
933
|
-
}
|
|
934
1106
|
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
935
1107
|
IVFFastScan_stats.t_scan += t_scan;
|
|
936
1108
|
|
|
937
|
-
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
938
|
-
for (int i = 0; i < 4; i++) {
|
|
939
|
-
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
940
|
-
}
|
|
941
|
-
}
|
|
942
|
-
|
|
943
1109
|
*ndis_out = ndis;
|
|
944
1110
|
*nlist_out = nlist;
|
|
945
1111
|
}
|
|
946
1112
|
|
|
947
|
-
template <class C, class Scaler>
|
|
948
1113
|
void IndexIVFFastScan::search_implem_14(
|
|
949
1114
|
idx_t n,
|
|
950
1115
|
const float* x,
|
|
951
1116
|
idx_t k,
|
|
952
1117
|
float* distances,
|
|
953
1118
|
idx_t* labels,
|
|
1119
|
+
const CoarseQuantized& cq,
|
|
954
1120
|
int impl,
|
|
955
|
-
const
|
|
1121
|
+
const NormTableScaler* scaler,
|
|
1122
|
+
const IVFSearchParameters* params) const {
|
|
956
1123
|
if (n == 0) { // does not work well with reservoir
|
|
957
1124
|
return;
|
|
958
1125
|
}
|
|
959
1126
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
960
1127
|
|
|
961
|
-
|
|
962
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
963
|
-
|
|
964
|
-
uint64_t ttg0 = get_cy();
|
|
965
|
-
|
|
966
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
967
|
-
|
|
968
|
-
uint64_t ttg1 = get_cy();
|
|
969
|
-
uint64_t coarse_search_tt = ttg1 - ttg0;
|
|
1128
|
+
const IDSelector* sel = params ? params->sel : nullptr;
|
|
970
1129
|
|
|
971
1130
|
size_t dim12 = ksub * M2;
|
|
972
1131
|
AlignedTable<uint8_t> dis_tables;
|
|
973
1132
|
AlignedTable<uint16_t> biases;
|
|
974
1133
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
975
1134
|
|
|
976
|
-
compute_LUT_uint8(
|
|
977
|
-
n,
|
|
978
|
-
x,
|
|
979
|
-
coarse_ids.get(),
|
|
980
|
-
coarse_dis.get(),
|
|
981
|
-
dis_tables,
|
|
982
|
-
biases,
|
|
983
|
-
normalizers.get());
|
|
984
|
-
|
|
985
|
-
uint64_t ttg2 = get_cy();
|
|
986
|
-
uint64_t lut_compute_tt = ttg2 - ttg1;
|
|
1135
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
987
1136
|
|
|
988
1137
|
struct QC {
|
|
989
1138
|
int qno; // sequence number of the query
|
|
@@ -991,14 +1140,15 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
991
1140
|
int rank; // this is the rank'th result of the coarse quantizer
|
|
992
1141
|
};
|
|
993
1142
|
bool single_LUT = !lookup_table_is_3d();
|
|
1143
|
+
size_t nprobe = cq.nprobe;
|
|
994
1144
|
|
|
995
1145
|
std::vector<QC> qcs;
|
|
996
1146
|
{
|
|
997
1147
|
int ij = 0;
|
|
998
1148
|
for (int i = 0; i < n; i++) {
|
|
999
1149
|
for (int j = 0; j < nprobe; j++) {
|
|
1000
|
-
if (
|
|
1001
|
-
qcs.push_back(QC{i, int(
|
|
1150
|
+
if (cq.ids[ij] >= 0) {
|
|
1151
|
+
qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
|
|
1002
1152
|
}
|
|
1003
1153
|
ij++;
|
|
1004
1154
|
}
|
|
@@ -1036,14 +1186,13 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1036
1186
|
ses.push_back(SE{i0_l, i1, list_size});
|
|
1037
1187
|
i0_l = i1;
|
|
1038
1188
|
}
|
|
1039
|
-
uint64_t ttg3 = get_cy();
|
|
1040
|
-
uint64_t compute_clusters_tt = ttg3 - ttg2;
|
|
1041
1189
|
|
|
1042
1190
|
// function to handle the global heap
|
|
1191
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
1043
1192
|
using HeapForIP = CMin<float, idx_t>;
|
|
1044
1193
|
using HeapForL2 = CMax<float, idx_t>;
|
|
1045
1194
|
auto init_result = [&](float* simi, idx_t* idxi) {
|
|
1046
|
-
if (
|
|
1195
|
+
if (!is_max) {
|
|
1047
1196
|
heap_heapify<HeapForIP>(k, simi, idxi);
|
|
1048
1197
|
} else {
|
|
1049
1198
|
heap_heapify<HeapForL2>(k, simi, idxi);
|
|
@@ -1054,7 +1203,7 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1054
1203
|
const idx_t* local_idx,
|
|
1055
1204
|
float* simi,
|
|
1056
1205
|
idx_t* idxi) {
|
|
1057
|
-
if (
|
|
1206
|
+
if (!is_max) {
|
|
1058
1207
|
heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
|
|
1059
1208
|
} else {
|
|
1060
1209
|
heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
|
|
@@ -1062,14 +1211,12 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1062
1211
|
};
|
|
1063
1212
|
|
|
1064
1213
|
auto reorder_result = [&](float* simi, idx_t* idxi) {
|
|
1065
|
-
if (
|
|
1214
|
+
if (!is_max) {
|
|
1066
1215
|
heap_reorder<HeapForIP>(k, simi, idxi);
|
|
1067
1216
|
} else {
|
|
1068
1217
|
heap_reorder<HeapForL2>(k, simi, idxi);
|
|
1069
1218
|
}
|
|
1070
1219
|
};
|
|
1071
|
-
uint64_t ttg4 = get_cy();
|
|
1072
|
-
uint64_t fn_tt = ttg4 - ttg3;
|
|
1073
1220
|
|
|
1074
1221
|
size_t ndis = 0;
|
|
1075
1222
|
size_t nlist_visited = 0;
|
|
@@ -1081,22 +1228,9 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1081
1228
|
std::vector<float> local_dis(k * n);
|
|
1082
1229
|
|
|
1083
1230
|
// prepare the result handlers
|
|
1084
|
-
std::unique_ptr<
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
using HeapHC = HeapHandler<C, true>;
|
|
1088
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
1089
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
1090
|
-
|
|
1091
|
-
if (k == 1) {
|
|
1092
|
-
handler.reset(new SingleResultHC(n, 0));
|
|
1093
|
-
} else if (impl == 14) {
|
|
1094
|
-
tmp_distances.resize(n * k);
|
|
1095
|
-
handler.reset(
|
|
1096
|
-
new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
|
|
1097
|
-
} else if (impl == 15) {
|
|
1098
|
-
handler.reset(new ReservoirHC(n, 0, k, 2 * k));
|
|
1099
|
-
}
|
|
1231
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
|
|
1232
|
+
is_max, impl, n, k, local_dis.data(), local_idx.data(), sel));
|
|
1233
|
+
handler->begin(normalizers.get());
|
|
1100
1234
|
|
|
1101
1235
|
int qbs2 = this->qbs2 ? this->qbs2 : 11;
|
|
1102
1236
|
|
|
@@ -1106,14 +1240,10 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1106
1240
|
handler->dbias = tmp_bias.data();
|
|
1107
1241
|
}
|
|
1108
1242
|
|
|
1109
|
-
uint64_t ttg5 = get_cy();
|
|
1110
|
-
uint64_t handler_tt = ttg5 - ttg4;
|
|
1111
|
-
|
|
1112
1243
|
std::set<int> q_set;
|
|
1113
1244
|
uint64_t t_copy_pack = 0, t_scan = 0;
|
|
1114
1245
|
#pragma omp for schedule(dynamic)
|
|
1115
1246
|
for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
|
|
1116
|
-
uint64_t tt0 = get_cy();
|
|
1117
1247
|
size_t i0 = ses[cluster].start;
|
|
1118
1248
|
size_t i1 = ses[cluster].end;
|
|
1119
1249
|
size_t list_size = ses[cluster].list_size;
|
|
@@ -1153,28 +1283,21 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1153
1283
|
handler->ntotal = list_size;
|
|
1154
1284
|
handler->q_map = q_map.data();
|
|
1155
1285
|
handler->id_map = ids.get();
|
|
1156
|
-
uint64_t tt1 = get_cy();
|
|
1157
|
-
|
|
1158
|
-
#define DISPATCH(classHC) \
|
|
1159
|
-
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
1160
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
1161
|
-
pq4_accumulate_loop_qbs( \
|
|
1162
|
-
qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
|
|
1163
|
-
}
|
|
1164
|
-
DISPATCH(HeapHC)
|
|
1165
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
1166
1286
|
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1287
|
+
pq4_accumulate_loop_qbs(
|
|
1288
|
+
qbs,
|
|
1289
|
+
list_size,
|
|
1290
|
+
M2,
|
|
1291
|
+
codes.get(),
|
|
1292
|
+
LUT.get(),
|
|
1293
|
+
*handler.get(),
|
|
1294
|
+
scaler);
|
|
1170
1295
|
}
|
|
1171
1296
|
|
|
1172
1297
|
// labels is in-place for HeapHC
|
|
1173
|
-
handler->
|
|
1174
|
-
local_dis.data(),
|
|
1175
|
-
local_idx.data(),
|
|
1176
|
-
skip & 16 ? nullptr : normalizers.get());
|
|
1298
|
+
handler->end();
|
|
1177
1299
|
|
|
1300
|
+
// merge per-thread results
|
|
1178
1301
|
#pragma omp single
|
|
1179
1302
|
{
|
|
1180
1303
|
// we init the results as a heap
|
|
@@ -1197,12 +1320,6 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1197
1320
|
|
|
1198
1321
|
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
1199
1322
|
IVFFastScan_stats.t_scan += t_scan;
|
|
1200
|
-
|
|
1201
|
-
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
1202
|
-
for (int i = 0; i < 4; i++) {
|
|
1203
|
-
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
1204
|
-
}
|
|
1205
|
-
}
|
|
1206
1323
|
}
|
|
1207
1324
|
#pragma omp barrier
|
|
1208
1325
|
#pragma omp single
|
|
@@ -1272,20 +1389,4 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
|
|
|
1272
1389
|
|
|
1273
1390
|
IVFFastScanStats IVFFastScan_stats;
|
|
1274
1391
|
|
|
1275
|
-
template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
|
|
1276
|
-
idx_t n,
|
|
1277
|
-
const float* x,
|
|
1278
|
-
idx_t k,
|
|
1279
|
-
float* distances,
|
|
1280
|
-
idx_t* labels,
|
|
1281
|
-
const NormTableScaler& scaler) const;
|
|
1282
|
-
|
|
1283
|
-
template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
|
|
1284
|
-
idx_t n,
|
|
1285
|
-
const float* x,
|
|
1286
|
-
idx_t k,
|
|
1287
|
-
float* distances,
|
|
1288
|
-
idx_t* labels,
|
|
1289
|
-
const NormTableScaler& scaler) const;
|
|
1290
|
-
|
|
1291
1392
|
} // namespace faiss
|