faiss 0.2.7 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
@@ -17,23 +17,170 @@
|
|
17
17
|
|
18
18
|
namespace faiss {
|
19
19
|
|
20
|
+
/*****************************************************************
|
21
|
+
* The classes below are intended to be used as template arguments
|
22
|
+
* they handle results for batches of queries (size nq).
|
23
|
+
* They can be called in two ways:
|
24
|
+
* - by instantiating a SingleResultHandler that tracks results for a single
|
25
|
+
* query
|
26
|
+
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
27
|
+
* results is submitted
|
28
|
+
* All classes are templated on C, which defines whether the min or the max of
|
29
|
+
* results is to be kept.
|
30
|
+
*****************************************************************/
|
31
|
+
|
32
|
+
template <class C>
|
33
|
+
struct BlockResultHandler {
|
34
|
+
size_t nq; // number of queries for which we search
|
35
|
+
|
36
|
+
explicit BlockResultHandler(size_t nq) : nq(nq) {}
|
37
|
+
|
38
|
+
// currently handled query range
|
39
|
+
size_t i0 = 0, i1 = 0;
|
40
|
+
|
41
|
+
// start collecting results for queries [i0, i1)
|
42
|
+
virtual void begin_multiple(size_t i0_2, size_t i1_2) {
|
43
|
+
this->i0 = i0_2;
|
44
|
+
this->i1 = i1_2;
|
45
|
+
}
|
46
|
+
|
47
|
+
// add results for queries [i0, i1) and database [j0, j1)
|
48
|
+
virtual void add_results(size_t, size_t, const typename C::T*) {}
|
49
|
+
|
50
|
+
// series of results for queries i0..i1 is done
|
51
|
+
virtual void end_multiple() {}
|
52
|
+
|
53
|
+
virtual ~BlockResultHandler() {}
|
54
|
+
};
|
55
|
+
|
56
|
+
// handler for a single query
|
57
|
+
template <class C>
|
58
|
+
struct ResultHandler {
|
59
|
+
// if not better than threshold, then not necessary to call add_result
|
60
|
+
typename C::T threshold = 0;
|
61
|
+
|
62
|
+
// return whether threshold was updated
|
63
|
+
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
|
64
|
+
|
65
|
+
virtual ~ResultHandler() {}
|
66
|
+
};
|
67
|
+
|
68
|
+
/*****************************************************************
|
69
|
+
* Single best result handler.
|
70
|
+
* Tracks the only best result, thus avoiding storing
|
71
|
+
* some temporary data in memory.
|
72
|
+
*****************************************************************/
|
73
|
+
|
74
|
+
template <class C>
|
75
|
+
struct Top1BlockResultHandler : BlockResultHandler<C> {
|
76
|
+
using T = typename C::T;
|
77
|
+
using TI = typename C::TI;
|
78
|
+
using BlockResultHandler<C>::i0;
|
79
|
+
using BlockResultHandler<C>::i1;
|
80
|
+
|
81
|
+
// contains exactly nq elements
|
82
|
+
T* dis_tab;
|
83
|
+
// contains exactly nq elements
|
84
|
+
TI* ids_tab;
|
85
|
+
|
86
|
+
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
|
87
|
+
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
|
88
|
+
|
89
|
+
struct SingleResultHandler : ResultHandler<C> {
|
90
|
+
Top1BlockResultHandler& hr;
|
91
|
+
using ResultHandler<C>::threshold;
|
92
|
+
|
93
|
+
TI min_idx;
|
94
|
+
size_t current_idx = 0;
|
95
|
+
|
96
|
+
explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
|
97
|
+
|
98
|
+
/// begin results for query # i
|
99
|
+
void begin(const size_t current_idx_2) {
|
100
|
+
this->current_idx = current_idx_2;
|
101
|
+
threshold = C::neutral();
|
102
|
+
min_idx = -1;
|
103
|
+
}
|
104
|
+
|
105
|
+
/// add one result for query i
|
106
|
+
bool add_result(T dis, TI idx) final {
|
107
|
+
if (C::cmp(this->threshold, dis)) {
|
108
|
+
threshold = dis;
|
109
|
+
min_idx = idx;
|
110
|
+
return true;
|
111
|
+
}
|
112
|
+
return false;
|
113
|
+
}
|
114
|
+
|
115
|
+
/// series of results for query i is done
|
116
|
+
void end() {
|
117
|
+
hr.dis_tab[current_idx] = threshold;
|
118
|
+
hr.ids_tab[current_idx] = min_idx;
|
119
|
+
}
|
120
|
+
};
|
121
|
+
|
122
|
+
/// begin
|
123
|
+
void begin_multiple(size_t i0, size_t i1) final {
|
124
|
+
this->i0 = i0;
|
125
|
+
this->i1 = i1;
|
126
|
+
|
127
|
+
for (size_t i = i0; i < i1; i++) {
|
128
|
+
this->dis_tab[i] = C::neutral();
|
129
|
+
}
|
130
|
+
}
|
131
|
+
|
132
|
+
/// add results for query i0..i1 and j0..j1
|
133
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
|
134
|
+
for (int64_t i = i0; i < i1; i++) {
|
135
|
+
const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0;
|
136
|
+
|
137
|
+
auto& min_distance = this->dis_tab[i];
|
138
|
+
auto& min_index = this->ids_tab[i];
|
139
|
+
|
140
|
+
for (size_t j = j0; j < j1; j++) {
|
141
|
+
const T distance = dis_tab_i[j];
|
142
|
+
|
143
|
+
if (C::cmp(min_distance, distance)) {
|
144
|
+
min_distance = distance;
|
145
|
+
min_index = j;
|
146
|
+
}
|
147
|
+
}
|
148
|
+
}
|
149
|
+
}
|
150
|
+
|
151
|
+
void add_result(const size_t i, const T dis, const TI idx) {
|
152
|
+
auto& min_distance = this->dis_tab[i];
|
153
|
+
auto& min_index = this->ids_tab[i];
|
154
|
+
|
155
|
+
if (C::cmp(min_distance, dis)) {
|
156
|
+
min_distance = dis;
|
157
|
+
min_index = idx;
|
158
|
+
}
|
159
|
+
}
|
160
|
+
};
|
161
|
+
|
20
162
|
/*****************************************************************
|
21
163
|
* Heap based result handler
|
22
164
|
*****************************************************************/
|
23
165
|
|
24
166
|
template <class C>
|
25
|
-
struct
|
167
|
+
struct HeapBlockResultHandler : BlockResultHandler<C> {
|
26
168
|
using T = typename C::T;
|
27
169
|
using TI = typename C::TI;
|
170
|
+
using BlockResultHandler<C>::i0;
|
171
|
+
using BlockResultHandler<C>::i1;
|
28
172
|
|
29
|
-
int nq;
|
30
173
|
T* heap_dis_tab;
|
31
174
|
TI* heap_ids_tab;
|
32
175
|
|
33
176
|
int64_t k; // number of results to keep
|
34
177
|
|
35
|
-
|
36
|
-
|
178
|
+
HeapBlockResultHandler(
|
179
|
+
size_t nq,
|
180
|
+
T* heap_dis_tab,
|
181
|
+
TI* heap_ids_tab,
|
182
|
+
size_t k)
|
183
|
+
: BlockResultHandler<C>(nq),
|
37
184
|
heap_dis_tab(heap_dis_tab),
|
38
185
|
heap_ids_tab(heap_ids_tab),
|
39
186
|
k(k) {}
|
@@ -43,30 +190,33 @@ struct HeapResultHandler {
|
|
43
190
|
* called from 1 thread)
|
44
191
|
*/
|
45
192
|
|
46
|
-
struct SingleResultHandler {
|
47
|
-
|
193
|
+
struct SingleResultHandler : ResultHandler<C> {
|
194
|
+
HeapBlockResultHandler& hr;
|
195
|
+
using ResultHandler<C>::threshold;
|
48
196
|
size_t k;
|
49
197
|
|
50
198
|
T* heap_dis;
|
51
199
|
TI* heap_ids;
|
52
|
-
T thresh;
|
53
200
|
|
54
|
-
SingleResultHandler(
|
201
|
+
explicit SingleResultHandler(HeapBlockResultHandler& hr)
|
202
|
+
: hr(hr), k(hr.k) {}
|
55
203
|
|
56
204
|
/// begin results for query # i
|
57
205
|
void begin(size_t i) {
|
58
206
|
heap_dis = hr.heap_dis_tab + i * k;
|
59
207
|
heap_ids = hr.heap_ids_tab + i * k;
|
60
208
|
heap_heapify<C>(k, heap_dis, heap_ids);
|
61
|
-
|
209
|
+
threshold = heap_dis[0];
|
62
210
|
}
|
63
211
|
|
64
212
|
/// add one result for query i
|
65
|
-
|
66
|
-
if (C::cmp(
|
213
|
+
bool add_result(T dis, TI idx) final {
|
214
|
+
if (C::cmp(threshold, dis)) {
|
67
215
|
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
68
|
-
|
216
|
+
threshold = heap_dis[0];
|
217
|
+
return true;
|
69
218
|
}
|
219
|
+
return false;
|
70
220
|
}
|
71
221
|
|
72
222
|
/// series of results for query i is done
|
@@ -79,19 +229,17 @@ struct HeapResultHandler {
|
|
79
229
|
* API for multiple results (called from 1 thread)
|
80
230
|
*/
|
81
231
|
|
82
|
-
size_t i0, i1;
|
83
|
-
|
84
232
|
/// begin
|
85
|
-
void begin_multiple(size_t
|
86
|
-
this->i0 =
|
87
|
-
this->i1 =
|
233
|
+
void begin_multiple(size_t i0_2, size_t i1_2) final {
|
234
|
+
this->i0 = i0_2;
|
235
|
+
this->i1 = i1_2;
|
88
236
|
for (size_t i = i0; i < i1; i++) {
|
89
237
|
heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
90
238
|
}
|
91
239
|
}
|
92
240
|
|
93
241
|
/// add results for query i0..i1 and j0..j1
|
94
|
-
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
242
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) final {
|
95
243
|
#pragma omp parallel for
|
96
244
|
for (int64_t i = i0; i < i1; i++) {
|
97
245
|
T* heap_dis = heap_dis_tab + i * k;
|
@@ -109,7 +257,7 @@ struct HeapResultHandler {
|
|
109
257
|
}
|
110
258
|
|
111
259
|
/// series of results for queries i0..i1 is done
|
112
|
-
void end_multiple() {
|
260
|
+
void end_multiple() final {
|
113
261
|
// maybe parallel for
|
114
262
|
for (size_t i = i0; i < i1; i++) {
|
115
263
|
heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
@@ -128,9 +276,10 @@ struct HeapResultHandler {
|
|
128
276
|
|
129
277
|
/// Reservoir for a single query
|
130
278
|
template <class C>
|
131
|
-
struct ReservoirTopN {
|
279
|
+
struct ReservoirTopN : ResultHandler<C> {
|
132
280
|
using T = typename C::T;
|
133
281
|
using TI = typename C::TI;
|
282
|
+
using ResultHandler<C>::threshold;
|
134
283
|
|
135
284
|
T* vals;
|
136
285
|
TI* ids;
|
@@ -139,8 +288,6 @@ struct ReservoirTopN {
|
|
139
288
|
size_t n; // number of requested elements
|
140
289
|
size_t capacity; // size of storage
|
141
290
|
|
142
|
-
T threshold; // current threshold
|
143
|
-
|
144
291
|
ReservoirTopN() {}
|
145
292
|
|
146
293
|
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
@@ -149,15 +296,22 @@ struct ReservoirTopN {
|
|
149
296
|
threshold = C::neutral();
|
150
297
|
}
|
151
298
|
|
152
|
-
|
299
|
+
bool add_result(T val, TI id) final {
|
300
|
+
bool updated_threshold = false;
|
153
301
|
if (C::cmp(threshold, val)) {
|
154
302
|
if (i == capacity) {
|
155
303
|
shrink_fuzzy();
|
304
|
+
updated_threshold = true;
|
156
305
|
}
|
157
306
|
vals[i] = val;
|
158
307
|
ids[i] = id;
|
159
308
|
i++;
|
160
309
|
}
|
310
|
+
return updated_threshold;
|
311
|
+
}
|
312
|
+
|
313
|
+
void add(T val, TI id) {
|
314
|
+
add_result(val, id);
|
161
315
|
}
|
162
316
|
|
163
317
|
// reduce storage from capacity to anything
|
@@ -169,6 +323,11 @@ struct ReservoirTopN {
|
|
169
323
|
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
170
324
|
}
|
171
325
|
|
326
|
+
void shrink() {
|
327
|
+
threshold = partition<C>(vals, ids, i, n);
|
328
|
+
i = n;
|
329
|
+
}
|
330
|
+
|
172
331
|
void to_result(T* heap_dis, TI* heap_ids) const {
|
173
332
|
for (int j = 0; j < std::min(i, n); j++) {
|
174
333
|
heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
|
@@ -187,23 +346,24 @@ struct ReservoirTopN {
|
|
187
346
|
};
|
188
347
|
|
189
348
|
template <class C>
|
190
|
-
struct
|
349
|
+
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
191
350
|
using T = typename C::T;
|
192
351
|
using TI = typename C::TI;
|
352
|
+
using BlockResultHandler<C>::i0;
|
353
|
+
using BlockResultHandler<C>::i1;
|
193
354
|
|
194
|
-
int nq;
|
195
355
|
T* heap_dis_tab;
|
196
356
|
TI* heap_ids_tab;
|
197
357
|
|
198
358
|
int64_t k; // number of results to keep
|
199
359
|
size_t capacity; // capacity of the reservoirs
|
200
360
|
|
201
|
-
|
361
|
+
ReservoirBlockResultHandler(
|
202
362
|
size_t nq,
|
203
363
|
T* heap_dis_tab,
|
204
364
|
TI* heap_ids_tab,
|
205
365
|
size_t k)
|
206
|
-
:
|
366
|
+
: BlockResultHandler<C>(nq),
|
207
367
|
heap_dis_tab(heap_dis_tab),
|
208
368
|
heap_ids_tab(heap_ids_tab),
|
209
369
|
k(k) {
|
@@ -216,40 +376,34 @@ struct ReservoirResultHandler {
|
|
216
376
|
* called from 1 thread)
|
217
377
|
*/
|
218
378
|
|
219
|
-
struct SingleResultHandler {
|
220
|
-
|
379
|
+
struct SingleResultHandler : ReservoirTopN<C> {
|
380
|
+
ReservoirBlockResultHandler& hr;
|
221
381
|
|
222
382
|
std::vector<T> reservoir_dis;
|
223
383
|
std::vector<TI> reservoir_ids;
|
224
|
-
ReservoirTopN<C> res1;
|
225
384
|
|
226
|
-
SingleResultHandler(
|
227
|
-
:
|
228
|
-
|
229
|
-
reservoir_ids(hr.capacity) {}
|
385
|
+
explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
|
386
|
+
: ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
|
387
|
+
hr(hr) {}
|
230
388
|
|
231
|
-
size_t
|
389
|
+
size_t qno;
|
232
390
|
|
233
391
|
/// begin results for query # i
|
234
|
-
void begin(size_t
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
this->
|
392
|
+
void begin(size_t qno_2) {
|
393
|
+
reservoir_dis.resize(hr.capacity);
|
394
|
+
reservoir_ids.resize(hr.capacity);
|
395
|
+
this->vals = reservoir_dis.data();
|
396
|
+
this->ids = reservoir_ids.data();
|
397
|
+
this->i = 0; // size of reservoir
|
398
|
+
this->threshold = C::neutral();
|
399
|
+
this->qno = qno_2;
|
241
400
|
}
|
242
401
|
|
243
|
-
///
|
244
|
-
void add_result(T dis, TI idx) {
|
245
|
-
res1.add(dis, idx);
|
246
|
-
}
|
247
|
-
|
248
|
-
/// series of results for query i is done
|
402
|
+
/// series of results for query qno is done
|
249
403
|
void end() {
|
250
|
-
T* heap_dis = hr.heap_dis_tab +
|
251
|
-
TI* heap_ids = hr.heap_ids_tab +
|
252
|
-
|
404
|
+
T* heap_dis = hr.heap_dis_tab + qno * hr.k;
|
405
|
+
TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
|
406
|
+
this->to_result(heap_dis, heap_ids);
|
253
407
|
}
|
254
408
|
};
|
255
409
|
|
@@ -257,44 +411,41 @@ struct ReservoirResultHandler {
|
|
257
411
|
* API for multiple results (called from 1 thread)
|
258
412
|
*/
|
259
413
|
|
260
|
-
size_t i0, i1;
|
261
|
-
|
262
414
|
std::vector<T> reservoir_dis;
|
263
415
|
std::vector<TI> reservoir_ids;
|
264
416
|
std::vector<ReservoirTopN<C>> reservoirs;
|
265
417
|
|
266
418
|
/// begin
|
267
|
-
void begin_multiple(size_t
|
268
|
-
this->i0 =
|
269
|
-
this->i1 =
|
419
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
420
|
+
this->i0 = i0_2;
|
421
|
+
this->i1 = i1_2;
|
270
422
|
reservoir_dis.resize((i1 - i0) * capacity);
|
271
423
|
reservoir_ids.resize((i1 - i0) * capacity);
|
272
424
|
reservoirs.clear();
|
273
|
-
for (size_t i =
|
425
|
+
for (size_t i = i0_2; i < i1_2; i++) {
|
274
426
|
reservoirs.emplace_back(
|
275
427
|
k,
|
276
428
|
capacity,
|
277
|
-
reservoir_dis.data() + (i -
|
278
|
-
reservoir_ids.data() + (i -
|
429
|
+
reservoir_dis.data() + (i - i0_2) * capacity,
|
430
|
+
reservoir_ids.data() + (i - i0_2) * capacity);
|
279
431
|
}
|
280
432
|
}
|
281
433
|
|
282
434
|
/// add results for query i0..i1 and j0..j1
|
283
435
|
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
284
|
-
// maybe parallel for
|
285
436
|
#pragma omp parallel for
|
286
437
|
for (int64_t i = i0; i < i1; i++) {
|
287
438
|
ReservoirTopN<C>& reservoir = reservoirs[i - i0];
|
288
439
|
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
289
440
|
for (size_t j = j0; j < j1; j++) {
|
290
441
|
T dis = dis_tab_i[j];
|
291
|
-
reservoir.
|
442
|
+
reservoir.add_result(dis, j);
|
292
443
|
}
|
293
444
|
}
|
294
445
|
}
|
295
446
|
|
296
447
|
/// series of results for queries i0..i1 is done
|
297
|
-
void end_multiple() {
|
448
|
+
void end_multiple() final {
|
298
449
|
// maybe parallel for
|
299
450
|
for (size_t i = i0; i < i1; i++) {
|
300
451
|
reservoirs[i - i0].to_result(
|
@@ -308,29 +459,33 @@ struct ReservoirResultHandler {
|
|
308
459
|
*****************************************************************/
|
309
460
|
|
310
461
|
template <class C>
|
311
|
-
struct
|
462
|
+
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
312
463
|
using T = typename C::T;
|
313
464
|
using TI = typename C::TI;
|
465
|
+
using BlockResultHandler<C>::i0;
|
466
|
+
using BlockResultHandler<C>::i1;
|
314
467
|
|
315
468
|
RangeSearchResult* res;
|
316
|
-
|
469
|
+
T radius;
|
317
470
|
|
318
|
-
|
319
|
-
: res(res), radius(radius) {}
|
471
|
+
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
|
472
|
+
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
|
320
473
|
|
321
474
|
/******************************************************
|
322
475
|
* API for 1 result at a time (each SingleResultHandler is
|
323
476
|
* called from 1 thread)
|
324
477
|
******************************************************/
|
325
478
|
|
326
|
-
struct SingleResultHandler {
|
479
|
+
struct SingleResultHandler : ResultHandler<C> {
|
327
480
|
// almost the same interface as RangeSearchResultHandler
|
481
|
+
using ResultHandler<C>::threshold;
|
328
482
|
RangeSearchPartialResult pres;
|
329
|
-
float radius;
|
330
483
|
RangeQueryResult* qr = nullptr;
|
331
484
|
|
332
|
-
SingleResultHandler(
|
333
|
-
: pres(rh.res)
|
485
|
+
explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
|
486
|
+
: pres(rh.res) {
|
487
|
+
threshold = rh.radius;
|
488
|
+
}
|
334
489
|
|
335
490
|
/// begin results for query # i
|
336
491
|
void begin(size_t i) {
|
@@ -338,10 +493,11 @@ struct RangeSearchResultHandler {
|
|
338
493
|
}
|
339
494
|
|
340
495
|
/// add one result for query i
|
341
|
-
|
342
|
-
if (C::cmp(
|
496
|
+
bool add_result(T dis, TI idx) final {
|
497
|
+
if (C::cmp(threshold, dis)) {
|
343
498
|
qr->add(dis, idx);
|
344
499
|
}
|
500
|
+
return false;
|
345
501
|
}
|
346
502
|
|
347
503
|
/// series of results for query i is done
|
@@ -356,16 +512,14 @@ struct RangeSearchResultHandler {
|
|
356
512
|
* API for multiple results (called from 1 thread)
|
357
513
|
******************************************************/
|
358
514
|
|
359
|
-
size_t i0, i1;
|
360
|
-
|
361
515
|
std::vector<RangeSearchPartialResult*> partial_results;
|
362
516
|
std::vector<size_t> j0s;
|
363
517
|
int pr = 0;
|
364
518
|
|
365
519
|
/// begin
|
366
|
-
void begin_multiple(size_t
|
367
|
-
this->i0 =
|
368
|
-
this->i1 =
|
520
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
521
|
+
this->i0 = i0_2;
|
522
|
+
this->i1 = i1_2;
|
369
523
|
}
|
370
524
|
|
371
525
|
/// add results for query i0..i1 and j0..j1
|
@@ -404,109 +558,11 @@ struct RangeSearchResultHandler {
|
|
404
558
|
}
|
405
559
|
}
|
406
560
|
|
407
|
-
|
408
|
-
|
409
|
-
~RangeSearchResultHandler() {
|
561
|
+
~RangeSearchBlockResultHandler() {
|
410
562
|
if (partial_results.size() > 0) {
|
411
563
|
RangeSearchPartialResult::merge(partial_results);
|
412
564
|
}
|
413
565
|
}
|
414
566
|
};
|
415
567
|
|
416
|
-
/*****************************************************************
|
417
|
-
* Single best result handler.
|
418
|
-
* Tracks the only best result, thus avoiding storing
|
419
|
-
* some temporary data in memory.
|
420
|
-
*****************************************************************/
|
421
|
-
|
422
|
-
template <class C>
|
423
|
-
struct SingleBestResultHandler {
|
424
|
-
using T = typename C::T;
|
425
|
-
using TI = typename C::TI;
|
426
|
-
|
427
|
-
int nq;
|
428
|
-
// contains exactly nq elements
|
429
|
-
T* dis_tab;
|
430
|
-
// contains exactly nq elements
|
431
|
-
TI* ids_tab;
|
432
|
-
|
433
|
-
SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
|
434
|
-
: nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
|
435
|
-
|
436
|
-
struct SingleResultHandler {
|
437
|
-
SingleBestResultHandler& hr;
|
438
|
-
|
439
|
-
T min_dis;
|
440
|
-
TI min_idx;
|
441
|
-
size_t current_idx = 0;
|
442
|
-
|
443
|
-
SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
|
444
|
-
|
445
|
-
/// begin results for query # i
|
446
|
-
void begin(const size_t current_idx) {
|
447
|
-
this->current_idx = current_idx;
|
448
|
-
min_dis = HUGE_VALF;
|
449
|
-
min_idx = 0;
|
450
|
-
}
|
451
|
-
|
452
|
-
/// add one result for query i
|
453
|
-
void add_result(T dis, TI idx) {
|
454
|
-
if (C::cmp(min_dis, dis)) {
|
455
|
-
min_dis = dis;
|
456
|
-
min_idx = idx;
|
457
|
-
}
|
458
|
-
}
|
459
|
-
|
460
|
-
/// series of results for query i is done
|
461
|
-
void end() {
|
462
|
-
hr.dis_tab[current_idx] = min_dis;
|
463
|
-
hr.ids_tab[current_idx] = min_idx;
|
464
|
-
}
|
465
|
-
};
|
466
|
-
|
467
|
-
size_t i0, i1;
|
468
|
-
|
469
|
-
/// begin
|
470
|
-
void begin_multiple(size_t i0, size_t i1) {
|
471
|
-
this->i0 = i0;
|
472
|
-
this->i1 = i1;
|
473
|
-
|
474
|
-
for (size_t i = i0; i < i1; i++) {
|
475
|
-
this->dis_tab[i] = HUGE_VALF;
|
476
|
-
}
|
477
|
-
}
|
478
|
-
|
479
|
-
/// add results for query i0..i1 and j0..j1
|
480
|
-
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
481
|
-
for (int64_t i = i0; i < i1; i++) {
|
482
|
-
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
483
|
-
|
484
|
-
auto& min_distance = this->dis_tab[i];
|
485
|
-
auto& min_index = this->ids_tab[i];
|
486
|
-
|
487
|
-
for (size_t j = j0; j < j1; j++) {
|
488
|
-
const T distance = dis_tab_i[j];
|
489
|
-
|
490
|
-
if (C::cmp(min_distance, distance)) {
|
491
|
-
min_distance = distance;
|
492
|
-
min_index = j;
|
493
|
-
}
|
494
|
-
}
|
495
|
-
}
|
496
|
-
}
|
497
|
-
|
498
|
-
void add_result(const size_t i, const T dis, const TI idx) {
|
499
|
-
auto& min_distance = this->dis_tab[i];
|
500
|
-
auto& min_index = this->ids_tab[i];
|
501
|
-
|
502
|
-
if (C::cmp(min_distance, dis)) {
|
503
|
-
min_distance = dis;
|
504
|
-
min_index = idx;
|
505
|
-
}
|
506
|
-
}
|
507
|
-
|
508
|
-
/// series of results for queries i0..i1 is done
|
509
|
-
void end_multiple() {}
|
510
|
-
};
|
511
|
-
|
512
568
|
} // namespace faiss
|