faiss 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
@@ -5,9 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// quiet the noise
|
9
|
-
// clang-format off
|
10
|
-
|
11
8
|
#include <faiss/IndexAdditiveQuantizer.h>
|
12
9
|
|
13
10
|
#include <algorithm>
|
@@ -21,7 +18,6 @@
|
|
21
18
|
#include <faiss/utils/extra_distances.h>
|
22
19
|
#include <faiss/utils/utils.h>
|
23
20
|
|
24
|
-
|
25
21
|
namespace faiss {
|
26
22
|
|
27
23
|
/**************************************************************************************
|
@@ -29,15 +25,13 @@ namespace faiss {
|
|
29
25
|
**************************************************************************************/
|
30
26
|
|
31
27
|
IndexAdditiveQuantizer::IndexAdditiveQuantizer(
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
IndexFlatCodes(aq->code_size, d, metric), aq(aq)
|
36
|
-
{
|
28
|
+
idx_t d,
|
29
|
+
AdditiveQuantizer* aq,
|
30
|
+
MetricType metric)
|
31
|
+
: IndexFlatCodes(aq->code_size, d, metric), aq(aq) {
|
37
32
|
FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
|
38
33
|
}
|
39
34
|
|
40
|
-
|
41
35
|
namespace {
|
42
36
|
|
43
37
|
/************************************************************
|
@@ -45,21 +39,22 @@ namespace {
|
|
45
39
|
************************************************************/
|
46
40
|
|
47
41
|
template <class VectorDistance>
|
48
|
-
struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
|
42
|
+
struct AQDistanceComputerDecompress : FlatCodesDistanceComputer {
|
49
43
|
std::vector<float> tmp;
|
50
|
-
const AdditiveQuantizer
|
44
|
+
const AdditiveQuantizer& aq;
|
51
45
|
VectorDistance vd;
|
52
46
|
size_t d;
|
53
47
|
|
54
|
-
AQDistanceComputerDecompress(
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
48
|
+
AQDistanceComputerDecompress(
|
49
|
+
const IndexAdditiveQuantizer& iaq,
|
50
|
+
VectorDistance vd)
|
51
|
+
: FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
52
|
+
tmp(iaq.d * 2),
|
53
|
+
aq(*iaq.aq),
|
54
|
+
vd(vd),
|
55
|
+
d(iaq.d) {}
|
61
56
|
|
62
|
-
const float
|
57
|
+
const float* q;
|
63
58
|
void set_query(const float* x) final {
|
64
59
|
q = x;
|
65
60
|
}
|
@@ -70,27 +65,25 @@ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
|
|
70
65
|
return vd(tmp.data(), tmp.data() + d);
|
71
66
|
}
|
72
67
|
|
73
|
-
float distance_to_code(const uint8_t
|
68
|
+
float distance_to_code(const uint8_t* code) final {
|
74
69
|
aq.decode(code, tmp.data(), 1);
|
75
70
|
return vd(q, tmp.data());
|
76
71
|
}
|
77
72
|
|
78
|
-
virtual ~AQDistanceComputerDecompress()
|
73
|
+
virtual ~AQDistanceComputerDecompress() = default;
|
79
74
|
};
|
80
75
|
|
81
|
-
|
82
|
-
|
83
|
-
struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
|
76
|
+
template <bool is_IP, AdditiveQuantizer::Search_type_t st>
|
77
|
+
struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
|
84
78
|
std::vector<float> LUT;
|
85
|
-
const AdditiveQuantizer
|
79
|
+
const AdditiveQuantizer& aq;
|
86
80
|
size_t d;
|
87
81
|
|
88
|
-
explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
{}
|
82
|
+
explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer& iaq)
|
83
|
+
: FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
|
84
|
+
LUT(iaq.aq->total_codebook_size + iaq.d * 2),
|
85
|
+
aq(*iaq.aq),
|
86
|
+
d(iaq.d) {}
|
94
87
|
|
95
88
|
float bias;
|
96
89
|
void set_query(const float* x) final {
|
@@ -104,40 +97,38 @@ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
|
|
104
97
|
}
|
105
98
|
|
106
99
|
float symmetric_dis(idx_t i, idx_t j) final {
|
107
|
-
float
|
100
|
+
float* tmp = LUT.data();
|
108
101
|
aq.decode(codes + i * d, tmp, 1);
|
109
102
|
aq.decode(codes + j * d, tmp + d, 1);
|
110
103
|
return fvec_L2sqr(tmp, tmp + d, d);
|
111
104
|
}
|
112
105
|
|
113
|
-
float distance_to_code(const uint8_t
|
106
|
+
float distance_to_code(const uint8_t* code) final {
|
114
107
|
return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
|
115
108
|
}
|
116
109
|
|
117
|
-
virtual ~AQDistanceComputerLUT()
|
110
|
+
virtual ~AQDistanceComputerLUT() = default;
|
118
111
|
};
|
119
112
|
|
120
|
-
|
121
|
-
|
122
113
|
/************************************************************
|
123
114
|
* scanning implementation for search
|
124
115
|
************************************************************/
|
125
116
|
|
126
|
-
|
127
|
-
template <class VectorDistance, class ResultHandler>
|
117
|
+
template <class VectorDistance, class BlockResultHandler>
|
128
118
|
void search_with_decompress(
|
129
119
|
const IndexAdditiveQuantizer& ir,
|
130
120
|
const float* xq,
|
131
121
|
VectorDistance& vd,
|
132
|
-
|
122
|
+
BlockResultHandler& res) {
|
133
123
|
const uint8_t* codes = ir.codes.data();
|
134
124
|
size_t ntotal = ir.ntotal;
|
135
125
|
size_t code_size = ir.code_size;
|
136
|
-
const AdditiveQuantizer
|
126
|
+
const AdditiveQuantizer* aq = ir.aq;
|
137
127
|
|
138
|
-
using SingleResultHandler =
|
128
|
+
using SingleResultHandler =
|
129
|
+
typename BlockResultHandler::SingleResultHandler;
|
139
130
|
|
140
|
-
#pragma omp parallel for if(res.nq > 100)
|
131
|
+
#pragma omp parallel for if (res.nq > 100)
|
141
132
|
for (int64_t q = 0; q < res.nq; q++) {
|
142
133
|
SingleResultHandler resi(res);
|
143
134
|
resi.begin(q);
|
@@ -152,52 +143,51 @@ void search_with_decompress(
|
|
152
143
|
}
|
153
144
|
}
|
154
145
|
|
155
|
-
template<
|
146
|
+
template <
|
147
|
+
bool is_IP,
|
148
|
+
AdditiveQuantizer::Search_type_t st,
|
149
|
+
class BlockResultHandler>
|
156
150
|
void search_with_LUT(
|
157
151
|
const IndexAdditiveQuantizer& ir,
|
158
152
|
const float* xq,
|
159
|
-
|
160
|
-
|
161
|
-
const AdditiveQuantizer & aq = *ir.aq;
|
153
|
+
BlockResultHandler& res) {
|
154
|
+
const AdditiveQuantizer& aq = *ir.aq;
|
162
155
|
const uint8_t* codes = ir.codes.data();
|
163
156
|
size_t ntotal = ir.ntotal;
|
164
157
|
size_t code_size = aq.code_size;
|
165
158
|
size_t nq = res.nq;
|
166
159
|
size_t d = ir.d;
|
167
160
|
|
168
|
-
using SingleResultHandler =
|
169
|
-
|
161
|
+
using SingleResultHandler =
|
162
|
+
typename BlockResultHandler::SingleResultHandler;
|
163
|
+
std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);
|
170
164
|
|
171
165
|
aq.compute_LUT(nq, xq, LUT.get());
|
172
166
|
|
173
|
-
#pragma omp parallel for if(nq > 100)
|
167
|
+
#pragma omp parallel for if (nq > 100)
|
174
168
|
for (int64_t q = 0; q < nq; q++) {
|
175
169
|
SingleResultHandler resi(res);
|
176
170
|
resi.begin(q);
|
177
171
|
std::vector<float> tmp(aq.d);
|
178
|
-
const float
|
172
|
+
const float* LUT_q = LUT.get() + aq.total_codebook_size * q;
|
179
173
|
float bias = 0;
|
180
|
-
if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
|
174
|
+
if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
|
175
|
+
// add ||x||^2
|
181
176
|
bias = fvec_norm_L2sqr(xq + q * d, d);
|
182
177
|
}
|
183
178
|
for (size_t i = 0; i < ntotal; i++) {
|
184
179
|
float dis = aq.compute_1_distance_LUT<is_IP, st>(
|
185
|
-
|
186
|
-
LUT_q
|
187
|
-
);
|
180
|
+
codes + i * code_size, LUT_q);
|
188
181
|
resi.add_result(dis + bias, i);
|
189
182
|
}
|
190
183
|
resi.end();
|
191
184
|
}
|
192
|
-
|
193
185
|
}
|
194
186
|
|
195
|
-
|
196
187
|
} // anonymous namespace
|
197
188
|
|
198
|
-
|
199
|
-
|
200
|
-
|
189
|
+
FlatCodesDistanceComputer* IndexAdditiveQuantizer::
|
190
|
+
get_FlatCodesDistanceComputer() const {
|
201
191
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
202
192
|
if (metric_type == METRIC_L2) {
|
203
193
|
using VD = VectorDistance<METRIC_L2>;
|
@@ -212,34 +202,36 @@ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceCompute
|
|
212
202
|
}
|
213
203
|
} else {
|
214
204
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
215
|
-
return new AQDistanceComputerLUT<
|
205
|
+
return new AQDistanceComputerLUT<
|
206
|
+
true,
|
207
|
+
AdditiveQuantizer::ST_LUT_nonorm>(*this);
|
216
208
|
} else {
|
217
|
-
switch(aq->search_type) {
|
218
|
-
#define DISPATCH(st)
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
209
|
+
switch (aq->search_type) {
|
210
|
+
#define DISPATCH(st) \
|
211
|
+
case AdditiveQuantizer::st: \
|
212
|
+
return new AQDistanceComputerLUT<false, AdditiveQuantizer::st>(*this); \
|
213
|
+
break;
|
214
|
+
DISPATCH(ST_norm_float)
|
215
|
+
DISPATCH(ST_LUT_nonorm)
|
216
|
+
DISPATCH(ST_norm_qint8)
|
217
|
+
DISPATCH(ST_norm_qint4)
|
218
|
+
DISPATCH(ST_norm_cqint4)
|
219
|
+
case AdditiveQuantizer::ST_norm_cqint8:
|
220
|
+
case AdditiveQuantizer::ST_norm_lsq2x4:
|
221
|
+
case AdditiveQuantizer::ST_norm_rq2x4:
|
222
|
+
return new AQDistanceComputerLUT<
|
223
|
+
false,
|
224
|
+
AdditiveQuantizer::ST_norm_cqint8>(*this);
|
225
|
+
break;
|
232
226
|
#undef DISPATCH
|
233
|
-
|
234
|
-
|
227
|
+
default:
|
228
|
+
FAISS_THROW_FMT(
|
229
|
+
"search type %d not supported", aq->search_type);
|
235
230
|
}
|
236
231
|
}
|
237
232
|
}
|
238
233
|
}
|
239
234
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
235
|
void IndexAdditiveQuantizer::search(
|
244
236
|
idx_t n,
|
245
237
|
const float* x,
|
@@ -247,62 +239,65 @@ void IndexAdditiveQuantizer::search(
|
|
247
239
|
float* distances,
|
248
240
|
idx_t* labels,
|
249
241
|
const SearchParameters* params) const {
|
250
|
-
|
251
|
-
|
242
|
+
FAISS_THROW_IF_NOT_MSG(
|
243
|
+
!params, "search params not supported for this index");
|
252
244
|
|
253
245
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
254
246
|
if (metric_type == METRIC_L2) {
|
255
247
|
using VD = VectorDistance<METRIC_L2>;
|
256
248
|
VD vd = {size_t(d), metric_arg};
|
257
|
-
|
249
|
+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
|
258
250
|
search_with_decompress(*this, x, vd, rh);
|
259
251
|
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
260
252
|
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
|
261
253
|
VD vd = {size_t(d), metric_arg};
|
262
|
-
|
254
|
+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
|
263
255
|
search_with_decompress(*this, x, vd, rh);
|
264
256
|
}
|
265
257
|
} else {
|
266
258
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
267
|
-
|
268
|
-
|
259
|
+
HeapBlockResultHandler<CMin<float, idx_t>> rh(
|
260
|
+
n, distances, labels, k);
|
261
|
+
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
|
262
|
+
*this, x, rh);
|
269
263
|
} else {
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
264
|
+
HeapBlockResultHandler<CMax<float, idx_t>> rh(
|
265
|
+
n, distances, labels, k);
|
266
|
+
switch (aq->search_type) {
|
267
|
+
#define DISPATCH(st) \
|
268
|
+
case AdditiveQuantizer::st: \
|
269
|
+
search_with_LUT<false, AdditiveQuantizer::st>(*this, x, rh); \
|
270
|
+
break;
|
271
|
+
DISPATCH(ST_norm_float)
|
272
|
+
DISPATCH(ST_LUT_nonorm)
|
273
|
+
DISPATCH(ST_norm_qint8)
|
274
|
+
DISPATCH(ST_norm_qint4)
|
275
|
+
DISPATCH(ST_norm_cqint4)
|
276
|
+
case AdditiveQuantizer::ST_norm_cqint8:
|
277
|
+
case AdditiveQuantizer::ST_norm_lsq2x4:
|
278
|
+
case AdditiveQuantizer::ST_norm_rq2x4:
|
279
|
+
search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
|
280
|
+
*this, x, rh);
|
281
|
+
break;
|
286
282
|
#undef DISPATCH
|
287
|
-
|
288
|
-
|
283
|
+
default:
|
284
|
+
FAISS_THROW_FMT(
|
285
|
+
"search type %d not supported", aq->search_type);
|
289
286
|
}
|
290
287
|
}
|
291
|
-
|
292
288
|
}
|
293
289
|
}
|
294
290
|
|
295
|
-
void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
291
|
+
void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
292
|
+
const {
|
296
293
|
return aq->compute_codes(x, bytes, n);
|
297
294
|
}
|
298
295
|
|
299
|
-
void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
296
|
+
void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
297
|
+
const {
|
300
298
|
return aq->decode(bytes, x, n);
|
301
299
|
}
|
302
300
|
|
303
|
-
|
304
|
-
|
305
|
-
|
306
301
|
/**************************************************************************************
|
307
302
|
* IndexResidualQuantizer
|
308
303
|
**************************************************************************************/
|
@@ -313,8 +308,11 @@ IndexResidualQuantizer::IndexResidualQuantizer(
|
|
313
308
|
size_t nbits, ///< number of bit per subvector index
|
314
309
|
MetricType metric,
|
315
310
|
Search_type_t search_type)
|
316
|
-
: IndexResidualQuantizer(
|
317
|
-
|
311
|
+
: IndexResidualQuantizer(
|
312
|
+
d,
|
313
|
+
std::vector<size_t>(M, nbits),
|
314
|
+
metric,
|
315
|
+
search_type) {}
|
318
316
|
|
319
317
|
IndexResidualQuantizer::IndexResidualQuantizer(
|
320
318
|
int d,
|
@@ -326,14 +324,14 @@ IndexResidualQuantizer::IndexResidualQuantizer(
|
|
326
324
|
is_trained = false;
|
327
325
|
}
|
328
326
|
|
329
|
-
IndexResidualQuantizer::IndexResidualQuantizer()
|
327
|
+
IndexResidualQuantizer::IndexResidualQuantizer()
|
328
|
+
: IndexResidualQuantizer(0, 0, 0) {}
|
330
329
|
|
331
330
|
void IndexResidualQuantizer::train(idx_t n, const float* x) {
|
332
331
|
rq.train(n, x);
|
333
332
|
is_trained = true;
|
334
333
|
}
|
335
334
|
|
336
|
-
|
337
335
|
/**************************************************************************************
|
338
336
|
* IndexLocalSearchQuantizer
|
339
337
|
**************************************************************************************/
|
@@ -344,31 +342,33 @@ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
|
|
344
342
|
size_t nbits, ///< number of bit per subvector index
|
345
343
|
MetricType metric,
|
346
344
|
Search_type_t search_type)
|
347
|
-
: IndexAdditiveQuantizer(d, &lsq, metric),
|
345
|
+
: IndexAdditiveQuantizer(d, &lsq, metric),
|
346
|
+
lsq(d, M, nbits, search_type) {
|
348
347
|
code_size = lsq.code_size;
|
349
348
|
is_trained = false;
|
350
349
|
}
|
351
350
|
|
352
|
-
IndexLocalSearchQuantizer::IndexLocalSearchQuantizer()
|
351
|
+
IndexLocalSearchQuantizer::IndexLocalSearchQuantizer()
|
352
|
+
: IndexLocalSearchQuantizer(0, 0, 0) {}
|
353
353
|
|
354
354
|
void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
|
355
355
|
lsq.train(n, x);
|
356
356
|
is_trained = true;
|
357
357
|
}
|
358
358
|
|
359
|
-
|
360
359
|
/**************************************************************************************
|
361
360
|
* IndexProductResidualQuantizer
|
362
361
|
**************************************************************************************/
|
363
362
|
|
364
363
|
IndexProductResidualQuantizer::IndexProductResidualQuantizer(
|
365
|
-
int d,
|
364
|
+
int d, ///< dimensionality of the input vectors
|
366
365
|
size_t nsplits, ///< number of residual quantizers
|
367
|
-
size_t Msub,
|
368
|
-
size_t nbits,
|
366
|
+
size_t Msub, ///< number of subquantizers per RQ
|
367
|
+
size_t nbits, ///< number of bit per subvector index
|
369
368
|
MetricType metric,
|
370
369
|
Search_type_t search_type)
|
371
|
-
: IndexAdditiveQuantizer(d, &prq, metric),
|
370
|
+
: IndexAdditiveQuantizer(d, &prq, metric),
|
371
|
+
prq(d, nsplits, Msub, nbits, search_type) {
|
372
372
|
code_size = prq.code_size;
|
373
373
|
is_trained = false;
|
374
374
|
}
|
@@ -381,19 +381,19 @@ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
|
|
381
381
|
is_trained = true;
|
382
382
|
}
|
383
383
|
|
384
|
-
|
385
384
|
/**************************************************************************************
|
386
385
|
* IndexProductLocalSearchQuantizer
|
387
386
|
**************************************************************************************/
|
388
387
|
|
389
388
|
IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
|
390
|
-
int d,
|
389
|
+
int d, ///< dimensionality of the input vectors
|
391
390
|
size_t nsplits, ///< number of local search quantizers
|
392
|
-
size_t Msub,
|
393
|
-
size_t nbits,
|
391
|
+
size_t Msub, ///< number of subquantizers per LSQ
|
392
|
+
size_t nbits, ///< number of bit per subvector index
|
394
393
|
MetricType metric,
|
395
394
|
Search_type_t search_type)
|
396
|
-
: IndexAdditiveQuantizer(d, &plsq, metric),
|
395
|
+
: IndexAdditiveQuantizer(d, &plsq, metric),
|
396
|
+
plsq(d, nsplits, Msub, nbits, search_type) {
|
397
397
|
code_size = plsq.code_size;
|
398
398
|
is_trained = false;
|
399
399
|
}
|
@@ -406,17 +406,15 @@ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
|
|
406
406
|
is_trained = true;
|
407
407
|
}
|
408
408
|
|
409
|
-
|
410
409
|
/**************************************************************************************
|
411
410
|
* AdditiveCoarseQuantizer
|
412
411
|
**************************************************************************************/
|
413
412
|
|
414
413
|
AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
Index(d, metric), aq(aq)
|
419
|
-
{}
|
414
|
+
idx_t d,
|
415
|
+
AdditiveQuantizer* aq,
|
416
|
+
MetricType metric)
|
417
|
+
: Index(d, metric), aq(aq) {}
|
420
418
|
|
421
419
|
void AdditiveCoarseQuantizer::add(idx_t, const float*) {
|
422
420
|
FAISS_THROW_MSG("not applicable");
|
@@ -430,17 +428,16 @@ void AdditiveCoarseQuantizer::reset() {
|
|
430
428
|
FAISS_THROW_MSG("not applicable");
|
431
429
|
}
|
432
430
|
|
433
|
-
|
434
431
|
void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
|
435
432
|
if (verbose) {
|
436
|
-
printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n",
|
433
|
+
printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n",
|
434
|
+
size_t(n));
|
437
435
|
}
|
438
436
|
size_t norms_size = sizeof(float) << aq->tot_bits;
|
439
437
|
|
440
|
-
FAISS_THROW_IF_NOT_MSG
|
441
|
-
|
442
|
-
|
443
|
-
);
|
438
|
+
FAISS_THROW_IF_NOT_MSG(
|
439
|
+
norms_size <= aq->max_mem_distances,
|
440
|
+
"the RCQ norms matrix will become too large, please reduce the number of quantization steps");
|
444
441
|
|
445
442
|
aq->train(n, x);
|
446
443
|
is_trained = true;
|
@@ -448,7 +445,8 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
|
|
448
445
|
|
449
446
|
if (metric_type == METRIC_L2) {
|
450
447
|
if (verbose) {
|
451
|
-
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
448
|
+
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
449
|
+
size_t(ntotal));
|
452
450
|
}
|
453
451
|
// this is not necessary for the residualcoarsequantizer when
|
454
452
|
// using beam search. We'll see if the memory overhead is too high
|
@@ -463,16 +461,15 @@ void AdditiveCoarseQuantizer::search(
|
|
463
461
|
idx_t k,
|
464
462
|
float* distances,
|
465
463
|
idx_t* labels,
|
466
|
-
const SearchParameters
|
467
|
-
|
468
|
-
|
464
|
+
const SearchParameters* params) const {
|
465
|
+
FAISS_THROW_IF_NOT_MSG(
|
466
|
+
!params, "search params not supported for this index");
|
469
467
|
|
470
468
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
471
469
|
aq->knn_centroids_inner_product(n, x, k, distances, labels);
|
472
470
|
} else if (metric_type == METRIC_L2) {
|
473
471
|
FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
|
474
|
-
aq->knn_centroids_L2(
|
475
|
-
n, x, k, distances, labels, centroid_norms.data());
|
472
|
+
aq->knn_centroids_L2(n, x, k, distances, labels, centroid_norms.data());
|
476
473
|
}
|
477
474
|
}
|
478
475
|
|
@@ -481,7 +478,7 @@ void AdditiveCoarseQuantizer::search(
|
|
481
478
|
**************************************************************************************/
|
482
479
|
|
483
480
|
ResidualCoarseQuantizer::ResidualCoarseQuantizer(
|
484
|
-
int d,
|
481
|
+
int d, ///< dimensionality of the input vectors
|
485
482
|
const std::vector<size_t>& nbits,
|
486
483
|
MetricType metric)
|
487
484
|
: AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits) {
|
@@ -496,21 +493,30 @@ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
|
|
496
493
|
MetricType metric)
|
497
494
|
: ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
|
498
495
|
|
499
|
-
ResidualCoarseQuantizer::ResidualCoarseQuantizer()
|
500
|
-
|
501
|
-
|
496
|
+
ResidualCoarseQuantizer::ResidualCoarseQuantizer()
|
497
|
+
: ResidualCoarseQuantizer(0, 0, 0) {}
|
502
498
|
|
503
499
|
void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
|
504
500
|
beam_factor = new_beam_factor;
|
505
501
|
if (new_beam_factor > 0) {
|
506
502
|
FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
|
503
|
+
if (rq.codebook_cross_products.size() == 0) {
|
504
|
+
rq.compute_codebook_tables();
|
505
|
+
}
|
507
506
|
return;
|
508
|
-
} else
|
509
|
-
|
510
|
-
|
507
|
+
} else {
|
508
|
+
// new_beam_factor = -1: exhaustive computation.
|
509
|
+
// Does not use the cross_products
|
510
|
+
rq.codebook_cross_products.resize(0);
|
511
|
+
// but the centroid norms are necessary!
|
512
|
+
if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
|
513
|
+
if (verbose) {
|
514
|
+
printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
|
515
|
+
size_t(ntotal));
|
516
|
+
}
|
517
|
+
centroid_norms.resize(ntotal);
|
518
|
+
aq->compute_centroid_norms(centroid_norms.data());
|
511
519
|
}
|
512
|
-
centroid_norms.resize(ntotal);
|
513
|
-
aq->compute_centroid_norms(centroid_norms.data());
|
514
520
|
}
|
515
521
|
}
|
516
522
|
|
@@ -520,13 +526,15 @@ void ResidualCoarseQuantizer::search(
|
|
520
526
|
idx_t k,
|
521
527
|
float* distances,
|
522
528
|
idx_t* labels,
|
523
|
-
const SearchParameters
|
524
|
-
) const {
|
525
|
-
|
529
|
+
const SearchParameters* params_in) const {
|
526
530
|
float beam_factor = this->beam_factor;
|
527
531
|
if (params_in) {
|
528
|
-
auto params =
|
529
|
-
|
532
|
+
auto params =
|
533
|
+
dynamic_cast<const SearchParametersResidualCoarseQuantizer*>(
|
534
|
+
params_in);
|
535
|
+
FAISS_THROW_IF_NOT_MSG(
|
536
|
+
params,
|
537
|
+
"need SearchParametersResidualCoarseQuantizer parameters");
|
530
538
|
beam_factor = params->beam_factor;
|
531
539
|
}
|
532
540
|
|
@@ -559,7 +567,12 @@ void ResidualCoarseQuantizer::search(
|
|
559
567
|
}
|
560
568
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
561
569
|
idx_t i1 = std::min(n, i0 + bs);
|
562
|
-
search(i1 - i0,
|
570
|
+
search(i1 - i0,
|
571
|
+
x + i0 * d,
|
572
|
+
k,
|
573
|
+
distances + i0 * k,
|
574
|
+
labels + i0 * k,
|
575
|
+
params_in);
|
563
576
|
InterruptCallback::check();
|
564
577
|
}
|
565
578
|
return;
|
@@ -571,6 +584,7 @@ void ResidualCoarseQuantizer::search(
|
|
571
584
|
rq.refine_beam(
|
572
585
|
n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
|
573
586
|
|
587
|
+
// pack int32 table
|
574
588
|
#pragma omp parallel for if (n > 4000)
|
575
589
|
for (idx_t i = 0; i < n; i++) {
|
576
590
|
memcpy(distances + i * k,
|
@@ -590,7 +604,8 @@ void ResidualCoarseQuantizer::search(
|
|
590
604
|
}
|
591
605
|
}
|
592
606
|
|
593
|
-
void ResidualCoarseQuantizer::initialize_from(
|
607
|
+
void ResidualCoarseQuantizer::initialize_from(
|
608
|
+
const ResidualCoarseQuantizer& other) {
|
594
609
|
FAISS_THROW_IF_NOT(rq.M <= other.rq.M);
|
595
610
|
rq.initialize_from(other.rq);
|
596
611
|
set_beam_factor(other.beam_factor);
|
@@ -598,7 +613,6 @@ void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &oth
|
|
598
613
|
ntotal = (idx_t)1 << aq->tot_bits;
|
599
614
|
}
|
600
615
|
|
601
|
-
|
602
616
|
/**************************************************************************************
|
603
617
|
* LocalSearchCoarseQuantizer
|
604
618
|
**************************************************************************************/
|
@@ -613,12 +627,8 @@ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
|
|
613
627
|
is_trained = false;
|
614
628
|
}
|
615
629
|
|
616
|
-
|
617
630
|
LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
|
618
631
|
aq = &lsq;
|
619
632
|
}
|
620
633
|
|
621
|
-
|
622
|
-
|
623
|
-
|
624
634
|
} // namespace faiss
|