faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -17,23 +17,170 @@
|
|
|
17
17
|
|
|
18
18
|
namespace faiss {
|
|
19
19
|
|
|
20
|
+
/*****************************************************************
|
|
21
|
+
* The classes below are intended to be used as template arguments
|
|
22
|
+
* they handle results for batches of queries (size nq).
|
|
23
|
+
* They can be called in two ways:
|
|
24
|
+
* - by instantiating a SingleResultHandler that tracks results for a single
|
|
25
|
+
* query
|
|
26
|
+
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
|
27
|
+
* results is submitted
|
|
28
|
+
* All classes are templated on C, which defines whether the min or the max of
|
|
29
|
+
* results is to be kept.
|
|
30
|
+
*****************************************************************/
|
|
31
|
+
|
|
32
|
+
template <class C>
|
|
33
|
+
struct BlockResultHandler {
|
|
34
|
+
size_t nq; // number of queries for which we search
|
|
35
|
+
|
|
36
|
+
explicit BlockResultHandler(size_t nq) : nq(nq) {}
|
|
37
|
+
|
|
38
|
+
// currently handled query range
|
|
39
|
+
size_t i0 = 0, i1 = 0;
|
|
40
|
+
|
|
41
|
+
// start collecting results for queries [i0, i1)
|
|
42
|
+
virtual void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
43
|
+
this->i0 = i0_2;
|
|
44
|
+
this->i1 = i1_2;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
// add results for queries [i0, i1) and database [j0, j1)
|
|
48
|
+
virtual void add_results(size_t, size_t, const typename C::T*) {}
|
|
49
|
+
|
|
50
|
+
// series of results for queries i0..i1 is done
|
|
51
|
+
virtual void end_multiple() {}
|
|
52
|
+
|
|
53
|
+
virtual ~BlockResultHandler() {}
|
|
54
|
+
};
|
|
55
|
+
|
|
56
|
+
// handler for a single query
|
|
57
|
+
template <class C>
|
|
58
|
+
struct ResultHandler {
|
|
59
|
+
// if not better than threshold, then not necessary to call add_result
|
|
60
|
+
typename C::T threshold = 0;
|
|
61
|
+
|
|
62
|
+
// return whether threshold was updated
|
|
63
|
+
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
|
|
64
|
+
|
|
65
|
+
virtual ~ResultHandler() {}
|
|
66
|
+
};
|
|
67
|
+
|
|
68
|
+
/*****************************************************************
|
|
69
|
+
* Single best result handler.
|
|
70
|
+
* Tracks the only best result, thus avoiding storing
|
|
71
|
+
* some temporary data in memory.
|
|
72
|
+
*****************************************************************/
|
|
73
|
+
|
|
74
|
+
template <class C>
|
|
75
|
+
struct Top1BlockResultHandler : BlockResultHandler<C> {
|
|
76
|
+
using T = typename C::T;
|
|
77
|
+
using TI = typename C::TI;
|
|
78
|
+
using BlockResultHandler<C>::i0;
|
|
79
|
+
using BlockResultHandler<C>::i1;
|
|
80
|
+
|
|
81
|
+
// contains exactly nq elements
|
|
82
|
+
T* dis_tab;
|
|
83
|
+
// contains exactly nq elements
|
|
84
|
+
TI* ids_tab;
|
|
85
|
+
|
|
86
|
+
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
|
|
87
|
+
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
|
|
88
|
+
|
|
89
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
90
|
+
Top1BlockResultHandler& hr;
|
|
91
|
+
using ResultHandler<C>::threshold;
|
|
92
|
+
|
|
93
|
+
TI min_idx;
|
|
94
|
+
size_t current_idx = 0;
|
|
95
|
+
|
|
96
|
+
explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
|
|
97
|
+
|
|
98
|
+
/// begin results for query # i
|
|
99
|
+
void begin(const size_t current_idx_2) {
|
|
100
|
+
this->current_idx = current_idx_2;
|
|
101
|
+
threshold = C::neutral();
|
|
102
|
+
min_idx = -1;
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
/// add one result for query i
|
|
106
|
+
bool add_result(T dis, TI idx) final {
|
|
107
|
+
if (C::cmp(this->threshold, dis)) {
|
|
108
|
+
threshold = dis;
|
|
109
|
+
min_idx = idx;
|
|
110
|
+
return true;
|
|
111
|
+
}
|
|
112
|
+
return false;
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
/// series of results for query i is done
|
|
116
|
+
void end() {
|
|
117
|
+
hr.dis_tab[current_idx] = threshold;
|
|
118
|
+
hr.ids_tab[current_idx] = min_idx;
|
|
119
|
+
}
|
|
120
|
+
};
|
|
121
|
+
|
|
122
|
+
/// begin
|
|
123
|
+
void begin_multiple(size_t i0, size_t i1) final {
|
|
124
|
+
this->i0 = i0;
|
|
125
|
+
this->i1 = i1;
|
|
126
|
+
|
|
127
|
+
for (size_t i = i0; i < i1; i++) {
|
|
128
|
+
this->dis_tab[i] = C::neutral();
|
|
129
|
+
}
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
/// add results for query i0..i1 and j0..j1
|
|
133
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
|
|
134
|
+
for (int64_t i = i0; i < i1; i++) {
|
|
135
|
+
const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0;
|
|
136
|
+
|
|
137
|
+
auto& min_distance = this->dis_tab[i];
|
|
138
|
+
auto& min_index = this->ids_tab[i];
|
|
139
|
+
|
|
140
|
+
for (size_t j = j0; j < j1; j++) {
|
|
141
|
+
const T distance = dis_tab_i[j];
|
|
142
|
+
|
|
143
|
+
if (C::cmp(min_distance, distance)) {
|
|
144
|
+
min_distance = distance;
|
|
145
|
+
min_index = j;
|
|
146
|
+
}
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
void add_result(const size_t i, const T dis, const TI idx) {
|
|
152
|
+
auto& min_distance = this->dis_tab[i];
|
|
153
|
+
auto& min_index = this->ids_tab[i];
|
|
154
|
+
|
|
155
|
+
if (C::cmp(min_distance, dis)) {
|
|
156
|
+
min_distance = dis;
|
|
157
|
+
min_index = idx;
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
};
|
|
161
|
+
|
|
20
162
|
/*****************************************************************
|
|
21
163
|
* Heap based result handler
|
|
22
164
|
*****************************************************************/
|
|
23
165
|
|
|
24
166
|
template <class C>
|
|
25
|
-
struct
|
|
167
|
+
struct HeapBlockResultHandler : BlockResultHandler<C> {
|
|
26
168
|
using T = typename C::T;
|
|
27
169
|
using TI = typename C::TI;
|
|
170
|
+
using BlockResultHandler<C>::i0;
|
|
171
|
+
using BlockResultHandler<C>::i1;
|
|
28
172
|
|
|
29
|
-
int nq;
|
|
30
173
|
T* heap_dis_tab;
|
|
31
174
|
TI* heap_ids_tab;
|
|
32
175
|
|
|
33
176
|
int64_t k; // number of results to keep
|
|
34
177
|
|
|
35
|
-
|
|
36
|
-
|
|
178
|
+
HeapBlockResultHandler(
|
|
179
|
+
size_t nq,
|
|
180
|
+
T* heap_dis_tab,
|
|
181
|
+
TI* heap_ids_tab,
|
|
182
|
+
size_t k)
|
|
183
|
+
: BlockResultHandler<C>(nq),
|
|
37
184
|
heap_dis_tab(heap_dis_tab),
|
|
38
185
|
heap_ids_tab(heap_ids_tab),
|
|
39
186
|
k(k) {}
|
|
@@ -43,30 +190,33 @@ struct HeapResultHandler {
|
|
|
43
190
|
* called from 1 thread)
|
|
44
191
|
*/
|
|
45
192
|
|
|
46
|
-
struct SingleResultHandler {
|
|
47
|
-
|
|
193
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
194
|
+
HeapBlockResultHandler& hr;
|
|
195
|
+
using ResultHandler<C>::threshold;
|
|
48
196
|
size_t k;
|
|
49
197
|
|
|
50
198
|
T* heap_dis;
|
|
51
199
|
TI* heap_ids;
|
|
52
|
-
T thresh;
|
|
53
200
|
|
|
54
|
-
SingleResultHandler(
|
|
201
|
+
explicit SingleResultHandler(HeapBlockResultHandler& hr)
|
|
202
|
+
: hr(hr), k(hr.k) {}
|
|
55
203
|
|
|
56
204
|
/// begin results for query # i
|
|
57
205
|
void begin(size_t i) {
|
|
58
206
|
heap_dis = hr.heap_dis_tab + i * k;
|
|
59
207
|
heap_ids = hr.heap_ids_tab + i * k;
|
|
60
208
|
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
61
|
-
|
|
209
|
+
threshold = heap_dis[0];
|
|
62
210
|
}
|
|
63
211
|
|
|
64
212
|
/// add one result for query i
|
|
65
|
-
|
|
66
|
-
if (C::cmp(
|
|
213
|
+
bool add_result(T dis, TI idx) final {
|
|
214
|
+
if (C::cmp(threshold, dis)) {
|
|
67
215
|
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
|
68
|
-
|
|
216
|
+
threshold = heap_dis[0];
|
|
217
|
+
return true;
|
|
69
218
|
}
|
|
219
|
+
return false;
|
|
70
220
|
}
|
|
71
221
|
|
|
72
222
|
/// series of results for query i is done
|
|
@@ -79,19 +229,17 @@ struct HeapResultHandler {
|
|
|
79
229
|
* API for multiple results (called from 1 thread)
|
|
80
230
|
*/
|
|
81
231
|
|
|
82
|
-
size_t i0, i1;
|
|
83
|
-
|
|
84
232
|
/// begin
|
|
85
|
-
void begin_multiple(size_t
|
|
86
|
-
this->i0 =
|
|
87
|
-
this->i1 =
|
|
233
|
+
void begin_multiple(size_t i0_2, size_t i1_2) final {
|
|
234
|
+
this->i0 = i0_2;
|
|
235
|
+
this->i1 = i1_2;
|
|
88
236
|
for (size_t i = i0; i < i1; i++) {
|
|
89
237
|
heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
90
238
|
}
|
|
91
239
|
}
|
|
92
240
|
|
|
93
241
|
/// add results for query i0..i1 and j0..j1
|
|
94
|
-
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
242
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) final {
|
|
95
243
|
#pragma omp parallel for
|
|
96
244
|
for (int64_t i = i0; i < i1; i++) {
|
|
97
245
|
T* heap_dis = heap_dis_tab + i * k;
|
|
@@ -109,7 +257,7 @@ struct HeapResultHandler {
|
|
|
109
257
|
}
|
|
110
258
|
|
|
111
259
|
/// series of results for queries i0..i1 is done
|
|
112
|
-
void end_multiple() {
|
|
260
|
+
void end_multiple() final {
|
|
113
261
|
// maybe parallel for
|
|
114
262
|
for (size_t i = i0; i < i1; i++) {
|
|
115
263
|
heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
@@ -128,9 +276,10 @@ struct HeapResultHandler {
|
|
|
128
276
|
|
|
129
277
|
/// Reservoir for a single query
|
|
130
278
|
template <class C>
|
|
131
|
-
struct ReservoirTopN {
|
|
279
|
+
struct ReservoirTopN : ResultHandler<C> {
|
|
132
280
|
using T = typename C::T;
|
|
133
281
|
using TI = typename C::TI;
|
|
282
|
+
using ResultHandler<C>::threshold;
|
|
134
283
|
|
|
135
284
|
T* vals;
|
|
136
285
|
TI* ids;
|
|
@@ -139,8 +288,6 @@ struct ReservoirTopN {
|
|
|
139
288
|
size_t n; // number of requested elements
|
|
140
289
|
size_t capacity; // size of storage
|
|
141
290
|
|
|
142
|
-
T threshold; // current threshold
|
|
143
|
-
|
|
144
291
|
ReservoirTopN() {}
|
|
145
292
|
|
|
146
293
|
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
|
@@ -149,15 +296,22 @@ struct ReservoirTopN {
|
|
|
149
296
|
threshold = C::neutral();
|
|
150
297
|
}
|
|
151
298
|
|
|
152
|
-
|
|
299
|
+
bool add_result(T val, TI id) final {
|
|
300
|
+
bool updated_threshold = false;
|
|
153
301
|
if (C::cmp(threshold, val)) {
|
|
154
302
|
if (i == capacity) {
|
|
155
303
|
shrink_fuzzy();
|
|
304
|
+
updated_threshold = true;
|
|
156
305
|
}
|
|
157
306
|
vals[i] = val;
|
|
158
307
|
ids[i] = id;
|
|
159
308
|
i++;
|
|
160
309
|
}
|
|
310
|
+
return updated_threshold;
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
void add(T val, TI id) {
|
|
314
|
+
add_result(val, id);
|
|
161
315
|
}
|
|
162
316
|
|
|
163
317
|
// reduce storage from capacity to anything
|
|
@@ -169,6 +323,11 @@ struct ReservoirTopN {
|
|
|
169
323
|
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
|
170
324
|
}
|
|
171
325
|
|
|
326
|
+
void shrink() {
|
|
327
|
+
threshold = partition<C>(vals, ids, i, n);
|
|
328
|
+
i = n;
|
|
329
|
+
}
|
|
330
|
+
|
|
172
331
|
void to_result(T* heap_dis, TI* heap_ids) const {
|
|
173
332
|
for (int j = 0; j < std::min(i, n); j++) {
|
|
174
333
|
heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
|
|
@@ -187,23 +346,24 @@ struct ReservoirTopN {
|
|
|
187
346
|
};
|
|
188
347
|
|
|
189
348
|
template <class C>
|
|
190
|
-
struct
|
|
349
|
+
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|
191
350
|
using T = typename C::T;
|
|
192
351
|
using TI = typename C::TI;
|
|
352
|
+
using BlockResultHandler<C>::i0;
|
|
353
|
+
using BlockResultHandler<C>::i1;
|
|
193
354
|
|
|
194
|
-
int nq;
|
|
195
355
|
T* heap_dis_tab;
|
|
196
356
|
TI* heap_ids_tab;
|
|
197
357
|
|
|
198
358
|
int64_t k; // number of results to keep
|
|
199
359
|
size_t capacity; // capacity of the reservoirs
|
|
200
360
|
|
|
201
|
-
|
|
361
|
+
ReservoirBlockResultHandler(
|
|
202
362
|
size_t nq,
|
|
203
363
|
T* heap_dis_tab,
|
|
204
364
|
TI* heap_ids_tab,
|
|
205
365
|
size_t k)
|
|
206
|
-
:
|
|
366
|
+
: BlockResultHandler<C>(nq),
|
|
207
367
|
heap_dis_tab(heap_dis_tab),
|
|
208
368
|
heap_ids_tab(heap_ids_tab),
|
|
209
369
|
k(k) {
|
|
@@ -216,40 +376,34 @@ struct ReservoirResultHandler {
|
|
|
216
376
|
* called from 1 thread)
|
|
217
377
|
*/
|
|
218
378
|
|
|
219
|
-
struct SingleResultHandler {
|
|
220
|
-
|
|
379
|
+
struct SingleResultHandler : ReservoirTopN<C> {
|
|
380
|
+
ReservoirBlockResultHandler& hr;
|
|
221
381
|
|
|
222
382
|
std::vector<T> reservoir_dis;
|
|
223
383
|
std::vector<TI> reservoir_ids;
|
|
224
|
-
ReservoirTopN<C> res1;
|
|
225
384
|
|
|
226
|
-
SingleResultHandler(
|
|
227
|
-
:
|
|
228
|
-
|
|
229
|
-
reservoir_ids(hr.capacity) {}
|
|
385
|
+
explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
|
|
386
|
+
: ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
|
|
387
|
+
hr(hr) {}
|
|
230
388
|
|
|
231
|
-
size_t
|
|
389
|
+
size_t qno;
|
|
232
390
|
|
|
233
391
|
/// begin results for query # i
|
|
234
|
-
void begin(size_t
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
this->
|
|
392
|
+
void begin(size_t qno_2) {
|
|
393
|
+
reservoir_dis.resize(hr.capacity);
|
|
394
|
+
reservoir_ids.resize(hr.capacity);
|
|
395
|
+
this->vals = reservoir_dis.data();
|
|
396
|
+
this->ids = reservoir_ids.data();
|
|
397
|
+
this->i = 0; // size of reservoir
|
|
398
|
+
this->threshold = C::neutral();
|
|
399
|
+
this->qno = qno_2;
|
|
241
400
|
}
|
|
242
401
|
|
|
243
|
-
///
|
|
244
|
-
void add_result(T dis, TI idx) {
|
|
245
|
-
res1.add(dis, idx);
|
|
246
|
-
}
|
|
247
|
-
|
|
248
|
-
/// series of results for query i is done
|
|
402
|
+
/// series of results for query qno is done
|
|
249
403
|
void end() {
|
|
250
|
-
T* heap_dis = hr.heap_dis_tab +
|
|
251
|
-
TI* heap_ids = hr.heap_ids_tab +
|
|
252
|
-
|
|
404
|
+
T* heap_dis = hr.heap_dis_tab + qno * hr.k;
|
|
405
|
+
TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
|
|
406
|
+
this->to_result(heap_dis, heap_ids);
|
|
253
407
|
}
|
|
254
408
|
};
|
|
255
409
|
|
|
@@ -257,44 +411,41 @@ struct ReservoirResultHandler {
|
|
|
257
411
|
* API for multiple results (called from 1 thread)
|
|
258
412
|
*/
|
|
259
413
|
|
|
260
|
-
size_t i0, i1;
|
|
261
|
-
|
|
262
414
|
std::vector<T> reservoir_dis;
|
|
263
415
|
std::vector<TI> reservoir_ids;
|
|
264
416
|
std::vector<ReservoirTopN<C>> reservoirs;
|
|
265
417
|
|
|
266
418
|
/// begin
|
|
267
|
-
void begin_multiple(size_t
|
|
268
|
-
this->i0 =
|
|
269
|
-
this->i1 =
|
|
419
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
420
|
+
this->i0 = i0_2;
|
|
421
|
+
this->i1 = i1_2;
|
|
270
422
|
reservoir_dis.resize((i1 - i0) * capacity);
|
|
271
423
|
reservoir_ids.resize((i1 - i0) * capacity);
|
|
272
424
|
reservoirs.clear();
|
|
273
|
-
for (size_t i =
|
|
425
|
+
for (size_t i = i0_2; i < i1_2; i++) {
|
|
274
426
|
reservoirs.emplace_back(
|
|
275
427
|
k,
|
|
276
428
|
capacity,
|
|
277
|
-
reservoir_dis.data() + (i -
|
|
278
|
-
reservoir_ids.data() + (i -
|
|
429
|
+
reservoir_dis.data() + (i - i0_2) * capacity,
|
|
430
|
+
reservoir_ids.data() + (i - i0_2) * capacity);
|
|
279
431
|
}
|
|
280
432
|
}
|
|
281
433
|
|
|
282
434
|
/// add results for query i0..i1 and j0..j1
|
|
283
435
|
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
284
|
-
// maybe parallel for
|
|
285
436
|
#pragma omp parallel for
|
|
286
437
|
for (int64_t i = i0; i < i1; i++) {
|
|
287
438
|
ReservoirTopN<C>& reservoir = reservoirs[i - i0];
|
|
288
439
|
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
|
289
440
|
for (size_t j = j0; j < j1; j++) {
|
|
290
441
|
T dis = dis_tab_i[j];
|
|
291
|
-
reservoir.
|
|
442
|
+
reservoir.add_result(dis, j);
|
|
292
443
|
}
|
|
293
444
|
}
|
|
294
445
|
}
|
|
295
446
|
|
|
296
447
|
/// series of results for queries i0..i1 is done
|
|
297
|
-
void end_multiple() {
|
|
448
|
+
void end_multiple() final {
|
|
298
449
|
// maybe parallel for
|
|
299
450
|
for (size_t i = i0; i < i1; i++) {
|
|
300
451
|
reservoirs[i - i0].to_result(
|
|
@@ -308,29 +459,33 @@ struct ReservoirResultHandler {
|
|
|
308
459
|
*****************************************************************/
|
|
309
460
|
|
|
310
461
|
template <class C>
|
|
311
|
-
struct
|
|
462
|
+
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
|
312
463
|
using T = typename C::T;
|
|
313
464
|
using TI = typename C::TI;
|
|
465
|
+
using BlockResultHandler<C>::i0;
|
|
466
|
+
using BlockResultHandler<C>::i1;
|
|
314
467
|
|
|
315
468
|
RangeSearchResult* res;
|
|
316
|
-
|
|
469
|
+
T radius;
|
|
317
470
|
|
|
318
|
-
|
|
319
|
-
: res(res), radius(radius) {}
|
|
471
|
+
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
|
|
472
|
+
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
|
|
320
473
|
|
|
321
474
|
/******************************************************
|
|
322
475
|
* API for 1 result at a time (each SingleResultHandler is
|
|
323
476
|
* called from 1 thread)
|
|
324
477
|
******************************************************/
|
|
325
478
|
|
|
326
|
-
struct SingleResultHandler {
|
|
479
|
+
struct SingleResultHandler : ResultHandler<C> {
|
|
327
480
|
// almost the same interface as RangeSearchResultHandler
|
|
481
|
+
using ResultHandler<C>::threshold;
|
|
328
482
|
RangeSearchPartialResult pres;
|
|
329
|
-
float radius;
|
|
330
483
|
RangeQueryResult* qr = nullptr;
|
|
331
484
|
|
|
332
|
-
SingleResultHandler(
|
|
333
|
-
: pres(rh.res)
|
|
485
|
+
explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
|
|
486
|
+
: pres(rh.res) {
|
|
487
|
+
threshold = rh.radius;
|
|
488
|
+
}
|
|
334
489
|
|
|
335
490
|
/// begin results for query # i
|
|
336
491
|
void begin(size_t i) {
|
|
@@ -338,10 +493,11 @@ struct RangeSearchResultHandler {
|
|
|
338
493
|
}
|
|
339
494
|
|
|
340
495
|
/// add one result for query i
|
|
341
|
-
|
|
342
|
-
if (C::cmp(
|
|
496
|
+
bool add_result(T dis, TI idx) final {
|
|
497
|
+
if (C::cmp(threshold, dis)) {
|
|
343
498
|
qr->add(dis, idx);
|
|
344
499
|
}
|
|
500
|
+
return false;
|
|
345
501
|
}
|
|
346
502
|
|
|
347
503
|
/// series of results for query i is done
|
|
@@ -356,16 +512,14 @@ struct RangeSearchResultHandler {
|
|
|
356
512
|
* API for multiple results (called from 1 thread)
|
|
357
513
|
******************************************************/
|
|
358
514
|
|
|
359
|
-
size_t i0, i1;
|
|
360
|
-
|
|
361
515
|
std::vector<RangeSearchPartialResult*> partial_results;
|
|
362
516
|
std::vector<size_t> j0s;
|
|
363
517
|
int pr = 0;
|
|
364
518
|
|
|
365
519
|
/// begin
|
|
366
|
-
void begin_multiple(size_t
|
|
367
|
-
this->i0 =
|
|
368
|
-
this->i1 =
|
|
520
|
+
void begin_multiple(size_t i0_2, size_t i1_2) {
|
|
521
|
+
this->i0 = i0_2;
|
|
522
|
+
this->i1 = i1_2;
|
|
369
523
|
}
|
|
370
524
|
|
|
371
525
|
/// add results for query i0..i1 and j0..j1
|
|
@@ -404,109 +558,11 @@ struct RangeSearchResultHandler {
|
|
|
404
558
|
}
|
|
405
559
|
}
|
|
406
560
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
~RangeSearchResultHandler() {
|
|
561
|
+
~RangeSearchBlockResultHandler() {
|
|
410
562
|
if (partial_results.size() > 0) {
|
|
411
563
|
RangeSearchPartialResult::merge(partial_results);
|
|
412
564
|
}
|
|
413
565
|
}
|
|
414
566
|
};
|
|
415
567
|
|
|
416
|
-
/*****************************************************************
|
|
417
|
-
* Single best result handler.
|
|
418
|
-
* Tracks the only best result, thus avoiding storing
|
|
419
|
-
* some temporary data in memory.
|
|
420
|
-
*****************************************************************/
|
|
421
|
-
|
|
422
|
-
template <class C>
|
|
423
|
-
struct SingleBestResultHandler {
|
|
424
|
-
using T = typename C::T;
|
|
425
|
-
using TI = typename C::TI;
|
|
426
|
-
|
|
427
|
-
int nq;
|
|
428
|
-
// contains exactly nq elements
|
|
429
|
-
T* dis_tab;
|
|
430
|
-
// contains exactly nq elements
|
|
431
|
-
TI* ids_tab;
|
|
432
|
-
|
|
433
|
-
SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
|
|
434
|
-
: nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
|
|
435
|
-
|
|
436
|
-
struct SingleResultHandler {
|
|
437
|
-
SingleBestResultHandler& hr;
|
|
438
|
-
|
|
439
|
-
T min_dis;
|
|
440
|
-
TI min_idx;
|
|
441
|
-
size_t current_idx = 0;
|
|
442
|
-
|
|
443
|
-
SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
|
|
444
|
-
|
|
445
|
-
/// begin results for query # i
|
|
446
|
-
void begin(const size_t current_idx) {
|
|
447
|
-
this->current_idx = current_idx;
|
|
448
|
-
min_dis = HUGE_VALF;
|
|
449
|
-
min_idx = 0;
|
|
450
|
-
}
|
|
451
|
-
|
|
452
|
-
/// add one result for query i
|
|
453
|
-
void add_result(T dis, TI idx) {
|
|
454
|
-
if (C::cmp(min_dis, dis)) {
|
|
455
|
-
min_dis = dis;
|
|
456
|
-
min_idx = idx;
|
|
457
|
-
}
|
|
458
|
-
}
|
|
459
|
-
|
|
460
|
-
/// series of results for query i is done
|
|
461
|
-
void end() {
|
|
462
|
-
hr.dis_tab[current_idx] = min_dis;
|
|
463
|
-
hr.ids_tab[current_idx] = min_idx;
|
|
464
|
-
}
|
|
465
|
-
};
|
|
466
|
-
|
|
467
|
-
size_t i0, i1;
|
|
468
|
-
|
|
469
|
-
/// begin
|
|
470
|
-
void begin_multiple(size_t i0, size_t i1) {
|
|
471
|
-
this->i0 = i0;
|
|
472
|
-
this->i1 = i1;
|
|
473
|
-
|
|
474
|
-
for (size_t i = i0; i < i1; i++) {
|
|
475
|
-
this->dis_tab[i] = HUGE_VALF;
|
|
476
|
-
}
|
|
477
|
-
}
|
|
478
|
-
|
|
479
|
-
/// add results for query i0..i1 and j0..j1
|
|
480
|
-
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
481
|
-
for (int64_t i = i0; i < i1; i++) {
|
|
482
|
-
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
|
483
|
-
|
|
484
|
-
auto& min_distance = this->dis_tab[i];
|
|
485
|
-
auto& min_index = this->ids_tab[i];
|
|
486
|
-
|
|
487
|
-
for (size_t j = j0; j < j1; j++) {
|
|
488
|
-
const T distance = dis_tab_i[j];
|
|
489
|
-
|
|
490
|
-
if (C::cmp(min_distance, distance)) {
|
|
491
|
-
min_distance = distance;
|
|
492
|
-
min_index = j;
|
|
493
|
-
}
|
|
494
|
-
}
|
|
495
|
-
}
|
|
496
|
-
}
|
|
497
|
-
|
|
498
|
-
void add_result(const size_t i, const T dis, const TI idx) {
|
|
499
|
-
auto& min_distance = this->dis_tab[i];
|
|
500
|
-
auto& min_index = this->ids_tab[i];
|
|
501
|
-
|
|
502
|
-
if (C::cmp(min_distance, dis)) {
|
|
503
|
-
min_distance = dis;
|
|
504
|
-
min_index = idx;
|
|
505
|
-
}
|
|
506
|
-
}
|
|
507
|
-
|
|
508
|
-
/// series of results for queries i0..i1 is done
|
|
509
|
-
void end_multiple() {}
|
|
510
|
-
};
|
|
511
|
-
|
|
512
568
|
} // namespace faiss
|