faiss 0.2.7 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
@@ -14,40 +14,86 @@
|
|
14
14
|
#include <faiss/utils/Heap.h>
|
15
15
|
#include <faiss/utils/simdlib.h>
|
16
16
|
|
17
|
+
#include <faiss/impl/FaissAssert.h>
|
18
|
+
#include <faiss/impl/ResultHandler.h>
|
17
19
|
#include <faiss/impl/platform_macros.h>
|
18
20
|
#include <faiss/utils/AlignedTable.h>
|
19
21
|
#include <faiss/utils/partitioning.h>
|
20
22
|
|
21
23
|
/** This file contains callbacks for kernels that compute distances.
|
22
|
-
*
|
23
|
-
* The SIMDResultHandler object is intended to be templated and inlined.
|
24
|
-
* Methods:
|
25
|
-
* - handle(): called when 32 distances are computed and provided in two
|
26
|
-
* simd16uint16. (q, b) indicate which entry it is in the block.
|
27
|
-
* - set_block_origin(): set the sub-matrix that is being computed
|
28
24
|
*/
|
29
25
|
|
30
26
|
namespace faiss {
|
31
27
|
|
28
|
+
struct SIMDResultHandler {
|
29
|
+
// used to dispatch templates
|
30
|
+
bool is_CMax = false;
|
31
|
+
uint8_t sizeof_ids = 0;
|
32
|
+
bool with_fields = false;
|
33
|
+
|
34
|
+
/** called when 32 distances are computed and provided in two
|
35
|
+
* simd16uint16. (q, b) indicate which entry it is in the block. */
|
36
|
+
virtual void handle(
|
37
|
+
size_t q,
|
38
|
+
size_t b,
|
39
|
+
simd16uint16 d0,
|
40
|
+
simd16uint16 d1) = 0;
|
41
|
+
|
42
|
+
/// set the sub-matrix that is being computed
|
43
|
+
virtual void set_block_origin(size_t i0, size_t j0) = 0;
|
44
|
+
|
45
|
+
virtual ~SIMDResultHandler() {}
|
46
|
+
};
|
47
|
+
|
48
|
+
/* Result handler that will return float resutls eventually */
|
49
|
+
struct SIMDResultHandlerToFloat : SIMDResultHandler {
|
50
|
+
size_t nq; // number of queries
|
51
|
+
size_t ntotal; // ignore excess elements after ntotal
|
52
|
+
|
53
|
+
/// these fields are used mainly for the IVF variants (with_id_map=true)
|
54
|
+
const idx_t* id_map = nullptr; // map offset in invlist to vector id
|
55
|
+
const int* q_map = nullptr; // map q to global query
|
56
|
+
const uint16_t* dbias =
|
57
|
+
nullptr; // table of biases to add to each query (for IVF L2 search)
|
58
|
+
const float* normalizers = nullptr; // size 2 * nq, to convert
|
59
|
+
|
60
|
+
SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
|
61
|
+
: nq(nq), ntotal(ntotal) {}
|
62
|
+
|
63
|
+
virtual void begin(const float* norms) {
|
64
|
+
normalizers = norms;
|
65
|
+
}
|
66
|
+
|
67
|
+
// called at end of search to convert int16 distances to float, before
|
68
|
+
// normalizers are deallocated
|
69
|
+
virtual void end() {
|
70
|
+
normalizers = nullptr;
|
71
|
+
}
|
72
|
+
};
|
73
|
+
|
74
|
+
FAISS_API extern bool simd_result_handlers_accept_virtual;
|
75
|
+
|
32
76
|
namespace simd_result_handlers {
|
33
77
|
|
34
|
-
/** Dummy structure that just computes a
|
78
|
+
/** Dummy structure that just computes a chqecksum on results
|
35
79
|
* (to avoid the computation to be optimized away) */
|
36
|
-
struct DummyResultHandler {
|
80
|
+
struct DummyResultHandler : SIMDResultHandler {
|
37
81
|
size_t cs = 0;
|
38
82
|
|
39
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
83
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
40
84
|
cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
|
41
85
|
}
|
42
86
|
|
43
|
-
void set_block_origin(size_t, size_t) {}
|
87
|
+
void set_block_origin(size_t, size_t) final {}
|
88
|
+
|
89
|
+
~DummyResultHandler() {}
|
44
90
|
};
|
45
91
|
|
46
92
|
/** memorize results in a nq-by-nb matrix.
|
47
93
|
*
|
48
94
|
* j0 is the current upper-left block of the matrix
|
49
95
|
*/
|
50
|
-
struct StoreResultHandler {
|
96
|
+
struct StoreResultHandler : SIMDResultHandler {
|
51
97
|
uint16_t* data;
|
52
98
|
size_t ld; // total number of columns
|
53
99
|
size_t i0 = 0;
|
@@ -55,32 +101,32 @@ struct StoreResultHandler {
|
|
55
101
|
|
56
102
|
StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
|
57
103
|
|
58
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
104
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
59
105
|
size_t ofs = (q + i0) * ld + j0 + b * 32;
|
60
106
|
d0.store(data + ofs);
|
61
107
|
d1.store(data + ofs + 16);
|
62
108
|
}
|
63
109
|
|
64
|
-
void set_block_origin(size_t
|
65
|
-
this->i0 =
|
66
|
-
this->j0 =
|
110
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
111
|
+
this->i0 = i0_in;
|
112
|
+
this->j0 = j0_in;
|
67
113
|
}
|
68
114
|
};
|
69
115
|
|
70
116
|
/** stores results in fixed-size matrix. */
|
71
117
|
template <int NQ, int BB>
|
72
|
-
struct FixedStorageHandler {
|
118
|
+
struct FixedStorageHandler : SIMDResultHandler {
|
73
119
|
simd16uint16 dis[NQ][BB];
|
74
120
|
int i0 = 0;
|
75
121
|
|
76
|
-
void handle(
|
122
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
77
123
|
dis[q + i0][2 * b] = d0;
|
78
124
|
dis[q + i0][2 * b + 1] = d1;
|
79
125
|
}
|
80
126
|
|
81
|
-
void set_block_origin(size_t
|
82
|
-
this->i0 =
|
83
|
-
assert(
|
127
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
128
|
+
this->i0 = i0_in;
|
129
|
+
assert(j0_in == 0);
|
84
130
|
}
|
85
131
|
|
86
132
|
template <class OtherResultHandler>
|
@@ -91,30 +137,29 @@ struct FixedStorageHandler {
|
|
91
137
|
}
|
92
138
|
}
|
93
139
|
}
|
140
|
+
virtual ~FixedStorageHandler() {}
|
94
141
|
};
|
95
142
|
|
96
|
-
/**
|
143
|
+
/** Result handler that compares distances to check if they need to be kept */
|
97
144
|
template <class C, bool with_id_map>
|
98
|
-
struct
|
145
|
+
struct ResultHandlerCompare : SIMDResultHandlerToFloat {
|
99
146
|
using TI = typename C::TI;
|
100
147
|
|
101
148
|
bool disable = false;
|
102
149
|
|
103
150
|
int64_t i0 = 0; // query origin
|
104
151
|
int64_t j0 = 0; // db origin
|
105
|
-
size_t ntotal; // ignore excess elements after ntotal
|
106
|
-
|
107
|
-
/// these fields are used mainly for the IVF variants (with_id_map=true)
|
108
|
-
const TI* id_map; // map offset in invlist to vector id
|
109
|
-
const int* q_map; // map q to global query
|
110
|
-
const uint16_t* dbias; // table of biases to add to each query
|
111
152
|
|
112
|
-
|
113
|
-
:
|
153
|
+
ResultHandlerCompare(size_t nq, size_t ntotal)
|
154
|
+
: SIMDResultHandlerToFloat(nq, ntotal) {
|
155
|
+
this->is_CMax = C::is_max;
|
156
|
+
this->sizeof_ids = sizeof(typename C::TI);
|
157
|
+
this->with_fields = with_id_map;
|
158
|
+
}
|
114
159
|
|
115
|
-
void set_block_origin(size_t
|
116
|
-
this->i0 =
|
117
|
-
this->j0 =
|
160
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
161
|
+
this->i0 = i0_in;
|
162
|
+
this->j0 = j0_in;
|
118
163
|
}
|
119
164
|
|
120
165
|
// adjust handler data for IVF.
|
@@ -172,43 +217,37 @@ struct SIMDResultHandler {
|
|
172
217
|
return lt_mask;
|
173
218
|
}
|
174
219
|
|
175
|
-
virtual
|
176
|
-
float* distances,
|
177
|
-
int64_t* labels,
|
178
|
-
const float* normalizers = nullptr) = 0;
|
179
|
-
|
180
|
-
virtual ~SIMDResultHandler() {}
|
220
|
+
virtual ~ResultHandlerCompare() {}
|
181
221
|
};
|
182
222
|
|
183
223
|
/** Special version for k=1 */
|
184
224
|
template <class C, bool with_id_map = false>
|
185
|
-
struct SingleResultHandler :
|
225
|
+
struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
|
186
226
|
using T = typename C::T;
|
187
227
|
using TI = typename C::TI;
|
228
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
229
|
+
using RHC::normalizers;
|
188
230
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
};
|
193
|
-
std::vector<Result> results;
|
231
|
+
std::vector<int16_t> idis;
|
232
|
+
float* dis;
|
233
|
+
int64_t* ids;
|
194
234
|
|
195
|
-
SingleResultHandler(size_t nq, size_t ntotal)
|
196
|
-
:
|
235
|
+
SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
|
236
|
+
: RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
|
197
237
|
for (int i = 0; i < nq; i++) {
|
198
|
-
|
199
|
-
|
238
|
+
ids[i] = -1;
|
239
|
+
idis[i] = C::neutral();
|
200
240
|
}
|
201
241
|
}
|
202
242
|
|
203
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
243
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
204
244
|
if (this->disable) {
|
205
245
|
return;
|
206
246
|
}
|
207
247
|
|
208
248
|
this->adjust_with_origin(q, d0, d1);
|
209
249
|
|
210
|
-
|
211
|
-
uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
|
250
|
+
uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
|
212
251
|
if (!lt_mask) {
|
213
252
|
return;
|
214
253
|
}
|
@@ -221,70 +260,61 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
|
|
221
260
|
// find first non-zero
|
222
261
|
int j = __builtin_ctz(lt_mask);
|
223
262
|
lt_mask -= 1 << j;
|
224
|
-
T
|
225
|
-
if (C::cmp(
|
226
|
-
|
227
|
-
|
263
|
+
T d = d32tab[j];
|
264
|
+
if (C::cmp(idis[q], d)) {
|
265
|
+
idis[q] = d;
|
266
|
+
ids[q] = this->adjust_id(b, j);
|
228
267
|
}
|
229
268
|
}
|
230
269
|
}
|
231
270
|
|
232
|
-
void
|
233
|
-
|
234
|
-
int64_t* labels,
|
235
|
-
const float* normalizers = nullptr) override {
|
236
|
-
for (int q = 0; q < results.size(); q++) {
|
271
|
+
void end() {
|
272
|
+
for (int q = 0; q < this->nq; q++) {
|
237
273
|
if (!normalizers) {
|
238
|
-
|
274
|
+
dis[q] = idis[q];
|
239
275
|
} else {
|
240
276
|
float one_a = 1 / normalizers[2 * q];
|
241
277
|
float b = normalizers[2 * q + 1];
|
242
|
-
|
278
|
+
dis[q] = b + idis[q] * one_a;
|
243
279
|
}
|
244
|
-
labels[q] = results[q].id;
|
245
280
|
}
|
246
281
|
}
|
247
282
|
};
|
248
283
|
|
249
284
|
/** Structure that collects results in a min- or max-heap */
|
250
285
|
template <class C, bool with_id_map = false>
|
251
|
-
struct HeapHandler :
|
286
|
+
struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
252
287
|
using T = typename C::T;
|
253
288
|
using TI = typename C::TI;
|
289
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
290
|
+
using RHC::normalizers;
|
254
291
|
|
255
|
-
|
256
|
-
|
257
|
-
|
292
|
+
std::vector<uint16_t> idis;
|
293
|
+
std::vector<TI> iids;
|
294
|
+
float* dis;
|
295
|
+
int64_t* ids;
|
258
296
|
|
259
297
|
int64_t k; // number of results to keep
|
260
298
|
|
261
|
-
HeapHandler(
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
: SIMDResultHandler<C, with_id_map>(ntotal),
|
268
|
-
nq(nq),
|
269
|
-
heap_dis_tab(heap_dis_tab),
|
270
|
-
heap_ids_tab(heap_ids_tab),
|
299
|
+
HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
|
300
|
+
: RHC(nq, ntotal),
|
301
|
+
idis(nq * k),
|
302
|
+
iids(nq * k),
|
303
|
+
dis(dis),
|
304
|
+
ids(ids),
|
271
305
|
k(k) {
|
272
|
-
|
273
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
274
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
275
|
-
heap_heapify<C>(k, heap_dis_in, heap_ids_in);
|
276
|
-
}
|
306
|
+
heap_heapify<C>(k * nq, idis.data(), iids.data());
|
277
307
|
}
|
278
308
|
|
279
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
309
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
280
310
|
if (this->disable) {
|
281
311
|
return;
|
282
312
|
}
|
283
313
|
|
284
314
|
this->adjust_with_origin(q, d0, d1);
|
285
315
|
|
286
|
-
T* heap_dis =
|
287
|
-
TI* heap_ids =
|
316
|
+
T* heap_dis = idis.data() + q * k;
|
317
|
+
TI* heap_ids = iids.data() + q * k;
|
288
318
|
|
289
319
|
uint16_t cur_thresh =
|
290
320
|
heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
|
@@ -313,16 +343,13 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
313
343
|
}
|
314
344
|
}
|
315
345
|
|
316
|
-
void
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
for (int q = 0; q < nq; q++) {
|
321
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
322
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
346
|
+
void end() override {
|
347
|
+
for (int q = 0; q < this->nq; q++) {
|
348
|
+
T* heap_dis_in = idis.data() + q * k;
|
349
|
+
TI* heap_ids_in = iids.data() + q * k;
|
323
350
|
heap_reorder<C>(k, heap_dis_in, heap_ids_in);
|
324
|
-
|
325
|
-
|
351
|
+
float* heap_dis = dis + q * k;
|
352
|
+
int64_t* heap_ids = ids + q * k;
|
326
353
|
|
327
354
|
float one_a = 1.0, b = 0.0;
|
328
355
|
if (normalizers) {
|
@@ -330,8 +357,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
330
357
|
b = normalizers[2 * q + 1];
|
331
358
|
}
|
332
359
|
for (int j = 0; j < k; j++) {
|
333
|
-
heap_ids[j] = heap_ids_in[j];
|
334
360
|
heap_dis[j] = heap_dis_in[j] * one_a + b;
|
361
|
+
heap_ids[j] = heap_ids_in[j];
|
335
362
|
}
|
336
363
|
}
|
337
364
|
}
|
@@ -342,114 +369,45 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
342
369
|
* Results are stored when they are below the threshold until the capacity is
|
343
370
|
* reached. Then a partition sort is used to update the threshold. */
|
344
371
|
|
345
|
-
namespace {
|
346
|
-
|
347
|
-
uint64_t get_cy() {
|
348
|
-
#ifdef MICRO_BENCHMARK
|
349
|
-
uint32_t high, low;
|
350
|
-
asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
|
351
|
-
return ((uint64_t)high << 32) | (low);
|
352
|
-
#else
|
353
|
-
return 0;
|
354
|
-
#endif
|
355
|
-
}
|
356
|
-
|
357
|
-
} // anonymous namespace
|
358
|
-
|
359
|
-
template <class C>
|
360
|
-
struct ReservoirTopN {
|
361
|
-
using T = typename C::T;
|
362
|
-
using TI = typename C::TI;
|
363
|
-
|
364
|
-
T* vals;
|
365
|
-
TI* ids;
|
366
|
-
|
367
|
-
size_t i; // number of stored elements
|
368
|
-
size_t n; // number of requested elements
|
369
|
-
size_t capacity; // size of storage
|
370
|
-
size_t cycles = 0;
|
371
|
-
|
372
|
-
T threshold; // current threshold
|
373
|
-
|
374
|
-
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
375
|
-
: vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
|
376
|
-
assert(n < capacity);
|
377
|
-
threshold = C::neutral();
|
378
|
-
}
|
379
|
-
|
380
|
-
void add(T val, TI id) {
|
381
|
-
if (C::cmp(threshold, val)) {
|
382
|
-
if (i == capacity) {
|
383
|
-
shrink_fuzzy();
|
384
|
-
}
|
385
|
-
vals[i] = val;
|
386
|
-
ids[i] = id;
|
387
|
-
i++;
|
388
|
-
}
|
389
|
-
}
|
390
|
-
|
391
|
-
/// shrink number of stored elements to n
|
392
|
-
void shrink_xx() {
|
393
|
-
uint64_t t0 = get_cy();
|
394
|
-
qselect(vals, ids, i, n);
|
395
|
-
i = n; // forget all elements above i = n
|
396
|
-
threshold = C::Crev::neutral();
|
397
|
-
for (size_t j = 0; j < n; j++) {
|
398
|
-
if (C::cmp(vals[j], threshold)) {
|
399
|
-
threshold = vals[j];
|
400
|
-
}
|
401
|
-
}
|
402
|
-
cycles += get_cy() - t0;
|
403
|
-
}
|
404
|
-
|
405
|
-
void shrink() {
|
406
|
-
uint64_t t0 = get_cy();
|
407
|
-
threshold = partition<C>(vals, ids, i, n);
|
408
|
-
i = n;
|
409
|
-
cycles += get_cy() - t0;
|
410
|
-
}
|
411
|
-
|
412
|
-
void shrink_fuzzy() {
|
413
|
-
uint64_t t0 = get_cy();
|
414
|
-
assert(i == capacity);
|
415
|
-
threshold = partition_fuzzy<C>(
|
416
|
-
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
417
|
-
cycles += get_cy() - t0;
|
418
|
-
}
|
419
|
-
};
|
420
|
-
|
421
372
|
/** Handler built from several ReservoirTopN (one per query) */
|
422
373
|
template <class C, bool with_id_map = false>
|
423
|
-
struct ReservoirHandler :
|
374
|
+
struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
424
375
|
using T = typename C::T;
|
425
376
|
using TI = typename C::TI;
|
377
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
378
|
+
using RHC::normalizers;
|
426
379
|
|
427
380
|
size_t capacity; // rounded up to multiple of 16
|
381
|
+
|
382
|
+
// where the final results will be written
|
383
|
+
float* dis;
|
384
|
+
int64_t* ids;
|
385
|
+
|
428
386
|
std::vector<TI> all_ids;
|
429
387
|
AlignedTable<T> all_vals;
|
430
|
-
|
431
388
|
std::vector<ReservoirTopN<C>> reservoirs;
|
432
389
|
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
390
|
+
ReservoirHandler(
|
391
|
+
size_t nq,
|
392
|
+
size_t ntotal,
|
393
|
+
size_t k,
|
394
|
+
size_t cap,
|
395
|
+
float* dis,
|
396
|
+
int64_t* ids)
|
397
|
+
: RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
|
440
398
|
assert(capacity % 16 == 0);
|
441
|
-
|
399
|
+
all_ids.resize(nq * capacity);
|
400
|
+
all_vals.resize(nq * capacity);
|
401
|
+
for (size_t q = 0; q < nq; q++) {
|
442
402
|
reservoirs.emplace_back(
|
443
|
-
|
403
|
+
k,
|
444
404
|
capacity,
|
445
|
-
all_vals.get() +
|
446
|
-
all_ids.data() +
|
405
|
+
all_vals.get() + q * capacity,
|
406
|
+
all_ids.data() + q * capacity);
|
447
407
|
}
|
448
|
-
times[0] = times[1] = times[2] = times[3] = 0;
|
449
408
|
}
|
450
409
|
|
451
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
452
|
-
uint64_t t0 = get_cy();
|
410
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
453
411
|
if (this->disable) {
|
454
412
|
return;
|
455
413
|
}
|
@@ -457,8 +415,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
457
415
|
|
458
416
|
ReservoirTopN<C>& res = reservoirs[q];
|
459
417
|
uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
|
460
|
-
uint64_t t1 = get_cy();
|
461
|
-
times[0] += t1 - t0;
|
462
418
|
|
463
419
|
if (!lt_mask) {
|
464
420
|
return;
|
@@ -474,20 +430,14 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
474
430
|
T dis = d32tab[j];
|
475
431
|
res.add(dis, this->adjust_id(b, j));
|
476
432
|
}
|
477
|
-
times[1] += get_cy() - t1;
|
478
433
|
}
|
479
434
|
|
480
|
-
void
|
481
|
-
float* distances,
|
482
|
-
int64_t* labels,
|
483
|
-
const float* normalizers = nullptr) override {
|
435
|
+
void end() override {
|
484
436
|
using Cf = typename std::conditional<
|
485
437
|
C::is_max,
|
486
438
|
CMax<float, int64_t>,
|
487
439
|
CMin<float, int64_t>>::type;
|
488
440
|
|
489
|
-
uint64_t t0 = get_cy();
|
490
|
-
uint64_t t3 = 0;
|
491
441
|
std::vector<int> perm(reservoirs[0].n);
|
492
442
|
for (int q = 0; q < reservoirs.size(); q++) {
|
493
443
|
ReservoirTopN<C>& res = reservoirs[q];
|
@@ -496,8 +446,8 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
496
446
|
if (res.i > res.n) {
|
497
447
|
res.shrink();
|
498
448
|
}
|
499
|
-
int64_t* heap_ids =
|
500
|
-
float* heap_dis =
|
449
|
+
int64_t* heap_ids = ids + q * n;
|
450
|
+
float* heap_dis = dis + q * n;
|
501
451
|
|
502
452
|
float one_a = 1.0, b = 0.0;
|
503
453
|
if (normalizers) {
|
@@ -518,14 +468,236 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
518
468
|
|
519
469
|
// possibly add empty results
|
520
470
|
heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
|
471
|
+
}
|
472
|
+
}
|
473
|
+
};
|
474
|
+
|
475
|
+
/** Result hanlder for range search. The difficulty is that the range distances
|
476
|
+
* have to be scaled using the scaler.
|
477
|
+
*/
|
478
|
+
|
479
|
+
template <class C, bool with_id_map = false>
|
480
|
+
struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
481
|
+
using T = typename C::T;
|
482
|
+
using TI = typename C::TI;
|
483
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
484
|
+
using RHC::normalizers;
|
485
|
+
using RHC::nq;
|
486
|
+
|
487
|
+
RangeSearchResult& rres;
|
488
|
+
float radius;
|
489
|
+
std::vector<uint16_t> thresholds;
|
490
|
+
std::vector<size_t> n_per_query;
|
491
|
+
size_t q0 = 0;
|
492
|
+
|
493
|
+
// we cannot use the RangeSearchPartialResult interface because queries can
|
494
|
+
// be performed by batches
|
495
|
+
struct Triplet {
|
496
|
+
idx_t q;
|
497
|
+
idx_t b;
|
498
|
+
uint16_t dis;
|
499
|
+
};
|
500
|
+
std::vector<Triplet> triplets;
|
501
|
+
|
502
|
+
RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
|
503
|
+
: RHC(rres.nq, ntotal), rres(rres), radius(radius) {
|
504
|
+
thresholds.resize(nq);
|
505
|
+
n_per_query.resize(nq + 1);
|
506
|
+
}
|
507
|
+
|
508
|
+
virtual void begin(const float* norms) {
|
509
|
+
normalizers = norms;
|
510
|
+
for (int q = 0; q < nq; ++q) {
|
511
|
+
thresholds[q] =
|
512
|
+
normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
|
513
|
+
}
|
514
|
+
}
|
515
|
+
|
516
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
517
|
+
if (this->disable) {
|
518
|
+
return;
|
519
|
+
}
|
520
|
+
this->adjust_with_origin(q, d0, d1);
|
521
|
+
|
522
|
+
uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
|
523
|
+
|
524
|
+
if (!lt_mask) {
|
525
|
+
return;
|
526
|
+
}
|
527
|
+
ALIGNED(32) uint16_t d32tab[32];
|
528
|
+
d0.store(d32tab);
|
529
|
+
d1.store(d32tab + 16);
|
530
|
+
|
531
|
+
while (lt_mask) {
|
532
|
+
// find first non-zero
|
533
|
+
int j = __builtin_ctz(lt_mask);
|
534
|
+
lt_mask -= 1 << j;
|
535
|
+
T dis = d32tab[j];
|
536
|
+
n_per_query[q]++;
|
537
|
+
triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
|
538
|
+
}
|
539
|
+
}
|
521
540
|
|
522
|
-
|
541
|
+
void end() override {
|
542
|
+
memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
|
543
|
+
rres.do_allocation();
|
544
|
+
for (auto it = triplets.begin(); it != triplets.end(); ++it) {
|
545
|
+
size_t& l = rres.lims[it->q];
|
546
|
+
rres.distances[l] = it->dis;
|
547
|
+
rres.labels[l] = it->b;
|
548
|
+
l++;
|
549
|
+
}
|
550
|
+
memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
|
551
|
+
rres.lims[0] = 0;
|
552
|
+
|
553
|
+
for (int q = 0; q < nq; q++) {
|
554
|
+
float one_a = 1 / normalizers[2 * q];
|
555
|
+
float b = normalizers[2 * q + 1];
|
556
|
+
for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
|
557
|
+
rres.distances[i] = rres.distances[i] * one_a + b;
|
558
|
+
}
|
523
559
|
}
|
524
|
-
times[2] += get_cy() - t0;
|
525
|
-
times[3] += t3;
|
526
560
|
}
|
527
561
|
};
|
528
562
|
|
563
|
+
#ifndef SWIG
|
564
|
+
|
565
|
+
// handler for a subset of queries
|
566
|
+
template <class C, bool with_id_map = false>
|
567
|
+
struct PartialRangeHandler : RangeHandler<C, with_id_map> {
|
568
|
+
using T = typename C::T;
|
569
|
+
using TI = typename C::TI;
|
570
|
+
using RHC = RangeHandler<C, with_id_map>;
|
571
|
+
using RHC::normalizers;
|
572
|
+
using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
|
573
|
+
|
574
|
+
RangeSearchPartialResult& pres;
|
575
|
+
|
576
|
+
PartialRangeHandler(
|
577
|
+
RangeSearchPartialResult& pres,
|
578
|
+
float radius,
|
579
|
+
size_t ntotal,
|
580
|
+
size_t q0,
|
581
|
+
size_t q1)
|
582
|
+
: RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
|
583
|
+
pres(pres) {
|
584
|
+
nq = q1 - q0;
|
585
|
+
this->q0 = q0;
|
586
|
+
}
|
587
|
+
|
588
|
+
// shift left n_per_query
|
589
|
+
void shift_n_per_query() {
|
590
|
+
memmove(n_per_query.data() + 1,
|
591
|
+
n_per_query.data(),
|
592
|
+
nq * sizeof(n_per_query[0]));
|
593
|
+
n_per_query[0] = 0;
|
594
|
+
}
|
595
|
+
|
596
|
+
// commit to partial result instead of full RangeResult
|
597
|
+
void end() override {
|
598
|
+
std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
|
599
|
+
for (int q = 0; q < nq; q++) {
|
600
|
+
n_per_query[q + 1] += n_per_query[q];
|
601
|
+
}
|
602
|
+
shift_n_per_query();
|
603
|
+
|
604
|
+
for (size_t i = 0; i < triplets.size(); i++) {
|
605
|
+
sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
|
606
|
+
}
|
607
|
+
shift_n_per_query();
|
608
|
+
|
609
|
+
size_t* lims = n_per_query.data();
|
610
|
+
|
611
|
+
for (int q = 0; q < nq; q++) {
|
612
|
+
float one_a = 1 / normalizers[2 * q];
|
613
|
+
float b = normalizers[2 * q + 1];
|
614
|
+
RangeQueryResult& qres = pres.new_result(q + q0);
|
615
|
+
for (size_t i = lims[q]; i < lims[q + 1]; i++) {
|
616
|
+
qres.add(
|
617
|
+
sorted_triplets[i].dis * one_a + b,
|
618
|
+
sorted_triplets[i].b);
|
619
|
+
}
|
620
|
+
}
|
621
|
+
}
|
622
|
+
};
|
623
|
+
|
624
|
+
#endif
|
625
|
+
|
626
|
+
/********************************************************************************
|
627
|
+
* Dynamic dispatching function. The consumer should have a templatized method f
|
628
|
+
* that will be replaced with the actual SIMDResultHandler that is determined
|
629
|
+
* dynamically.
|
630
|
+
*/
|
631
|
+
|
632
|
+
template <class C, bool W, class Consumer, class... Types>
|
633
|
+
void dispatch_SIMDResultHanlder_fixedCW(
|
634
|
+
SIMDResultHandler& res,
|
635
|
+
Consumer& consumer,
|
636
|
+
Types... args) {
|
637
|
+
if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
|
638
|
+
consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
|
639
|
+
} else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
|
640
|
+
consumer.template f<HeapHandler<C, W>>(*resh, args...);
|
641
|
+
} else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
|
642
|
+
consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
|
643
|
+
} else { // generic handler -- will not be inlined
|
644
|
+
FAISS_THROW_IF_NOT_FMT(
|
645
|
+
simd_result_handlers_accept_virtual,
|
646
|
+
"Running vitrual handler for %s",
|
647
|
+
typeid(res).name());
|
648
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
649
|
+
}
|
650
|
+
}
|
651
|
+
|
652
|
+
template <class C, class Consumer, class... Types>
|
653
|
+
void dispatch_SIMDResultHanlder_fixedC(
|
654
|
+
SIMDResultHandler& res,
|
655
|
+
Consumer& consumer,
|
656
|
+
Types... args) {
|
657
|
+
if (res.with_fields) {
|
658
|
+
dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
|
659
|
+
} else {
|
660
|
+
dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
|
661
|
+
}
|
662
|
+
}
|
663
|
+
|
664
|
+
template <class Consumer, class... Types>
|
665
|
+
void dispatch_SIMDResultHanlder(
|
666
|
+
SIMDResultHandler& res,
|
667
|
+
Consumer& consumer,
|
668
|
+
Types... args) {
|
669
|
+
if (res.sizeof_ids == 0) {
|
670
|
+
if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
|
671
|
+
consumer.template f<StoreResultHandler>(*resh, args...);
|
672
|
+
} else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
|
673
|
+
consumer.template f<DummyResultHandler>(*resh, args...);
|
674
|
+
} else { // generic path
|
675
|
+
FAISS_THROW_IF_NOT_FMT(
|
676
|
+
simd_result_handlers_accept_virtual,
|
677
|
+
"Running vitrual handler for %s",
|
678
|
+
typeid(res).name());
|
679
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
680
|
+
}
|
681
|
+
} else if (res.sizeof_ids == sizeof(int)) {
|
682
|
+
if (res.is_CMax) {
|
683
|
+
dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
|
684
|
+
res, consumer, args...);
|
685
|
+
} else {
|
686
|
+
dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
|
687
|
+
res, consumer, args...);
|
688
|
+
}
|
689
|
+
} else if (res.sizeof_ids == sizeof(int64_t)) {
|
690
|
+
if (res.is_CMax) {
|
691
|
+
dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
|
692
|
+
res, consumer, args...);
|
693
|
+
} else {
|
694
|
+
dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
|
695
|
+
res, consumer, args...);
|
696
|
+
}
|
697
|
+
} else {
|
698
|
+
FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
|
699
|
+
}
|
700
|
+
}
|
529
701
|
} // namespace simd_result_handlers
|
530
702
|
|
531
703
|
} // namespace faiss
|