faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -14,40 +14,87 @@
|
|
|
14
14
|
#include <faiss/utils/Heap.h>
|
|
15
15
|
#include <faiss/utils/simdlib.h>
|
|
16
16
|
|
|
17
|
+
#include <faiss/impl/FaissAssert.h>
|
|
18
|
+
#include <faiss/impl/IDSelector.h>
|
|
19
|
+
#include <faiss/impl/ResultHandler.h>
|
|
17
20
|
#include <faiss/impl/platform_macros.h>
|
|
18
21
|
#include <faiss/utils/AlignedTable.h>
|
|
19
22
|
#include <faiss/utils/partitioning.h>
|
|
20
23
|
|
|
21
24
|
/** This file contains callbacks for kernels that compute distances.
|
|
22
|
-
*
|
|
23
|
-
* The SIMDResultHandler object is intended to be templated and inlined.
|
|
24
|
-
* Methods:
|
|
25
|
-
* - handle(): called when 32 distances are computed and provided in two
|
|
26
|
-
* simd16uint16. (q, b) indicate which entry it is in the block.
|
|
27
|
-
* - set_block_origin(): set the sub-matrix that is being computed
|
|
28
25
|
*/
|
|
29
26
|
|
|
30
27
|
namespace faiss {
|
|
31
28
|
|
|
29
|
+
struct SIMDResultHandler {
|
|
30
|
+
// used to dispatch templates
|
|
31
|
+
bool is_CMax = false;
|
|
32
|
+
uint8_t sizeof_ids = 0;
|
|
33
|
+
bool with_fields = false;
|
|
34
|
+
|
|
35
|
+
/** called when 32 distances are computed and provided in two
|
|
36
|
+
* simd16uint16. (q, b) indicate which entry it is in the block. */
|
|
37
|
+
virtual void handle(
|
|
38
|
+
size_t q,
|
|
39
|
+
size_t b,
|
|
40
|
+
simd16uint16 d0,
|
|
41
|
+
simd16uint16 d1) = 0;
|
|
42
|
+
|
|
43
|
+
/// set the sub-matrix that is being computed
|
|
44
|
+
virtual void set_block_origin(size_t i0, size_t j0) = 0;
|
|
45
|
+
|
|
46
|
+
virtual ~SIMDResultHandler() {}
|
|
47
|
+
};
|
|
48
|
+
|
|
49
|
+
/* Result handler that will return float results eventually */
|
|
50
|
+
struct SIMDResultHandlerToFloat : SIMDResultHandler {
|
|
51
|
+
size_t nq; // number of queries
|
|
52
|
+
size_t ntotal; // ignore excess elements after ntotal
|
|
53
|
+
|
|
54
|
+
/// these fields are used mainly for the IVF variants (with_id_map=true)
|
|
55
|
+
const idx_t* id_map = nullptr; // map offset in invlist to vector id
|
|
56
|
+
const int* q_map = nullptr; // map q to global query
|
|
57
|
+
const uint16_t* dbias =
|
|
58
|
+
nullptr; // table of biases to add to each query (for IVF L2 search)
|
|
59
|
+
const float* normalizers = nullptr; // size 2 * nq, to convert
|
|
60
|
+
|
|
61
|
+
SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
|
|
62
|
+
: nq(nq), ntotal(ntotal) {}
|
|
63
|
+
|
|
64
|
+
virtual void begin(const float* norms) {
|
|
65
|
+
normalizers = norms;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
// called at end of search to convert int16 distances to float, before
|
|
69
|
+
// normalizers are deallocated
|
|
70
|
+
virtual void end() {
|
|
71
|
+
normalizers = nullptr;
|
|
72
|
+
}
|
|
73
|
+
};
|
|
74
|
+
|
|
75
|
+
FAISS_API extern bool simd_result_handlers_accept_virtual;
|
|
76
|
+
|
|
32
77
|
namespace simd_result_handlers {
|
|
33
78
|
|
|
34
|
-
/** Dummy structure that just computes a
|
|
79
|
+
/** Dummy structure that just computes a checksum on results
|
|
35
80
|
* (to avoid the computation to be optimized away) */
|
|
36
|
-
struct DummyResultHandler {
|
|
81
|
+
struct DummyResultHandler : SIMDResultHandler {
|
|
37
82
|
size_t cs = 0;
|
|
38
83
|
|
|
39
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
84
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
40
85
|
cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
|
|
41
86
|
}
|
|
42
87
|
|
|
43
|
-
void set_block_origin(size_t, size_t) {}
|
|
88
|
+
void set_block_origin(size_t, size_t) final {}
|
|
89
|
+
|
|
90
|
+
~DummyResultHandler() {}
|
|
44
91
|
};
|
|
45
92
|
|
|
46
93
|
/** memorize results in a nq-by-nb matrix.
|
|
47
94
|
*
|
|
48
95
|
* j0 is the current upper-left block of the matrix
|
|
49
96
|
*/
|
|
50
|
-
struct StoreResultHandler {
|
|
97
|
+
struct StoreResultHandler : SIMDResultHandler {
|
|
51
98
|
uint16_t* data;
|
|
52
99
|
size_t ld; // total number of columns
|
|
53
100
|
size_t i0 = 0;
|
|
@@ -55,32 +102,32 @@ struct StoreResultHandler {
|
|
|
55
102
|
|
|
56
103
|
StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
|
|
57
104
|
|
|
58
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
105
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
59
106
|
size_t ofs = (q + i0) * ld + j0 + b * 32;
|
|
60
107
|
d0.store(data + ofs);
|
|
61
108
|
d1.store(data + ofs + 16);
|
|
62
109
|
}
|
|
63
110
|
|
|
64
|
-
void set_block_origin(size_t
|
|
65
|
-
this->i0 =
|
|
66
|
-
this->j0 =
|
|
111
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
112
|
+
this->i0 = i0_in;
|
|
113
|
+
this->j0 = j0_in;
|
|
67
114
|
}
|
|
68
115
|
};
|
|
69
116
|
|
|
70
117
|
/** stores results in fixed-size matrix. */
|
|
71
118
|
template <int NQ, int BB>
|
|
72
|
-
struct FixedStorageHandler {
|
|
119
|
+
struct FixedStorageHandler : SIMDResultHandler {
|
|
73
120
|
simd16uint16 dis[NQ][BB];
|
|
74
121
|
int i0 = 0;
|
|
75
122
|
|
|
76
|
-
void handle(
|
|
123
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
77
124
|
dis[q + i0][2 * b] = d0;
|
|
78
125
|
dis[q + i0][2 * b + 1] = d1;
|
|
79
126
|
}
|
|
80
127
|
|
|
81
|
-
void set_block_origin(size_t
|
|
82
|
-
this->i0 =
|
|
83
|
-
assert(
|
|
128
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
129
|
+
this->i0 = i0_in;
|
|
130
|
+
assert(j0_in == 0);
|
|
84
131
|
}
|
|
85
132
|
|
|
86
133
|
template <class OtherResultHandler>
|
|
@@ -91,30 +138,32 @@ struct FixedStorageHandler {
|
|
|
91
138
|
}
|
|
92
139
|
}
|
|
93
140
|
}
|
|
141
|
+
|
|
142
|
+
virtual ~FixedStorageHandler() {}
|
|
94
143
|
};
|
|
95
144
|
|
|
96
|
-
/**
|
|
145
|
+
/** Result handler that compares distances to check if they need to be kept */
|
|
97
146
|
template <class C, bool with_id_map>
|
|
98
|
-
struct
|
|
147
|
+
struct ResultHandlerCompare : SIMDResultHandlerToFloat {
|
|
99
148
|
using TI = typename C::TI;
|
|
100
149
|
|
|
101
150
|
bool disable = false;
|
|
102
151
|
|
|
103
152
|
int64_t i0 = 0; // query origin
|
|
104
153
|
int64_t j0 = 0; // db origin
|
|
105
|
-
size_t ntotal; // ignore excess elements after ntotal
|
|
106
154
|
|
|
107
|
-
|
|
108
|
-
const TI* id_map; // map offset in invlist to vector id
|
|
109
|
-
const int* q_map; // map q to global query
|
|
110
|
-
const uint16_t* dbias; // table of biases to add to each query
|
|
155
|
+
const IDSelector* sel;
|
|
111
156
|
|
|
112
|
-
|
|
113
|
-
:
|
|
157
|
+
ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in)
|
|
158
|
+
: SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} {
|
|
159
|
+
this->is_CMax = C::is_max;
|
|
160
|
+
this->sizeof_ids = sizeof(typename C::TI);
|
|
161
|
+
this->with_fields = with_id_map;
|
|
162
|
+
}
|
|
114
163
|
|
|
115
|
-
void set_block_origin(size_t
|
|
116
|
-
this->i0 =
|
|
117
|
-
this->j0 =
|
|
164
|
+
void set_block_origin(size_t i0_in, size_t j0_in) final {
|
|
165
|
+
this->i0 = i0_in;
|
|
166
|
+
this->j0 = j0_in;
|
|
118
167
|
}
|
|
119
168
|
|
|
120
169
|
// adjust handler data for IVF.
|
|
@@ -172,43 +221,42 @@ struct SIMDResultHandler {
|
|
|
172
221
|
return lt_mask;
|
|
173
222
|
}
|
|
174
223
|
|
|
175
|
-
virtual
|
|
176
|
-
float* distances,
|
|
177
|
-
int64_t* labels,
|
|
178
|
-
const float* normalizers = nullptr) = 0;
|
|
179
|
-
|
|
180
|
-
virtual ~SIMDResultHandler() {}
|
|
224
|
+
virtual ~ResultHandlerCompare() {}
|
|
181
225
|
};
|
|
182
226
|
|
|
183
227
|
/** Special version for k=1 */
|
|
184
228
|
template <class C, bool with_id_map = false>
|
|
185
|
-
struct SingleResultHandler :
|
|
229
|
+
struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
|
|
186
230
|
using T = typename C::T;
|
|
187
231
|
using TI = typename C::TI;
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
SingleResultHandler(
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
232
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
233
|
+
using RHC::normalizers;
|
|
234
|
+
|
|
235
|
+
std::vector<int16_t> idis;
|
|
236
|
+
float* dis;
|
|
237
|
+
int64_t* ids;
|
|
238
|
+
|
|
239
|
+
SingleResultHandler(
|
|
240
|
+
size_t nq,
|
|
241
|
+
size_t ntotal,
|
|
242
|
+
float* dis,
|
|
243
|
+
int64_t* ids,
|
|
244
|
+
const IDSelector* sel_in)
|
|
245
|
+
: RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) {
|
|
246
|
+
for (size_t i = 0; i < nq; i++) {
|
|
247
|
+
ids[i] = -1;
|
|
248
|
+
idis[i] = C::neutral();
|
|
200
249
|
}
|
|
201
250
|
}
|
|
202
251
|
|
|
203
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
252
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
204
253
|
if (this->disable) {
|
|
205
254
|
return;
|
|
206
255
|
}
|
|
207
256
|
|
|
208
257
|
this->adjust_with_origin(q, d0, d1);
|
|
209
258
|
|
|
210
|
-
|
|
211
|
-
uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
|
|
259
|
+
uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
|
|
212
260
|
if (!lt_mask) {
|
|
213
261
|
return;
|
|
214
262
|
}
|
|
@@ -217,74 +265,87 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
217
265
|
d0.store(d32tab);
|
|
218
266
|
d1.store(d32tab + 16);
|
|
219
267
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
268
|
+
if (this->sel != nullptr) {
|
|
269
|
+
while (lt_mask) {
|
|
270
|
+
// find first non-zero
|
|
271
|
+
int j = __builtin_ctz(lt_mask);
|
|
272
|
+
auto real_idx = this->adjust_id(b, j);
|
|
273
|
+
lt_mask -= 1 << j;
|
|
274
|
+
if (this->sel->is_member(real_idx)) {
|
|
275
|
+
T d = d32tab[j];
|
|
276
|
+
if (C::cmp(idis[q], d)) {
|
|
277
|
+
idis[q] = d;
|
|
278
|
+
ids[q] = real_idx;
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
} else {
|
|
283
|
+
while (lt_mask) {
|
|
284
|
+
// find first non-zero
|
|
285
|
+
int j = __builtin_ctz(lt_mask);
|
|
286
|
+
lt_mask -= 1 << j;
|
|
287
|
+
T d = d32tab[j];
|
|
288
|
+
if (C::cmp(idis[q], d)) {
|
|
289
|
+
idis[q] = d;
|
|
290
|
+
ids[q] = this->adjust_id(b, j);
|
|
291
|
+
}
|
|
228
292
|
}
|
|
229
293
|
}
|
|
230
294
|
}
|
|
231
295
|
|
|
232
|
-
void
|
|
233
|
-
|
|
234
|
-
int64_t* labels,
|
|
235
|
-
const float* normalizers = nullptr) override {
|
|
236
|
-
for (int q = 0; q < results.size(); q++) {
|
|
296
|
+
void end() {
|
|
297
|
+
for (size_t q = 0; q < this->nq; q++) {
|
|
237
298
|
if (!normalizers) {
|
|
238
|
-
|
|
299
|
+
dis[q] = idis[q];
|
|
239
300
|
} else {
|
|
240
301
|
float one_a = 1 / normalizers[2 * q];
|
|
241
302
|
float b = normalizers[2 * q + 1];
|
|
242
|
-
|
|
303
|
+
dis[q] = b + idis[q] * one_a;
|
|
243
304
|
}
|
|
244
|
-
labels[q] = results[q].id;
|
|
245
305
|
}
|
|
246
306
|
}
|
|
247
307
|
};
|
|
248
308
|
|
|
249
309
|
/** Structure that collects results in a min- or max-heap */
|
|
250
310
|
template <class C, bool with_id_map = false>
|
|
251
|
-
struct HeapHandler :
|
|
311
|
+
struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
|
|
252
312
|
using T = typename C::T;
|
|
253
313
|
using TI = typename C::TI;
|
|
314
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
315
|
+
using RHC::normalizers;
|
|
254
316
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
317
|
+
std::vector<uint16_t> idis;
|
|
318
|
+
std::vector<TI> iids;
|
|
319
|
+
float* dis;
|
|
320
|
+
int64_t* ids;
|
|
258
321
|
|
|
259
322
|
int64_t k; // number of results to keep
|
|
260
323
|
|
|
261
324
|
HeapHandler(
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
325
|
+
size_t nq,
|
|
326
|
+
size_t ntotal,
|
|
327
|
+
int64_t k,
|
|
328
|
+
float* dis,
|
|
329
|
+
int64_t* ids,
|
|
330
|
+
const IDSelector* sel_in)
|
|
331
|
+
: RHC(nq, ntotal, sel_in),
|
|
332
|
+
idis(nq * k),
|
|
333
|
+
iids(nq * k),
|
|
334
|
+
dis(dis),
|
|
335
|
+
ids(ids),
|
|
271
336
|
k(k) {
|
|
272
|
-
|
|
273
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
|
274
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
|
275
|
-
heap_heapify<C>(k, heap_dis_in, heap_ids_in);
|
|
276
|
-
}
|
|
337
|
+
heap_heapify<C>(k * nq, idis.data(), iids.data());
|
|
277
338
|
}
|
|
278
339
|
|
|
279
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
340
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
280
341
|
if (this->disable) {
|
|
281
342
|
return;
|
|
282
343
|
}
|
|
283
344
|
|
|
284
345
|
this->adjust_with_origin(q, d0, d1);
|
|
285
346
|
|
|
286
|
-
T* heap_dis =
|
|
287
|
-
TI* heap_ids =
|
|
347
|
+
T* heap_dis = idis.data() + q * k;
|
|
348
|
+
TI* heap_ids = iids.data() + q * k;
|
|
288
349
|
|
|
289
350
|
uint16_t cur_thresh =
|
|
290
351
|
heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
|
|
@@ -300,29 +361,41 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
300
361
|
d0.store(d32tab);
|
|
301
362
|
d1.store(d32tab + 16);
|
|
302
363
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
364
|
+
if (this->sel != nullptr) {
|
|
365
|
+
while (lt_mask) {
|
|
366
|
+
// find first non-zero
|
|
367
|
+
int j = __builtin_ctz(lt_mask);
|
|
368
|
+
auto real_idx = this->adjust_id(b, j);
|
|
369
|
+
lt_mask -= 1 << j;
|
|
370
|
+
if (this->sel->is_member(real_idx)) {
|
|
371
|
+
T dis = d32tab[j];
|
|
372
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
373
|
+
heap_replace_top<C>(
|
|
374
|
+
k, heap_dis, heap_ids, dis, real_idx);
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
} else {
|
|
379
|
+
while (lt_mask) {
|
|
380
|
+
// find first non-zero
|
|
381
|
+
int j = __builtin_ctz(lt_mask);
|
|
382
|
+
lt_mask -= 1 << j;
|
|
383
|
+
T dis = d32tab[j];
|
|
384
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
385
|
+
int64_t idx = this->adjust_id(b, j);
|
|
386
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
|
|
387
|
+
}
|
|
312
388
|
}
|
|
313
389
|
}
|
|
314
390
|
}
|
|
315
391
|
|
|
316
|
-
void
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
for (int q = 0; q < nq; q++) {
|
|
321
|
-
T* heap_dis_in = heap_dis_tab + q * k;
|
|
322
|
-
TI* heap_ids_in = heap_ids_tab + q * k;
|
|
392
|
+
void end() override {
|
|
393
|
+
for (size_t q = 0; q < this->nq; q++) {
|
|
394
|
+
T* heap_dis_in = idis.data() + q * k;
|
|
395
|
+
TI* heap_ids_in = iids.data() + q * k;
|
|
323
396
|
heap_reorder<C>(k, heap_dis_in, heap_ids_in);
|
|
324
|
-
|
|
325
|
-
|
|
397
|
+
float* heap_dis = dis + q * k;
|
|
398
|
+
int64_t* heap_ids = ids + q * k;
|
|
326
399
|
|
|
327
400
|
float one_a = 1.0, b = 0.0;
|
|
328
401
|
if (normalizers) {
|
|
@@ -330,8 +403,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
330
403
|
b = normalizers[2 * q + 1];
|
|
331
404
|
}
|
|
332
405
|
for (int j = 0; j < k; j++) {
|
|
333
|
-
heap_ids[j] = heap_ids_in[j];
|
|
334
406
|
heap_dis[j] = heap_dis_in[j] * one_a + b;
|
|
407
|
+
heap_ids[j] = heap_ids_in[j];
|
|
335
408
|
}
|
|
336
409
|
}
|
|
337
410
|
}
|
|
@@ -342,114 +415,49 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
342
415
|
* Results are stored when they are below the threshold until the capacity is
|
|
343
416
|
* reached. Then a partition sort is used to update the threshold. */
|
|
344
417
|
|
|
345
|
-
namespace {
|
|
346
|
-
|
|
347
|
-
uint64_t get_cy() {
|
|
348
|
-
#ifdef MICRO_BENCHMARK
|
|
349
|
-
uint32_t high, low;
|
|
350
|
-
asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
|
|
351
|
-
return ((uint64_t)high << 32) | (low);
|
|
352
|
-
#else
|
|
353
|
-
return 0;
|
|
354
|
-
#endif
|
|
355
|
-
}
|
|
356
|
-
|
|
357
|
-
} // anonymous namespace
|
|
358
|
-
|
|
359
|
-
template <class C>
|
|
360
|
-
struct ReservoirTopN {
|
|
361
|
-
using T = typename C::T;
|
|
362
|
-
using TI = typename C::TI;
|
|
363
|
-
|
|
364
|
-
T* vals;
|
|
365
|
-
TI* ids;
|
|
366
|
-
|
|
367
|
-
size_t i; // number of stored elements
|
|
368
|
-
size_t n; // number of requested elements
|
|
369
|
-
size_t capacity; // size of storage
|
|
370
|
-
size_t cycles = 0;
|
|
371
|
-
|
|
372
|
-
T threshold; // current threshold
|
|
373
|
-
|
|
374
|
-
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
|
375
|
-
: vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
|
|
376
|
-
assert(n < capacity);
|
|
377
|
-
threshold = C::neutral();
|
|
378
|
-
}
|
|
379
|
-
|
|
380
|
-
void add(T val, TI id) {
|
|
381
|
-
if (C::cmp(threshold, val)) {
|
|
382
|
-
if (i == capacity) {
|
|
383
|
-
shrink_fuzzy();
|
|
384
|
-
}
|
|
385
|
-
vals[i] = val;
|
|
386
|
-
ids[i] = id;
|
|
387
|
-
i++;
|
|
388
|
-
}
|
|
389
|
-
}
|
|
390
|
-
|
|
391
|
-
/// shrink number of stored elements to n
|
|
392
|
-
void shrink_xx() {
|
|
393
|
-
uint64_t t0 = get_cy();
|
|
394
|
-
qselect(vals, ids, i, n);
|
|
395
|
-
i = n; // forget all elements above i = n
|
|
396
|
-
threshold = C::Crev::neutral();
|
|
397
|
-
for (size_t j = 0; j < n; j++) {
|
|
398
|
-
if (C::cmp(vals[j], threshold)) {
|
|
399
|
-
threshold = vals[j];
|
|
400
|
-
}
|
|
401
|
-
}
|
|
402
|
-
cycles += get_cy() - t0;
|
|
403
|
-
}
|
|
404
|
-
|
|
405
|
-
void shrink() {
|
|
406
|
-
uint64_t t0 = get_cy();
|
|
407
|
-
threshold = partition<C>(vals, ids, i, n);
|
|
408
|
-
i = n;
|
|
409
|
-
cycles += get_cy() - t0;
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
void shrink_fuzzy() {
|
|
413
|
-
uint64_t t0 = get_cy();
|
|
414
|
-
assert(i == capacity);
|
|
415
|
-
threshold = partition_fuzzy<C>(
|
|
416
|
-
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
|
417
|
-
cycles += get_cy() - t0;
|
|
418
|
-
}
|
|
419
|
-
};
|
|
420
|
-
|
|
421
418
|
/** Handler built from several ReservoirTopN (one per query) */
|
|
422
419
|
template <class C, bool with_id_map = false>
|
|
423
|
-
struct ReservoirHandler :
|
|
420
|
+
struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
|
|
424
421
|
using T = typename C::T;
|
|
425
422
|
using TI = typename C::TI;
|
|
423
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
424
|
+
using RHC::normalizers;
|
|
426
425
|
|
|
427
426
|
size_t capacity; // rounded up to multiple of 16
|
|
427
|
+
|
|
428
|
+
// where the final results will be written
|
|
429
|
+
float* dis;
|
|
430
|
+
int64_t* ids;
|
|
431
|
+
|
|
428
432
|
std::vector<TI> all_ids;
|
|
429
433
|
AlignedTable<T> all_vals;
|
|
430
|
-
|
|
431
434
|
std::vector<ReservoirTopN<C>> reservoirs;
|
|
432
435
|
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
436
|
+
ReservoirHandler(
|
|
437
|
+
size_t nq,
|
|
438
|
+
size_t ntotal,
|
|
439
|
+
size_t k,
|
|
440
|
+
size_t cap,
|
|
441
|
+
float* dis,
|
|
442
|
+
int64_t* ids,
|
|
443
|
+
const IDSelector* sel_in)
|
|
444
|
+
: RHC(nq, ntotal, sel_in),
|
|
445
|
+
capacity((cap + 15) & ~15),
|
|
446
|
+
dis(dis),
|
|
447
|
+
ids(ids) {
|
|
440
448
|
assert(capacity % 16 == 0);
|
|
441
|
-
|
|
449
|
+
all_ids.resize(nq * capacity);
|
|
450
|
+
all_vals.resize(nq * capacity);
|
|
451
|
+
for (size_t q = 0; q < nq; q++) {
|
|
442
452
|
reservoirs.emplace_back(
|
|
443
|
-
|
|
453
|
+
k,
|
|
444
454
|
capacity,
|
|
445
|
-
all_vals.get() +
|
|
446
|
-
all_ids.data() +
|
|
455
|
+
all_vals.get() + q * capacity,
|
|
456
|
+
all_ids.data() + q * capacity);
|
|
447
457
|
}
|
|
448
|
-
times[0] = times[1] = times[2] = times[3] = 0;
|
|
449
458
|
}
|
|
450
459
|
|
|
451
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
|
|
452
|
-
uint64_t t0 = get_cy();
|
|
460
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
453
461
|
if (this->disable) {
|
|
454
462
|
return;
|
|
455
463
|
}
|
|
@@ -457,8 +465,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
457
465
|
|
|
458
466
|
ReservoirTopN<C>& res = reservoirs[q];
|
|
459
467
|
uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
|
|
460
|
-
uint64_t t1 = get_cy();
|
|
461
|
-
times[0] += t1 - t0;
|
|
462
468
|
|
|
463
469
|
if (!lt_mask) {
|
|
464
470
|
return;
|
|
@@ -467,65 +473,315 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
|
|
|
467
473
|
d0.store(d32tab);
|
|
468
474
|
d1.store(d32tab + 16);
|
|
469
475
|
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
+
if (this->sel != nullptr) {
|
|
477
|
+
while (lt_mask) {
|
|
478
|
+
// find first non-zero
|
|
479
|
+
int j = __builtin_ctz(lt_mask);
|
|
480
|
+
auto real_idx = this->adjust_id(b, j);
|
|
481
|
+
lt_mask -= 1 << j;
|
|
482
|
+
if (this->sel->is_member(real_idx)) {
|
|
483
|
+
T dis = d32tab[j];
|
|
484
|
+
res.add(dis, real_idx);
|
|
485
|
+
}
|
|
486
|
+
}
|
|
487
|
+
} else {
|
|
488
|
+
while (lt_mask) {
|
|
489
|
+
// find first non-zero
|
|
490
|
+
int j = __builtin_ctz(lt_mask);
|
|
491
|
+
lt_mask -= 1 << j;
|
|
492
|
+
T dis = d32tab[j];
|
|
493
|
+
res.add(dis, this->adjust_id(b, j));
|
|
494
|
+
}
|
|
476
495
|
}
|
|
477
|
-
times[1] += get_cy() - t1;
|
|
478
496
|
}
|
|
479
497
|
|
|
480
|
-
void
|
|
481
|
-
float* distances,
|
|
482
|
-
int64_t* labels,
|
|
483
|
-
const float* normalizers = nullptr) override {
|
|
498
|
+
void end() override {
|
|
484
499
|
using Cf = typename std::conditional<
|
|
485
500
|
C::is_max,
|
|
486
501
|
CMax<float, int64_t>,
|
|
487
502
|
CMin<float, int64_t>>::type;
|
|
488
503
|
|
|
489
|
-
uint64_t t0 = get_cy();
|
|
490
|
-
uint64_t t3 = 0;
|
|
491
504
|
std::vector<int> perm(reservoirs[0].n);
|
|
492
|
-
for (
|
|
505
|
+
for (size_t q = 0; q < reservoirs.size(); q++) {
|
|
493
506
|
ReservoirTopN<C>& res = reservoirs[q];
|
|
494
507
|
size_t n = res.n;
|
|
495
508
|
|
|
496
509
|
if (res.i > res.n) {
|
|
497
510
|
res.shrink();
|
|
498
511
|
}
|
|
499
|
-
int64_t* heap_ids =
|
|
500
|
-
float* heap_dis =
|
|
512
|
+
int64_t* heap_ids = ids + q * n;
|
|
513
|
+
float* heap_dis = dis + q * n;
|
|
501
514
|
|
|
502
515
|
float one_a = 1.0, b = 0.0;
|
|
503
516
|
if (normalizers) {
|
|
504
517
|
one_a = 1 / normalizers[2 * q];
|
|
505
518
|
b = normalizers[2 * q + 1];
|
|
506
519
|
}
|
|
507
|
-
for (
|
|
520
|
+
for (size_t i = 0; i < res.i; i++) {
|
|
508
521
|
perm[i] = i;
|
|
509
522
|
}
|
|
510
523
|
// indirect sort of result arrays
|
|
511
524
|
std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
|
|
512
525
|
return C::cmp(res.vals[j], res.vals[i]);
|
|
513
526
|
});
|
|
514
|
-
for (
|
|
527
|
+
for (size_t i = 0; i < res.i; i++) {
|
|
515
528
|
heap_dis[i] = res.vals[perm[i]] * one_a + b;
|
|
516
529
|
heap_ids[i] = res.ids[perm[i]];
|
|
517
530
|
}
|
|
518
531
|
|
|
519
532
|
// possibly add empty results
|
|
520
533
|
heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
|
|
534
|
+
}
|
|
535
|
+
}
|
|
536
|
+
};
|
|
537
|
+
|
|
538
|
+
/** Result handler for range search. The difficulty is that the range distances
|
|
539
|
+
* have to be scaled using the scaler.
|
|
540
|
+
*/
|
|
541
|
+
|
|
542
|
+
template <class C, bool with_id_map = false>
|
|
543
|
+
struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
|
|
544
|
+
using T = typename C::T;
|
|
545
|
+
using TI = typename C::TI;
|
|
546
|
+
using RHC = ResultHandlerCompare<C, with_id_map>;
|
|
547
|
+
using RHC::normalizers;
|
|
548
|
+
using RHC::nq;
|
|
549
|
+
|
|
550
|
+
RangeSearchResult& rres;
|
|
551
|
+
float radius;
|
|
552
|
+
std::vector<uint16_t> thresholds;
|
|
553
|
+
std::vector<size_t> n_per_query;
|
|
554
|
+
size_t q0 = 0;
|
|
555
|
+
|
|
556
|
+
// we cannot use the RangeSearchPartialResult interface because queries can
|
|
557
|
+
// be performed by batches
|
|
558
|
+
struct Triplet {
|
|
559
|
+
idx_t q;
|
|
560
|
+
idx_t b;
|
|
561
|
+
uint16_t dis;
|
|
562
|
+
};
|
|
563
|
+
std::vector<Triplet> triplets;
|
|
564
|
+
|
|
565
|
+
RangeHandler(
|
|
566
|
+
RangeSearchResult& rres,
|
|
567
|
+
float radius,
|
|
568
|
+
size_t ntotal,
|
|
569
|
+
const IDSelector* sel_in)
|
|
570
|
+
: RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) {
|
|
571
|
+
thresholds.resize(nq);
|
|
572
|
+
n_per_query.resize(nq + 1);
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
virtual void begin(const float* norms) override {
|
|
576
|
+
normalizers = norms;
|
|
577
|
+
for (int q = 0; q < nq; ++q) {
|
|
578
|
+
thresholds[q] =
|
|
579
|
+
normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
|
|
580
|
+
}
|
|
581
|
+
}
|
|
582
|
+
|
|
583
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
|
|
584
|
+
if (this->disable) {
|
|
585
|
+
return;
|
|
586
|
+
}
|
|
587
|
+
this->adjust_with_origin(q, d0, d1);
|
|
521
588
|
|
|
522
|
-
|
|
589
|
+
uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
|
|
590
|
+
|
|
591
|
+
if (!lt_mask) {
|
|
592
|
+
return;
|
|
593
|
+
}
|
|
594
|
+
ALIGNED(32) uint16_t d32tab[32];
|
|
595
|
+
d0.store(d32tab);
|
|
596
|
+
d1.store(d32tab + 16);
|
|
597
|
+
|
|
598
|
+
if (this->sel != nullptr) {
|
|
599
|
+
while (lt_mask) {
|
|
600
|
+
// find first non-zero
|
|
601
|
+
int j = __builtin_ctz(lt_mask);
|
|
602
|
+
lt_mask -= 1 << j;
|
|
603
|
+
|
|
604
|
+
auto real_idx = this->adjust_id(b, j);
|
|
605
|
+
if (this->sel->is_member(real_idx)) {
|
|
606
|
+
T dis = d32tab[j];
|
|
607
|
+
n_per_query[q]++;
|
|
608
|
+
triplets.push_back({idx_t(q + q0), real_idx, dis});
|
|
609
|
+
}
|
|
610
|
+
}
|
|
611
|
+
} else {
|
|
612
|
+
while (lt_mask) {
|
|
613
|
+
// find first non-zero
|
|
614
|
+
int j = __builtin_ctz(lt_mask);
|
|
615
|
+
lt_mask -= 1 << j;
|
|
616
|
+
T dis = d32tab[j];
|
|
617
|
+
n_per_query[q]++;
|
|
618
|
+
triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
|
|
619
|
+
}
|
|
620
|
+
}
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
void end() override {
|
|
624
|
+
memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
|
|
625
|
+
rres.do_allocation();
|
|
626
|
+
for (auto it = triplets.begin(); it != triplets.end(); ++it) {
|
|
627
|
+
size_t& l = rres.lims[it->q];
|
|
628
|
+
rres.distances[l] = it->dis;
|
|
629
|
+
rres.labels[l] = it->b;
|
|
630
|
+
l++;
|
|
631
|
+
}
|
|
632
|
+
memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
|
|
633
|
+
rres.lims[0] = 0;
|
|
634
|
+
|
|
635
|
+
for (int q = 0; q < nq; q++) {
|
|
636
|
+
float one_a = 1 / normalizers[2 * q];
|
|
637
|
+
float b = normalizers[2 * q + 1];
|
|
638
|
+
for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
|
|
639
|
+
rres.distances[i] = rres.distances[i] * one_a + b;
|
|
640
|
+
}
|
|
523
641
|
}
|
|
524
|
-
times[2] += get_cy() - t0;
|
|
525
|
-
times[3] += t3;
|
|
526
642
|
}
|
|
527
643
|
};
|
|
528
644
|
|
|
645
|
+
#ifndef SWIG
|
|
646
|
+
|
|
647
|
+
// handler for a subset of queries
|
|
648
|
+
template <class C, bool with_id_map = false>
|
|
649
|
+
struct PartialRangeHandler : RangeHandler<C, with_id_map> {
|
|
650
|
+
using T = typename C::T;
|
|
651
|
+
using TI = typename C::TI;
|
|
652
|
+
using RHC = RangeHandler<C, with_id_map>;
|
|
653
|
+
using RHC::normalizers;
|
|
654
|
+
using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
|
|
655
|
+
|
|
656
|
+
RangeSearchPartialResult& pres;
|
|
657
|
+
|
|
658
|
+
PartialRangeHandler(
|
|
659
|
+
RangeSearchPartialResult& pres,
|
|
660
|
+
float radius,
|
|
661
|
+
size_t ntotal,
|
|
662
|
+
size_t q0,
|
|
663
|
+
size_t q1,
|
|
664
|
+
const IDSelector* sel_in)
|
|
665
|
+
: RangeHandler<C, with_id_map>(*pres.res, radius, ntotal, sel_in),
|
|
666
|
+
pres(pres) {
|
|
667
|
+
nq = q1 - q0;
|
|
668
|
+
this->q0 = q0;
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
// shift left n_per_query
|
|
672
|
+
void shift_n_per_query() {
|
|
673
|
+
memmove(n_per_query.data() + 1,
|
|
674
|
+
n_per_query.data(),
|
|
675
|
+
nq * sizeof(n_per_query[0]));
|
|
676
|
+
n_per_query[0] = 0;
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
// commit to partial result instead of full RangeResult
|
|
680
|
+
void end() override {
|
|
681
|
+
std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
|
|
682
|
+
for (int q = 0; q < nq; q++) {
|
|
683
|
+
n_per_query[q + 1] += n_per_query[q];
|
|
684
|
+
}
|
|
685
|
+
shift_n_per_query();
|
|
686
|
+
|
|
687
|
+
for (size_t i = 0; i < triplets.size(); i++) {
|
|
688
|
+
sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
|
|
689
|
+
}
|
|
690
|
+
shift_n_per_query();
|
|
691
|
+
|
|
692
|
+
size_t* lims = n_per_query.data();
|
|
693
|
+
|
|
694
|
+
for (int q = 0; q < nq; q++) {
|
|
695
|
+
float one_a = 1 / normalizers[2 * q];
|
|
696
|
+
float b = normalizers[2 * q + 1];
|
|
697
|
+
RangeQueryResult& qres = pres.new_result(q + q0);
|
|
698
|
+
for (size_t i = lims[q]; i < lims[q + 1]; i++) {
|
|
699
|
+
qres.add(
|
|
700
|
+
sorted_triplets[i].dis * one_a + b,
|
|
701
|
+
sorted_triplets[i].b);
|
|
702
|
+
}
|
|
703
|
+
}
|
|
704
|
+
}
|
|
705
|
+
};
|
|
706
|
+
|
|
707
|
+
#endif
|
|
708
|
+
|
|
709
|
+
/********************************************************************************
|
|
710
|
+
* Dynamic dispatching function. The consumer should have a templatized method f
|
|
711
|
+
* that will be replaced with the actual SIMDResultHandler that is determined
|
|
712
|
+
* dynamically.
|
|
713
|
+
*/
|
|
714
|
+
|
|
715
|
+
template <class C, bool W, class Consumer, class... Types>
|
|
716
|
+
void dispatch_SIMDResultHandler_fixedCW(
|
|
717
|
+
SIMDResultHandler& res,
|
|
718
|
+
Consumer& consumer,
|
|
719
|
+
Types... args) {
|
|
720
|
+
if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
|
|
721
|
+
consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
|
|
722
|
+
} else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
|
|
723
|
+
consumer.template f<HeapHandler<C, W>>(*resh, args...);
|
|
724
|
+
} else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
|
|
725
|
+
consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
|
|
726
|
+
} else { // generic handler -- will not be inlined
|
|
727
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
728
|
+
simd_result_handlers_accept_virtual,
|
|
729
|
+
"Running vitrual handler for %s",
|
|
730
|
+
typeid(res).name());
|
|
731
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
|
732
|
+
}
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
template <class C, class Consumer, class... Types>
|
|
736
|
+
void dispatch_SIMDResultHandler_fixedC(
|
|
737
|
+
SIMDResultHandler& res,
|
|
738
|
+
Consumer& consumer,
|
|
739
|
+
Types... args) {
|
|
740
|
+
if (res.with_fields) {
|
|
741
|
+
dispatch_SIMDResultHandler_fixedCW<C, true>(res, consumer, args...);
|
|
742
|
+
} else {
|
|
743
|
+
dispatch_SIMDResultHandler_fixedCW<C, false>(res, consumer, args...);
|
|
744
|
+
}
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
template <class Consumer, class... Types>
|
|
748
|
+
void dispatch_SIMDResultHandler(
|
|
749
|
+
SIMDResultHandler& res,
|
|
750
|
+
Consumer& consumer,
|
|
751
|
+
Types... args) {
|
|
752
|
+
if (res.sizeof_ids == 0) {
|
|
753
|
+
if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
|
|
754
|
+
consumer.template f<StoreResultHandler>(*resh, args...);
|
|
755
|
+
} else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
|
|
756
|
+
consumer.template f<DummyResultHandler>(*resh, args...);
|
|
757
|
+
} else { // generic path
|
|
758
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
759
|
+
simd_result_handlers_accept_virtual,
|
|
760
|
+
"Running vitrual handler for %s",
|
|
761
|
+
typeid(res).name());
|
|
762
|
+
consumer.template f<SIMDResultHandler>(res, args...);
|
|
763
|
+
}
|
|
764
|
+
} else if (res.sizeof_ids == sizeof(int)) {
|
|
765
|
+
if (res.is_CMax) {
|
|
766
|
+
dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int>>(
|
|
767
|
+
res, consumer, args...);
|
|
768
|
+
} else {
|
|
769
|
+
dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int>>(
|
|
770
|
+
res, consumer, args...);
|
|
771
|
+
}
|
|
772
|
+
} else if (res.sizeof_ids == sizeof(int64_t)) {
|
|
773
|
+
if (res.is_CMax) {
|
|
774
|
+
dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int64_t>>(
|
|
775
|
+
res, consumer, args...);
|
|
776
|
+
} else {
|
|
777
|
+
dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int64_t>>(
|
|
778
|
+
res, consumer, args...);
|
|
779
|
+
}
|
|
780
|
+
} else {
|
|
781
|
+
FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
|
|
782
|
+
}
|
|
783
|
+
}
|
|
784
|
+
|
|
529
785
|
} // namespace simd_result_handlers
|
|
530
786
|
|
|
531
787
|
} // namespace faiss
|