faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -5,9 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// quiet the noise
|
|
9
|
-
// clang-format off
|
|
10
|
-
|
|
11
8
|
#include <faiss/IndexAdditiveQuantizer.h>
|
|
12
9
|
|
|
13
10
|
#include <algorithm>
|
|
@@ -21,7 +18,6 @@
|
|
|
21
18
|
#include <faiss/utils/extra_distances.h>
|
|
22
19
|
#include <faiss/utils/utils.h>
|
|
23
20
|
|
|
24
|
-
|
|
25
21
|
namespace faiss {
|
|
26
22
|
|
|
27
23
|
/**************************************************************************************
|
|
@@ -29,15 +25,13 @@ namespace faiss {
|
|
|
29
25
|
**************************************************************************************/
|
|
30
26
|
|
|
31
27
|
IndexAdditiveQuantizer::IndexAdditiveQuantizer(
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
IndexFlatCodes(aq->code_size, d, metric), aq(aq)
|
|
36
|
-
{
|
|
28
|
+
idx_t d,
|
|
29
|
+
AdditiveQuantizer* aq,
|
|
30
|
+
MetricType metric)
|
|
31
|
+
: IndexFlatCodes(aq->code_size, d, metric), aq(aq) {
|
|
37
32
|
FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
|
|
38
33
|
}
|
|
39
34
|
|
|
40
|
-
|
|
41
35
|
namespace {
|
|
42
36
|
|
|
43
37
|
/************************************************************
|
|
@@ -45,21 +39,22 @@ namespace {
|
|
|
45
39
|
************************************************************/
|
|
46
40
|
|
|
47
41
|
template <class VectorDistance>
|
|
48
|
-
struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
|
|
42
|
+
struct AQDistanceComputerDecompress : FlatCodesDistanceComputer {
|
|
49
43
|
std::vector<float> tmp;
|
|
50
|
-
const AdditiveQuantizer
|
|
44
|
+
const AdditiveQuantizer& aq;
|
|
51
45
|
VectorDistance vd;
|
|
52
46
|
size_t d;
|
|
53
47
|
|
|
54
|
-
AQDistanceComputerDecompress(
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
48
|
+
AQDistanceComputerDecompress(
|
|
49
|
+
const IndexAdditiveQuantizer& iaq,
|
|
50
|
+
VectorDistance vd)
|
|
51
|
+
: FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
|
52
|
+
tmp(iaq.d * 2),
|
|
53
|
+
aq(*iaq.aq),
|
|
54
|
+
vd(vd),
|
|
55
|
+
d(iaq.d) {}
|
|
61
56
|
|
|
62
|
-
const float
|
|
57
|
+
const float* q;
|
|
63
58
|
void set_query(const float* x) final {
|
|
64
59
|
q = x;
|
|
65
60
|
}
|
|
@@ -70,27 +65,25 @@ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
|
|
|
70
65
|
return vd(tmp.data(), tmp.data() + d);
|
|
71
66
|
}
|
|
72
67
|
|
|
73
|
-
float distance_to_code(const uint8_t
|
|
68
|
+
float distance_to_code(const uint8_t* code) final {
|
|
74
69
|
aq.decode(code, tmp.data(), 1);
|
|
75
70
|
return vd(q, tmp.data());
|
|
76
71
|
}
|
|
77
72
|
|
|
78
|
-
virtual ~AQDistanceComputerDecompress()
|
|
73
|
+
virtual ~AQDistanceComputerDecompress() = default;
|
|
79
74
|
};
|
|
80
75
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
|
|
76
|
+
template <bool is_IP, AdditiveQuantizer::Search_type_t st>
|
|
77
|
+
struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
|
|
84
78
|
std::vector<float> LUT;
|
|
85
|
-
const AdditiveQuantizer
|
|
79
|
+
const AdditiveQuantizer& aq;
|
|
86
80
|
size_t d;
|
|
87
81
|
|
|
88
|
-
explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
{}
|
|
82
|
+
explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer& iaq)
|
|
83
|
+
: FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
|
84
|
+
LUT(iaq.aq->total_codebook_size + iaq.d * 2),
|
|
85
|
+
aq(*iaq.aq),
|
|
86
|
+
d(iaq.d) {}
|
|
94
87
|
|
|
95
88
|
float bias;
|
|
96
89
|
void set_query(const float* x) final {
|
|
@@ -104,40 +97,38 @@ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
|
|
|
104
97
|
}
|
|
105
98
|
|
|
106
99
|
float symmetric_dis(idx_t i, idx_t j) final {
|
|
107
|
-
float
|
|
100
|
+
float* tmp = LUT.data();
|
|
108
101
|
aq.decode(codes + i * d, tmp, 1);
|
|
109
102
|
aq.decode(codes + j * d, tmp + d, 1);
|
|
110
103
|
return fvec_L2sqr(tmp, tmp + d, d);
|
|
111
104
|
}
|
|
112
105
|
|
|
113
|
-
float distance_to_code(const uint8_t
|
|
106
|
+
float distance_to_code(const uint8_t* code) final {
|
|
114
107
|
return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
|
|
115
108
|
}
|
|
116
109
|
|
|
117
|
-
virtual ~AQDistanceComputerLUT()
|
|
110
|
+
virtual ~AQDistanceComputerLUT() = default;
|
|
118
111
|
};
|
|
119
112
|
|
|
120
|
-
|
|
121
|
-
|
|
122
113
|
/************************************************************
|
|
123
114
|
* scanning implementation for search
|
|
124
115
|
************************************************************/
|
|
125
116
|
|
|
126
|
-
|
|
127
|
-
template <class VectorDistance, class ResultHandler>
|
|
117
|
+
template <class VectorDistance, class BlockResultHandler>
|
|
128
118
|
void search_with_decompress(
|
|
129
119
|
const IndexAdditiveQuantizer& ir,
|
|
130
120
|
const float* xq,
|
|
131
121
|
VectorDistance& vd,
|
|
132
|
-
|
|
122
|
+
BlockResultHandler& res) {
|
|
133
123
|
const uint8_t* codes = ir.codes.data();
|
|
134
124
|
size_t ntotal = ir.ntotal;
|
|
135
125
|
size_t code_size = ir.code_size;
|
|
136
|
-
const AdditiveQuantizer
|
|
126
|
+
const AdditiveQuantizer* aq = ir.aq;
|
|
137
127
|
|
|
138
|
-
using SingleResultHandler =
|
|
128
|
+
using SingleResultHandler =
|
|
129
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
139
130
|
|
|
140
|
-
#pragma omp parallel for if(res.nq > 100)
|
|
131
|
+
#pragma omp parallel for if (res.nq > 100)
|
|
141
132
|
for (int64_t q = 0; q < res.nq; q++) {
|
|
142
133
|
SingleResultHandler resi(res);
|
|
143
134
|
resi.begin(q);
|
|
@@ -152,52 +143,51 @@ void search_with_decompress(
|
|
|
152
143
|
}
|
|
153
144
|
}
|
|
154
145
|
|
|
155
|
-
template<
|
|
146
|
+
template <
|
|
147
|
+
bool is_IP,
|
|
148
|
+
AdditiveQuantizer::Search_type_t st,
|
|
149
|
+
class BlockResultHandler>
|
|
156
150
|
void search_with_LUT(
|
|
157
151
|
const IndexAdditiveQuantizer& ir,
|
|
158
152
|
const float* xq,
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
const AdditiveQuantizer & aq = *ir.aq;
|
|
153
|
+
BlockResultHandler& res) {
|
|
154
|
+
const AdditiveQuantizer& aq = *ir.aq;
|
|
162
155
|
const uint8_t* codes = ir.codes.data();
|
|
163
156
|
size_t ntotal = ir.ntotal;
|
|
164
157
|
size_t code_size = aq.code_size;
|
|
165
158
|
size_t nq = res.nq;
|
|
166
159
|
size_t d = ir.d;
|
|
167
160
|
|
|
168
|
-
using SingleResultHandler =
|
|
169
|
-
|
|
161
|
+
using SingleResultHandler =
|
|
162
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
163
|
+
std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);
|
|
170
164
|
|
|
171
165
|
aq.compute_LUT(nq, xq, LUT.get());
|
|
172
166
|
|
|
173
|
-
#pragma omp parallel for if(nq > 100)
|
|
167
|
+
#pragma omp parallel for if (nq > 100)
|
|
174
168
|
for (int64_t q = 0; q < nq; q++) {
|
|
175
169
|
SingleResultHandler resi(res);
|
|
176
170
|
resi.begin(q);
|
|
177
171
|
std::vector<float> tmp(aq.d);
|
|
178
|
-
const float
|
|
172
|
+
const float* LUT_q = LUT.get() + aq.total_codebook_size * q;
|
|
179
173
|
float bias = 0;
|
|
180
|
-
if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
|
|
174
|
+
if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
|
|
175
|
+
// add ||x||^2
|
|
181
176
|
bias = fvec_norm_L2sqr(xq + q * d, d);
|
|
182
177
|
}
|
|
183
178
|
for (size_t i = 0; i < ntotal; i++) {
|
|
184
179
|
float dis = aq.compute_1_distance_LUT<is_IP, st>(
|
|
185
|
-
|
|
186
|
-
LUT_q
|
|
187
|
-
);
|
|
180
|
+
codes + i * code_size, LUT_q);
|
|
188
181
|
resi.add_result(dis + bias, i);
|
|
189
182
|
}
|
|
190
183
|
resi.end();
|
|
191
184
|
}
|
|
192
|
-
|
|
193
185
|
}
|
|
194
186
|
|
|
195
|
-
|
|
196
187
|
} // anonymous namespace
|
|
197
188
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
189
|
+
FlatCodesDistanceComputer* IndexAdditiveQuantizer::
|
|
190
|
+
get_FlatCodesDistanceComputer() const {
|
|
201
191
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
|
202
192
|
if (metric_type == METRIC_L2) {
|
|
203
193
|
using VD = VectorDistance<METRIC_L2>;
|
|
@@ -212,34 +202,36 @@ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceCompute
|
|
|
212
202
|
}
|
|
213
203
|
} else {
|
|
214
204
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
215
|
-
return new AQDistanceComputerLUT<
|
|
205
|
+
return new AQDistanceComputerLUT<
|
|
206
|
+
true,
|
|
207
|
+
AdditiveQuantizer::ST_LUT_nonorm>(*this);
|
|
216
208
|
} else {
|
|
217
|
-
switch(aq->search_type) {
|
|
218
|
-
#define DISPATCH(st)
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
209
|
+
switch (aq->search_type) {
|
|
210
|
+
#define DISPATCH(st) \
|
|
211
|
+
case AdditiveQuantizer::st: \
|
|
212
|
+
return new AQDistanceComputerLUT<false, AdditiveQuantizer::st>(*this); \
|
|
213
|
+
break;
|
|
214
|
+
DISPATCH(ST_norm_float)
|
|
215
|
+
DISPATCH(ST_LUT_nonorm)
|
|
216
|
+
DISPATCH(ST_norm_qint8)
|
|
217
|
+
DISPATCH(ST_norm_qint4)
|
|
218
|
+
DISPATCH(ST_norm_cqint4)
|
|
219
|
+
case AdditiveQuantizer::ST_norm_cqint8:
|
|
220
|
+
case AdditiveQuantizer::ST_norm_lsq2x4:
|
|
221
|
+
case AdditiveQuantizer::ST_norm_rq2x4:
|
|
222
|
+
return new AQDistanceComputerLUT<
|
|
223
|
+
false,
|
|
224
|
+
AdditiveQuantizer::ST_norm_cqint8>(*this);
|
|
225
|
+
break;
|
|
232
226
|
#undef DISPATCH
|
|
233
|
-
|
|
234
|
-
|
|
227
|
+
default:
|
|
228
|
+
FAISS_THROW_FMT(
|
|
229
|
+
"search type %d not supported", aq->search_type);
|
|
235
230
|
}
|
|
236
231
|
}
|
|
237
232
|
}
|
|
238
233
|
}
|
|
239
234
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
235
|
void IndexAdditiveQuantizer::search(
|
|
244
236
|
idx_t n,
|
|
245
237
|
const float* x,
|
|
@@ -247,62 +239,65 @@ void IndexAdditiveQuantizer::search(
|
|
|
247
239
|
float* distances,
|
|
248
240
|
idx_t* labels,
|
|
249
241
|
const SearchParameters* params) const {
|
|
250
|
-
|
|
251
|
-
|
|
242
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
243
|
+
!params, "search params not supported for this index");
|
|
252
244
|
|
|
253
245
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
|
254
246
|
if (metric_type == METRIC_L2) {
|
|
255
247
|
using VD = VectorDistance<METRIC_L2>;
|
|
256
248
|
VD vd = {size_t(d), metric_arg};
|
|
257
|
-
|
|
249
|
+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
|
|
258
250
|
search_with_decompress(*this, x, vd, rh);
|
|
259
251
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
260
252
|
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
|
|
261
253
|
VD vd = {size_t(d), metric_arg};
|
|
262
|
-
|
|
254
|
+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
|
|
263
255
|
search_with_decompress(*this, x, vd, rh);
|
|
264
256
|
}
|
|
265
257
|
} else {
|
|
266
258
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
267
|
-
|
|
268
|
-
|
|
259
|
+
HeapBlockResultHandler<CMin<float, idx_t>> rh(
|
|
260
|
+
n, distances, labels, k);
|
|
261
|
+
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
|
|
262
|
+
*this, x, rh);
|
|
269
263
|
} else {
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
264
|
+
HeapBlockResultHandler<CMax<float, idx_t>> rh(
|
|
265
|
+
n, distances, labels, k);
|
|
266
|
+
switch (aq->search_type) {
|
|
267
|
+
#define DISPATCH(st) \
|
|
268
|
+
case AdditiveQuantizer::st: \
|
|
269
|
+
search_with_LUT<false, AdditiveQuantizer::st>(*this, x, rh); \
|
|
270
|
+
break;
|
|
271
|
+
DISPATCH(ST_norm_float)
|
|
272
|
+
DISPATCH(ST_LUT_nonorm)
|
|
273
|
+
DISPATCH(ST_norm_qint8)
|
|
274
|
+
DISPATCH(ST_norm_qint4)
|
|
275
|
+
DISPATCH(ST_norm_cqint4)
|
|
276
|
+
case AdditiveQuantizer::ST_norm_cqint8:
|
|
277
|
+
case AdditiveQuantizer::ST_norm_lsq2x4:
|
|
278
|
+
case AdditiveQuantizer::ST_norm_rq2x4:
|
|
279
|
+
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
|
|
280
|
+
*this, x, rh);
|
|
281
|
+
break;
|
|
286
282
|
#undef DISPATCH
|
|
287
|
-
|
|
288
|
-
|
|
283
|
+
default:
|
|
284
|
+
FAISS_THROW_FMT(
|
|
285
|
+
"search type %d not supported", aq->search_type);
|
|
289
286
|
}
|
|
290
287
|
}
|
|
291
|
-
|
|
292
288
|
}
|
|
293
289
|
}
|
|
294
290
|
|
|
295
|
-
void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
|
291
|
+
void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
|
292
|
+
const {
|
|
296
293
|
return aq->compute_codes(x, bytes, n);
|
|
297
294
|
}
|
|
298
295
|
|
|
299
|
-
void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
296
|
+
void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
297
|
+
const {
|
|
300
298
|
return aq->decode(bytes, x, n);
|
|
301
299
|
}
|
|
302
300
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
301
|
/**************************************************************************************
|
|
307
302
|
* IndexResidualQuantizer
|
|
308
303
|
**************************************************************************************/
|
|
@@ -313,8 +308,11 @@ IndexResidualQuantizer::IndexResidualQuantizer(
|
|
|
313
308
|
size_t nbits, ///< number of bit per subvector index
|
|
314
309
|
MetricType metric,
|
|
315
310
|
Search_type_t search_type)
|
|
316
|
-
: IndexResidualQuantizer(
|
|
317
|
-
|
|
311
|
+
: IndexResidualQuantizer(
|
|
312
|
+
d,
|
|
313
|
+
std::vector<size_t>(M, nbits),
|
|
314
|
+
metric,
|
|
315
|
+
search_type) {}
|
|
318
316
|
|
|
319
317
|
IndexResidualQuantizer::IndexResidualQuantizer(
|
|
320
318
|
int d,
|
|
@@ -326,14 +324,14 @@ IndexResidualQuantizer::IndexResidualQuantizer(
|
|
|
326
324
|
is_trained = false;
|
|
327
325
|
}
|
|
328
326
|
|
|
329
|
-
IndexResidualQuantizer::IndexResidualQuantizer()
|
|
327
|
+
IndexResidualQuantizer::IndexResidualQuantizer()
|
|
328
|
+
: IndexResidualQuantizer(0, 0, 0) {}
|
|
330
329
|
|
|
331
330
|
void IndexResidualQuantizer::train(idx_t n, const float* x) {
|
|
332
331
|
rq.train(n, x);
|
|
333
332
|
is_trained = true;
|
|
334
333
|
}
|
|
335
334
|
|
|
336
|
-
|
|
337
335
|
/**************************************************************************************
|
|
338
336
|
* IndexLocalSearchQuantizer
|
|
339
337
|
**************************************************************************************/
|
|
@@ -344,31 +342,33 @@ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
|
|
|
344
342
|
size_t nbits, ///< number of bit per subvector index
|
|
345
343
|
MetricType metric,
|
|
346
344
|
Search_type_t search_type)
|
|
347
|
-
: IndexAdditiveQuantizer(d, &lsq, metric),
|
|
345
|
+
: IndexAdditiveQuantizer(d, &lsq, metric),
|
|
346
|
+
lsq(d, M, nbits, search_type) {
|
|
348
347
|
code_size = lsq.code_size;
|
|
349
348
|
is_trained = false;
|
|
350
349
|
}
|
|
351
350
|
|
|
352
|
-
IndexLocalSearchQuantizer::IndexLocalSearchQuantizer()
|
|
351
|
+
IndexLocalSearchQuantizer::IndexLocalSearchQuantizer()
|
|
352
|
+
: IndexLocalSearchQuantizer(0, 0, 0) {}
|
|
353
353
|
|
|
354
354
|
void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
|
|
355
355
|
lsq.train(n, x);
|
|
356
356
|
is_trained = true;
|
|
357
357
|
}
|
|
358
358
|
|
|
359
|
-
|
|
360
359
|
/**************************************************************************************
|
|
361
360
|
* IndexProductResidualQuantizer
|
|
362
361
|
**************************************************************************************/
|
|
363
362
|
|
|
364
363
|
IndexProductResidualQuantizer::IndexProductResidualQuantizer(
|
|
365
|
-
int d,
|
|
364
|
+
int d, ///< dimensionality of the input vectors
|
|
366
365
|
size_t nsplits, ///< number of residual quantizers
|
|
367
|
-
size_t Msub,
|
|
368
|
-
size_t nbits,
|
|
366
|
+
size_t Msub, ///< number of subquantizers per RQ
|
|
367
|
+
size_t nbits, ///< number of bit per subvector index
|
|
369
368
|
MetricType metric,
|
|
370
369
|
Search_type_t search_type)
|
|
371
|
-
: IndexAdditiveQuantizer(d, &prq, metric),
|
|
370
|
+
: IndexAdditiveQuantizer(d, &prq, metric),
|
|
371
|
+
prq(d, nsplits, Msub, nbits, search_type) {
|
|
372
372
|
code_size = prq.code_size;
|
|
373
373
|
is_trained = false;
|
|
374
374
|
}
|
|
@@ -381,19 +381,19 @@ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
|
|
|
381
381
|
is_trained = true;
|
|
382
382
|
}
|
|
383
383
|
|
|
384
|
-
|
|
385
384
|
/**************************************************************************************
|
|
386
385
|
* IndexProductLocalSearchQuantizer
|
|
387
386
|
**************************************************************************************/
|
|
388
387
|
|
|
389
388
|
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
|
|
390
|
-
int d,
|
|
389
|
+
int d, ///< dimensionality of the input vectors
|
|
391
390
|
size_t nsplits, ///< number of local search quantizers
|
|
392
|
-
size_t Msub,
|
|
393
|
-
size_t nbits,
|
|
391
|
+
size_t Msub, ///< number of subquantizers per LSQ
|
|
392
|
+
size_t nbits, ///< number of bit per subvector index
|
|
394
393
|
MetricType metric,
|
|
395
394
|
Search_type_t search_type)
|
|
396
|
-
: IndexAdditiveQuantizer(d, &plsq, metric),
|
|
395
|
+
: IndexAdditiveQuantizer(d, &plsq, metric),
|
|
396
|
+
plsq(d, nsplits, Msub, nbits, search_type) {
|
|
397
397
|
code_size = plsq.code_size;
|
|
398
398
|
is_trained = false;
|
|
399
399
|
}
|
|
@@ -406,17 +406,15 @@ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
|
|
|
406
406
|
is_trained = true;
|
|
407
407
|
}
|
|
408
408
|
|
|
409
|
-
|
|
410
409
|
/**************************************************************************************
|
|
411
410
|
* AdditiveCoarseQuantizer
|
|
412
411
|
**************************************************************************************/
|
|
413
412
|
|
|
414
413
|
AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
Index(d, metric), aq(aq)
|
|
419
|
-
{}
|
|
414
|
+
idx_t d,
|
|
415
|
+
AdditiveQuantizer* aq,
|
|
416
|
+
MetricType metric)
|
|
417
|
+
: Index(d, metric), aq(aq) {}
|
|
420
418
|
|
|
421
419
|
void AdditiveCoarseQuantizer::add(idx_t, const float*) {
|
|
422
420
|
FAISS_THROW_MSG("not applicable");
|
|
@@ -430,17 +428,16 @@ void AdditiveCoarseQuantizer::reset() {
|
|
|
430
428
|
FAISS_THROW_MSG("not applicable");
|
|
431
429
|
}
|
|
432
430
|
|
|
433
|
-
|
|
434
431
|
void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
|
|
435
432
|
if (verbose) {
|
|
436
|
-
printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n",
|
|
433
|
+
printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n",
|
|
434
|
+
size_t(n));
|
|
437
435
|
}
|
|
438
436
|
size_t norms_size = sizeof(float) << aq->tot_bits;
|
|
439
437
|
|
|
440
|
-
FAISS_THROW_IF_NOT_MSG
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
);
|
|
438
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
439
|
+
norms_size <= aq->max_mem_distances,
|
|
440
|
+
"the RCQ norms matrix will become too large, please reduce the number of quantization steps");
|
|
444
441
|
|
|
445
442
|
aq->train(n, x);
|
|
446
443
|
is_trained = true;
|
|
@@ -448,7 +445,8 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
|
|
|
448
445
|
|
|
449
446
|
if (metric_type == METRIC_L2) {
|
|
450
447
|
if (verbose) {
|
|
451
|
-
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
|
448
|
+
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
|
449
|
+
size_t(ntotal));
|
|
452
450
|
}
|
|
453
451
|
// this is not necessary for the residualcoarsequantizer when
|
|
454
452
|
// using beam search. We'll see if the memory overhead is too high
|
|
@@ -463,16 +461,15 @@ void AdditiveCoarseQuantizer::search(
|
|
|
463
461
|
idx_t k,
|
|
464
462
|
float* distances,
|
|
465
463
|
idx_t* labels,
|
|
466
|
-
const SearchParameters
|
|
467
|
-
|
|
468
|
-
|
|
464
|
+
const SearchParameters* params) const {
|
|
465
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
466
|
+
!params, "search params not supported for this index");
|
|
469
467
|
|
|
470
468
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
471
469
|
aq->knn_centroids_inner_product(n, x, k, distances, labels);
|
|
472
470
|
} else if (metric_type == METRIC_L2) {
|
|
473
471
|
FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
|
|
474
|
-
aq->knn_centroids_L2(
|
|
475
|
-
n, x, k, distances, labels, centroid_norms.data());
|
|
472
|
+
aq->knn_centroids_L2(n, x, k, distances, labels, centroid_norms.data());
|
|
476
473
|
}
|
|
477
474
|
}
|
|
478
475
|
|
|
@@ -481,7 +478,7 @@ void AdditiveCoarseQuantizer::search(
|
|
|
481
478
|
**************************************************************************************/
|
|
482
479
|
|
|
483
480
|
ResidualCoarseQuantizer::ResidualCoarseQuantizer(
|
|
484
|
-
int d,
|
|
481
|
+
int d, ///< dimensionality of the input vectors
|
|
485
482
|
const std::vector<size_t>& nbits,
|
|
486
483
|
MetricType metric)
|
|
487
484
|
: AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits) {
|
|
@@ -496,21 +493,30 @@ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
|
|
|
496
493
|
MetricType metric)
|
|
497
494
|
: ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
|
|
498
495
|
|
|
499
|
-
ResidualCoarseQuantizer::ResidualCoarseQuantizer()
|
|
500
|
-
|
|
501
|
-
|
|
496
|
+
ResidualCoarseQuantizer::ResidualCoarseQuantizer()
|
|
497
|
+
: ResidualCoarseQuantizer(0, 0, 0) {}
|
|
502
498
|
|
|
503
499
|
void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
|
|
504
500
|
beam_factor = new_beam_factor;
|
|
505
501
|
if (new_beam_factor > 0) {
|
|
506
502
|
FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
|
|
503
|
+
if (rq.codebook_cross_products.size() == 0) {
|
|
504
|
+
rq.compute_codebook_tables();
|
|
505
|
+
}
|
|
507
506
|
return;
|
|
508
|
-
} else
|
|
509
|
-
|
|
510
|
-
|
|
507
|
+
} else {
|
|
508
|
+
// new_beam_factor = -1: exhaustive computation.
|
|
509
|
+
// Does not use the cross_products
|
|
510
|
+
rq.codebook_cross_products.resize(0);
|
|
511
|
+
// but the centroid norms are necessary!
|
|
512
|
+
if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
|
|
513
|
+
if (verbose) {
|
|
514
|
+
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
|
515
|
+
size_t(ntotal));
|
|
516
|
+
}
|
|
517
|
+
centroid_norms.resize(ntotal);
|
|
518
|
+
aq->compute_centroid_norms(centroid_norms.data());
|
|
511
519
|
}
|
|
512
|
-
centroid_norms.resize(ntotal);
|
|
513
|
-
aq->compute_centroid_norms(centroid_norms.data());
|
|
514
520
|
}
|
|
515
521
|
}
|
|
516
522
|
|
|
@@ -520,13 +526,15 @@ void ResidualCoarseQuantizer::search(
|
|
|
520
526
|
idx_t k,
|
|
521
527
|
float* distances,
|
|
522
528
|
idx_t* labels,
|
|
523
|
-
const SearchParameters
|
|
524
|
-
) const {
|
|
525
|
-
|
|
529
|
+
const SearchParameters* params_in) const {
|
|
526
530
|
float beam_factor = this->beam_factor;
|
|
527
531
|
if (params_in) {
|
|
528
|
-
auto params =
|
|
529
|
-
|
|
532
|
+
auto params =
|
|
533
|
+
dynamic_cast<const SearchParametersResidualCoarseQuantizer*>(
|
|
534
|
+
params_in);
|
|
535
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
536
|
+
params,
|
|
537
|
+
"need SearchParametersResidualCoarseQuantizer parameters");
|
|
530
538
|
beam_factor = params->beam_factor;
|
|
531
539
|
}
|
|
532
540
|
|
|
@@ -559,7 +567,12 @@ void ResidualCoarseQuantizer::search(
|
|
|
559
567
|
}
|
|
560
568
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
561
569
|
idx_t i1 = std::min(n, i0 + bs);
|
|
562
|
-
search(i1 - i0,
|
|
570
|
+
search(i1 - i0,
|
|
571
|
+
x + i0 * d,
|
|
572
|
+
k,
|
|
573
|
+
distances + i0 * k,
|
|
574
|
+
labels + i0 * k,
|
|
575
|
+
params_in);
|
|
563
576
|
InterruptCallback::check();
|
|
564
577
|
}
|
|
565
578
|
return;
|
|
@@ -571,6 +584,7 @@ void ResidualCoarseQuantizer::search(
|
|
|
571
584
|
rq.refine_beam(
|
|
572
585
|
n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
|
|
573
586
|
|
|
587
|
+
// pack int32 table
|
|
574
588
|
#pragma omp parallel for if (n > 4000)
|
|
575
589
|
for (idx_t i = 0; i < n; i++) {
|
|
576
590
|
memcpy(distances + i * k,
|
|
@@ -590,7 +604,8 @@ void ResidualCoarseQuantizer::search(
|
|
|
590
604
|
}
|
|
591
605
|
}
|
|
592
606
|
|
|
593
|
-
void ResidualCoarseQuantizer::initialize_from(
|
|
607
|
+
void ResidualCoarseQuantizer::initialize_from(
|
|
608
|
+
const ResidualCoarseQuantizer& other) {
|
|
594
609
|
FAISS_THROW_IF_NOT(rq.M <= other.rq.M);
|
|
595
610
|
rq.initialize_from(other.rq);
|
|
596
611
|
set_beam_factor(other.beam_factor);
|
|
@@ -598,7 +613,6 @@ void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &oth
|
|
|
598
613
|
ntotal = (idx_t)1 << aq->tot_bits;
|
|
599
614
|
}
|
|
600
615
|
|
|
601
|
-
|
|
602
616
|
/**************************************************************************************
|
|
603
617
|
* LocalSearchCoarseQuantizer
|
|
604
618
|
**************************************************************************************/
|
|
@@ -613,12 +627,8 @@ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
|
|
|
613
627
|
is_trained = false;
|
|
614
628
|
}
|
|
615
629
|
|
|
616
|
-
|
|
617
630
|
LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
|
|
618
631
|
aq = &lsq;
|
|
619
632
|
}
|
|
620
633
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
634
|
} // namespace faiss
|