faiss 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
|
@@ -14,40 +14,86 @@
|
|
|
14
14
|
#include <faiss/utils/Heap.h>
|
|
15
15
|
#include <faiss/utils/simdlib.h>
|
|
16
16
|
|
|
17
|
+
#include <faiss/impl/FaissAssert.h>
|
|
18
|
+
#include <faiss/impl/ResultHandler.h>
|
|
17
19
|
#include <faiss/impl/platform_macros.h>
|
|
18
20
|
#include <faiss/utils/AlignedTable.h>
|
|
19
21
|
#include <faiss/utils/partitioning.h>
|
|
20
22
|
|
|
21
23
|
/** This file contains callbacks for kernels that compute distances.
|
|
22
|
-
*
|
|
23
|
-
* The SIMDResultHandler object is intended to be templated and inlined.
|
|
24
|
-
* Methods:
|
|
25
|
-
* - handle(): called when 32 distances are computed and provided in two
|
|
26
|
-
* simd16uint16. (q, b) indicate which entry it is in the block.
|
|
27
|
-
* - set_block_origin(): set the sub-matrix that is being computed
|
|
28
24
|
*/
|
|
29
25
|
|
|
30
26
|
namespace faiss {
|
|
31
27
|
|
|
28
|
+
struct SIMDResultHandler {
|
|
29
|
+
// used to dispatch templates
|
|
30
|
+
bool is_CMax = false;
|
|
31
|
+
uint8_t sizeof_ids = 0;
|
|
32
|
+
bool with_fields = false;
|
|
33
|
+
|
|
34
|
+
/** called when 32 distances are computed and provided in two
|
|
35
|
+
* simd16uint16. (q, b) indicate which entry it is in the block. */
|
|
36
|
+
virtual void handle(
|
|
37
|
+
size_t q,
|
|
38
|
+
size_t b,
|
|
39
|
+
simd16uint16 d0,
|
|
40
|
+
simd16uint16 d1) = 0;
|
|
41
|
+
|
|
42
|
+
/// set the sub-matrix that is being computed
|
|
43
|
+
virtual void set_block_origin(size_t i0, size_t j0) = 0;
|
|
44
|
+
|
|
45
|
+
virtual ~SIMDResultHandler() {}
|
|
46
|
+
};
|
|
47
|
+
|
|
48
|
+
/* Result handler that will return float resutls eventually */
|
|
49
|
+
struct SIMDResultHandlerToFloat : SIMDResultHandler {
|
|
50
|
+
size_t nq; // number of queries
|
|
51
|
+
size_t ntotal; // ignore excess elements after ntotal
|
|
52
|
+
|
|
53
|
+
/// these fields are used mainly for the IVF variants (with_id_map=true)
|
|
54
|
+
const idx_t* id_map = nullptr; // map offset in invlist to vector id
|
|
55
|
+
const int* q_map = nullptr; // map q to global query
|
|
56
|
+
const uint16_t* dbias =
|
|
57
|
+
nullptr; // table of biases to add to each query (for IVF L2 search)
|
|
58
|
+
const float* normalizers = nullptr; // size 2 * nq, to convert
|
|
59
|
+
|
|
60
|
+
SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
|
|
61
|
+
: nq(nq), ntotal(ntotal) {}
|
|
62
|
+
|
|
63
|
+
virtual void begin(const float* norms) {
|
|
64
|
+
normalizers = norms;
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
// called at end of search to convert int16 distances to float, before
|
|
68
|
+
// normalizers are deallocated
|
|
69
|
+
virtual void end() {
|
|
70
|
+
normalizers = nullptr;
|
|
71
|
+
}
|
|
72
|
+
};
|
|
73
|
+
|
|
74
|
+
FAISS_API extern bool simd_result_handlers_accept_virtual;
|
|
75
|
+
|
|
32
76
|
namespace simd_result_handlers {
|
|
33
77
|
|
|
34
|
-
/** Dummy structure that just computes a
|
|
78
|
+
/** Dummy structure that just computes a chqecksum on results
|
|
35
79
|
* (to avoid the computation to be optimized away) */
|
|
36
|
-
struct DummyResultHandler {
|
|
80
|
+
struct DummyResultHandler : SIMDResultHandler {
|
|
37
81
|
size_t cs = 0;
|
|
38
82
|
|
|
39
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
83
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
40
84
|
cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
|
|
41
85
|
}
|
|
42
86
|
|
|
43
|
-
void set_block_origin(size_t, size_t) {}
|
|
87
|
+
void set_block_origin(size_t, size_t) final {}
|
|
88
|
+
|
|
89
|
+
~DummyResultHandler() {}
|
|
44
90
|
};
|
|
45
91
|
|
|
46
92
|
/** memorize results in a nq-by-nb matrix.
|
|
47
93
|
*
|
|
48
94
|
* j0 is the current upper-left block of the matrix
|
|
49
95
|
*/
|
|
50
|
-
struct StoreResultHandler {
|
|
96
|
+
struct StoreResultHandler : SIMDResultHandler {
|
|
51
97
|
uint16_t* data;
|
|
52
98
|
size_t ld; // total number of columns
|
|
53
99
|
size_t i0 = 0;
|
|
@@ -55,32 +101,32 @@ struct StoreResultHandler {
|
|
|
55
101
|
|
|
56
102
|
StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
|
|
57
103
|
|
|
58
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
104
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
59
105
|
size_t ofs = (q + i0) * ld + j0 + b * 32;
|
|
60
106
|
d0.store(data + ofs);
|
|
61
107
|
d1.store(data + ofs + 16);
|
|
62
108
|
}
|
|
63
109
|
|
|
64
|
-
void set_block_origin(size_t
|
|
65
|
-
this->i0 =
|
|
66
|
-
this->j0 =
|
|
110
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
111
|
+
this->i0 = i0_in;
|
|
112
|
+
this->j0 = j0_in;
|
|
67
113
|
}
|
|
68
114
|
};
|
|
69
115
|
|
|
70
116
|
/** stores results in fixed-size matrix. */
|
|
71
117
|
template <int NQ, int BB>
|
|
72
|
-
struct FixedStorageHandler {
|
|
118
|
+
struct FixedStorageHandler : SIMDResultHandler {
|
|
73
119
|
simd16uint16 dis[NQ][BB];
|
|
74
120
|
int i0 = 0;
|
|
75
121
|
|
|
76
|
-
void handle(
|
|
122
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
77
123
|
dis[q + i0][2 * b] = d0;
|
|
78
124
|
dis[q + i0][2 * b + 1] = d1;
|
|
79
125
|
}
|
|
80
126
|
|
|
81
|
-
void set_block_origin(size_t
|
|
82
|
-
this->i0 =
|
|
83
|
-
assert(
|
|
127
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
128
|
+
this->i0 = i0_in;
|
|
129
|
+
assert(j0_in == 0);
|
|
84
130
|
}
|
|
85
131
|
|
|
86
132
|
template <class OtherResultHandler>
|
|
@@ -91,30 +137,29 @@ struct FixedStorageHandler {
|
|
|
91
137
|
}
|
|
92
138
|
}
|
|
93
139
|
}
|
|
140
|
+
virtual ~FixedStorageHandler() {}
|
|
94
141
|
};
|
|
95
142
|
|
|
96
|
-
/**
|
|
143
|
+
/** Result handler that compares distances to check if they need to be kept */
|
|
97
144
|
template <class C, bool with_id_map>
|
|
98
|
-
struct
|
|
145
|
+
struct ResultHandlerCompare : SIMDResultHandlerToFloat {
|
|
99
146
|
using TI = typename C::TI;
|
|
100
147
|
|
|
101
148
|
bool disable = false;
|
|
102
149
|
|
|
103
150
|
int64_t i0 = 0; // query origin
|
|
104
151
|
int64_t j0 = 0; // db origin
|
|
105
|
-
size_t ntotal; // ignore excess elements after ntotal
|
|
106
|
-
|
|
107
|
-
/// these fields are used mainly for the IVF variants (with_id_map=true)
|
|
108
|
-
const TI* id_map; // map offset in invlist to vector id
|
|
109
|
-
const int* q_map; // map q to global query
|
|
110
|
-
const uint16_t* dbias; // table of biases to add to each query
|
|
111
152
|
|
|
112
|
-
|
|
113
|
-
:
|
|
153
|
+
ResultHandlerCompare(size_t nq, size_t ntotal)
|
|
154
|
+
: SIMDResultHandlerToFloat(nq, ntotal) {
|
|
155
|
+
this->is_CMax = C::is_max;
|
|
156
|
+
this->sizeof_ids = sizeof(typename C::TI);
|
|
157
|
+
this->with_fields = with_id_map;
|
|
158
|
+
}
|
|
114
159
|
|
|
115
|
-
void set_block_origin(size_t
|
|
116
|
-
this->i0 =
|
|
117
|
-
this->j0 =
|
|
160
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
161
|
+
this->i0 = i0_in;
|
|
162
|
+
this->j0 = j0_in;
|
|
118
163
|
}
|
|
119
164
|
|
|
120
165
|
// adjust handler data for IVF.
|
|
@@ -172,43 +217,37 @@ struct SIMDResultHandler {
|
|
|
172
217
|
return lt_mask;
|
|
173
218
|
}
|
|
174
219
|
|
|
175
|
-
virtual
|
|
176
|
-
float* distances,
|
|
177
|
-
int64_t* labels,
|
|
178
|
-
const float* normalizers = nullptr) = 0;
|
|
179
|
-
|
|
180
|
-
virtual ~SIMDResultHandler() {}
|
|
220
|
+
virtual ~ResultHandlerCompare() {}
|
|
181
221
|
};
|
|
182
222
|
|
|
183
223
|
/** Special version for k=1 */
|
|
184
224
|
template <class C, bool with_id_map = false>
|
|
185
|
-
struct SingleResultHandler :
|
|
225
|
+
struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
|
|
186
226
|
using T = typename C::T;
|
|
187
227
|
using TI = typename C::TI;
|
|
228
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
229
|
+
using RHC::normalizers;
|
|
188
230
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
};
|
|
193
|
-
std::vector<Result> results;
|
|
231
|
+
std::vector<int16_t> idis;
|
|
232
|
+
float* dis;
|
|
233
|
+
int64_t* ids;
|
|
194
234
|
|
|
195
|
-
SingleResultHandler(size_t nq, size_t ntotal)
|
|
196
|
-
:
|
|
235
|
+
SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
|
|
236
|
+
: RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
|
|
197
237
|
for (int i = 0; i < nq; i++) {
|
|
198
|
-
|
|
199
|
-
|
|
238
|
+
ids[i] = -1;
|
|
239
|
+
idis[i] = C::neutral();
|
|
200
240
|
}
|
|
201
241
|
}
|
|
202
242
|
|
|
203
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
243
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
204
244
|
if (this->disable) {
|
|
205
245
|
return;
|
|
206
246
|
}
|
|
207
247
|
|
|
208
248
|
this->adjust_with_origin(q, d0, d1);
|
|
209
249
|
|
|
210
|
-
|
|
211
|
-
uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
|
|
250
|
+
uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
|
|
212
251
|
if (!lt_mask) {
|
|
213
252
|
return;
|
|
214
253
|
}
|
|
@@ -221,70 +260,61 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
221
260
|
// find first non-zero
|
|
222
261
|
int j = __builtin_ctz(lt_mask);
|
|
223
262
|
lt_mask -= 1 << j;
|
|
224
|
-
T
|
|
225
|
-
if (C::cmp(
|
|
226
|
-
|
|
227
|
-
|
|
263
|
+
T d = d32tab[j];
|
|
264
|
+
if (C::cmp(idis[q], d)) {
|
|
265
|
+
idis[q] = d;
|
|
266
|
+
ids[q] = this->adjust_id(b, j);
|
|
228
267
|
}
|
|
229
268
|
}
|
|
230
269
|
}
|
|
231
270
|
|
|
232
|
-
void
|
|
233
|
-
|
|
234
|
-
int64_t* labels,
|
|
235
|
-
const float* normalizers = nullptr) override {
|
|
236
|
-
for (int q = 0; q < results.size(); q++) {
|
|
271
|
+
void end() {
|
|
272
|
+
for (int q = 0; q < this->nq; q++) {
|
|
237
273
|
if (!normalizers) {
|
|
238
|
-
|
|
274
|
+
dis[q] = idis[q];
|
|
239
275
|
} else {
|
|
240
276
|
float one_a = 1 / normalizers[2 * q];
|
|
241
277
|
float b = normalizers[2 * q + 1];
|
|
242
|
-
|
|
278
|
+
dis[q] = b + idis[q] * one_a;
|
|
243
279
|
}
|
|
244
|
-
labels[q] = results[q].id;
|
|
245
280
|
}
|
|
246
281
|
}
|
|
247
282
|
};
|
|
248
283
|
|
|
249
284
|
/** Structure that collects results in a min- or max-heap */
|
|
250
285
|
template <class C, bool with_id_map = false>
|
|
251
|
-
struct HeapHandler :
|
|
286
|
+
struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
252
287
|
using T = typename C::T;
|
|
253
288
|
using TI = typename C::TI;
|
|
289
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
290
|
+
using RHC::normalizers;
|
|
254
291
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
292
|
+
std::vector<uint16_t> idis;
|
|
293
|
+
std::vector<TI> iids;
|
|
294
|
+
float* dis;
|
|
295
|
+
int64_t* ids;
|
|
258
296
|
|
|
259
297
|
int64_t k; // number of results to keep
|
|
260
298
|
|
|
261
|
-
HeapHandler(
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
: SIMDResultHandler<C, with_id_map>(ntotal),
|
|
268
|
-
nq(nq),
|
|
269
|
-
heap_dis_tab(heap_dis_tab),
|
|
270
|
-
heap_ids_tab(heap_ids_tab),
|
|
299
|
+
HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
|
|
300
|
+
: RHC(nq, ntotal),
|
|
301
|
+
idis(nq * k),
|
|
302
|
+
iids(nq * k),
|
|
303
|
+
dis(dis),
|
|
304
|
+
ids(ids),
|
|
271
305
|
k(k) {
|
|
272
|
-
|
|
273
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
|
274
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
|
275
|
-
heap_heapify<C>(k, heap_dis_in, heap_ids_in);
|
|
276
|
-
}
|
|
306
|
+
heap_heapify<C>(k * nq, idis.data(), iids.data());
|
|
277
307
|
}
|
|
278
308
|
|
|
279
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
309
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
280
310
|
if (this->disable) {
|
|
281
311
|
return;
|
|
282
312
|
}
|
|
283
313
|
|
|
284
314
|
this->adjust_with_origin(q, d0, d1);
|
|
285
315
|
|
|
286
|
-
T* heap_dis =
|
|
287
|
-
TI* heap_ids =
|
|
316
|
+
T* heap_dis = idis.data() + q * k;
|
|
317
|
+
TI* heap_ids = iids.data() + q * k;
|
|
288
318
|
|
|
289
319
|
uint16_t cur_thresh =
|
|
290
320
|
heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
|
|
@@ -313,16 +343,13 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
313
343
|
}
|
|
314
344
|
}
|
|
315
345
|
|
|
316
|
-
void
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
for (int q = 0; q < nq; q++) {
|
|
321
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
|
322
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
|
346
|
+
void end() override {
|
|
347
|
+
for (int q = 0; q < this->nq; q++) {
|
|
348
|
+
T* heap_dis_in = idis.data() + q * k;
|
|
349
|
+
TI* heap_ids_in = iids.data() + q * k;
|
|
323
350
|
heap_reorder<C>(k, heap_dis_in, heap_ids_in);
|
|
324
|
-
|
|
325
|
-
|
|
351
|
+
float* heap_dis = dis + q * k;
|
|
352
|
+
int64_t* heap_ids = ids + q * k;
|
|
326
353
|
|
|
327
354
|
float one_a = 1.0, b = 0.0;
|
|
328
355
|
if (normalizers) {
|
|
@@ -330,8 +357,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
330
357
|
b = normalizers[2 * q + 1];
|
|
331
358
|
}
|
|
332
359
|
for (int j = 0; j < k; j++) {
|
|
333
|
-
heap_ids[j] = heap_ids_in[j];
|
|
334
360
|
heap_dis[j] = heap_dis_in[j] * one_a + b;
|
|
361
|
+
heap_ids[j] = heap_ids_in[j];
|
|
335
362
|
}
|
|
336
363
|
}
|
|
337
364
|
}
|
|
@@ -342,114 +369,45 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
342
369
|
* Results are stored when they are below the threshold until the capacity is
|
|
343
370
|
* reached. Then a partition sort is used to update the threshold. */
|
|
344
371
|
|
|
345
|
-
namespace {
|
|
346
|
-
|
|
347
|
-
uint64_t get_cy() {
|
|
348
|
-
#ifdef MICRO_BENCHMARK
|
|
349
|
-
uint32_t high, low;
|
|
350
|
-
asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
|
|
351
|
-
return ((uint64_t)high << 32) | (low);
|
|
352
|
-
#else
|
|
353
|
-
return 0;
|
|
354
|
-
#endif
|
|
355
|
-
}
|
|
356
|
-
|
|
357
|
-
} // anonymous namespace
|
|
358
|
-
|
|
359
|
-
template <class C>
|
|
360
|
-
struct ReservoirTopN {
|
|
361
|
-
using T = typename C::T;
|
|
362
|
-
using TI = typename C::TI;
|
|
363
|
-
|
|
364
|
-
T* vals;
|
|
365
|
-
TI* ids;
|
|
366
|
-
|
|
367
|
-
size_t i; // number of stored elements
|
|
368
|
-
size_t n; // number of requested elements
|
|
369
|
-
size_t capacity; // size of storage
|
|
370
|
-
size_t cycles = 0;
|
|
371
|
-
|
|
372
|
-
T threshold; // current threshold
|
|
373
|
-
|
|
374
|
-
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
|
375
|
-
: vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
|
|
376
|
-
assert(n < capacity);
|
|
377
|
-
threshold = C::neutral();
|
|
378
|
-
}
|
|
379
|
-
|
|
380
|
-
void add(T val, TI id) {
|
|
381
|
-
if (C::cmp(threshold, val)) {
|
|
382
|
-
if (i == capacity) {
|
|
383
|
-
shrink_fuzzy();
|
|
384
|
-
}
|
|
385
|
-
vals[i] = val;
|
|
386
|
-
ids[i] = id;
|
|
387
|
-
i++;
|
|
388
|
-
}
|
|
389
|
-
}
|
|
390
|
-
|
|
391
|
-
/// shrink number of stored elements to n
|
|
392
|
-
void shrink_xx() {
|
|
393
|
-
uint64_t t0 = get_cy();
|
|
394
|
-
qselect(vals, ids, i, n);
|
|
395
|
-
i = n; // forget all elements above i = n
|
|
396
|
-
threshold = C::Crev::neutral();
|
|
397
|
-
for (size_t j = 0; j < n; j++) {
|
|
398
|
-
if (C::cmp(vals[j], threshold)) {
|
|
399
|
-
threshold = vals[j];
|
|
400
|
-
}
|
|
401
|
-
}
|
|
402
|
-
cycles += get_cy() - t0;
|
|
403
|
-
}
|
|
404
|
-
|
|
405
|
-
void shrink() {
|
|
406
|
-
uint64_t t0 = get_cy();
|
|
407
|
-
threshold = partition<C>(vals, ids, i, n);
|
|
408
|
-
i = n;
|
|
409
|
-
cycles += get_cy() - t0;
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
void shrink_fuzzy() {
|
|
413
|
-
uint64_t t0 = get_cy();
|
|
414
|
-
assert(i == capacity);
|
|
415
|
-
threshold = partition_fuzzy<C>(
|
|
416
|
-
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
|
417
|
-
cycles += get_cy() - t0;
|
|
418
|
-
}
|
|
419
|
-
};
|
|
420
|
-
|
|
421
372
|
/** Handler built from several ReservoirTopN (one per query) */
|
|
422
373
|
template <class C, bool with_id_map = false>
|
|
423
|
-
struct ReservoirHandler :
|
|
374
|
+
struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
424
375
|
using T = typename C::T;
|
|
425
376
|
using TI = typename C::TI;
|
|
377
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
378
|
+
using RHC::normalizers;
|
|
426
379
|
|
|
427
380
|
size_t capacity; // rounded up to multiple of 16
|
|
381
|
+
|
|
382
|
+
// where the final results will be written
|
|
383
|
+
float* dis;
|
|
384
|
+
int64_t* ids;
|
|
385
|
+
|
|
428
386
|
std::vector<TI> all_ids;
|
|
429
387
|
AlignedTable<T> all_vals;
|
|
430
|
-
|
|
431
388
|
std::vector<ReservoirTopN<C>> reservoirs;
|
|
432
389
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
390
|
+
ReservoirHandler(
|
|
391
|
+
size_t nq,
|
|
392
|
+
size_t ntotal,
|
|
393
|
+
size_t k,
|
|
394
|
+
size_t cap,
|
|
395
|
+
float* dis,
|
|
396
|
+
int64_t* ids)
|
|
397
|
+
: RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
|
|
440
398
|
assert(capacity % 16 == 0);
|
|
441
|
-
|
|
399
|
+
all_ids.resize(nq * capacity);
|
|
400
|
+
all_vals.resize(nq * capacity);
|
|
401
|
+
for (size_t q = 0; q < nq; q++) {
|
|
442
402
|
reservoirs.emplace_back(
|
|
443
|
-
|
|
403
|
+
k,
|
|
444
404
|
capacity,
|
|
445
|
-
all_vals.get() +
|
|
446
|
-
all_ids.data() +
|
|
405
|
+
all_vals.get() + q * capacity,
|
|
406
|
+
all_ids.data() + q * capacity);
|
|
447
407
|
}
|
|
448
|
-
times[0] = times[1] = times[2] = times[3] = 0;
|
|
449
408
|
}
|
|
450
409
|
|
|
451
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
452
|
-
uint64_t t0 = get_cy();
|
|
410
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
453
411
|
if (this->disable) {
|
|
454
412
|
return;
|
|
455
413
|
}
|
|
@@ -457,8 +415,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
457
415
|
|
|
458
416
|
ReservoirTopN<C>& res = reservoirs[q];
|
|
459
417
|
uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
|
|
460
|
-
uint64_t t1 = get_cy();
|
|
461
|
-
times[0] += t1 - t0;
|
|
462
418
|
|
|
463
419
|
if (!lt_mask) {
|
|
464
420
|
return;
|
|
@@ -474,20 +430,14 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
474
430
|
T dis = d32tab[j];
|
|
475
431
|
res.add(dis, this->adjust_id(b, j));
|
|
476
432
|
}
|
|
477
|
-
times[1] += get_cy() - t1;
|
|
478
433
|
}
|
|
479
434
|
|
|
480
|
-
void
|
|
481
|
-
float* distances,
|
|
482
|
-
int64_t* labels,
|
|
483
|
-
const float* normalizers = nullptr) override {
|
|
435
|
+
void end() override {
|
|
484
436
|
using Cf = typename std::conditional<
|
|
485
437
|
C::is_max,
|
|
486
438
|
CMax<float, int64_t>,
|
|
487
439
|
CMin<float, int64_t>>::type;
|
|
488
440
|
|
|
489
|
-
uint64_t t0 = get_cy();
|
|
490
|
-
uint64_t t3 = 0;
|
|
491
441
|
std::vector<int> perm(reservoirs[0].n);
|
|
492
442
|
for (int q = 0; q < reservoirs.size(); q++) {
|
|
493
443
|
ReservoirTopN<C>& res = reservoirs[q];
|
|
@@ -496,8 +446,8 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
496
446
|
if (res.i > res.n) {
|
|
497
447
|
res.shrink();
|
|
498
448
|
}
|
|
499
|
-
int64_t* heap_ids =
|
|
500
|
-
float* heap_dis =
|
|
449
|
+
int64_t* heap_ids = ids + q * n;
|
|
450
|
+
float* heap_dis = dis + q * n;
|
|
501
451
|
|
|
502
452
|
float one_a = 1.0, b = 0.0;
|
|
503
453
|
if (normalizers) {
|
|
@@ -518,14 +468,236 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
518
468
|
|
|
519
469
|
// possibly add empty results
|
|
520
470
|
heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
|
|
471
|
+
}
|
|
472
|
+
}
|
|
473
|
+
};
|
|
474
|
+
|
|
475
|
+
/** Result hanlder for range search. The difficulty is that the range distances
|
|
476
|
+
* have to be scaled using the scaler.
|
|
477
|
+
*/
|
|
478
|
+
|
|
479
|
+
template <class C, bool with_id_map = false>
|
|
480
|
+
struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
|
481
|
+
using T = typename C::T;
|
|
482
|
+
using TI = typename C::TI;
|
|
483
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
484
|
+
using RHC::normalizers;
|
|
485
|
+
using RHC::nq;
|
|
486
|
+
|
|
487
|
+
RangeSearchResult& rres;
|
|
488
|
+
float radius;
|
|
489
|
+
std::vector<uint16_t> thresholds;
|
|
490
|
+
std::vector<size_t> n_per_query;
|
|
491
|
+
size_t q0 = 0;
|
|
492
|
+
|
|
493
|
+
// we cannot use the RangeSearchPartialResult interface because queries can
|
|
494
|
+
// be performed by batches
|
|
495
|
+
struct Triplet {
|
|
496
|
+
idx_t q;
|
|
497
|
+
idx_t b;
|
|
498
|
+
uint16_t dis;
|
|
499
|
+
};
|
|
500
|
+
std::vector<Triplet> triplets;
|
|
501
|
+
|
|
502
|
+
RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
|
|
503
|
+
: RHC(rres.nq, ntotal), rres(rres), radius(radius) {
|
|
504
|
+
thresholds.resize(nq);
|
|
505
|
+
n_per_query.resize(nq + 1);
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
virtual void begin(const float* norms) {
|
|
509
|
+
normalizers = norms;
|
|
510
|
+
for (int q = 0; q < nq; ++q) {
|
|
511
|
+
thresholds[q] =
|
|
512
|
+
normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
|
|
513
|
+
}
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
517
|
+
if (this->disable) {
|
|
518
|
+
return;
|
|
519
|
+
}
|
|
520
|
+
this->adjust_with_origin(q, d0, d1);
|
|
521
|
+
|
|
522
|
+
uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
|
|
523
|
+
|
|
524
|
+
if (!lt_mask) {
|
|
525
|
+
return;
|
|
526
|
+
}
|
|
527
|
+
ALIGNED(32) uint16_t d32tab[32];
|
|
528
|
+
d0.store(d32tab);
|
|
529
|
+
d1.store(d32tab + 16);
|
|
530
|
+
|
|
531
|
+
while (lt_mask) {
|
|
532
|
+
// find first non-zero
|
|
533
|
+
int j = __builtin_ctz(lt_mask);
|
|
534
|
+
lt_mask -= 1 << j;
|
|
535
|
+
T dis = d32tab[j];
|
|
536
|
+
n_per_query[q]++;
|
|
537
|
+
triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
|
|
538
|
+
}
|
|
539
|
+
}
|
|
521
540
|
|
|
522
|
-
|
|
541
|
+
void end() override {
|
|
542
|
+
memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
|
|
543
|
+
rres.do_allocation();
|
|
544
|
+
for (auto it = triplets.begin(); it != triplets.end(); ++it) {
|
|
545
|
+
size_t& l = rres.lims[it->q];
|
|
546
|
+
rres.distances[l] = it->dis;
|
|
547
|
+
rres.labels[l] = it->b;
|
|
548
|
+
l++;
|
|
549
|
+
}
|
|
550
|
+
memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
|
|
551
|
+
rres.lims[0] = 0;
|
|
552
|
+
|
|
553
|
+
for (int q = 0; q < nq; q++) {
|
|
554
|
+
float one_a = 1 / normalizers[2 * q];
|
|
555
|
+
float b = normalizers[2 * q + 1];
|
|
556
|
+
for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
|
|
557
|
+
rres.distances[i] = rres.distances[i] * one_a + b;
|
|
558
|
+
}
|
|
523
559
|
}
|
|
524
|
-
times[2] += get_cy() - t0;
|
|
525
|
-
times[3] += t3;
|
|
526
560
|
}
|
|
527
561
|
};
|
|
528
562
|
|
|
563
|
+
#ifndef SWIG
|
|
564
|
+
|
|
565
|
+
// handler for a subset of queries
|
|
566
|
+
template <class C, bool with_id_map = false>
|
|
567
|
+
struct PartialRangeHandler : RangeHandler<C, with_id_map> {
|
|
568
|
+
using T = typename C::T;
|
|
569
|
+
using TI = typename C::TI;
|
|
570
|
+
using RHC = RangeHandler<C, with_id_map>;
|
|
571
|
+
using RHC::normalizers;
|
|
572
|
+
using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
|
|
573
|
+
|
|
574
|
+
RangeSearchPartialResult& pres;
|
|
575
|
+
|
|
576
|
+
PartialRangeHandler(
|
|
577
|
+
RangeSearchPartialResult& pres,
|
|
578
|
+
float radius,
|
|
579
|
+
size_t ntotal,
|
|
580
|
+
size_t q0,
|
|
581
|
+
size_t q1)
|
|
582
|
+
: RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
|
|
583
|
+
pres(pres) {
|
|
584
|
+
nq = q1 - q0;
|
|
585
|
+
this->q0 = q0;
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
// shift left n_per_query
|
|
589
|
+
void shift_n_per_query() {
|
|
590
|
+
memmove(n_per_query.data() + 1,
|
|
591
|
+
n_per_query.data(),
|
|
592
|
+
nq * sizeof(n_per_query[0]));
|
|
593
|
+
n_per_query[0] = 0;
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
// commit to partial result instead of full RangeResult
|
|
597
|
+
void end() override {
|
|
598
|
+
std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
|
|
599
|
+
for (int q = 0; q < nq; q++) {
|
|
600
|
+
n_per_query[q + 1] += n_per_query[q];
|
|
601
|
+
}
|
|
602
|
+
shift_n_per_query();
|
|
603
|
+
|
|
604
|
+
for (size_t i = 0; i < triplets.size(); i++) {
|
|
605
|
+
sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
|
|
606
|
+
}
|
|
607
|
+
shift_n_per_query();
|
|
608
|
+
|
|
609
|
+
size_t* lims = n_per_query.data();
|
|
610
|
+
|
|
611
|
+
for (int q = 0; q < nq; q++) {
|
|
612
|
+
float one_a = 1 / normalizers[2 * q];
|
|
613
|
+
float b = normalizers[2 * q + 1];
|
|
614
|
+
RangeQueryResult& qres = pres.new_result(q + q0);
|
|
615
|
+
for (size_t i = lims[q]; i < lims[q + 1]; i++) {
|
|
616
|
+
qres.add(
|
|
617
|
+
sorted_triplets[i].dis * one_a + b,
|
|
618
|
+
sorted_triplets[i].b);
|
|
619
|
+
}
|
|
620
|
+
}
|
|
621
|
+
}
|
|
622
|
+
};
|
|
623
|
+
|
|
624
|
+
#endif
|
|
625
|
+
|
|
626
|
+
/********************************************************************************
 * Dynamic dispatching function. The consumer should have a templatized method f
 * that is instantiated with the concrete SIMDResultHandler type determined at
 * runtime.
 */
|
|
631
|
+
|
|
632
|
+
template <class C, bool W, class Consumer, class... Types>
|
|
633
|
+
void dispatch_SIMDResultHanlder_fixedCW(
|
|
634
|
+
SIMDResultHandler& res,
|
|
635
|
+
Consumer& consumer,
|
|
636
|
+
Types... args) {
|
|
637
|
+
if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
|
|
638
|
+
consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
|
|
639
|
+
} else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
|
|
640
|
+
consumer.template f<HeapHandler<C, W>>(*resh, args...);
|
|
641
|
+
} else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
|
|
642
|
+
consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
|
|
643
|
+
} else { // generic handler -- will not be inlined
|
|
644
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
645
|
+
simd_result_handlers_accept_virtual,
|
|
646
|
+
"Running vitrual handler for %s",
|
|
647
|
+
typeid(res).name());
|
|
648
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
|
649
|
+
}
|
|
650
|
+
}
|
|
651
|
+
|
|
652
|
+
template <class C, class Consumer, class... Types>
|
|
653
|
+
void dispatch_SIMDResultHanlder_fixedC(
|
|
654
|
+
SIMDResultHandler& res,
|
|
655
|
+
Consumer& consumer,
|
|
656
|
+
Types... args) {
|
|
657
|
+
if (res.with_fields) {
|
|
658
|
+
dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
|
|
659
|
+
} else {
|
|
660
|
+
dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
|
|
661
|
+
}
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
template <class Consumer, class... Types>
|
|
665
|
+
void dispatch_SIMDResultHanlder(
|
|
666
|
+
SIMDResultHandler& res,
|
|
667
|
+
Consumer& consumer,
|
|
668
|
+
Types... args) {
|
|
669
|
+
if (res.sizeof_ids == 0) {
|
|
670
|
+
if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
|
|
671
|
+
consumer.template f<StoreResultHandler>(*resh, args...);
|
|
672
|
+
} else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
|
|
673
|
+
consumer.template f<DummyResultHandler>(*resh, args...);
|
|
674
|
+
} else { // generic path
|
|
675
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
676
|
+
simd_result_handlers_accept_virtual,
|
|
677
|
+
"Running vitrual handler for %s",
|
|
678
|
+
typeid(res).name());
|
|
679
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
|
680
|
+
}
|
|
681
|
+
} else if (res.sizeof_ids == sizeof(int)) {
|
|
682
|
+
if (res.is_CMax) {
|
|
683
|
+
dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
|
|
684
|
+
res, consumer, args...);
|
|
685
|
+
} else {
|
|
686
|
+
dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
|
|
687
|
+
res, consumer, args...);
|
|
688
|
+
}
|
|
689
|
+
} else if (res.sizeof_ids == sizeof(int64_t)) {
|
|
690
|
+
if (res.is_CMax) {
|
|
691
|
+
dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
|
|
692
|
+
res, consumer, args...);
|
|
693
|
+
} else {
|
|
694
|
+
dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
|
|
695
|
+
res, consumer, args...);
|
|
696
|
+
}
|
|
697
|
+
} else {
|
|
698
|
+
FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
|
|
699
|
+
}
|
|
700
|
+
}
|
|
529
701
|
} // namespace simd_result_handlers
|
|
530
702
|
|
|
531
703
|
} // namespace faiss
|