faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -43,6 +43,8 @@ IndexIVFFastScan::IndexIVFFastScan(
|
|
|
43
43
|
size_t code_size,
|
|
44
44
|
MetricType metric)
|
|
45
45
|
: IndexIVF(quantizer, d, nlist, code_size, metric) {
|
|
46
|
+
// unlike other indexes, we prefer no residuals for performance reasons.
|
|
47
|
+
by_residual = false;
|
|
46
48
|
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
|
|
47
49
|
}
|
|
48
50
|
|
|
@@ -50,6 +52,7 @@ IndexIVFFastScan::IndexIVFFastScan() {
|
|
|
50
52
|
bbs = 0;
|
|
51
53
|
M2 = 0;
|
|
52
54
|
is_trained = false;
|
|
55
|
+
by_residual = false;
|
|
53
56
|
}
|
|
54
57
|
|
|
55
58
|
void IndexIVFFastScan::init_fastscan(
|
|
@@ -79,7 +82,7 @@ void IndexIVFFastScan::init_code_packer() {
|
|
|
79
82
|
bil->packer = get_CodePacker();
|
|
80
83
|
}
|
|
81
84
|
|
|
82
|
-
IndexIVFFastScan::~IndexIVFFastScan()
|
|
85
|
+
IndexIVFFastScan::~IndexIVFFastScan() = default;
|
|
83
86
|
|
|
84
87
|
/*********************************************************
|
|
85
88
|
* Code management functions
|
|
@@ -195,7 +198,7 @@ CodePacker* IndexIVFFastScan::get_CodePacker() const {
|
|
|
195
198
|
|
|
196
199
|
namespace {
|
|
197
200
|
|
|
198
|
-
template <class C, typename dis_t
|
|
201
|
+
template <class C, typename dis_t>
|
|
199
202
|
void estimators_from_tables_generic(
|
|
200
203
|
const IndexIVFFastScan& index,
|
|
201
204
|
const uint8_t* codes,
|
|
@@ -206,22 +209,26 @@ void estimators_from_tables_generic(
|
|
|
206
209
|
size_t k,
|
|
207
210
|
typename C::T* heap_dis,
|
|
208
211
|
int64_t* heap_ids,
|
|
209
|
-
const
|
|
212
|
+
const NormTableScaler* scaler) {
|
|
210
213
|
using accu_t = typename C::T;
|
|
214
|
+
int nscale = scaler ? scaler->nscale : 0;
|
|
211
215
|
for (size_t j = 0; j < ncodes; ++j) {
|
|
212
216
|
BitstringReader bsr(codes + j * index.code_size, index.code_size);
|
|
213
217
|
accu_t dis = bias;
|
|
214
218
|
const dis_t* __restrict dt = dis_table;
|
|
215
|
-
|
|
219
|
+
|
|
220
|
+
for (size_t m = 0; m < index.M - nscale; m++) {
|
|
216
221
|
uint64_t c = bsr.read(index.nbits);
|
|
217
222
|
dis += dt[c];
|
|
218
223
|
dt += index.ksub;
|
|
219
224
|
}
|
|
220
225
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
226
|
+
if (scaler) {
|
|
227
|
+
for (size_t m = 0; m < nscale; m++) {
|
|
228
|
+
uint64_t c = bsr.read(index.nbits);
|
|
229
|
+
dis += scaler->scale_one(dt[c]);
|
|
230
|
+
dt += index.ksub;
|
|
231
|
+
}
|
|
225
232
|
}
|
|
226
233
|
|
|
227
234
|
if (C::cmp(heap_dis[0], dis)) {
|
|
@@ -242,18 +249,15 @@ using namespace quantize_lut;
|
|
|
242
249
|
void IndexIVFFastScan::compute_LUT_uint8(
|
|
243
250
|
size_t n,
|
|
244
251
|
const float* x,
|
|
245
|
-
const
|
|
246
|
-
const float* coarse_dis,
|
|
252
|
+
const CoarseQuantized& cq,
|
|
247
253
|
AlignedTable<uint8_t>& dis_tables,
|
|
248
254
|
AlignedTable<uint16_t>& biases,
|
|
249
255
|
float* normalizers) const {
|
|
250
256
|
AlignedTable<float> dis_tables_float;
|
|
251
257
|
AlignedTable<float> biases_float;
|
|
252
258
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
|
|
256
|
-
|
|
259
|
+
compute_LUT(n, x, cq, dis_tables_float, biases_float);
|
|
260
|
+
size_t nprobe = cq.nprobe;
|
|
257
261
|
bool lut_is_3d = lookup_table_is_3d();
|
|
258
262
|
size_t dim123 = ksub * M;
|
|
259
263
|
size_t dim123_2 = ksub * M2;
|
|
@@ -265,7 +269,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
|
|
|
265
269
|
if (biases_float.get()) {
|
|
266
270
|
biases.resize(n * nprobe);
|
|
267
271
|
}
|
|
268
|
-
uint64_t t1 = get_cy();
|
|
269
272
|
|
|
270
273
|
#pragma omp parallel for if (n > 100)
|
|
271
274
|
for (int64_t i = 0; i < n; i++) {
|
|
@@ -291,7 +294,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
|
|
|
291
294
|
normalizers + 2 * i,
|
|
292
295
|
normalizers + 2 * i + 1);
|
|
293
296
|
}
|
|
294
|
-
IVFFastScan_stats.t_round += get_cy() - t1;
|
|
295
297
|
}
|
|
296
298
|
|
|
297
299
|
/*********************************************************
|
|
@@ -305,44 +307,161 @@ void IndexIVFFastScan::search(
|
|
|
305
307
|
float* distances,
|
|
306
308
|
idx_t* labels,
|
|
307
309
|
const SearchParameters* params) const {
|
|
310
|
+
auto paramsi = dynamic_cast<const SearchParametersIVF*>(params);
|
|
311
|
+
FAISS_THROW_IF_NOT_MSG(!params || paramsi, "need IVFSearchParameters");
|
|
312
|
+
search_preassigned(
|
|
313
|
+
n, x, k, nullptr, nullptr, distances, labels, false, paramsi);
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
void IndexIVFFastScan::search_preassigned(
|
|
317
|
+
idx_t n,
|
|
318
|
+
const float* x,
|
|
319
|
+
idx_t k,
|
|
320
|
+
const idx_t* assign,
|
|
321
|
+
const float* centroid_dis,
|
|
322
|
+
float* distances,
|
|
323
|
+
idx_t* labels,
|
|
324
|
+
bool store_pairs,
|
|
325
|
+
const IVFSearchParameters* params,
|
|
326
|
+
IndexIVFStats* stats) const {
|
|
327
|
+
size_t nprobe = this->nprobe;
|
|
328
|
+
if (params) {
|
|
329
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
330
|
+
!params->quantizer_params, "quantizer params not supported");
|
|
331
|
+
FAISS_THROW_IF_NOT(params->max_codes == 0);
|
|
332
|
+
nprobe = params->nprobe;
|
|
333
|
+
}
|
|
308
334
|
FAISS_THROW_IF_NOT_MSG(
|
|
309
|
-
!
|
|
335
|
+
!store_pairs, "store_pairs not supported for this index");
|
|
336
|
+
FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
|
|
310
337
|
FAISS_THROW_IF_NOT(k > 0);
|
|
311
338
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
339
|
+
const CoarseQuantized cq = {nprobe, centroid_dis, assign};
|
|
340
|
+
search_dispatch_implem(n, x, k, distances, labels, cq, nullptr);
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
void IndexIVFFastScan::range_search(
|
|
344
|
+
idx_t n,
|
|
345
|
+
const float* x,
|
|
346
|
+
float radius,
|
|
347
|
+
RangeSearchResult* result,
|
|
348
|
+
const SearchParameters* params) const {
|
|
349
|
+
FAISS_THROW_IF_NOT(!params);
|
|
350
|
+
const CoarseQuantized cq = {nprobe, nullptr, nullptr};
|
|
351
|
+
range_search_dispatch_implem(n, x, radius, *result, cq, nullptr);
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
namespace {
|
|
355
|
+
|
|
356
|
+
template <class C>
|
|
357
|
+
ResultHandlerCompare<C, true>* make_knn_handler_fixC(
|
|
358
|
+
int impl,
|
|
359
|
+
idx_t n,
|
|
360
|
+
idx_t k,
|
|
361
|
+
float* distances,
|
|
362
|
+
idx_t* labels) {
|
|
363
|
+
using HeapHC = HeapHandler<C, true>;
|
|
364
|
+
using ReservoirHC = ReservoirHandler<C, true>;
|
|
365
|
+
using SingleResultHC = SingleResultHandler<C, true>;
|
|
366
|
+
|
|
367
|
+
if (k == 1) {
|
|
368
|
+
return new SingleResultHC(n, 0, distances, labels);
|
|
369
|
+
} else if (impl % 2 == 0) {
|
|
370
|
+
return new HeapHC(n, 0, k, distances, labels);
|
|
371
|
+
} else /* if (impl % 2 == 1) */ {
|
|
372
|
+
return new ReservoirHC(n, 0, k, 2 * k, distances, labels);
|
|
373
|
+
}
|
|
374
|
+
}
|
|
375
|
+
|
|
376
|
+
SIMDResultHandlerToFloat* make_knn_handler(
|
|
377
|
+
bool is_max,
|
|
378
|
+
int impl,
|
|
379
|
+
idx_t n,
|
|
380
|
+
idx_t k,
|
|
381
|
+
float* distances,
|
|
382
|
+
idx_t* labels) {
|
|
383
|
+
if (is_max) {
|
|
384
|
+
return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
|
|
385
|
+
impl, n, k, distances, labels);
|
|
315
386
|
} else {
|
|
316
|
-
|
|
387
|
+
return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
|
|
388
|
+
impl, n, k, distances, labels);
|
|
317
389
|
}
|
|
318
390
|
}
|
|
319
391
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
392
|
+
using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
|
|
393
|
+
|
|
394
|
+
struct CoarseQuantizedWithBuffer : CoarseQuantized {
|
|
395
|
+
explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq)
|
|
396
|
+
: CoarseQuantized(cq) {}
|
|
397
|
+
|
|
398
|
+
bool done() const {
|
|
399
|
+
return ids != nullptr;
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
std::vector<idx_t> ids_buffer;
|
|
403
|
+
std::vector<float> dis_buffer;
|
|
404
|
+
|
|
405
|
+
void quantize(const Index* quantizer, idx_t n, const float* x) {
|
|
406
|
+
dis_buffer.resize(nprobe * n);
|
|
407
|
+
ids_buffer.resize(nprobe * n);
|
|
408
|
+
quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data());
|
|
409
|
+
dis = dis_buffer.data();
|
|
410
|
+
ids = ids_buffer.data();
|
|
411
|
+
}
|
|
412
|
+
};
|
|
413
|
+
|
|
414
|
+
struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
|
|
415
|
+
size_t i0, i1;
|
|
416
|
+
CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
|
|
417
|
+
: CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
|
|
418
|
+
if (done()) {
|
|
419
|
+
dis += nprobe * i0;
|
|
420
|
+
ids += nprobe * i0;
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
void quantize_slice(const Index* quantizer, const float* x) {
|
|
425
|
+
quantize(quantizer, i1 - i0, x + quantizer->d * i0);
|
|
426
|
+
}
|
|
427
|
+
};
|
|
428
|
+
|
|
429
|
+
int compute_search_nslice(
|
|
430
|
+
const IndexIVFFastScan* index,
|
|
431
|
+
size_t n,
|
|
432
|
+
size_t nprobe) {
|
|
433
|
+
int nslice;
|
|
434
|
+
if (n <= omp_get_max_threads()) {
|
|
435
|
+
nslice = n;
|
|
436
|
+
} else if (index->lookup_table_is_3d()) {
|
|
437
|
+
// make sure we don't make too big LUT tables
|
|
438
|
+
size_t lut_size_per_query = index->M * index->ksub * nprobe *
|
|
439
|
+
(sizeof(float) + sizeof(uint8_t));
|
|
440
|
+
|
|
441
|
+
size_t max_lut_size = precomputed_table_max_bytes;
|
|
442
|
+
// how many queries we can handle within mem budget
|
|
443
|
+
size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
444
|
+
nslice = roundup(
|
|
445
|
+
std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
|
|
446
|
+
} else {
|
|
447
|
+
// LUTs unlikely to be a limiting factor
|
|
448
|
+
nslice = omp_get_max_threads();
|
|
449
|
+
}
|
|
450
|
+
return nslice;
|
|
327
451
|
}
|
|
328
452
|
|
|
329
|
-
|
|
453
|
+
} // namespace
|
|
454
|
+
|
|
330
455
|
void IndexIVFFastScan::search_dispatch_implem(
|
|
331
456
|
idx_t n,
|
|
332
457
|
const float* x,
|
|
333
458
|
idx_t k,
|
|
334
459
|
float* distances,
|
|
335
460
|
idx_t* labels,
|
|
336
|
-
const
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
CMin<float, int64_t>>::type;
|
|
341
|
-
|
|
342
|
-
using C = typename std::conditional<
|
|
343
|
-
is_max,
|
|
344
|
-
CMax<uint16_t, int64_t>,
|
|
345
|
-
CMin<uint16_t, int64_t>>::type;
|
|
461
|
+
const CoarseQuantized& cq_in,
|
|
462
|
+
const NormTableScaler* scaler) const {
|
|
463
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
464
|
+
using RH = SIMDResultHandlerToFloat;
|
|
346
465
|
|
|
347
466
|
if (n == 0) {
|
|
348
467
|
return;
|
|
@@ -357,70 +476,74 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
357
476
|
} else {
|
|
358
477
|
impl = 10;
|
|
359
478
|
}
|
|
360
|
-
if (k > 20) {
|
|
479
|
+
if (k > 20) { // use reservoir rather than heap
|
|
361
480
|
impl++;
|
|
362
481
|
}
|
|
363
482
|
}
|
|
364
483
|
|
|
484
|
+
bool multiple_threads =
|
|
485
|
+
n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
|
|
486
|
+
if (impl >= 100) {
|
|
487
|
+
multiple_threads = false;
|
|
488
|
+
impl -= 100;
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
CoarseQuantizedWithBuffer cq(cq_in);
|
|
492
|
+
|
|
493
|
+
if (!cq.done() && !multiple_threads) {
|
|
494
|
+
// we do the coarse quantization here except when search is
|
|
495
|
+
// sliced over threads (then it is more efficient to have each thread do
|
|
496
|
+
// its own coarse quantization)
|
|
497
|
+
cq.quantize(quantizer, n, x);
|
|
498
|
+
}
|
|
499
|
+
|
|
365
500
|
if (impl == 1) {
|
|
366
|
-
|
|
501
|
+
if (is_max) {
|
|
502
|
+
search_implem_1<CMax<float, int64_t>>(
|
|
503
|
+
n, x, k, distances, labels, cq, scaler);
|
|
504
|
+
} else {
|
|
505
|
+
search_implem_1<CMin<float, int64_t>>(
|
|
506
|
+
n, x, k, distances, labels, cq, scaler);
|
|
507
|
+
}
|
|
367
508
|
} else if (impl == 2) {
|
|
368
|
-
|
|
509
|
+
if (is_max) {
|
|
510
|
+
search_implem_2<CMax<uint16_t, int64_t>>(
|
|
511
|
+
n, x, k, distances, labels, cq, scaler);
|
|
512
|
+
} else {
|
|
513
|
+
search_implem_2<CMin<uint16_t, int64_t>>(
|
|
514
|
+
n, x, k, distances, labels, cq, scaler);
|
|
515
|
+
}
|
|
369
516
|
|
|
370
517
|
} else if (impl >= 10 && impl <= 15) {
|
|
371
518
|
size_t ndis = 0, nlist_visited = 0;
|
|
372
519
|
|
|
373
|
-
if (
|
|
520
|
+
if (!multiple_threads) {
|
|
521
|
+
// clang-format off
|
|
374
522
|
if (impl == 12 || impl == 13) {
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
x,
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
labels,
|
|
381
|
-
impl,
|
|
382
|
-
&ndis,
|
|
383
|
-
&nlist_visited,
|
|
384
|
-
scaler);
|
|
523
|
+
std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
|
|
524
|
+
search_implem_12(
|
|
525
|
+
n, x, *handler.get(),
|
|
526
|
+
cq, &ndis, &nlist_visited, scaler);
|
|
527
|
+
|
|
385
528
|
} else if (impl == 14 || impl == 15) {
|
|
386
|
-
|
|
529
|
+
|
|
530
|
+
search_implem_14(
|
|
531
|
+
n, x, k, distances, labels,
|
|
532
|
+
cq, impl, scaler);
|
|
387
533
|
} else {
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
x,
|
|
391
|
-
|
|
392
|
-
distances,
|
|
393
|
-
labels,
|
|
394
|
-
impl,
|
|
395
|
-
&ndis,
|
|
396
|
-
&nlist_visited,
|
|
397
|
-
scaler);
|
|
534
|
+
std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
|
|
535
|
+
search_implem_10(
|
|
536
|
+
n, x, *handler.get(), cq,
|
|
537
|
+
&ndis, &nlist_visited, scaler);
|
|
398
538
|
}
|
|
539
|
+
// clang-format on
|
|
399
540
|
} else {
|
|
400
541
|
// explicitly slice over threads
|
|
401
|
-
int nslice;
|
|
402
|
-
if (
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
size_t lut_size_per_query =
|
|
407
|
-
M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
|
|
408
|
-
|
|
409
|
-
size_t max_lut_size = precomputed_table_max_bytes;
|
|
410
|
-
// how many queries we can handle within mem budget
|
|
411
|
-
size_t nq_ok =
|
|
412
|
-
std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
413
|
-
nslice =
|
|
414
|
-
roundup(std::max(size_t(n / nq_ok), size_t(1)),
|
|
415
|
-
omp_get_max_threads());
|
|
416
|
-
} else {
|
|
417
|
-
// LUTs unlikely to be a limiting factor
|
|
418
|
-
nslice = omp_get_max_threads();
|
|
419
|
-
}
|
|
420
|
-
if (impl == 14 ||
|
|
421
|
-
impl == 15) { // this might require slicing if there are too
|
|
422
|
-
// many queries (for now we keep this simple)
|
|
423
|
-
search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
|
|
542
|
+
int nslice = compute_search_nslice(this, n, cq.nprobe);
|
|
543
|
+
if (impl == 14 || impl == 15) {
|
|
544
|
+
// this might require slicing if there are too
|
|
545
|
+
// many queries (for now we keep this simple)
|
|
546
|
+
search_implem_14(n, x, k, distances, labels, cq, impl, scaler);
|
|
424
547
|
} else {
|
|
425
548
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
426
549
|
for (int slice = 0; slice < nslice; slice++) {
|
|
@@ -428,29 +551,23 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
428
551
|
idx_t i1 = n * (slice + 1) / nslice;
|
|
429
552
|
float* dis_i = distances + i0 * k;
|
|
430
553
|
idx_t* lab_i = labels + i0 * k;
|
|
554
|
+
CoarseQuantizedSlice cq_i(cq, i0, i1);
|
|
555
|
+
if (!cq_i.done()) {
|
|
556
|
+
cq_i.quantize_slice(quantizer, x);
|
|
557
|
+
}
|
|
558
|
+
std::unique_ptr<RH> handler(make_knn_handler(
|
|
559
|
+
is_max, impl, i1 - i0, k, dis_i, lab_i));
|
|
560
|
+
// clang-format off
|
|
431
561
|
if (impl == 12 || impl == 13) {
|
|
432
|
-
search_implem_12
|
|
433
|
-
i1 - i0,
|
|
434
|
-
|
|
435
|
-
k,
|
|
436
|
-
dis_i,
|
|
437
|
-
lab_i,
|
|
438
|
-
impl,
|
|
439
|
-
&ndis,
|
|
440
|
-
&nlist_visited,
|
|
441
|
-
scaler);
|
|
562
|
+
search_implem_12(
|
|
563
|
+
i1 - i0, x + i0 * d, *handler.get(),
|
|
564
|
+
cq_i, &ndis, &nlist_visited, scaler);
|
|
442
565
|
} else {
|
|
443
|
-
search_implem_10
|
|
444
|
-
i1 - i0,
|
|
445
|
-
|
|
446
|
-
k,
|
|
447
|
-
dis_i,
|
|
448
|
-
lab_i,
|
|
449
|
-
impl,
|
|
450
|
-
&ndis,
|
|
451
|
-
&nlist_visited,
|
|
452
|
-
scaler);
|
|
566
|
+
search_implem_10(
|
|
567
|
+
i1 - i0, x + i0 * d, *handler.get(),
|
|
568
|
+
cq_i, &ndis, &nlist_visited, scaler);
|
|
453
569
|
}
|
|
570
|
+
// clang-format on
|
|
454
571
|
}
|
|
455
572
|
}
|
|
456
573
|
}
|
|
@@ -462,31 +579,139 @@ void IndexIVFFastScan::search_dispatch_implem(
|
|
|
462
579
|
}
|
|
463
580
|
}
|
|
464
581
|
|
|
465
|
-
|
|
582
|
+
void IndexIVFFastScan::range_search_dispatch_implem(
|
|
583
|
+
idx_t n,
|
|
584
|
+
const float* x,
|
|
585
|
+
float radius,
|
|
586
|
+
RangeSearchResult& rres,
|
|
587
|
+
const CoarseQuantized& cq_in,
|
|
588
|
+
const NormTableScaler* scaler) const {
|
|
589
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
590
|
+
|
|
591
|
+
if (n == 0) {
|
|
592
|
+
return;
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
// actual implementation used
|
|
596
|
+
int impl = implem;
|
|
597
|
+
|
|
598
|
+
if (impl == 0) {
|
|
599
|
+
if (bbs == 32) {
|
|
600
|
+
impl = 12;
|
|
601
|
+
} else {
|
|
602
|
+
impl = 10;
|
|
603
|
+
}
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
CoarseQuantizedWithBuffer cq(cq_in);
|
|
607
|
+
|
|
608
|
+
bool multiple_threads =
|
|
609
|
+
n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
|
|
610
|
+
if (impl >= 100) {
|
|
611
|
+
multiple_threads = false;
|
|
612
|
+
impl -= 100;
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
if (!multiple_threads && !cq.done()) {
|
|
616
|
+
cq.quantize(quantizer, n, x);
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
size_t ndis = 0, nlist_visited = 0;
|
|
620
|
+
|
|
621
|
+
if (!multiple_threads) { // single thread
|
|
622
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler;
|
|
623
|
+
if (is_max) {
|
|
624
|
+
handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
|
|
625
|
+
rres, radius, 0));
|
|
626
|
+
} else {
|
|
627
|
+
handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
|
|
628
|
+
rres, radius, 0));
|
|
629
|
+
}
|
|
630
|
+
if (impl == 12) {
|
|
631
|
+
search_implem_12(
|
|
632
|
+
n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
|
|
633
|
+
} else if (impl == 10) {
|
|
634
|
+
search_implem_10(
|
|
635
|
+
n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
|
|
636
|
+
} else {
|
|
637
|
+
FAISS_THROW_FMT("Range search implem %d not impemented", impl);
|
|
638
|
+
}
|
|
639
|
+
} else {
|
|
640
|
+
// explicitly slice over threads
|
|
641
|
+
int nslice = compute_search_nslice(this, n, cq.nprobe);
|
|
642
|
+
#pragma omp parallel
|
|
643
|
+
{
|
|
644
|
+
RangeSearchPartialResult pres(&rres);
|
|
645
|
+
|
|
646
|
+
#pragma omp for reduction(+ : ndis, nlist_visited)
|
|
647
|
+
for (int slice = 0; slice < nslice; slice++) {
|
|
648
|
+
idx_t i0 = n * slice / nslice;
|
|
649
|
+
idx_t i1 = n * (slice + 1) / nslice;
|
|
650
|
+
CoarseQuantizedSlice cq_i(cq, i0, i1);
|
|
651
|
+
if (!cq_i.done()) {
|
|
652
|
+
cq_i.quantize_slice(quantizer, x);
|
|
653
|
+
}
|
|
654
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler;
|
|
655
|
+
if (is_max) {
|
|
656
|
+
handler.reset(new PartialRangeHandler<
|
|
657
|
+
CMax<uint16_t, int64_t>,
|
|
658
|
+
true>(pres, radius, 0, i0, i1));
|
|
659
|
+
} else {
|
|
660
|
+
handler.reset(new PartialRangeHandler<
|
|
661
|
+
CMin<uint16_t, int64_t>,
|
|
662
|
+
true>(pres, radius, 0, i0, i1));
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
if (impl == 12 || impl == 13) {
|
|
666
|
+
search_implem_12(
|
|
667
|
+
i1 - i0,
|
|
668
|
+
x + i0 * d,
|
|
669
|
+
*handler.get(),
|
|
670
|
+
cq_i,
|
|
671
|
+
&ndis,
|
|
672
|
+
&nlist_visited,
|
|
673
|
+
scaler);
|
|
674
|
+
} else {
|
|
675
|
+
search_implem_10(
|
|
676
|
+
i1 - i0,
|
|
677
|
+
x + i0 * d,
|
|
678
|
+
*handler.get(),
|
|
679
|
+
cq_i,
|
|
680
|
+
&ndis,
|
|
681
|
+
&nlist_visited,
|
|
682
|
+
scaler);
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
pres.finalize();
|
|
686
|
+
}
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
indexIVF_stats.nq += n;
|
|
690
|
+
indexIVF_stats.ndis += ndis;
|
|
691
|
+
indexIVF_stats.nlist += nlist_visited;
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
template <class C>
|
|
466
695
|
void IndexIVFFastScan::search_implem_1(
|
|
467
696
|
idx_t n,
|
|
468
697
|
const float* x,
|
|
469
698
|
idx_t k,
|
|
470
699
|
float* distances,
|
|
471
700
|
idx_t* labels,
|
|
472
|
-
const
|
|
701
|
+
const CoarseQuantized& cq,
|
|
702
|
+
const NormTableScaler* scaler) const {
|
|
473
703
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
474
704
|
|
|
475
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
476
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
477
|
-
|
|
478
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
479
|
-
|
|
480
705
|
size_t dim12 = ksub * M;
|
|
481
706
|
AlignedTable<float> dis_tables;
|
|
482
707
|
AlignedTable<float> biases;
|
|
483
708
|
|
|
484
|
-
compute_LUT(n, x,
|
|
709
|
+
compute_LUT(n, x, cq, dis_tables, biases);
|
|
485
710
|
|
|
486
711
|
bool single_LUT = !lookup_table_is_3d();
|
|
487
712
|
|
|
488
713
|
size_t ndis = 0, nlist_visited = 0;
|
|
489
|
-
|
|
714
|
+
size_t nprobe = cq.nprobe;
|
|
490
715
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
491
716
|
for (idx_t i = 0; i < n; i++) {
|
|
492
717
|
int64_t* heap_ids = labels + i * k;
|
|
@@ -501,7 +726,7 @@ void IndexIVFFastScan::search_implem_1(
|
|
|
501
726
|
if (!single_LUT) {
|
|
502
727
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
503
728
|
}
|
|
504
|
-
idx_t list_no =
|
|
729
|
+
idx_t list_no = cq.ids[i * nprobe + j];
|
|
505
730
|
if (list_no < 0)
|
|
506
731
|
continue;
|
|
507
732
|
size_t ls = orig_invlists->list_size(list_no);
|
|
@@ -533,38 +758,28 @@ void IndexIVFFastScan::search_implem_1(
|
|
|
533
758
|
indexIVF_stats.nlist += nlist_visited;
|
|
534
759
|
}
|
|
535
760
|
|
|
536
|
-
template <class C
|
|
761
|
+
template <class C>
|
|
537
762
|
void IndexIVFFastScan::search_implem_2(
|
|
538
763
|
idx_t n,
|
|
539
764
|
const float* x,
|
|
540
765
|
idx_t k,
|
|
541
766
|
float* distances,
|
|
542
767
|
idx_t* labels,
|
|
543
|
-
const
|
|
768
|
+
const CoarseQuantized& cq,
|
|
769
|
+
const NormTableScaler* scaler) const {
|
|
544
770
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
545
771
|
|
|
546
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
547
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
548
|
-
|
|
549
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
550
|
-
|
|
551
772
|
size_t dim12 = ksub * M2;
|
|
552
773
|
AlignedTable<uint8_t> dis_tables;
|
|
553
774
|
AlignedTable<uint16_t> biases;
|
|
554
775
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
555
776
|
|
|
556
|
-
compute_LUT_uint8(
|
|
557
|
-
n,
|
|
558
|
-
x,
|
|
559
|
-
coarse_ids.get(),
|
|
560
|
-
coarse_dis.get(),
|
|
561
|
-
dis_tables,
|
|
562
|
-
biases,
|
|
563
|
-
normalizers.get());
|
|
777
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
564
778
|
|
|
565
779
|
bool single_LUT = !lookup_table_is_3d();
|
|
566
780
|
|
|
567
781
|
size_t ndis = 0, nlist_visited = 0;
|
|
782
|
+
size_t nprobe = cq.nprobe;
|
|
568
783
|
|
|
569
784
|
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
570
785
|
for (idx_t i = 0; i < n; i++) {
|
|
@@ -581,7 +796,7 @@ void IndexIVFFastScan::search_implem_2(
|
|
|
581
796
|
if (!single_LUT) {
|
|
582
797
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
583
798
|
}
|
|
584
|
-
idx_t list_no =
|
|
799
|
+
idx_t list_no = cq.ids[i * nprobe + j];
|
|
585
800
|
if (list_no < 0)
|
|
586
801
|
continue;
|
|
587
802
|
size_t ls = orig_invlists->list_size(list_no);
|
|
@@ -626,171 +841,99 @@ void IndexIVFFastScan::search_implem_2(
|
|
|
626
841
|
indexIVF_stats.nlist += nlist_visited;
|
|
627
842
|
}
|
|
628
843
|
|
|
629
|
-
template <class C, class Scaler>
|
|
630
844
|
void IndexIVFFastScan::search_implem_10(
|
|
631
845
|
idx_t n,
|
|
632
846
|
const float* x,
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
idx_t* labels,
|
|
636
|
-
int impl,
|
|
847
|
+
SIMDResultHandlerToFloat& handler,
|
|
848
|
+
const CoarseQuantized& cq,
|
|
637
849
|
size_t* ndis_out,
|
|
638
850
|
size_t* nlist_out,
|
|
639
|
-
const
|
|
640
|
-
memset(distances, -1, sizeof(float) * k * n);
|
|
641
|
-
memset(labels, -1, sizeof(idx_t) * k * n);
|
|
642
|
-
|
|
643
|
-
using HeapHC = HeapHandler<C, true>;
|
|
644
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
645
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
646
|
-
|
|
647
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
648
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
649
|
-
|
|
650
|
-
uint64_t times[10];
|
|
651
|
-
memset(times, 0, sizeof(times));
|
|
652
|
-
int ti = 0;
|
|
653
|
-
#define TIC times[ti++] = get_cy()
|
|
654
|
-
TIC;
|
|
655
|
-
|
|
656
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
657
|
-
|
|
658
|
-
TIC;
|
|
659
|
-
|
|
851
|
+
const NormTableScaler* scaler) const {
|
|
660
852
|
size_t dim12 = ksub * M2;
|
|
661
853
|
AlignedTable<uint8_t> dis_tables;
|
|
662
854
|
AlignedTable<uint16_t> biases;
|
|
663
855
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
664
856
|
|
|
665
|
-
compute_LUT_uint8(
|
|
666
|
-
n,
|
|
667
|
-
x,
|
|
668
|
-
coarse_ids.get(),
|
|
669
|
-
coarse_dis.get(),
|
|
670
|
-
dis_tables,
|
|
671
|
-
biases,
|
|
672
|
-
normalizers.get());
|
|
673
|
-
|
|
674
|
-
TIC;
|
|
857
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
675
858
|
|
|
676
859
|
bool single_LUT = !lookup_table_is_3d();
|
|
677
860
|
|
|
678
|
-
|
|
679
|
-
|
|
861
|
+
size_t ndis = 0;
|
|
862
|
+
int qmap1[1];
|
|
680
863
|
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
const uint8_t* LUT = nullptr;
|
|
685
|
-
int qmap1[1] = {0};
|
|
686
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
687
|
-
|
|
688
|
-
if (k == 1) {
|
|
689
|
-
handler.reset(new SingleResultHC(1, 0));
|
|
690
|
-
} else if (impl == 10) {
|
|
691
|
-
handler.reset(new HeapHC(
|
|
692
|
-
1, tmp_distances.get(), labels + i * k, k, 0));
|
|
693
|
-
} else if (impl == 11) {
|
|
694
|
-
handler.reset(new ReservoirHC(1, 0, k, 2 * k));
|
|
695
|
-
} else {
|
|
696
|
-
FAISS_THROW_MSG("invalid");
|
|
697
|
-
}
|
|
864
|
+
handler.q_map = qmap1;
|
|
865
|
+
handler.begin(skip & 16 ? nullptr : normalizers.get());
|
|
866
|
+
size_t nprobe = cq.nprobe;
|
|
698
867
|
|
|
699
|
-
|
|
868
|
+
for (idx_t i = 0; i < n; i++) {
|
|
869
|
+
const uint8_t* LUT = nullptr;
|
|
870
|
+
qmap1[0] = i;
|
|
700
871
|
|
|
701
|
-
|
|
702
|
-
|
|
872
|
+
if (single_LUT) {
|
|
873
|
+
LUT = dis_tables.get() + i * dim12;
|
|
874
|
+
}
|
|
875
|
+
for (idx_t j = 0; j < nprobe; j++) {
|
|
876
|
+
size_t ij = i * nprobe + j;
|
|
877
|
+
if (!single_LUT) {
|
|
878
|
+
LUT = dis_tables.get() + ij * dim12;
|
|
879
|
+
}
|
|
880
|
+
if (biases.get()) {
|
|
881
|
+
handler.dbias = biases.get() + ij;
|
|
703
882
|
}
|
|
704
|
-
for (idx_t j = 0; j < nprobe; j++) {
|
|
705
|
-
size_t ij = i * nprobe + j;
|
|
706
|
-
if (!single_LUT) {
|
|
707
|
-
LUT = dis_tables.get() + ij * dim12;
|
|
708
|
-
}
|
|
709
|
-
if (biases.get()) {
|
|
710
|
-
handler->dbias = biases.get() + ij;
|
|
711
|
-
}
|
|
712
|
-
|
|
713
|
-
idx_t list_no = coarse_ids[ij];
|
|
714
|
-
if (list_no < 0)
|
|
715
|
-
continue;
|
|
716
|
-
size_t ls = invlists->list_size(list_no);
|
|
717
|
-
if (ls == 0)
|
|
718
|
-
continue;
|
|
719
883
|
|
|
720
|
-
|
|
721
|
-
|
|
884
|
+
idx_t list_no = cq.ids[ij];
|
|
885
|
+
if (list_no < 0) {
|
|
886
|
+
continue;
|
|
887
|
+
}
|
|
888
|
+
size_t ls = invlists->list_size(list_no);
|
|
889
|
+
if (ls == 0) {
|
|
890
|
+
continue;
|
|
891
|
+
}
|
|
722
892
|
|
|
723
|
-
|
|
724
|
-
|
|
893
|
+
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
894
|
+
InvertedLists::ScopedIds ids(invlists, list_no);
|
|
725
895
|
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
729
|
-
pq4_accumulate_loop( \
|
|
730
|
-
1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
|
|
731
|
-
}
|
|
732
|
-
DISPATCH(HeapHC)
|
|
733
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
734
|
-
#undef DISPATCH
|
|
896
|
+
handler.ntotal = ls;
|
|
897
|
+
handler.id_map = ids.get();
|
|
735
898
|
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
899
|
+
pq4_accumulate_loop(
|
|
900
|
+
1,
|
|
901
|
+
roundup(ls, bbs),
|
|
902
|
+
bbs,
|
|
903
|
+
M2,
|
|
904
|
+
codes.get(),
|
|
905
|
+
LUT,
|
|
906
|
+
handler,
|
|
907
|
+
scaler);
|
|
739
908
|
|
|
740
|
-
|
|
741
|
-
distances + i * k,
|
|
742
|
-
labels + i * k,
|
|
743
|
-
skip & 16 ? nullptr : normalizers.get() + i * 2);
|
|
909
|
+
ndis++;
|
|
744
910
|
}
|
|
745
911
|
}
|
|
912
|
+
handler.end();
|
|
746
913
|
*ndis_out = ndis;
|
|
747
914
|
*nlist_out = nlist;
|
|
748
915
|
}
|
|
749
916
|
|
|
750
|
-
template <class C, class Scaler>
|
|
751
917
|
void IndexIVFFastScan::search_implem_12(
|
|
752
918
|
idx_t n,
|
|
753
919
|
const float* x,
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
idx_t* labels,
|
|
757
|
-
int impl,
|
|
920
|
+
SIMDResultHandlerToFloat& handler,
|
|
921
|
+
const CoarseQuantized& cq,
|
|
758
922
|
size_t* ndis_out,
|
|
759
923
|
size_t* nlist_out,
|
|
760
|
-
const
|
|
924
|
+
const NormTableScaler* scaler) const {
|
|
761
925
|
if (n == 0) { // does not work well with reservoir
|
|
762
926
|
return;
|
|
763
927
|
}
|
|
764
928
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
765
929
|
|
|
766
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
767
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
768
|
-
|
|
769
|
-
uint64_t times[10];
|
|
770
|
-
memset(times, 0, sizeof(times));
|
|
771
|
-
int ti = 0;
|
|
772
|
-
#define TIC times[ti++] = get_cy()
|
|
773
|
-
TIC;
|
|
774
|
-
|
|
775
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
776
|
-
|
|
777
|
-
TIC;
|
|
778
|
-
|
|
779
930
|
size_t dim12 = ksub * M2;
|
|
780
931
|
AlignedTable<uint8_t> dis_tables;
|
|
781
932
|
AlignedTable<uint16_t> biases;
|
|
782
933
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
783
934
|
|
|
784
|
-
compute_LUT_uint8(
|
|
785
|
-
|
|
786
|
-
x,
|
|
787
|
-
coarse_ids.get(),
|
|
788
|
-
coarse_dis.get(),
|
|
789
|
-
dis_tables,
|
|
790
|
-
biases,
|
|
791
|
-
normalizers.get());
|
|
792
|
-
|
|
793
|
-
TIC;
|
|
935
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
936
|
+
handler.begin(skip & 16 ? nullptr : normalizers.get());
|
|
794
937
|
|
|
795
938
|
struct QC {
|
|
796
939
|
int qno; // sequence number of the query
|
|
@@ -798,14 +941,15 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
798
941
|
int rank; // this is the rank'th result of the coarse quantizer
|
|
799
942
|
};
|
|
800
943
|
bool single_LUT = !lookup_table_is_3d();
|
|
944
|
+
size_t nprobe = cq.nprobe;
|
|
801
945
|
|
|
802
946
|
std::vector<QC> qcs;
|
|
803
947
|
{
|
|
804
948
|
int ij = 0;
|
|
805
949
|
for (int i = 0; i < n; i++) {
|
|
806
950
|
for (int j = 0; j < nprobe; j++) {
|
|
807
|
-
if (
|
|
808
|
-
qcs.push_back(QC{i, int(
|
|
951
|
+
if (cq.ids[ij] >= 0) {
|
|
952
|
+
qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
|
|
809
953
|
}
|
|
810
954
|
ij++;
|
|
811
955
|
}
|
|
@@ -814,42 +958,21 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
814
958
|
return a.list_no < b.list_no;
|
|
815
959
|
});
|
|
816
960
|
}
|
|
817
|
-
TIC;
|
|
818
|
-
|
|
819
961
|
// prepare the result handlers
|
|
820
962
|
|
|
821
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
822
|
-
AlignedTable<uint16_t> tmp_distances;
|
|
823
|
-
|
|
824
|
-
using HeapHC = HeapHandler<C, true>;
|
|
825
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
826
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
827
|
-
|
|
828
|
-
if (k == 1) {
|
|
829
|
-
handler.reset(new SingleResultHC(n, 0));
|
|
830
|
-
} else if (impl == 12) {
|
|
831
|
-
tmp_distances.resize(n * k);
|
|
832
|
-
handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
|
|
833
|
-
} else if (impl == 13) {
|
|
834
|
-
handler.reset(new ReservoirHC(n, 0, k, 2 * k));
|
|
835
|
-
}
|
|
836
|
-
|
|
837
963
|
int qbs2 = this->qbs2 ? this->qbs2 : 11;
|
|
838
964
|
|
|
839
965
|
std::vector<uint16_t> tmp_bias;
|
|
840
966
|
if (biases.get()) {
|
|
841
967
|
tmp_bias.resize(qbs2);
|
|
842
|
-
handler
|
|
968
|
+
handler.dbias = tmp_bias.data();
|
|
843
969
|
}
|
|
844
|
-
TIC;
|
|
845
970
|
|
|
846
971
|
size_t ndis = 0;
|
|
847
972
|
|
|
848
973
|
size_t i0 = 0;
|
|
849
974
|
uint64_t t_copy_pack = 0, t_scan = 0;
|
|
850
975
|
while (i0 < qcs.size()) {
|
|
851
|
-
uint64_t tt0 = get_cy();
|
|
852
|
-
|
|
853
976
|
// find all queries that access this inverted list
|
|
854
977
|
int list_no = qcs[i0].list_no;
|
|
855
978
|
size_t i1 = i0 + 1;
|
|
@@ -897,93 +1020,47 @@ void IndexIVFFastScan::search_implem_12(
|
|
|
897
1020
|
|
|
898
1021
|
// prepare the handler
|
|
899
1022
|
|
|
900
|
-
handler
|
|
901
|
-
handler
|
|
902
|
-
handler
|
|
903
|
-
uint64_t tt1 = get_cy();
|
|
904
|
-
|
|
905
|
-
#define DISPATCH(classHC) \
|
|
906
|
-
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
907
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
908
|
-
pq4_accumulate_loop_qbs( \
|
|
909
|
-
qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
|
|
910
|
-
}
|
|
911
|
-
DISPATCH(HeapHC)
|
|
912
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
913
|
-
|
|
914
|
-
// prepare for next loop
|
|
915
|
-
i0 = i1;
|
|
1023
|
+
handler.ntotal = list_size;
|
|
1024
|
+
handler.q_map = q_map.data();
|
|
1025
|
+
handler.id_map = ids.get();
|
|
916
1026
|
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
1027
|
+
pq4_accumulate_loop_qbs(
|
|
1028
|
+
qbs, list_size, M2, codes.get(), LUT.get(), handler, scaler);
|
|
1029
|
+
// prepare for next loop
|
|
1030
|
+
i0 = i1;
|
|
920
1031
|
}
|
|
921
|
-
TIC;
|
|
922
|
-
|
|
923
|
-
// labels is in-place for HeapHC
|
|
924
|
-
handler->to_flat_arrays(
|
|
925
|
-
distances, labels, skip & 16 ? nullptr : normalizers.get());
|
|
926
1032
|
|
|
927
|
-
|
|
1033
|
+
handler.end();
|
|
928
1034
|
|
|
929
1035
|
// these stats are not thread-safe
|
|
930
1036
|
|
|
931
|
-
for (int i = 1; i < ti; i++) {
|
|
932
|
-
IVFFastScan_stats.times[i] += times[i] - times[i - 1];
|
|
933
|
-
}
|
|
934
1037
|
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
935
1038
|
IVFFastScan_stats.t_scan += t_scan;
|
|
936
1039
|
|
|
937
|
-
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
938
|
-
for (int i = 0; i < 4; i++) {
|
|
939
|
-
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
940
|
-
}
|
|
941
|
-
}
|
|
942
|
-
|
|
943
1040
|
*ndis_out = ndis;
|
|
944
1041
|
*nlist_out = nlist;
|
|
945
1042
|
}
|
|
946
1043
|
|
|
947
|
-
template <class C, class Scaler>
|
|
948
1044
|
void IndexIVFFastScan::search_implem_14(
|
|
949
1045
|
idx_t n,
|
|
950
1046
|
const float* x,
|
|
951
1047
|
idx_t k,
|
|
952
1048
|
float* distances,
|
|
953
1049
|
idx_t* labels,
|
|
1050
|
+
const CoarseQuantized& cq,
|
|
954
1051
|
int impl,
|
|
955
|
-
const
|
|
1052
|
+
const NormTableScaler* scaler) const {
|
|
956
1053
|
if (n == 0) { // does not work well with reservoir
|
|
957
1054
|
return;
|
|
958
1055
|
}
|
|
959
1056
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
960
1057
|
|
|
961
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
962
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
963
|
-
|
|
964
|
-
uint64_t ttg0 = get_cy();
|
|
965
|
-
|
|
966
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
967
|
-
|
|
968
|
-
uint64_t ttg1 = get_cy();
|
|
969
|
-
uint64_t coarse_search_tt = ttg1 - ttg0;
|
|
970
|
-
|
|
971
1058
|
size_t dim12 = ksub * M2;
|
|
972
1059
|
AlignedTable<uint8_t> dis_tables;
|
|
973
1060
|
AlignedTable<uint16_t> biases;
|
|
974
1061
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
975
1062
|
|
|
976
|
-
compute_LUT_uint8(
|
|
977
|
-
n,
|
|
978
|
-
x,
|
|
979
|
-
coarse_ids.get(),
|
|
980
|
-
coarse_dis.get(),
|
|
981
|
-
dis_tables,
|
|
982
|
-
biases,
|
|
983
|
-
normalizers.get());
|
|
984
|
-
|
|
985
|
-
uint64_t ttg2 = get_cy();
|
|
986
|
-
uint64_t lut_compute_tt = ttg2 - ttg1;
|
|
1063
|
+
compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
|
|
987
1064
|
|
|
988
1065
|
struct QC {
|
|
989
1066
|
int qno; // sequence number of the query
|
|
@@ -991,14 +1068,15 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
991
1068
|
int rank; // this is the rank'th result of the coarse quantizer
|
|
992
1069
|
};
|
|
993
1070
|
bool single_LUT = !lookup_table_is_3d();
|
|
1071
|
+
size_t nprobe = cq.nprobe;
|
|
994
1072
|
|
|
995
1073
|
std::vector<QC> qcs;
|
|
996
1074
|
{
|
|
997
1075
|
int ij = 0;
|
|
998
1076
|
for (int i = 0; i < n; i++) {
|
|
999
1077
|
for (int j = 0; j < nprobe; j++) {
|
|
1000
|
-
if (
|
|
1001
|
-
qcs.push_back(QC{i, int(
|
|
1078
|
+
if (cq.ids[ij] >= 0) {
|
|
1079
|
+
qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
|
|
1002
1080
|
}
|
|
1003
1081
|
ij++;
|
|
1004
1082
|
}
|
|
@@ -1036,14 +1114,13 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1036
1114
|
ses.push_back(SE{i0_l, i1, list_size});
|
|
1037
1115
|
i0_l = i1;
|
|
1038
1116
|
}
|
|
1039
|
-
uint64_t ttg3 = get_cy();
|
|
1040
|
-
uint64_t compute_clusters_tt = ttg3 - ttg2;
|
|
1041
1117
|
|
|
1042
1118
|
// function to handle the global heap
|
|
1119
|
+
bool is_max = !is_similarity_metric(metric_type);
|
|
1043
1120
|
using HeapForIP = CMin<float, idx_t>;
|
|
1044
1121
|
using HeapForL2 = CMax<float, idx_t>;
|
|
1045
1122
|
auto init_result = [&](float* simi, idx_t* idxi) {
|
|
1046
|
-
if (
|
|
1123
|
+
if (!is_max) {
|
|
1047
1124
|
heap_heapify<HeapForIP>(k, simi, idxi);
|
|
1048
1125
|
} else {
|
|
1049
1126
|
heap_heapify<HeapForL2>(k, simi, idxi);
|
|
@@ -1054,7 +1131,7 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1054
1131
|
const idx_t* local_idx,
|
|
1055
1132
|
float* simi,
|
|
1056
1133
|
idx_t* idxi) {
|
|
1057
|
-
if (
|
|
1134
|
+
if (!is_max) {
|
|
1058
1135
|
heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
|
|
1059
1136
|
} else {
|
|
1060
1137
|
heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
|
|
@@ -1062,14 +1139,12 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1062
1139
|
};
|
|
1063
1140
|
|
|
1064
1141
|
auto reorder_result = [&](float* simi, idx_t* idxi) {
|
|
1065
|
-
if (
|
|
1142
|
+
if (!is_max) {
|
|
1066
1143
|
heap_reorder<HeapForIP>(k, simi, idxi);
|
|
1067
1144
|
} else {
|
|
1068
1145
|
heap_reorder<HeapForL2>(k, simi, idxi);
|
|
1069
1146
|
}
|
|
1070
1147
|
};
|
|
1071
|
-
uint64_t ttg4 = get_cy();
|
|
1072
|
-
uint64_t fn_tt = ttg4 - ttg3;
|
|
1073
1148
|
|
|
1074
1149
|
size_t ndis = 0;
|
|
1075
1150
|
size_t nlist_visited = 0;
|
|
@@ -1081,22 +1156,9 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1081
1156
|
std::vector<float> local_dis(k * n);
|
|
1082
1157
|
|
|
1083
1158
|
// prepare the result handlers
|
|
1084
|
-
std::unique_ptr<
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
using HeapHC = HeapHandler<C, true>;
|
|
1088
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
1089
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
1090
|
-
|
|
1091
|
-
if (k == 1) {
|
|
1092
|
-
handler.reset(new SingleResultHC(n, 0));
|
|
1093
|
-
} else if (impl == 14) {
|
|
1094
|
-
tmp_distances.resize(n * k);
|
|
1095
|
-
handler.reset(
|
|
1096
|
-
new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
|
|
1097
|
-
} else if (impl == 15) {
|
|
1098
|
-
handler.reset(new ReservoirHC(n, 0, k, 2 * k));
|
|
1099
|
-
}
|
|
1159
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
|
|
1160
|
+
is_max, impl, n, k, local_dis.data(), local_idx.data()));
|
|
1161
|
+
handler->begin(normalizers.get());
|
|
1100
1162
|
|
|
1101
1163
|
int qbs2 = this->qbs2 ? this->qbs2 : 11;
|
|
1102
1164
|
|
|
@@ -1105,15 +1167,10 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1105
1167
|
tmp_bias.resize(qbs2);
|
|
1106
1168
|
handler->dbias = tmp_bias.data();
|
|
1107
1169
|
}
|
|
1108
|
-
|
|
1109
|
-
uint64_t ttg5 = get_cy();
|
|
1110
|
-
uint64_t handler_tt = ttg5 - ttg4;
|
|
1111
|
-
|
|
1112
1170
|
std::set<int> q_set;
|
|
1113
1171
|
uint64_t t_copy_pack = 0, t_scan = 0;
|
|
1114
1172
|
#pragma omp for schedule(dynamic)
|
|
1115
1173
|
for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
|
|
1116
|
-
uint64_t tt0 = get_cy();
|
|
1117
1174
|
size_t i0 = ses[cluster].start;
|
|
1118
1175
|
size_t i1 = ses[cluster].end;
|
|
1119
1176
|
size_t list_size = ses[cluster].list_size;
|
|
@@ -1153,28 +1210,21 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1153
1210
|
handler->ntotal = list_size;
|
|
1154
1211
|
handler->q_map = q_map.data();
|
|
1155
1212
|
handler->id_map = ids.get();
|
|
1156
|
-
uint64_t tt1 = get_cy();
|
|
1157
|
-
|
|
1158
|
-
#define DISPATCH(classHC) \
|
|
1159
|
-
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
1160
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
1161
|
-
pq4_accumulate_loop_qbs( \
|
|
1162
|
-
qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
|
|
1163
|
-
}
|
|
1164
|
-
DISPATCH(HeapHC)
|
|
1165
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
1166
1213
|
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
|
|
1214
|
+
pq4_accumulate_loop_qbs(
|
|
1215
|
+
qbs,
|
|
1216
|
+
list_size,
|
|
1217
|
+
M2,
|
|
1218
|
+
codes.get(),
|
|
1219
|
+
LUT.get(),
|
|
1220
|
+
*handler.get(),
|
|
1221
|
+
scaler);
|
|
1170
1222
|
}
|
|
1171
1223
|
|
|
1172
1224
|
// labels is in-place for HeapHC
|
|
1173
|
-
handler->
|
|
1174
|
-
local_dis.data(),
|
|
1175
|
-
local_idx.data(),
|
|
1176
|
-
skip & 16 ? nullptr : normalizers.get());
|
|
1225
|
+
handler->end();
|
|
1177
1226
|
|
|
1227
|
+
// merge per-thread results
|
|
1178
1228
|
#pragma omp single
|
|
1179
1229
|
{
|
|
1180
1230
|
// we init the results as a heap
|
|
@@ -1197,12 +1247,6 @@ void IndexIVFFastScan::search_implem_14(
|
|
|
1197
1247
|
|
|
1198
1248
|
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
1199
1249
|
IVFFastScan_stats.t_scan += t_scan;
|
|
1200
|
-
|
|
1201
|
-
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
1202
|
-
for (int i = 0; i < 4; i++) {
|
|
1203
|
-
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
1204
|
-
}
|
|
1205
|
-
}
|
|
1206
1250
|
}
|
|
1207
1251
|
#pragma omp barrier
|
|
1208
1252
|
#pragma omp single
|
|
@@ -1272,20 +1316,4 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
|
|
|
1272
1316
|
|
|
1273
1317
|
IVFFastScanStats IVFFastScan_stats;
|
|
1274
1318
|
|
|
1275
|
-
template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
|
|
1276
|
-
idx_t n,
|
|
1277
|
-
const float* x,
|
|
1278
|
-
idx_t k,
|
|
1279
|
-
float* distances,
|
|
1280
|
-
idx_t* labels,
|
|
1281
|
-
const NormTableScaler& scaler) const;
|
|
1282
|
-
|
|
1283
|
-
template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
|
|
1284
|
-
idx_t n,
|
|
1285
|
-
const float* x,
|
|
1286
|
-
idx_t k,
|
|
1287
|
-
float* distances,
|
|
1288
|
-
idx_t* labels,
|
|
1289
|
-
const NormTableScaler& scaler) const;
|
|
1290
|
-
|
|
1291
1319
|
} // namespace faiss
|