faiss 0.2.3 → 0.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,19 +5,19 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
/*
|
11
|
-
* implementation of
|
9
|
+
* implementation of the index_factory function. Lots of regex parsing code.
|
12
10
|
*/
|
13
11
|
|
14
12
|
#include <faiss/index_factory.h>
|
15
|
-
|
16
|
-
#include
|
13
|
+
#include "faiss/MetricType.h"
|
14
|
+
#include "faiss/impl/FaissAssert.h"
|
17
15
|
|
18
16
|
#include <cinttypes>
|
19
17
|
#include <cmath>
|
20
18
|
|
19
|
+
#include <map>
|
20
|
+
|
21
21
|
#include <regex>
|
22
22
|
|
23
23
|
#include <faiss/impl/FaissAssert.h>
|
@@ -25,13 +25,16 @@
|
|
25
25
|
#include <faiss/utils/utils.h>
|
26
26
|
|
27
27
|
#include <faiss/Index2Layer.h>
|
28
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
28
29
|
#include <faiss/IndexFlat.h>
|
29
30
|
#include <faiss/IndexHNSW.h>
|
30
31
|
#include <faiss/IndexIVF.h>
|
32
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
31
33
|
#include <faiss/IndexIVFFlat.h>
|
32
34
|
#include <faiss/IndexIVFPQ.h>
|
33
35
|
#include <faiss/IndexIVFPQFastScan.h>
|
34
36
|
#include <faiss/IndexIVFPQR.h>
|
37
|
+
#include <faiss/IndexIVFSpectralHash.h>
|
35
38
|
#include <faiss/IndexLSH.h>
|
36
39
|
#include <faiss/IndexLattice.h>
|
37
40
|
#include <faiss/IndexNSG.h>
|
@@ -39,7 +42,6 @@
|
|
39
42
|
#include <faiss/IndexPQFastScan.h>
|
40
43
|
#include <faiss/IndexPreTransform.h>
|
41
44
|
#include <faiss/IndexRefine.h>
|
42
|
-
#include <faiss/IndexResidual.h>
|
43
45
|
#include <faiss/IndexScalarQuantizer.h>
|
44
46
|
#include <faiss/MetaIndexes.h>
|
45
47
|
#include <faiss/VectorTransform.h>
|
@@ -48,6 +50,7 @@
|
|
48
50
|
#include <faiss/IndexBinaryHNSW.h>
|
49
51
|
#include <faiss/IndexBinaryHash.h>
|
50
52
|
#include <faiss/IndexBinaryIVF.h>
|
53
|
+
#include <string>
|
51
54
|
|
52
55
|
namespace faiss {
|
53
56
|
|
@@ -55,16 +58,49 @@ namespace faiss {
|
|
55
58
|
* index_factory
|
56
59
|
***************************************************************/
|
57
60
|
|
61
|
+
int index_factory_verbose = 0;
|
62
|
+
|
58
63
|
namespace {
|
59
64
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
+
/***************************************************************
|
66
|
+
* Small functions
|
67
|
+
*/
|
68
|
+
|
69
|
+
bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
|
70
|
+
return std::regex_match(s, sm, std::regex(pat));
|
71
|
+
}
|
72
|
+
|
73
|
+
// find first pair of matching parentheses
|
74
|
+
void find_matching_parentheses(
|
75
|
+
const std::string& s,
|
76
|
+
int& i0,
|
77
|
+
int& i1,
|
78
|
+
int begin = 0) {
|
79
|
+
int st = 0;
|
80
|
+
i0 = i1 = 0;
|
81
|
+
for (int i = begin; i < s.length(); i++) {
|
82
|
+
if (s[i] == '(') {
|
83
|
+
if (st == 0) {
|
84
|
+
i0 = i;
|
85
|
+
}
|
86
|
+
st++;
|
87
|
+
}
|
88
|
+
|
89
|
+
if (s[i] == ')') {
|
90
|
+
st--;
|
91
|
+
if (st == 0) {
|
92
|
+
i1 = i;
|
93
|
+
return;
|
94
|
+
}
|
95
|
+
if (st < 0) {
|
96
|
+
FAISS_THROW_FMT(
|
97
|
+
"factory string %s: unbalanced parentheses", s.c_str());
|
98
|
+
}
|
65
99
|
}
|
66
100
|
}
|
67
|
-
|
101
|
+
FAISS_THROW_FMT(
|
102
|
+
"factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
|
103
|
+
}
|
68
104
|
|
69
105
|
/// what kind of training does this coarse quantizer require?
|
70
106
|
char get_trains_alone(const Index* coarse_quantizer) {
|
@@ -83,447 +119,582 @@ char get_trains_alone(const Index* coarse_quantizer) {
|
|
83
119
|
// kmeans index
|
84
120
|
}
|
85
121
|
|
86
|
-
|
87
|
-
|
122
|
+
// set the fields for factory-constructed IVF structures
|
123
|
+
IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
|
124
|
+
index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
|
125
|
+
index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
|
126
|
+
index_ivf->own_fields = true;
|
127
|
+
return index_ivf;
|
88
128
|
}
|
89
129
|
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
for (i = s.length() - 1; i >= 0; i--) {
|
94
|
-
if (!isdigit(s[i]))
|
95
|
-
break;
|
130
|
+
int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
|
131
|
+
if (mr.length() == 0) {
|
132
|
+
return deflt;
|
96
133
|
}
|
97
|
-
return
|
134
|
+
return std::stoi(mr.str().substr(begin));
|
98
135
|
}
|
99
136
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
}
|
137
|
+
std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
|
138
|
+
{"SQ8", ScalarQuantizer::QT_8bit},
|
139
|
+
{"SQ4", ScalarQuantizer::QT_4bit},
|
140
|
+
{"SQ6", ScalarQuantizer::QT_6bit},
|
141
|
+
{"SQfp16", ScalarQuantizer::QT_fp16},
|
142
|
+
};
|
143
|
+
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
|
144
|
+
|
145
|
+
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
|
146
|
+
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
|
147
|
+
{"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
|
148
|
+
{"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
|
149
|
+
{"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
|
150
|
+
{"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
|
151
|
+
{"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
|
152
|
+
};
|
110
153
|
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
}
|
154
|
+
const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
|
155
|
+
const std::string aq_norm_pattern =
|
156
|
+
"(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)";
|
157
|
+
|
158
|
+
AdditiveQuantizer::Search_type_t aq_parse_search_type(
|
159
|
+
std::string stok,
|
160
|
+
MetricType metric) {
|
161
|
+
if (stok == "") {
|
162
|
+
return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
|
163
|
+
: AdditiveQuantizer::ST_LUT_nonorm;
|
122
164
|
}
|
123
|
-
|
124
|
-
|
165
|
+
int pos = stok.rfind("_");
|
166
|
+
return aq_search_type[stok.substr(pos)];
|
125
167
|
}
|
126
168
|
|
127
|
-
|
169
|
+
std::vector<size_t> aq_parse_nbits(std::string stok) {
|
170
|
+
std::vector<size_t> nbits;
|
171
|
+
std::smatch sm;
|
172
|
+
while (std::regex_search(stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
|
173
|
+
int M = std::stoi(sm[1].str());
|
174
|
+
int nbit = std::stoi(sm[2].str());
|
175
|
+
nbits.resize(nbits.size() + M, nbit);
|
176
|
+
stok = sm.suffix();
|
177
|
+
}
|
178
|
+
return nbits;
|
179
|
+
}
|
128
180
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
Index* coarse_quantizer = nullptr;
|
133
|
-
std::string parenthesis_ivf, parenthesis_refine;
|
134
|
-
Index* index = nullptr;
|
135
|
-
bool add_idmap = false;
|
136
|
-
int d_in = d;
|
181
|
+
/***************************************************************
|
182
|
+
* Parse VectorTransform
|
183
|
+
*/
|
137
184
|
|
138
|
-
|
185
|
+
VectorTransform* parse_VectorTransform(const std::string& description, int d) {
|
186
|
+
std::smatch sm;
|
187
|
+
auto match = [&sm, description](std::string pattern) {
|
188
|
+
return re_match(description, pattern, sm);
|
189
|
+
};
|
190
|
+
if (match("PCA(W?)(R?)([0-9]+)")) {
|
191
|
+
bool white = sm[1].length() > 0;
|
192
|
+
bool rot = sm[2].length() > 0;
|
193
|
+
return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
|
194
|
+
}
|
195
|
+
if (match("L2[nN]orm")) {
|
196
|
+
return new NormalizationTransform(d, 2.0);
|
197
|
+
}
|
198
|
+
if (match("RR([0-9]+)?")) {
|
199
|
+
return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
|
200
|
+
}
|
201
|
+
if (match("ITQ([0-9]+)?")) {
|
202
|
+
return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
|
203
|
+
}
|
204
|
+
if (match("OPQ([0-9]+)(_[0-9]+)?")) {
|
205
|
+
int M = std::stoi(sm[1].str());
|
206
|
+
int d_out = mres_to_int(sm[2], d, 1);
|
207
|
+
return new OPQMatrix(d, M, d_out);
|
208
|
+
}
|
209
|
+
if (match("Pad([0-9]+)")) {
|
210
|
+
int d_out = std::stoi(sm[1].str());
|
211
|
+
return new RemapDimensionsTransform(d, std::max(d_out, d), false);
|
212
|
+
}
|
213
|
+
return nullptr;
|
214
|
+
};
|
139
215
|
|
140
|
-
|
141
|
-
|
216
|
+
/***************************************************************
|
217
|
+
* Parse IndexIVF
|
218
|
+
*/
|
142
219
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
int
|
147
|
-
|
220
|
+
// parsing guard + function
|
221
|
+
Index* parse_coarse_quantizer(
|
222
|
+
const std::string& description,
|
223
|
+
int d,
|
224
|
+
MetricType mt,
|
225
|
+
std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
|
226
|
+
size_t& nlist,
|
227
|
+
bool& use_2layer) {
|
228
|
+
std::smatch sm;
|
229
|
+
auto match = [&sm, description](std::string pattern) {
|
230
|
+
return re_match(description, pattern, sm);
|
231
|
+
};
|
232
|
+
use_2layer = false;
|
233
|
+
|
234
|
+
if (match("IVF([0-9]+)")) {
|
235
|
+
nlist = std::stoi(sm[1].str());
|
236
|
+
return new IndexFlat(d, mt);
|
237
|
+
}
|
238
|
+
if (match("IMI2x([0-9]+)")) {
|
239
|
+
int nbit = std::stoi(sm[1].str());
|
240
|
+
FAISS_THROW_IF_NOT_MSG(
|
241
|
+
mt == METRIC_L2,
|
242
|
+
"MultiIndex not implemented for inner prod search");
|
243
|
+
nlist = (size_t)1 << (2 * nbit);
|
244
|
+
return new MultiIndexQuantizer(d, 2, nbit);
|
245
|
+
}
|
246
|
+
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
|
247
|
+
nlist = std::stoi(sm[1].str());
|
248
|
+
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
|
249
|
+
return new IndexHNSWFlat(d, hnsw_M, mt);
|
250
|
+
}
|
251
|
+
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
|
252
|
+
nlist = std::stoi(sm[1].str());
|
253
|
+
int R = std::stoi(sm[2]);
|
254
|
+
return new IndexNSGFlat(d, R, mt);
|
255
|
+
}
|
256
|
+
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
|
257
|
+
nlist = std::stoi(sm[1].str());
|
258
|
+
int no = std::stoi(sm[2].str());
|
259
|
+
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
|
260
|
+
return parenthesis_indexes[no].release();
|
261
|
+
}
|
148
262
|
|
149
|
-
|
263
|
+
// these two generate Index2Layer's not IndexIVF's
|
264
|
+
if (match("Residual([0-9]+)x([0-9]+)")) {
|
265
|
+
FAISS_THROW_IF_NOT_MSG(
|
266
|
+
mt == METRIC_L2,
|
267
|
+
"MultiIndex not implemented for inner prod search");
|
268
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
|
269
|
+
nlist = (size_t)1 << (M * nbit);
|
270
|
+
use_2layer = true;
|
271
|
+
return new MultiIndexQuantizer(d, M, nbit);
|
272
|
+
}
|
273
|
+
if (match("Residual([0-9]+)")) {
|
274
|
+
FAISS_THROW_IF_NOT_MSG(
|
275
|
+
mt == METRIC_L2,
|
276
|
+
"Residual not implemented for inner prod search");
|
277
|
+
use_2layer = true;
|
278
|
+
nlist = mres_to_int(sm[1]);
|
279
|
+
return new IndexFlatL2(d);
|
280
|
+
}
|
281
|
+
return nullptr;
|
282
|
+
}
|
150
283
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
284
|
+
// parse the code description part of the IVF description
|
285
|
+
|
286
|
+
IndexIVF* parse_IndexIVF(
|
287
|
+
const std::string& code_string,
|
288
|
+
std::unique_ptr<Index>& quantizer,
|
289
|
+
size_t nlist,
|
290
|
+
MetricType mt) {
|
291
|
+
std::smatch sm;
|
292
|
+
auto match = [&sm, &code_string](const std::string pattern) {
|
293
|
+
return re_match(code_string, pattern, sm);
|
294
|
+
};
|
295
|
+
auto get_q = [&quantizer] { return quantizer.release(); };
|
296
|
+
int d = quantizer->d;
|
297
|
+
|
298
|
+
if (match("Flat")) {
|
299
|
+
return new IndexIVFFlat(get_q(), d, nlist, mt);
|
300
|
+
}
|
301
|
+
if (match("FlatDedup")) {
|
302
|
+
return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
|
303
|
+
}
|
304
|
+
if (match(sq_pattern)) {
|
305
|
+
return new IndexIVFScalarQuantizer(
|
306
|
+
get_q(), d, nlist, sq_types[sm[1].str()], mt);
|
307
|
+
}
|
308
|
+
if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
|
309
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
|
310
|
+
IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
|
311
|
+
index_ivf->do_polysemous_training = sm[3].str() != "np";
|
312
|
+
return index_ivf;
|
313
|
+
}
|
314
|
+
if (match("PQ([0-9]+)\\+([0-9]+)")) {
|
315
|
+
FAISS_THROW_IF_NOT_MSG(
|
316
|
+
mt == METRIC_L2,
|
317
|
+
"IVFPQR not implemented for inner product search");
|
318
|
+
int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
|
319
|
+
return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
|
320
|
+
}
|
321
|
+
if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
|
322
|
+
int M = mres_to_int(sm[1]);
|
323
|
+
int bbs = mres_to_int(sm[3], 32, 1);
|
324
|
+
IndexIVFPQFastScan* index_ivf =
|
325
|
+
new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
|
326
|
+
index_ivf->by_residual = sm[2].str() == "r";
|
327
|
+
return index_ivf;
|
328
|
+
}
|
329
|
+
if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
|
330
|
+
std::vector<size_t> nbits = aq_parse_nbits(sm.str());
|
331
|
+
AdditiveQuantizer::Search_type_t st =
|
332
|
+
aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
333
|
+
IndexIVF* index_ivf;
|
334
|
+
if (sm[1].str() == "RQ") {
|
335
|
+
index_ivf = new IndexIVFResidualQuantizer(
|
336
|
+
get_q(), d, nlist, nbits, mt, st);
|
155
337
|
} else {
|
156
|
-
|
338
|
+
FAISS_THROW_IF_NOT(nbits.size() > 0);
|
339
|
+
index_ivf = new IndexIVFLocalSearchQuantizer(
|
340
|
+
get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
|
157
341
|
}
|
158
|
-
|
159
|
-
}
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
std::string stok(tok);
|
170
|
-
nbit = 8;
|
171
|
-
int bbs = -1;
|
172
|
-
char c;
|
173
|
-
|
174
|
-
// to avoid mem leaks with exceptions:
|
175
|
-
// do all tests before any instanciation
|
176
|
-
|
177
|
-
VectorTransform* vt_1 = nullptr;
|
178
|
-
Index* coarse_quantizer_1 = nullptr;
|
179
|
-
Index* index_1 = nullptr;
|
180
|
-
|
181
|
-
// VectorTransforms
|
182
|
-
if (sscanf(tok, "PCA%d", &d_out) == 1) {
|
183
|
-
vt_1 = new PCAMatrix(d, d_out);
|
184
|
-
d = d_out;
|
185
|
-
} else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
|
186
|
-
vt_1 = new PCAMatrix(d, d_out, 0, true);
|
187
|
-
d = d_out;
|
188
|
-
} else if (sscanf(tok, "RR%d", &d_out) == 1) {
|
189
|
-
vt_1 = new RandomRotationMatrix(d, d_out);
|
190
|
-
d = d_out;
|
191
|
-
} else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
|
192
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, false);
|
193
|
-
d = d_out;
|
194
|
-
} else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
|
195
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, true);
|
196
|
-
d = d_out;
|
197
|
-
} else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
|
198
|
-
vt_1 = new OPQMatrix(d, opq_M, d_out);
|
199
|
-
d = d_out;
|
200
|
-
} else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
|
201
|
-
vt_1 = new OPQMatrix(d, opq_M);
|
202
|
-
} else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
|
203
|
-
vt_1 = new ITQTransform(d, d_out, true);
|
204
|
-
d = d_out;
|
205
|
-
} else if (stok == "ITQ") {
|
206
|
-
vt_1 = new ITQTransform(d, d, false);
|
207
|
-
} else if (sscanf(tok, "Pad%d", &d_out) == 1) {
|
208
|
-
if (d_out > d) {
|
209
|
-
vt_1 = new RemapDimensionsTransform(d, d_out, false);
|
210
|
-
d = d_out;
|
211
|
-
}
|
212
|
-
} else if (stok == "L2norm") {
|
213
|
-
vt_1 = new NormalizationTransform(d, 2.0);
|
214
|
-
|
215
|
-
// coarse quantizers
|
216
|
-
} else if (
|
217
|
-
!coarse_quantizer &&
|
218
|
-
sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
|
219
|
-
coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
|
220
|
-
|
221
|
-
} else if (
|
222
|
-
!coarse_quantizer &&
|
223
|
-
sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
|
224
|
-
coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
|
225
|
-
|
226
|
-
} else if (
|
227
|
-
!coarse_quantizer &&
|
228
|
-
sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
|
229
|
-
if (!parenthesis_ivf.empty()) {
|
230
|
-
coarse_quantizer_1 =
|
231
|
-
index_factory(d, parenthesis_ivf.c_str(), metric);
|
232
|
-
} else if (metric == METRIC_L2) {
|
233
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
234
|
-
} else {
|
235
|
-
coarse_quantizer_1 = new IndexFlatIP(d);
|
236
|
-
}
|
237
|
-
|
238
|
-
} else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
|
239
|
-
FAISS_THROW_IF_NOT_MSG(
|
240
|
-
metric == METRIC_L2,
|
241
|
-
"MultiIndex not implemented for inner prod search");
|
242
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
|
243
|
-
ncentroids = 1 << (2 * nbit);
|
244
|
-
|
245
|
-
} else if (
|
246
|
-
!coarse_quantizer &&
|
247
|
-
sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
|
248
|
-
FAISS_THROW_IF_NOT_MSG(
|
249
|
-
metric == METRIC_L2,
|
250
|
-
"MultiIndex not implemented for inner prod search");
|
251
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
|
252
|
-
ncentroids = int64_t(1) << (M * nbit);
|
253
|
-
use_2layer = true;
|
254
|
-
|
255
|
-
} else if (std::regex_match(
|
256
|
-
stok,
|
257
|
-
std::regex(
|
258
|
-
"(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
|
259
|
-
std::vector<size_t> nbits;
|
260
|
-
std::smatch sm;
|
261
|
-
bool is_RCQ = stok.find("RCQ") == 0;
|
262
|
-
while (std::regex_search(
|
263
|
-
stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
|
264
|
-
int M = std::stoi(sm[1].str());
|
265
|
-
int nbit = std::stoi(sm[2].str());
|
266
|
-
nbits.resize(nbits.size() + M, nbit);
|
267
|
-
stok = sm.suffix();
|
268
|
-
}
|
269
|
-
if (!is_RCQ) {
|
270
|
-
index_1 = new IndexResidual(d, nbits, metric);
|
271
|
-
} else {
|
272
|
-
index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
|
273
|
-
}
|
274
|
-
} else if (
|
275
|
-
!coarse_quantizer &&
|
276
|
-
sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
|
277
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
278
|
-
use_2layer = true;
|
279
|
-
|
280
|
-
} else if (stok == "IDMap") {
|
281
|
-
add_idmap = true;
|
282
|
-
|
283
|
-
// IVFs
|
284
|
-
} else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
|
285
|
-
if (coarse_quantizer) {
|
286
|
-
// if there was an IVF in front, then it is an IVFFlat
|
287
|
-
IndexIVF* index_ivf = stok == "Flat"
|
288
|
-
? new IndexIVFFlat(
|
289
|
-
coarse_quantizer, d, ncentroids, metric)
|
290
|
-
: new IndexIVFFlatDedup(
|
291
|
-
coarse_quantizer, d, ncentroids, metric);
|
292
|
-
index_ivf->quantizer_trains_alone =
|
293
|
-
get_trains_alone(coarse_quantizer);
|
294
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
295
|
-
del_coarse_quantizer.release();
|
296
|
-
index_ivf->own_fields = true;
|
297
|
-
index_1 = index_ivf;
|
298
|
-
} else if (hnsw_M > 0) {
|
299
|
-
index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
|
300
|
-
} else if (nsg_R > 0) {
|
301
|
-
index_1 = new IndexNSGFlat(d, nsg_R, metric);
|
302
|
-
} else {
|
303
|
-
FAISS_THROW_IF_NOT_MSG(
|
304
|
-
stok != "FlatDedup",
|
305
|
-
"dedup supported only for IVFFlat");
|
306
|
-
index_1 = new IndexFlat(d, metric);
|
307
|
-
}
|
308
|
-
} else if (
|
309
|
-
!index &&
|
310
|
-
(stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
|
311
|
-
stok == "SQfp16")) {
|
312
|
-
ScalarQuantizer::QuantizerType qt = stok == "SQ8"
|
313
|
-
? ScalarQuantizer::QT_8bit
|
314
|
-
: stok == "SQ6" ? ScalarQuantizer::QT_6bit
|
315
|
-
: stok == "SQ4" ? ScalarQuantizer::QT_4bit
|
316
|
-
: stok == "SQfp16" ? ScalarQuantizer::QT_fp16
|
317
|
-
: ScalarQuantizer::QT_4bit;
|
318
|
-
if (coarse_quantizer) {
|
319
|
-
FAISS_THROW_IF_NOT(!use_2layer);
|
320
|
-
IndexIVFScalarQuantizer* index_ivf =
|
321
|
-
new IndexIVFScalarQuantizer(
|
322
|
-
coarse_quantizer, d, ncentroids, qt, metric);
|
323
|
-
index_ivf->quantizer_trains_alone =
|
324
|
-
get_trains_alone(coarse_quantizer);
|
325
|
-
del_coarse_quantizer.release();
|
326
|
-
index_ivf->own_fields = true;
|
327
|
-
index_1 = index_ivf;
|
328
|
-
} else if (hnsw_M > 0) {
|
329
|
-
index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
|
330
|
-
} else {
|
331
|
-
index_1 = new IndexScalarQuantizer(d, qt, metric);
|
332
|
-
}
|
333
|
-
} else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
|
334
|
-
FAISS_THROW_IF_NOT_MSG(
|
335
|
-
coarse_quantizer, "PQ with + works only with an IVF");
|
336
|
-
FAISS_THROW_IF_NOT_MSG(
|
337
|
-
metric == METRIC_L2,
|
338
|
-
"IVFPQR not implemented for inner product search");
|
339
|
-
IndexIVFPQR* index_ivf = new IndexIVFPQR(
|
340
|
-
coarse_quantizer, d, ncentroids, M, 8, M2, 8);
|
341
|
-
index_ivf->quantizer_trains_alone =
|
342
|
-
get_trains_alone(coarse_quantizer);
|
343
|
-
del_coarse_quantizer.release();
|
344
|
-
index_ivf->own_fields = true;
|
345
|
-
index_1 = index_ivf;
|
346
|
-
} else if (
|
347
|
-
!index &&
|
348
|
-
(sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
|
349
|
-
(sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
|
350
|
-
(sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
|
351
|
-
if (bbs == -1) {
|
352
|
-
bbs = 32;
|
353
|
-
}
|
354
|
-
bool by_residual = str_ends_with(stok, "fsr");
|
355
|
-
if (coarse_quantizer) {
|
356
|
-
IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
|
357
|
-
coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
|
358
|
-
index_ivf->quantizer_trains_alone =
|
359
|
-
get_trains_alone(coarse_quantizer);
|
360
|
-
index_ivf->metric_type = metric;
|
361
|
-
index_ivf->by_residual = by_residual;
|
362
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
363
|
-
del_coarse_quantizer.release();
|
364
|
-
index_ivf->own_fields = true;
|
365
|
-
index_1 = index_ivf;
|
366
|
-
} else {
|
367
|
-
IndexPQFastScan* index_pq =
|
368
|
-
new IndexPQFastScan(d, M, 4, metric, bbs);
|
369
|
-
index_1 = index_pq;
|
370
|
-
}
|
371
|
-
} else if (
|
372
|
-
!index &&
|
373
|
-
(sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
|
374
|
-
sscanf(tok, "PQ%d", &M) == 1 ||
|
375
|
-
sscanf(tok, "PQ%dnp", &M) == 1)) {
|
376
|
-
bool do_polysemous_training = stok.find("np") == std::string::npos;
|
377
|
-
if (coarse_quantizer) {
|
378
|
-
if (!use_2layer) {
|
379
|
-
IndexIVFPQ* index_ivf = new IndexIVFPQ(
|
380
|
-
coarse_quantizer, d, ncentroids, M, nbit);
|
381
|
-
index_ivf->quantizer_trains_alone =
|
382
|
-
get_trains_alone(coarse_quantizer);
|
383
|
-
index_ivf->metric_type = metric;
|
384
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
385
|
-
del_coarse_quantizer.release();
|
386
|
-
index_ivf->own_fields = true;
|
387
|
-
index_ivf->do_polysemous_training = do_polysemous_training;
|
388
|
-
index_1 = index_ivf;
|
389
|
-
} else {
|
390
|
-
Index2Layer* index_2l = new Index2Layer(
|
391
|
-
coarse_quantizer, ncentroids, M, nbit);
|
392
|
-
index_2l->q1.quantizer_trains_alone =
|
393
|
-
get_trains_alone(coarse_quantizer);
|
394
|
-
index_2l->q1.own_fields = true;
|
395
|
-
index_1 = index_2l;
|
396
|
-
}
|
397
|
-
} else if (hnsw_M > 0) {
|
398
|
-
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
399
|
-
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
400
|
-
do_polysemous_training;
|
401
|
-
index_1 = ipq;
|
402
|
-
} else {
|
403
|
-
IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
|
404
|
-
index_pq->do_polysemous_training = do_polysemous_training;
|
405
|
-
index_1 = index_pq;
|
406
|
-
}
|
407
|
-
} else if (
|
408
|
-
!index &&
|
409
|
-
sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
|
410
|
-
Index* quant = new IndexFlatL2(d);
|
411
|
-
IndexHNSW2Level* hidx2l =
|
412
|
-
new IndexHNSW2Level(quant, ncent, pq_m, M);
|
413
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
414
|
-
idx2l->q1.own_fields = true;
|
415
|
-
index_1 = hidx2l;
|
416
|
-
} else if (
|
417
|
-
!index &&
|
418
|
-
sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
|
419
|
-
Index* quant = new MultiIndexQuantizer(d, 2, nbit);
|
420
|
-
IndexHNSW2Level* hidx2l =
|
421
|
-
new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
|
422
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
423
|
-
idx2l->q1.own_fields = true;
|
424
|
-
idx2l->q1.quantizer_trains_alone = 1;
|
425
|
-
index_1 = hidx2l;
|
426
|
-
} else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
|
427
|
-
index_1 = new IndexHNSWPQ(d, pq_m, M);
|
428
|
-
} else if (
|
429
|
-
!index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
|
430
|
-
pq_m == 8) {
|
431
|
-
index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
|
432
|
-
} else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
|
433
|
-
hnsw_M = M;
|
434
|
-
// here it is unclear what we want: HNSW flat or HNSWx,Y ?
|
435
|
-
} else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
|
436
|
-
nsg_R = R;
|
437
|
-
} else if (
|
438
|
-
!index &&
|
439
|
-
(stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
|
440
|
-
stok == "LSHt")) {
|
441
|
-
bool rotate_data = strstr(tok, "r") != nullptr;
|
442
|
-
bool train_thresholds = strstr(tok, "t") != nullptr;
|
443
|
-
index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
|
444
|
-
} else if (
|
445
|
-
!index &&
|
446
|
-
sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
|
447
|
-
FAISS_THROW_IF_NOT(!coarse_quantizer);
|
448
|
-
index_1 = new IndexLattice(d, M, nbit, r2);
|
449
|
-
} else if (stok == "RFlat") {
|
450
|
-
parenthesis_refine = "Flat";
|
451
|
-
} else if (stok == "Refine") {
|
452
|
-
FAISS_THROW_IF_NOT_MSG(
|
453
|
-
!parenthesis_refine.empty(),
|
454
|
-
"Refine index should be provided in parentheses");
|
455
|
-
} else {
|
456
|
-
FAISS_THROW_FMT(
|
457
|
-
"could not parse token \"%s\" in %s\n",
|
458
|
-
tok,
|
459
|
-
description_in);
|
342
|
+
return index_ivf;
|
343
|
+
}
|
344
|
+
if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
|
345
|
+
int outdim = mres_to_int(sm[2], d); // is also the number of bits
|
346
|
+
std::unique_ptr<VectorTransform> vt;
|
347
|
+
if (sm[1] == "ITQ") {
|
348
|
+
vt.reset(new ITQTransform(d, outdim, d != outdim));
|
349
|
+
} else if (sm[1] == "PCA") {
|
350
|
+
vt.reset(new PCAMatrix(d, outdim));
|
351
|
+
} else if (sm[1] == "PCAR") {
|
352
|
+
vt.reset(new PCAMatrix(d, outdim, 0, true));
|
460
353
|
}
|
354
|
+
// the rationale for -1e10 is that this corresponds to simple
|
355
|
+
// thresholding
|
356
|
+
float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
|
357
|
+
IndexIVFSpectralHash* index_ivf =
|
358
|
+
new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
|
359
|
+
index_ivf->replace_vt(vt.release(), true);
|
360
|
+
if (sm[4].length()) {
|
361
|
+
std::string s = sm[4].str();
|
362
|
+
index_ivf->threshold_type = s == "g"
|
363
|
+
? IndexIVFSpectralHash::Thresh_global
|
364
|
+
: s == "c"
|
365
|
+
? IndexIVFSpectralHash::Thresh_centroid
|
366
|
+
:
|
367
|
+
/* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
|
368
|
+
}
|
369
|
+
return index_ivf;
|
370
|
+
}
|
371
|
+
return nullptr;
|
372
|
+
}
|
373
|
+
|
374
|
+
/***************************************************************
|
375
|
+
* Parse IndexHNSW
|
376
|
+
*/
|
377
|
+
|
378
|
+
IndexHNSW* parse_IndexHNSW(
|
379
|
+
const std::string code_string,
|
380
|
+
int d,
|
381
|
+
MetricType mt,
|
382
|
+
int hnsw_M) {
|
383
|
+
std::smatch sm;
|
384
|
+
auto match = [&sm, &code_string](const std::string& pattern) {
|
385
|
+
return re_match(code_string, pattern, sm);
|
386
|
+
};
|
387
|
+
|
388
|
+
if (match("Flat|")) {
|
389
|
+
return new IndexHNSWFlat(d, hnsw_M, mt);
|
390
|
+
}
|
391
|
+
if (match("PQ([0-9]+)(np)?")) {
|
392
|
+
int M = std::stoi(sm[1].str());
|
393
|
+
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
394
|
+
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
395
|
+
sm[2].str() != "np";
|
396
|
+
return ipq;
|
397
|
+
}
|
398
|
+
if (match(sq_pattern)) {
|
399
|
+
return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
|
400
|
+
}
|
401
|
+
if (match("([0-9]+)\\+PQ([0-9]+)?")) {
|
402
|
+
int ncent = mres_to_int(sm[1]);
|
403
|
+
int pq_m = mres_to_int(sm[2]);
|
404
|
+
IndexHNSW2Level* hidx2l =
|
405
|
+
new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
|
406
|
+
dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
|
407
|
+
return hidx2l;
|
408
|
+
}
|
409
|
+
if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
|
410
|
+
int nbit = mres_to_int(sm[1]);
|
411
|
+
int pq_m = mres_to_int(sm[2]);
|
412
|
+
Index* quant = new MultiIndexQuantizer(d, 2, nbit);
|
413
|
+
IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
|
414
|
+
quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
|
415
|
+
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
416
|
+
idx2l->q1.own_fields = true;
|
417
|
+
idx2l->q1.quantizer_trains_alone = 1;
|
418
|
+
return hidx2l;
|
419
|
+
}
|
461
420
|
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
421
|
+
return nullptr;
|
422
|
+
}
|
423
|
+
|
424
|
+
/***************************************************************
|
425
|
+
* Parse basic indexes
|
426
|
+
*/
|
427
|
+
|
428
|
+
/**
 * Try to build one of the "basic" (non-composite) index types from a factory
 * description string.
 *
 * The description is tested against a fixed sequence of patterns; the first
 * one that matches decides which index is allocated.
 *
 * @param description  factory string for a single index, e.g. "Flat" or "PQ16"
 * @param d            dimension handed to the index constructor
 * @param metric       metric handed to the constructors that accept one
 * @return a newly allocated index that the caller must take ownership of, or
 *         nullptr when no pattern recognizes the description
 */
Index* parse_other_indexes(
        const std::string& description,
        int d,
        MetricType metric) {
    std::smatch sm;
    // Capture by reference: the description outlives this function call, so
    // there is no need to copy it into the closure (parse_IndexHNSW does the
    // same with its code string).
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };

    // IndexFlat
    if (description == "Flat") {
        return new IndexFlat(d, metric);
    }

    // IndexLSH
    if (match("LSH(r?)(t?)")) {
        // the optional "r" / "t" suffixes switch on rotation / trained
        // thresholds
        bool rotate_data = sm[1].length() > 0;
        bool train_thresholds = sm[2].length() > 0;
        FAISS_THROW_IF_NOT(metric == METRIC_L2);
        return new IndexLSH(d, d, rotate_data, train_thresholds);
    }

    // IndexLattice
    if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
        int M = std::stoi(sm[1].str());
        int r2 = std::stoi(sm[2].str());
        int nbit = std::stoi(sm[3].str());
        return new IndexLattice(d, M, nbit, r2);
    }

    // IndexNSGFlat
    if (match("NSG([0-9]+)(,Flat)?")) {
        return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
    }

    // IndexScalarQuantizer
    if (match(sq_pattern)) {
        return new IndexScalarQuantizer(d, sq_types[description], metric);
    }

    // IndexPQ
    if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
        int M = std::stoi(sm[1].str());
        int nbit = mres_to_int(sm[2], 8, 1);
        IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
        // a trailing "np" turns polysemous training off
        index_pq->do_polysemous_training = sm[3].str() != "np";
        return index_pq;
    }

    // IndexPQFastScan
    if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
        int M = std::stoi(sm[1].str());
        int bbs = mres_to_int(sm[2], 32, 1);
        return new IndexPQFastScan(d, M, 4, metric, bbs);
    }

    // IndexResidualCoarseQuantizer and IndexResidualQuantizer
    std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
    if (match(pattern)) {
        std::vector<size_t> nbits = aq_parse_nbits(description);
        if (sm[1].str() == "RCQ") {
            return new ResidualCoarseQuantizer(d, nbits, metric);
        }
        // the last capture group is handed to aq_parse_search_type
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexResidualQuantizer(d, nbits, metric, st);
    }

    // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
    if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
        // M and nbit come straight from the capture groups; the per-level
        // nbits vector that used to be computed here was never read.
        int M = mres_to_int(sm[2]);
        int nbit = mres_to_int(sm[3]);
        if (sm[1].str() == "LSCQ") {
            return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
        }
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
    }

    // nothing matched: let the caller try other index families
    return nullptr;
}
|
510
|
+
|
511
|
+
/***************************************************************
|
512
|
+
* Driver function
|
513
|
+
*/
|
514
|
+
std::unique_ptr<Index> index_factory_sub(
|
515
|
+
int d,
|
516
|
+
std::string description,
|
517
|
+
MetricType metric) {
|
518
|
+
// handle composite indexes
|
519
|
+
|
520
|
+
bool verbose = index_factory_verbose;
|
473
521
|
|
474
|
-
|
475
|
-
|
476
|
-
|
522
|
+
if (verbose) {
|
523
|
+
printf("begin parse VectorTransforms: %s \n", description.c_str());
|
524
|
+
}
|
525
|
+
|
526
|
+
// for the current match
|
527
|
+
std::smatch sm;
|
528
|
+
|
529
|
+
// handle refines
|
530
|
+
if (re_match(description, "(.+),RFlat", sm) ||
|
531
|
+
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
|
532
|
+
std::unique_ptr<Index> filter_index =
|
533
|
+
index_factory_sub(d, sm[1].str(), metric);
|
534
|
+
std::unique_ptr<Index> refine_index;
|
535
|
+
|
536
|
+
if (sm.size() == 3) { // Refine
|
537
|
+
refine_index = index_factory_sub(d, sm[2].str(), metric);
|
538
|
+
} else { // RFlat
|
539
|
+
refine_index.reset(new IndexFlat(d, metric));
|
477
540
|
}
|
541
|
+
IndexRefine* index_rf =
|
542
|
+
new IndexRefine(filter_index.get(), refine_index.get());
|
543
|
+
index_rf->own_fields = true;
|
544
|
+
filter_index.release();
|
545
|
+
refine_index.release();
|
546
|
+
index_rf->own_refine_index = true;
|
547
|
+
return std::unique_ptr<Index>(index_rf);
|
548
|
+
}
|
478
549
|
|
479
|
-
|
480
|
-
|
481
|
-
|
550
|
+
// IndexPreTransform
|
551
|
+
// should handle this first (even before parentheses) because it changes d
|
552
|
+
std::vector<std::unique_ptr<VectorTransform>> vts;
|
553
|
+
VectorTransform* vt = nullptr;
|
554
|
+
while (re_match(description, "([^,]+),(.*)", sm) &&
|
555
|
+
(vt = parse_VectorTransform(sm[1], d))) {
|
556
|
+
// reset loop
|
557
|
+
description = sm[sm.size() - 1];
|
558
|
+
vts.emplace_back(vt);
|
559
|
+
d = vts.back()->d_out;
|
560
|
+
}
|
561
|
+
|
562
|
+
if (vts.size() > 0) {
|
563
|
+
std::unique_ptr<Index> sub_index =
|
564
|
+
index_factory_sub(d, description, metric);
|
565
|
+
IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
|
566
|
+
std::unique_ptr<Index> ret(index_pt);
|
567
|
+
index_pt->own_fields = true;
|
568
|
+
sub_index.release();
|
569
|
+
while (vts.size() > 0) {
|
570
|
+
if (verbose) {
|
571
|
+
printf("prepend trans %d -> %d\n",
|
572
|
+
vts.back()->d_in,
|
573
|
+
vts.back()->d_out);
|
574
|
+
}
|
575
|
+
index_pt->prepend_transform(vts.back().release());
|
576
|
+
vts.pop_back();
|
482
577
|
}
|
578
|
+
return ret;
|
483
579
|
}
|
484
580
|
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
581
|
+
// what we got from the parentheses
|
582
|
+
std::vector<std::unique_ptr<Index>> parenthesis_indexes;
|
583
|
+
|
584
|
+
int begin = 0;
|
585
|
+
while (description.find('(', begin) != std::string::npos) {
|
586
|
+
// replace indexes in () with Index0, Index1, etc.
|
587
|
+
int i0, i1;
|
588
|
+
find_matching_parentheses(description, i0, i1, begin);
|
589
|
+
std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
|
590
|
+
int no = parenthesis_indexes.size();
|
591
|
+
parenthesis_indexes.push_back(
|
592
|
+
index_factory_sub(d, sub_description, metric));
|
593
|
+
description = description.substr(0, i0 + 1) + "Index" +
|
594
|
+
std::to_string(no) + description.substr(i1);
|
595
|
+
begin = i1 + 1;
|
491
596
|
}
|
492
597
|
|
493
|
-
|
494
|
-
|
598
|
+
if (verbose) {
|
599
|
+
printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
|
600
|
+
description.c_str(),
|
601
|
+
parenthesis_indexes.size(),
|
602
|
+
d);
|
603
|
+
}
|
495
604
|
|
496
|
-
//
|
497
|
-
|
498
|
-
|
605
|
+
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
606
|
+
// support both
|
607
|
+
if (re_match(description, "(.+),IDMap", sm) ||
|
608
|
+
re_match(description, "IDMap,(.+)", sm)) {
|
609
|
+
IndexIDMap* idmap = new IndexIDMap(
|
610
|
+
index_factory_sub(d, sm[1].str(), metric).release());
|
611
|
+
idmap->own_fields = true;
|
612
|
+
return std::unique_ptr<Index>(idmap);
|
613
|
+
}
|
499
614
|
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
615
|
+
{ // handle basic index types
|
616
|
+
Index* index = parse_other_indexes(description, d, metric);
|
617
|
+
if (index) {
|
618
|
+
return std::unique_ptr<Index>(index);
|
619
|
+
}
|
504
620
|
}
|
505
621
|
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
622
|
+
// HNSW variants (it was unclear in the old version that the separator was a
|
623
|
+
// "," so we support both "_" and ",")
|
624
|
+
if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
|
625
|
+
int hnsw_M = mres_to_int(sm[1], 32);
|
626
|
+
// We also accept empty code string (synonym of Flat)
|
627
|
+
std::string code_string =
|
628
|
+
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
629
|
+
if (verbose) {
|
630
|
+
printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
|
631
|
+
description.c_str(),
|
632
|
+
code_string.c_str(),
|
633
|
+
hnsw_M);
|
513
634
|
}
|
514
|
-
|
635
|
+
|
636
|
+
IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
|
637
|
+
FAISS_THROW_IF_NOT_FMT(
|
638
|
+
index,
|
639
|
+
"could not parse HNSW code description %s in %s",
|
640
|
+
code_string.c_str(),
|
641
|
+
description.c_str());
|
642
|
+
return std::unique_ptr<Index>(index);
|
515
643
|
}
|
516
644
|
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
645
|
+
// IndexIVF
|
646
|
+
{
|
647
|
+
size_t nlist;
|
648
|
+
bool use_2layer;
|
649
|
+
size_t comma = description.find(",");
|
650
|
+
std::string coarse_string = description.substr(0, comma);
|
651
|
+
// Match coarse quantizer part first
|
652
|
+
std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
|
653
|
+
description.substr(0, comma),
|
654
|
+
d,
|
655
|
+
metric,
|
656
|
+
parenthesis_indexes,
|
657
|
+
nlist,
|
658
|
+
use_2layer));
|
659
|
+
|
660
|
+
if (comma != std::string::npos && quantizer.get()) {
|
661
|
+
std::string code_description = description.substr(comma + 1);
|
662
|
+
if (use_2layer) {
|
663
|
+
bool ok =
|
664
|
+
re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
|
665
|
+
FAISS_THROW_IF_NOT_FMT(
|
666
|
+
ok,
|
667
|
+
"could not parse 2 layer code description %s in %s",
|
668
|
+
code_description.c_str(),
|
669
|
+
description.c_str());
|
670
|
+
int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
|
671
|
+
Index2Layer* index_2l =
|
672
|
+
new Index2Layer(quantizer.release(), nlist, M, nbit);
|
673
|
+
index_2l->q1.own_fields = true;
|
674
|
+
index_2l->q1.quantizer_trains_alone =
|
675
|
+
get_trains_alone(index_2l->q1.quantizer);
|
676
|
+
return std::unique_ptr<Index>(index_2l);
|
677
|
+
}
|
678
|
+
|
679
|
+
IndexIVF* index_ivf =
|
680
|
+
parse_IndexIVF(code_description, quantizer, nlist, metric);
|
681
|
+
|
682
|
+
FAISS_THROW_IF_NOT_FMT(
|
683
|
+
index_ivf,
|
684
|
+
"could not parse code description %s in %s",
|
685
|
+
code_description.c_str(),
|
686
|
+
description.c_str());
|
687
|
+
return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
|
688
|
+
}
|
524
689
|
}
|
690
|
+
FAISS_THROW_FMT("could not parse index string %s", description.c_str());
|
691
|
+
return nullptr;
|
692
|
+
}
|
525
693
|
|
526
|
-
|
694
|
+
} // anonymous namespace
|
695
|
+
|
696
|
+
/**
 * Public entry point of the index factory.
 *
 * Delegates to index_factory_sub and hands the resulting index to the caller
 * as a raw pointer; the caller becomes its owner.
 *
 * @param d            dimension passed to index_factory_sub
 * @param description  factory string describing the index to build
 * @param metric       metric passed to index_factory_sub
 */
Index* index_factory(int d, const char* description, MetricType metric) {
    std::unique_ptr<Index> built = index_factory_sub(d, description, metric);
    return built.release();
}
|
528
699
|
|
529
700
|
IndexBinary* index_binary_factory(int d, const char* description) {
|