faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -12,28 +12,196 @@
|
|
|
12
12
|
#pragma once
|
|
13
13
|
|
|
14
14
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
15
|
+
#include <faiss/impl/FaissException.h>
|
|
16
|
+
#include <faiss/impl/IDSelector.h>
|
|
15
17
|
#include <faiss/utils/Heap.h>
|
|
16
18
|
#include <faiss/utils/partitioning.h>
|
|
17
19
|
|
|
20
|
+
#include <algorithm>
|
|
21
|
+
#include <iostream>
|
|
22
|
+
|
|
18
23
|
namespace faiss {
|
|
19
24
|
|
|
20
25
|
/*****************************************************************
|
|
21
|
-
*
|
|
26
|
+
* The classes below are intended to be used as template arguments
|
|
27
|
+
* they handle results for batches of queries (size nq).
|
|
28
|
+
* They can be called in two ways:
|
|
29
|
+
* - by instantiating a SingleResultHandler that tracks results for a single
|
|
30
|
+
* query
|
|
31
|
+
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
|
32
|
+
* results is submitted
|
|
33
|
+
* All classes are templated on C, which defines whether the min or the max of
|
|
34
|
+
* results is to be kept, and on sel, so that the codepaths for with / without
|
|
35
|
+
* selector can be separated at compile time.
|
|
22
36
|
*****************************************************************/
|
|
23
37
|
|
|
38
|
+
template <class C, bool use_sel = false>
|
|
39
|
+
struct BlockResultHandler {
|
|
40
|
+
size_t nq; // number of queries for which we search
|
|
41
|
+
const IDSelector* sel;
|
|
42
|
+
|
|
43
|
+
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
|
|
44
|
+
: nq(nq), sel(sel) {
|
|
45
|
+
assert(!use_sel || sel);
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
// currently handled query range
|
|
49
|
+
size_t i0 = 0, i1 = 0;
|
|
50
|
+
|
|
51
|
+
// start collecting results for queries [i0, i1)
|
|
52
|
+
virtual void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
53
|
+
this->i0 = i0_2;
|
|
54
|
+
this->i1 = i1_2;
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
// add results for queries [i0, i1) and database [j0, j1)
|
|
58
|
+
virtual void add_results(size_t, size_t, const typename C::T*) {}
|
|
59
|
+
|
|
60
|
+
// series of results for queries i0..i1 is done
|
|
61
|
+
virtual void end_multiple() {}
|
|
62
|
+
|
|
63
|
+
virtual ~BlockResultHandler() {}
|
|
64
|
+
|
|
65
|
+
bool is_in_selection(idx_t i) const {
|
|
66
|
+
return !use_sel || sel->is_member(i);
|
|
67
|
+
}
|
|
68
|
+
};
|
|
69
|
+
|
|
70
|
+
// handler for a single query
|
|
24
71
|
template <class C>
|
|
25
|
-
struct
|
|
72
|
+
struct ResultHandler {
|
|
73
|
+
// if not better than threshold, then not necessary to call add_result
|
|
74
|
+
typename C::T threshold = C::neutral();
|
|
75
|
+
|
|
76
|
+
// return whether threshold was updated
|
|
77
|
+
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
|
|
78
|
+
|
|
79
|
+
virtual ~ResultHandler() {}
|
|
80
|
+
};
|
|
81
|
+
|
|
82
|
+
/*****************************************************************
|
|
83
|
+
* Single best result handler.
|
|
84
|
+
* Tracks the only best result, thus avoiding storing
|
|
85
|
+
* some temporary data in memory.
|
|
86
|
+
*****************************************************************/
|
|
87
|
+
|
|
88
|
+
template <class C, bool use_sel = false>
|
|
89
|
+
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
90
|
+
using T = typename C::T;
|
|
91
|
+
using TI = typename C::TI;
|
|
92
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
93
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
94
|
+
|
|
95
|
+
// contains exactly nq elements
|
|
96
|
+
T* dis_tab;
|
|
97
|
+
// contains exactly nq elements
|
|
98
|
+
TI* ids_tab;
|
|
99
|
+
|
|
100
|
+
Top1BlockResultHandler(
|
|
101
|
+
size_t nq,
|
|
102
|
+
T* dis_tab,
|
|
103
|
+
TI* ids_tab,
|
|
104
|
+
const IDSelector* sel = nullptr)
|
|
105
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
106
|
+
dis_tab(dis_tab),
|
|
107
|
+
ids_tab(ids_tab) {}
|
|
108
|
+
|
|
109
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
110
|
+
Top1BlockResultHandler& hr;
|
|
111
|
+
using ResultHandler<C>::threshold;
|
|
112
|
+
|
|
113
|
+
TI min_idx;
|
|
114
|
+
size_t current_idx = 0;
|
|
115
|
+
|
|
116
|
+
explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
|
|
117
|
+
|
|
118
|
+
/// begin results for query # i
|
|
119
|
+
void begin(const size_t current_idx_2) {
|
|
120
|
+
this->current_idx = current_idx_2;
|
|
121
|
+
threshold = C::neutral();
|
|
122
|
+
min_idx = -1;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
/// add one result for query i
|
|
126
|
+
bool add_result(T dis, TI idx) final {
|
|
127
|
+
if (C::cmp(this->threshold, dis)) {
|
|
128
|
+
threshold = dis;
|
|
129
|
+
min_idx = idx;
|
|
130
|
+
return true;
|
|
131
|
+
}
|
|
132
|
+
return false;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/// series of results for query i is done
|
|
136
|
+
void end() {
|
|
137
|
+
hr.dis_tab[current_idx] = threshold;
|
|
138
|
+
hr.ids_tab[current_idx] = min_idx;
|
|
139
|
+
}
|
|
140
|
+
};
|
|
141
|
+
|
|
142
|
+
/// begin
|
|
143
|
+
void begin_multiple(size_t i0, size_t i1) final {
|
|
144
|
+
this->i0 = i0;
|
|
145
|
+
this->i1 = i1;
|
|
146
|
+
|
|
147
|
+
for (size_t i = i0; i < i1; i++) {
|
|
148
|
+
this->dis_tab[i] = C::neutral();
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
/// add results for query i0..i1 and j0..j1
|
|
153
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
|
|
154
|
+
for (int64_t i = i0; i < i1; i++) {
|
|
155
|
+
const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0;
|
|
156
|
+
|
|
157
|
+
auto& min_distance = this->dis_tab[i];
|
|
158
|
+
auto& min_index = this->ids_tab[i];
|
|
159
|
+
|
|
160
|
+
for (size_t j = j0; j < j1; j++) {
|
|
161
|
+
const T distance = dis_tab_i[j];
|
|
162
|
+
|
|
163
|
+
if (C::cmp(min_distance, distance)) {
|
|
164
|
+
min_distance = distance;
|
|
165
|
+
min_index = j;
|
|
166
|
+
}
|
|
167
|
+
}
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
void add_result(const size_t i, const T dis, const TI idx) {
|
|
172
|
+
auto& min_distance = this->dis_tab[i];
|
|
173
|
+
auto& min_index = this->ids_tab[i];
|
|
174
|
+
|
|
175
|
+
if (C::cmp(min_distance, dis)) {
|
|
176
|
+
min_distance = dis;
|
|
177
|
+
min_index = idx;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
};
|
|
181
|
+
|
|
182
|
+
/*****************************************************************
|
|
183
|
+
* Heap based result handler
|
|
184
|
+
*****************************************************************/
|
|
185
|
+
|
|
186
|
+
template <class C, bool use_sel = false>
|
|
187
|
+
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
26
188
|
using T = typename C::T;
|
|
27
189
|
using TI = typename C::TI;
|
|
190
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
191
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
28
192
|
|
|
29
|
-
int nq;
|
|
30
193
|
T* heap_dis_tab;
|
|
31
194
|
TI* heap_ids_tab;
|
|
32
195
|
|
|
33
196
|
int64_t k; // number of results to keep
|
|
34
197
|
|
|
35
|
-
|
|
36
|
-
|
|
198
|
+
HeapBlockResultHandler(
|
|
199
|
+
size_t nq,
|
|
200
|
+
T* heap_dis_tab,
|
|
201
|
+
TI* heap_ids_tab,
|
|
202
|
+
size_t k,
|
|
203
|
+
const IDSelector* sel = nullptr)
|
|
204
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
37
205
|
heap_dis_tab(heap_dis_tab),
|
|
38
206
|
heap_ids_tab(heap_ids_tab),
|
|
39
207
|
k(k) {}
|
|
@@ -43,30 +211,33 @@ struct HeapResultHandler {
|
|
|
43
211
|
* called from 1 thread)
|
|
44
212
|
*/
|
|
45
213
|
|
|
46
|
-
struct SingleResultHandler {
|
|
47
|
-
|
|
214
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
215
|
+
HeapBlockResultHandler& hr;
|
|
216
|
+
using ResultHandler<C>::threshold;
|
|
48
217
|
size_t k;
|
|
49
218
|
|
|
50
219
|
T* heap_dis;
|
|
51
220
|
TI* heap_ids;
|
|
52
|
-
T thresh;
|
|
53
221
|
|
|
54
|
-
SingleResultHandler(
|
|
222
|
+
explicit SingleResultHandler(HeapBlockResultHandler& hr)
|
|
223
|
+
: hr(hr), k(hr.k) {}
|
|
55
224
|
|
|
56
225
|
/// begin results for query # i
|
|
57
226
|
void begin(size_t i) {
|
|
58
227
|
heap_dis = hr.heap_dis_tab + i * k;
|
|
59
228
|
heap_ids = hr.heap_ids_tab + i * k;
|
|
60
229
|
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
61
|
-
|
|
230
|
+
threshold = heap_dis[0];
|
|
62
231
|
}
|
|
63
232
|
|
|
64
233
|
/// add one result for query i
|
|
65
|
-
|
|
66
|
-
if (C::cmp(
|
|
234
|
+
bool add_result(T dis, TI idx) final {
|
|
235
|
+
if (C::cmp(threshold, dis)) {
|
|
67
236
|
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
|
68
|
-
|
|
237
|
+
threshold = heap_dis[0];
|
|
238
|
+
return true;
|
|
69
239
|
}
|
|
240
|
+
return false;
|
|
70
241
|
}
|
|
71
242
|
|
|
72
243
|
/// series of results for query i is done
|
|
@@ -79,19 +250,17 @@ struct HeapResultHandler {
|
|
|
79
250
|
* API for multiple results (called from 1 thread)
|
|
80
251
|
*/
|
|
81
252
|
|
|
82
|
-
size_t i0, i1;
|
|
83
|
-
|
|
84
253
|
/// begin
|
|
85
|
-
void begin_multiple(size_t
|
|
86
|
-
this->i0 =
|
|
87
|
-
this->i1 =
|
|
254
|
+
void begin_multiple(size_t i0_2, size_t i1_2) final {
|
|
255
|
+
this->i0 = i0_2;
|
|
256
|
+
this->i1 = i1_2;
|
|
88
257
|
for (size_t i = i0; i < i1; i++) {
|
|
89
258
|
heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
90
259
|
}
|
|
91
260
|
}
|
|
92
261
|
|
|
93
262
|
/// add results for query i0..i1 and j0..j1
|
|
94
|
-
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
263
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) final {
|
|
95
264
|
#pragma omp parallel for
|
|
96
265
|
for (int64_t i = i0; i < i1; i++) {
|
|
97
266
|
T* heap_dis = heap_dis_tab + i * k;
|
|
@@ -109,7 +278,7 @@ struct HeapResultHandler {
|
|
|
109
278
|
}
|
|
110
279
|
|
|
111
280
|
/// series of results for queries i0..i1 is done
|
|
112
|
-
void end_multiple() {
|
|
281
|
+
void end_multiple() final {
|
|
113
282
|
// maybe parallel for
|
|
114
283
|
for (size_t i = i0; i < i1; i++) {
|
|
115
284
|
heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
@@ -128,9 +297,10 @@ struct HeapResultHandler {
|
|
|
128
297
|
|
|
129
298
|
/// Reservoir for a single query
|
|
130
299
|
template <class C>
|
|
131
|
-
struct ReservoirTopN {
|
|
300
|
+
struct ReservoirTopN : ResultHandler<C> {
|
|
132
301
|
using T = typename C::T;
|
|
133
302
|
using TI = typename C::TI;
|
|
303
|
+
using ResultHandler<C>::threshold;
|
|
134
304
|
|
|
135
305
|
T* vals;
|
|
136
306
|
TI* ids;
|
|
@@ -139,8 +309,6 @@ struct ReservoirTopN {
|
|
|
139
309
|
size_t n; // number of requested elements
|
|
140
310
|
size_t capacity; // size of storage
|
|
141
311
|
|
|
142
|
-
T threshold; // current threshold
|
|
143
|
-
|
|
144
312
|
ReservoirTopN() {}
|
|
145
313
|
|
|
146
314
|
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
|
@@ -149,15 +317,22 @@ struct ReservoirTopN {
|
|
|
149
317
|
threshold = C::neutral();
|
|
150
318
|
}
|
|
151
319
|
|
|
152
|
-
|
|
320
|
+
bool add_result(T val, TI id) final {
|
|
321
|
+
bool updated_threshold = false;
|
|
153
322
|
if (C::cmp(threshold, val)) {
|
|
154
323
|
if (i == capacity) {
|
|
155
324
|
shrink_fuzzy();
|
|
325
|
+
updated_threshold = true;
|
|
156
326
|
}
|
|
157
327
|
vals[i] = val;
|
|
158
328
|
ids[i] = id;
|
|
159
329
|
i++;
|
|
160
330
|
}
|
|
331
|
+
return updated_threshold;
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
void add(T val, TI id) {
|
|
335
|
+
add_result(val, id);
|
|
161
336
|
}
|
|
162
337
|
|
|
163
338
|
// reduce storage from capacity to anything
|
|
@@ -169,6 +344,11 @@ struct ReservoirTopN {
|
|
|
169
344
|
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
|
170
345
|
}
|
|
171
346
|
|
|
347
|
+
void shrink() {
|
|
348
|
+
threshold = partition<C>(vals, ids, i, n);
|
|
349
|
+
i = n;
|
|
350
|
+
}
|
|
351
|
+
|
|
172
352
|
void to_result(T* heap_dis, TI* heap_ids) const {
|
|
173
353
|
for (int j = 0; j < std::min(i, n); j++) {
|
|
174
354
|
heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
|
|
@@ -186,24 +366,26 @@ struct ReservoirTopN {
|
|
|
186
366
|
}
|
|
187
367
|
};
|
|
188
368
|
|
|
189
|
-
template <class C>
|
|
190
|
-
struct
|
|
369
|
+
template <class C, bool use_sel = false>
|
|
370
|
+
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
191
371
|
using T = typename C::T;
|
|
192
372
|
using TI = typename C::TI;
|
|
373
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
374
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
193
375
|
|
|
194
|
-
int nq;
|
|
195
376
|
T* heap_dis_tab;
|
|
196
377
|
TI* heap_ids_tab;
|
|
197
378
|
|
|
198
379
|
int64_t k; // number of results to keep
|
|
199
380
|
size_t capacity; // capacity of the reservoirs
|
|
200
381
|
|
|
201
|
-
|
|
382
|
+
ReservoirBlockResultHandler(
|
|
202
383
|
size_t nq,
|
|
203
384
|
T* heap_dis_tab,
|
|
204
385
|
TI* heap_ids_tab,
|
|
205
|
-
size_t k
|
|
206
|
-
|
|
386
|
+
size_t k,
|
|
387
|
+
const IDSelector* sel = nullptr)
|
|
388
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
207
389
|
heap_dis_tab(heap_dis_tab),
|
|
208
390
|
heap_ids_tab(heap_ids_tab),
|
|
209
391
|
k(k) {
|
|
@@ -216,40 +398,34 @@ struct ReservoirResultHandler {
|
|
|
216
398
|
* called from 1 thread)
|
|
217
399
|
*/
|
|
218
400
|
|
|
219
|
-
struct SingleResultHandler {
|
|
220
|
-
|
|
401
|
+
struct SingleResultHandler : ReservoirTopN<C> {
|
|
402
|
+
ReservoirBlockResultHandler& hr;
|
|
221
403
|
|
|
222
404
|
std::vector<T> reservoir_dis;
|
|
223
405
|
std::vector<TI> reservoir_ids;
|
|
224
|
-
ReservoirTopN<C> res1;
|
|
225
406
|
|
|
226
|
-
SingleResultHandler(
|
|
227
|
-
:
|
|
228
|
-
|
|
229
|
-
reservoir_ids(hr.capacity) {}
|
|
407
|
+
explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
|
|
408
|
+
: ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
|
|
409
|
+
hr(hr) {}
|
|
230
410
|
|
|
231
|
-
size_t
|
|
411
|
+
size_t qno;
|
|
232
412
|
|
|
233
413
|
/// begin results for query # i
|
|
234
|
-
void begin(size_t
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
this->
|
|
414
|
+
void begin(size_t qno_2) {
|
|
415
|
+
reservoir_dis.resize(hr.capacity);
|
|
416
|
+
reservoir_ids.resize(hr.capacity);
|
|
417
|
+
this->vals = reservoir_dis.data();
|
|
418
|
+
this->ids = reservoir_ids.data();
|
|
419
|
+
this->i = 0; // size of reservoir
|
|
420
|
+
this->threshold = C::neutral();
|
|
421
|
+
this->qno = qno_2;
|
|
241
422
|
}
|
|
242
423
|
|
|
243
|
-
///
|
|
244
|
-
void add_result(T dis, TI idx) {
|
|
245
|
-
res1.add(dis, idx);
|
|
246
|
-
}
|
|
247
|
-
|
|
248
|
-
/// series of results for query i is done
|
|
424
|
+
/// series of results for query qno is done
|
|
249
425
|
void end() {
|
|
250
|
-
T* heap_dis = hr.heap_dis_tab +
|
|
251
|
-
TI* heap_ids = hr.heap_ids_tab +
|
|
252
|
-
|
|
426
|
+
T* heap_dis = hr.heap_dis_tab + qno * hr.k;
|
|
427
|
+
TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
|
|
428
|
+
this->to_result(heap_dis, heap_ids);
|
|
253
429
|
}
|
|
254
430
|
};
|
|
255
431
|
|
|
@@ -257,44 +433,41 @@ struct ReservoirResultHandler {
|
|
|
257
433
|
* API for multiple results (called from 1 thread)
|
|
258
434
|
*/
|
|
259
435
|
|
|
260
|
-
size_t i0, i1;
|
|
261
|
-
|
|
262
436
|
std::vector<T> reservoir_dis;
|
|
263
437
|
std::vector<TI> reservoir_ids;
|
|
264
438
|
std::vector<ReservoirTopN<C>> reservoirs;
|
|
265
439
|
|
|
266
440
|
/// begin
|
|
267
|
-
void begin_multiple(size_t
|
|
268
|
-
this->i0 =
|
|
269
|
-
this->i1 =
|
|
441
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
442
|
+
this->i0 = i0_2;
|
|
443
|
+
this->i1 = i1_2;
|
|
270
444
|
reservoir_dis.resize((i1 - i0) * capacity);
|
|
271
445
|
reservoir_ids.resize((i1 - i0) * capacity);
|
|
272
446
|
reservoirs.clear();
|
|
273
|
-
for (size_t i =
|
|
447
|
+
for (size_t i = i0_2; i < i1_2; i++) {
|
|
274
448
|
reservoirs.emplace_back(
|
|
275
449
|
k,
|
|
276
450
|
capacity,
|
|
277
|
-
reservoir_dis.data() + (i -
|
|
278
|
-
reservoir_ids.data() + (i -
|
|
451
|
+
reservoir_dis.data() + (i - i0_2) * capacity,
|
|
452
|
+
reservoir_ids.data() + (i - i0_2) * capacity);
|
|
279
453
|
}
|
|
280
454
|
}
|
|
281
455
|
|
|
282
456
|
/// add results for query i0..i1 and j0..j1
|
|
283
457
|
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
284
|
-
// maybe parallel for
|
|
285
458
|
#pragma omp parallel for
|
|
286
459
|
for (int64_t i = i0; i < i1; i++) {
|
|
287
460
|
ReservoirTopN<C>& reservoir = reservoirs[i - i0];
|
|
288
461
|
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
|
289
462
|
for (size_t j = j0; j < j1; j++) {
|
|
290
463
|
T dis = dis_tab_i[j];
|
|
291
|
-
reservoir.
|
|
464
|
+
reservoir.add_result(dis, j);
|
|
292
465
|
}
|
|
293
466
|
}
|
|
294
467
|
}
|
|
295
468
|
|
|
296
469
|
/// series of results for queries i0..i1 is done
|
|
297
|
-
void end_multiple() {
|
|
470
|
+
void end_multiple() final {
|
|
298
471
|
// maybe parallel for
|
|
299
472
|
for (size_t i = i0; i < i1; i++) {
|
|
300
473
|
reservoirs[i - i0].to_result(
|
|
@@ -307,30 +480,39 @@ struct ReservoirResultHandler {
|
|
|
307
480
|
* Result handler for range searches
|
|
308
481
|
*****************************************************************/
|
|
309
482
|
|
|
310
|
-
template <class C>
|
|
311
|
-
struct
|
|
483
|
+
template <class C, bool use_sel = false>
|
|
484
|
+
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
312
485
|
using T = typename C::T;
|
|
313
486
|
using TI = typename C::TI;
|
|
487
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
488
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
314
489
|
|
|
315
490
|
RangeSearchResult* res;
|
|
316
|
-
|
|
491
|
+
T radius;
|
|
317
492
|
|
|
318
|
-
|
|
319
|
-
|
|
493
|
+
RangeSearchBlockResultHandler(
|
|
494
|
+
RangeSearchResult* res,
|
|
495
|
+
float radius,
|
|
496
|
+
const IDSelector* sel = nullptr)
|
|
497
|
+
: BlockResultHandler<C, use_sel>(res->nq, sel),
|
|
498
|
+
res(res),
|
|
499
|
+
radius(radius) {}
|
|
320
500
|
|
|
321
501
|
/******************************************************
|
|
322
502
|
* API for 1 result at a time (each SingleResultHandler is
|
|
323
503
|
* called from 1 thread)
|
|
324
504
|
******************************************************/
|
|
325
505
|
|
|
326
|
-
struct SingleResultHandler {
|
|
506
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
327
507
|
// almost the same interface as RangeSearchResultHandler
|
|
508
|
+
using ResultHandler<C>::threshold;
|
|
328
509
|
RangeSearchPartialResult pres;
|
|
329
|
-
float radius;
|
|
330
510
|
RangeQueryResult* qr = nullptr;
|
|
331
511
|
|
|
332
|
-
SingleResultHandler(
|
|
333
|
-
: pres(rh.res)
|
|
512
|
+
explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
|
|
513
|
+
: pres(rh.res) {
|
|
514
|
+
threshold = rh.radius;
|
|
515
|
+
}
|
|
334
516
|
|
|
335
517
|
/// begin results for query # i
|
|
336
518
|
void begin(size_t i) {
|
|
@@ -338,17 +520,26 @@ struct RangeSearchResultHandler {
|
|
|
338
520
|
}
|
|
339
521
|
|
|
340
522
|
/// add one result for query i
|
|
341
|
-
|
|
342
|
-
if (C::cmp(
|
|
523
|
+
bool add_result(T dis, TI idx) final {
|
|
524
|
+
if (C::cmp(threshold, dis)) {
|
|
343
525
|
qr->add(dis, idx);
|
|
344
526
|
}
|
|
527
|
+
return false;
|
|
345
528
|
}
|
|
346
529
|
|
|
347
530
|
/// series of results for query i is done
|
|
348
531
|
void end() {}
|
|
349
532
|
|
|
350
533
|
~SingleResultHandler() {
|
|
351
|
-
|
|
534
|
+
try {
|
|
535
|
+
// finalize the partial result
|
|
536
|
+
pres.finalize();
|
|
537
|
+
} catch (const faiss::FaissException& e) {
|
|
538
|
+
// Do nothing if allocation fails in finalizing partial results.
|
|
539
|
+
#ifndef NDEBUG
|
|
540
|
+
std::cerr << e.what() << std::endl;
|
|
541
|
+
#endif
|
|
542
|
+
}
|
|
352
543
|
}
|
|
353
544
|
};
|
|
354
545
|
|
|
@@ -356,16 +547,14 @@ struct RangeSearchResultHandler {
|
|
|
356
547
|
* API for multiple results (called from 1 thread)
|
|
357
548
|
******************************************************/
|
|
358
549
|
|
|
359
|
-
size_t i0, i1;
|
|
360
|
-
|
|
361
550
|
std::vector<RangeSearchPartialResult*> partial_results;
|
|
362
551
|
std::vector<size_t> j0s;
|
|
363
552
|
int pr = 0;
|
|
364
553
|
|
|
365
554
|
/// begin
|
|
366
|
-
void begin_multiple(size_t
|
|
367
|
-
this->i0 =
|
|
368
|
-
this->i1 =
|
|
555
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
556
|
+
this->i0 = i0_2;
|
|
557
|
+
this->i1 = i1_2;
|
|
369
558
|
}
|
|
370
559
|
|
|
371
560
|
/// add results for query i0..i1 and j0..j1
|
|
@@ -404,109 +593,95 @@ struct RangeSearchResultHandler {
|
|
|
404
593
|
}
|
|
405
594
|
}
|
|
406
595
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
596
|
+
~RangeSearchBlockResultHandler() {
|
|
597
|
+
try {
|
|
598
|
+
if (partial_results.size() > 0) {
|
|
599
|
+
RangeSearchPartialResult::merge(partial_results);
|
|
600
|
+
}
|
|
601
|
+
} catch (const faiss::FaissException& e) {
|
|
602
|
+
// Do nothing if allocation fails in merge.
|
|
603
|
+
#ifndef NDEBUG
|
|
604
|
+
std::cerr << e.what() << std::endl;
|
|
605
|
+
#endif
|
|
412
606
|
}
|
|
413
607
|
}
|
|
414
608
|
};
|
|
415
609
|
|
|
416
610
|
/*****************************************************************
|
|
417
|
-
*
|
|
418
|
-
* Tracks the only best result, thus avoiding storing
|
|
419
|
-
* some temporary data in memory.
|
|
611
|
+
* Dispatcher function to choose the right knn result handler depending on k
|
|
420
612
|
*****************************************************************/
|
|
421
613
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
void begin(const size_t current_idx) {
|
|
447
|
-
this->current_idx = current_idx;
|
|
448
|
-
min_dis = HUGE_VALF;
|
|
449
|
-
min_idx = 0;
|
|
450
|
-
}
|
|
451
|
-
|
|
452
|
-
/// add one result for query i
|
|
453
|
-
void add_result(T dis, TI idx) {
|
|
454
|
-
if (C::cmp(min_dis, dis)) {
|
|
455
|
-
min_dis = dis;
|
|
456
|
-
min_idx = idx;
|
|
457
|
-
}
|
|
458
|
-
}
|
|
614
|
+
// declared in distances.cpp
|
|
615
|
+
FAISS_API extern int distance_compute_min_k_reservoir;
|
|
616
|
+
|
|
617
|
+
template <class Consumer, class... Types>
|
|
618
|
+
typename Consumer::T dispatch_knn_ResultHandler(
|
|
619
|
+
size_t nx,
|
|
620
|
+
float* vals,
|
|
621
|
+
int64_t* ids,
|
|
622
|
+
size_t k,
|
|
623
|
+
MetricType metric,
|
|
624
|
+
const IDSelector* sel,
|
|
625
|
+
Consumer& consumer,
|
|
626
|
+
Types... args) {
|
|
627
|
+
#define DISPATCH_C_SEL(C, use_sel) \
|
|
628
|
+
if (k == 1) { \
|
|
629
|
+
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
|
|
630
|
+
return consumer.template f<>(res, args...); \
|
|
631
|
+
} else if (k < distance_compute_min_k_reservoir) { \
|
|
632
|
+
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
|
633
|
+
return consumer.template f<>(res, args...); \
|
|
634
|
+
} else { \
|
|
635
|
+
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
|
636
|
+
return consumer.template f<>(res, args...); \
|
|
637
|
+
}
|
|
459
638
|
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
639
|
+
if (is_similarity_metric(metric)) {
|
|
640
|
+
using C = CMin<float, int64_t>;
|
|
641
|
+
if (sel) {
|
|
642
|
+
DISPATCH_C_SEL(C, true);
|
|
643
|
+
} else {
|
|
644
|
+
DISPATCH_C_SEL(C, false);
|
|
464
645
|
}
|
|
465
|
-
}
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
this->i0 = i0;
|
|
472
|
-
this->i1 = i1;
|
|
473
|
-
|
|
474
|
-
for (size_t i = i0; i < i1; i++) {
|
|
475
|
-
this->dis_tab[i] = HUGE_VALF;
|
|
646
|
+
} else {
|
|
647
|
+
using C = CMax<float, int64_t>;
|
|
648
|
+
if (sel) {
|
|
649
|
+
DISPATCH_C_SEL(C, true);
|
|
650
|
+
} else {
|
|
651
|
+
DISPATCH_C_SEL(C, false);
|
|
476
652
|
}
|
|
477
653
|
}
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
654
|
+
#undef DISPATCH_C_SEL
|
|
655
|
+
}
|
|
656
|
+
|
|
657
|
+
template <class Consumer, class... Types>
|
|
658
|
+
typename Consumer::T dispatch_range_ResultHandler(
|
|
659
|
+
RangeSearchResult* res,
|
|
660
|
+
float radius,
|
|
661
|
+
MetricType metric,
|
|
662
|
+
const IDSelector* sel,
|
|
663
|
+
Consumer& consumer,
|
|
664
|
+
Types... args) {
|
|
665
|
+
#define DISPATCH_C_SEL(C, use_sel) \
|
|
666
|
+
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
|
|
667
|
+
return consumer.template f<>(resb, args...);
|
|
668
|
+
|
|
669
|
+
if (is_similarity_metric(metric)) {
|
|
670
|
+
using C = CMin<float, int64_t>;
|
|
671
|
+
if (sel) {
|
|
672
|
+
DISPATCH_C_SEL(C, true);
|
|
673
|
+
} else {
|
|
674
|
+
DISPATCH_C_SEL(C, false);
|
|
495
675
|
}
|
|
496
|
-
}
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
if (C::cmp(min_distance, dis)) {
|
|
503
|
-
min_distance = dis;
|
|
504
|
-
min_index = idx;
|
|
676
|
+
} else {
|
|
677
|
+
using C = CMax<float, int64_t>;
|
|
678
|
+
if (sel) {
|
|
679
|
+
DISPATCH_C_SEL(C, true);
|
|
680
|
+
} else {
|
|
681
|
+
DISPATCH_C_SEL(C, false);
|
|
505
682
|
}
|
|
506
683
|
}
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
void end_multiple() {}
|
|
510
|
-
};
|
|
684
|
+
#undef DISPATCH_C_SEL
|
|
685
|
+
}
|
|
511
686
|
|
|
512
687
|
} // namespace faiss
|