faiss 0.2.3 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,19 +5,19 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
/*
|
11
|
-
* implementation of
|
9
|
+
* implementation of the index_factory function. Lots of regex parsing code.
|
12
10
|
*/
|
13
11
|
|
14
12
|
#include <faiss/index_factory.h>
|
15
|
-
|
16
|
-
#include
|
13
|
+
#include "faiss/MetricType.h"
|
14
|
+
#include "faiss/impl/FaissAssert.h"
|
17
15
|
|
18
16
|
#include <cinttypes>
|
19
17
|
#include <cmath>
|
20
18
|
|
19
|
+
#include <map>
|
20
|
+
|
21
21
|
#include <regex>
|
22
22
|
|
23
23
|
#include <faiss/impl/FaissAssert.h>
|
@@ -25,13 +25,16 @@
|
|
25
25
|
#include <faiss/utils/utils.h>
|
26
26
|
|
27
27
|
#include <faiss/Index2Layer.h>
|
28
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
28
29
|
#include <faiss/IndexFlat.h>
|
29
30
|
#include <faiss/IndexHNSW.h>
|
30
31
|
#include <faiss/IndexIVF.h>
|
32
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
31
33
|
#include <faiss/IndexIVFFlat.h>
|
32
34
|
#include <faiss/IndexIVFPQ.h>
|
33
35
|
#include <faiss/IndexIVFPQFastScan.h>
|
34
36
|
#include <faiss/IndexIVFPQR.h>
|
37
|
+
#include <faiss/IndexIVFSpectralHash.h>
|
35
38
|
#include <faiss/IndexLSH.h>
|
36
39
|
#include <faiss/IndexLattice.h>
|
37
40
|
#include <faiss/IndexNSG.h>
|
@@ -39,7 +42,6 @@
|
|
39
42
|
#include <faiss/IndexPQFastScan.h>
|
40
43
|
#include <faiss/IndexPreTransform.h>
|
41
44
|
#include <faiss/IndexRefine.h>
|
42
|
-
#include <faiss/IndexResidual.h>
|
43
45
|
#include <faiss/IndexScalarQuantizer.h>
|
44
46
|
#include <faiss/MetaIndexes.h>
|
45
47
|
#include <faiss/VectorTransform.h>
|
@@ -48,6 +50,7 @@
|
|
48
50
|
#include <faiss/IndexBinaryHNSW.h>
|
49
51
|
#include <faiss/IndexBinaryHash.h>
|
50
52
|
#include <faiss/IndexBinaryIVF.h>
|
53
|
+
#include <string>
|
51
54
|
|
52
55
|
namespace faiss {
|
53
56
|
|
@@ -55,16 +58,49 @@ namespace faiss {
|
|
55
58
|
* index_factory
|
56
59
|
***************************************************************/
|
57
60
|
|
61
|
+
int index_factory_verbose = 0;
|
62
|
+
|
58
63
|
namespace {
|
59
64
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
+
/***************************************************************
|
66
|
+
* Small functions
|
67
|
+
*/
|
68
|
+
|
69
|
+
/// Test whether the whole of `s` matches the regular expression `pat`.
///
/// On success the capture groups are stored in `sm`. The sub-matches refer
/// to the characters of `s`, so `s` must stay alive while `sm` is used.
bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
    const std::regex compiled(pat);
    const bool whole_match = std::regex_match(s, sm, compiled);
    return whole_match;
}
|
72
|
+
|
73
|
+
// find first pair of matching parentheses
|
74
|
+
void find_matching_parentheses(
|
75
|
+
const std::string& s,
|
76
|
+
int& i0,
|
77
|
+
int& i1,
|
78
|
+
int begin = 0) {
|
79
|
+
int st = 0;
|
80
|
+
i0 = i1 = 0;
|
81
|
+
for (int i = begin; i < s.length(); i++) {
|
82
|
+
if (s[i] == '(') {
|
83
|
+
if (st == 0) {
|
84
|
+
i0 = i;
|
85
|
+
}
|
86
|
+
st++;
|
87
|
+
}
|
88
|
+
|
89
|
+
if (s[i] == ')') {
|
90
|
+
st--;
|
91
|
+
if (st == 0) {
|
92
|
+
i1 = i;
|
93
|
+
return;
|
94
|
+
}
|
95
|
+
if (st < 0) {
|
96
|
+
FAISS_THROW_FMT(
|
97
|
+
"factory string %s: unbalanced parentheses", s.c_str());
|
98
|
+
}
|
65
99
|
}
|
66
100
|
}
|
67
|
-
|
101
|
+
FAISS_THROW_FMT(
|
102
|
+
"factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
|
103
|
+
}
|
68
104
|
|
69
105
|
/// what kind of training does this coarse quantizer require?
|
70
106
|
char get_trains_alone(const Index* coarse_quantizer) {
|
@@ -83,447 +119,582 @@ char get_trains_alone(const Index* coarse_quantizer) {
|
|
83
119
|
// kmeans index
|
84
120
|
}
|
85
121
|
|
86
|
-
|
87
|
-
|
122
|
+
// set the fields for factory-constructed IVF structures
|
123
|
+
IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
|
124
|
+
index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
|
125
|
+
index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
|
126
|
+
index_ivf->own_fields = true;
|
127
|
+
return index_ivf;
|
88
128
|
}
|
89
129
|
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
for (i = s.length() - 1; i >= 0; i--) {
|
94
|
-
if (!isdigit(s[i]))
|
95
|
-
break;
|
130
|
+
/// Convert a regex sub-match to an int.
///
/// @param mr     sub-match to convert
/// @param deflt  value returned when the sub-match is empty (e.g. an optional
///               group that did not participate in the match)
/// @param begin  number of leading characters of the sub-match to skip before
///               parsing, e.g. 1 to drop the '_' of a "(_[0-9]+)" group
int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
    const std::string text = mr.str();
    return text.empty() ? deflt : std::stoi(text.substr(begin));
}
|
99
136
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
}
|
137
|
+
std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
|
138
|
+
{"SQ8", ScalarQuantizer::QT_8bit},
|
139
|
+
{"SQ4", ScalarQuantizer::QT_4bit},
|
140
|
+
{"SQ6", ScalarQuantizer::QT_6bit},
|
141
|
+
{"SQfp16", ScalarQuantizer::QT_fp16},
|
142
|
+
};
|
143
|
+
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
|
144
|
+
|
145
|
+
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
|
146
|
+
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
|
147
|
+
{"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
|
148
|
+
{"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
|
149
|
+
{"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
|
150
|
+
{"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
|
151
|
+
{"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
|
152
|
+
};
|
110
153
|
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
}
|
154
|
+
// Size part of an additive quantizer description: one or more "<M>x<nbits>"
// groups separated by '_' (for example "1x16_6x8").
const std::string aq_def_pattern("[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*");
// Optional norm-storage suffix; the empty alternative means "no suffix".
const std::string aq_norm_pattern(
        "(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)");
|
157
|
+
|
158
|
+
/// Translate the norm-suffix of an additive quantizer token into a search
/// type.
///
/// @param stok    suffix token such as "_Nqint8"; only the part starting at
///                the last '_' is looked up in aq_search_type. May be empty.
/// @param metric  used only when `stok` is empty: METRIC_L2 selects
///                ST_decompress, any other metric selects ST_LUT_nonorm.
///
/// Throws through FAISS_THROW_FMT when the suffix is not a key of
/// aq_search_type. A plain operator[] lookup would instead insert the unknown
/// key into the table and hand back a value-initialized search type.
AdditiveQuantizer::Search_type_t aq_parse_search_type(
        std::string stok,
        MetricType metric) {
    if (stok.empty()) {
        return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
                                   : AdditiveQuantizer::ST_LUT_nonorm;
    }
    // keep the position as size_t so that npos is not narrowed to -1
    const size_t pos = stok.rfind('_');
    const auto it = pos == std::string::npos
            ? aq_search_type.end()
            : aq_search_type.find(stok.substr(pos));
    if (it == aq_search_type.end()) {
        FAISS_THROW_FMT(
                "unknown additive quantizer search type \"%s\"", stok.c_str());
    }
    return it->second;
}
|
126
168
|
|
127
|
-
|
169
|
+
/// Expand the "<M>x<nbits>" groups found in `stok` into a per-codebook list
/// of bit counts.
///
/// Every group contributes M entries equal to nbits, in order of appearance,
/// so "RQ1x16_6x8" yields {16, 8, 8, 8, 8, 8, 8}. Text around the groups (a
/// prefix such as "RQ", a norm suffix) is ignored. The result is empty when
/// no group is present.
std::vector<size_t> aq_parse_nbits(std::string stok) {
    // compiled once instead of on every loop iteration
    static const std::regex group_re("([0-9]+)x([0-9]+)");
    std::vector<size_t> nbits;
    // Iterate over the successive non-overlapping matches without modifying
    // the string the match results point into.
    const std::sregex_iterator last;
    for (std::sregex_iterator it(stok.begin(), stok.end(), group_re);
         it != last;
         ++it) {
        const int M = std::stoi((*it)[1].str());
        const int nbit = std::stoi((*it)[2].str());
        nbits.insert(
                nbits.end(),
                static_cast<size_t>(M),
                static_cast<size_t>(nbit));
    }
    return nbits;
}
|
128
180
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
Index* coarse_quantizer = nullptr;
|
133
|
-
std::string parenthesis_ivf, parenthesis_refine;
|
134
|
-
Index* index = nullptr;
|
135
|
-
bool add_idmap = false;
|
136
|
-
int d_in = d;
|
181
|
+
/***************************************************************
|
182
|
+
* Parse VectorTransform
|
183
|
+
*/
|
137
184
|
|
138
|
-
|
185
|
+
/// Parse one vector-transform token of a factory string.
///
/// The whole `description` must match one of:
///   PCA[W][R]<n>     -> PCAMatrix(d, n, W ? -0.5 : 0, R)
///   L2norm / L2Norm  -> NormalizationTransform(d, 2.0)
///   RR[<n>]          -> RandomRotationMatrix(d, n), n defaults to d
///   ITQ[<n>]         -> ITQTransform(d, n, <n was given>), n defaults to d
///   OPQ<M>[_<n>]     -> OPQMatrix(d, M, n), n defaults to d
///   Pad<n>           -> RemapDimensionsTransform(d, max(n, d), false)
///
/// @param description  the token to parse
/// @param d            input dimension handed to the transform constructor
/// @return a newly allocated transform, or nullptr when `description` is none
///         of the tokens above
VectorTransform* parse_VectorTransform(const std::string& description, int d) {
    std::smatch sm;
    // Capture by reference: `description` outlives the lambda, and `sm` then
    // refers to the caller's string rather than to a per-lambda copy.
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };
    if (match("PCA(W?)(R?)([0-9]+)")) {
        bool white = sm[1].length() > 0;
        bool rot = sm[2].length() > 0;
        return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
    }
    if (match("L2[nN]orm")) {
        return new NormalizationTransform(d, 2.0);
    }
    if (match("RR([0-9]+)?")) {
        return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
    }
    if (match("ITQ([0-9]+)?")) {
        return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
    }
    if (match("OPQ([0-9]+)(_[0-9]+)?")) {
        int M = std::stoi(sm[1].str());
        // begin = 1 skips the leading '_' of the optional group
        int d_out = mres_to_int(sm[2], d, 1);
        return new OPQMatrix(d, M, d_out);
    }
    if (match("Pad([0-9]+)")) {
        int d_out = std::stoi(sm[1].str());
        // the output dimension is never smaller than the input one
        return new RemapDimensionsTransform(d, std::max(d_out, d), false);
    }
    return nullptr;
}
|
139
215
|
|
140
|
-
|
141
|
-
|
216
|
+
/***************************************************************
|
217
|
+
* Parse IndexIVF
|
218
|
+
*/
|
142
219
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
int
|
147
|
-
|
220
|
+
// parsing guard + function
|
221
|
+
Index* parse_coarse_quantizer(
|
222
|
+
const std::string& description,
|
223
|
+
int d,
|
224
|
+
MetricType mt,
|
225
|
+
std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
|
226
|
+
size_t& nlist,
|
227
|
+
bool& use_2layer) {
|
228
|
+
std::smatch sm;
|
229
|
+
auto match = [&sm, description](std::string pattern) {
|
230
|
+
return re_match(description, pattern, sm);
|
231
|
+
};
|
232
|
+
use_2layer = false;
|
233
|
+
|
234
|
+
if (match("IVF([0-9]+)")) {
|
235
|
+
nlist = std::stoi(sm[1].str());
|
236
|
+
return new IndexFlat(d, mt);
|
237
|
+
}
|
238
|
+
if (match("IMI2x([0-9]+)")) {
|
239
|
+
int nbit = std::stoi(sm[1].str());
|
240
|
+
FAISS_THROW_IF_NOT_MSG(
|
241
|
+
mt == METRIC_L2,
|
242
|
+
"MultiIndex not implemented for inner prod search");
|
243
|
+
nlist = (size_t)1 << (2 * nbit);
|
244
|
+
return new MultiIndexQuantizer(d, 2, nbit);
|
245
|
+
}
|
246
|
+
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
|
247
|
+
nlist = std::stoi(sm[1].str());
|
248
|
+
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
|
249
|
+
return new IndexHNSWFlat(d, hnsw_M, mt);
|
250
|
+
}
|
251
|
+
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
|
252
|
+
nlist = std::stoi(sm[1].str());
|
253
|
+
int R = std::stoi(sm[2]);
|
254
|
+
return new IndexNSGFlat(d, R, mt);
|
255
|
+
}
|
256
|
+
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
|
257
|
+
nlist = std::stoi(sm[1].str());
|
258
|
+
int no = std::stoi(sm[2].str());
|
259
|
+
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
|
260
|
+
return parenthesis_indexes[no].release();
|
261
|
+
}
|
148
262
|
|
149
|
-
|
263
|
+
// these two generate Index2Layer's not IndexIVF's
|
264
|
+
if (match("Residual([0-9]+)x([0-9]+)")) {
|
265
|
+
FAISS_THROW_IF_NOT_MSG(
|
266
|
+
mt == METRIC_L2,
|
267
|
+
"MultiIndex not implemented for inner prod search");
|
268
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
|
269
|
+
nlist = (size_t)1 << (M * nbit);
|
270
|
+
use_2layer = true;
|
271
|
+
return new MultiIndexQuantizer(d, M, nbit);
|
272
|
+
}
|
273
|
+
if (match("Residual([0-9]+)")) {
|
274
|
+
FAISS_THROW_IF_NOT_MSG(
|
275
|
+
mt == METRIC_L2,
|
276
|
+
"Residual not implemented for inner prod search");
|
277
|
+
use_2layer = true;
|
278
|
+
nlist = mres_to_int(sm[1]);
|
279
|
+
return new IndexFlatL2(d);
|
280
|
+
}
|
281
|
+
return nullptr;
|
282
|
+
}
|
150
283
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
284
|
+
// parse the code description part of the IVF description
|
285
|
+
|
286
|
+
IndexIVF* parse_IndexIVF(
|
287
|
+
const std::string& code_string,
|
288
|
+
std::unique_ptr<Index>& quantizer,
|
289
|
+
size_t nlist,
|
290
|
+
MetricType mt) {
|
291
|
+
std::smatch sm;
|
292
|
+
auto match = [&sm, &code_string](const std::string pattern) {
|
293
|
+
return re_match(code_string, pattern, sm);
|
294
|
+
};
|
295
|
+
auto get_q = [&quantizer] { return quantizer.release(); };
|
296
|
+
int d = quantizer->d;
|
297
|
+
|
298
|
+
if (match("Flat")) {
|
299
|
+
return new IndexIVFFlat(get_q(), d, nlist, mt);
|
300
|
+
}
|
301
|
+
if (match("FlatDedup")) {
|
302
|
+
return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
|
303
|
+
}
|
304
|
+
if (match(sq_pattern)) {
|
305
|
+
return new IndexIVFScalarQuantizer(
|
306
|
+
get_q(), d, nlist, sq_types[sm[1].str()], mt);
|
307
|
+
}
|
308
|
+
if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
|
309
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
|
310
|
+
IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
|
311
|
+
index_ivf->do_polysemous_training = sm[3].str() != "np";
|
312
|
+
return index_ivf;
|
313
|
+
}
|
314
|
+
if (match("PQ([0-9]+)\\+([0-9]+)")) {
|
315
|
+
FAISS_THROW_IF_NOT_MSG(
|
316
|
+
mt == METRIC_L2,
|
317
|
+
"IVFPQR not implemented for inner product search");
|
318
|
+
int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
|
319
|
+
return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
|
320
|
+
}
|
321
|
+
if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
|
322
|
+
int M = mres_to_int(sm[1]);
|
323
|
+
int bbs = mres_to_int(sm[3], 32, 1);
|
324
|
+
IndexIVFPQFastScan* index_ivf =
|
325
|
+
new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
|
326
|
+
index_ivf->by_residual = sm[2].str() == "r";
|
327
|
+
return index_ivf;
|
328
|
+
}
|
329
|
+
if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
|
330
|
+
std::vector<size_t> nbits = aq_parse_nbits(sm.str());
|
331
|
+
AdditiveQuantizer::Search_type_t st =
|
332
|
+
aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
333
|
+
IndexIVF* index_ivf;
|
334
|
+
if (sm[1].str() == "RQ") {
|
335
|
+
index_ivf = new IndexIVFResidualQuantizer(
|
336
|
+
get_q(), d, nlist, nbits, mt, st);
|
155
337
|
} else {
|
156
|
-
|
338
|
+
FAISS_THROW_IF_NOT(nbits.size() > 0);
|
339
|
+
index_ivf = new IndexIVFLocalSearchQuantizer(
|
340
|
+
get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
|
157
341
|
}
|
158
|
-
|
159
|
-
}
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
std::string stok(tok);
|
170
|
-
nbit = 8;
|
171
|
-
int bbs = -1;
|
172
|
-
char c;
|
173
|
-
|
174
|
-
// to avoid mem leaks with exceptions:
|
175
|
-
// do all tests before any instanciation
|
176
|
-
|
177
|
-
VectorTransform* vt_1 = nullptr;
|
178
|
-
Index* coarse_quantizer_1 = nullptr;
|
179
|
-
Index* index_1 = nullptr;
|
180
|
-
|
181
|
-
// VectorTransforms
|
182
|
-
if (sscanf(tok, "PCA%d", &d_out) == 1) {
|
183
|
-
vt_1 = new PCAMatrix(d, d_out);
|
184
|
-
d = d_out;
|
185
|
-
} else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
|
186
|
-
vt_1 = new PCAMatrix(d, d_out, 0, true);
|
187
|
-
d = d_out;
|
188
|
-
} else if (sscanf(tok, "RR%d", &d_out) == 1) {
|
189
|
-
vt_1 = new RandomRotationMatrix(d, d_out);
|
190
|
-
d = d_out;
|
191
|
-
} else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
|
192
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, false);
|
193
|
-
d = d_out;
|
194
|
-
} else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
|
195
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, true);
|
196
|
-
d = d_out;
|
197
|
-
} else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
|
198
|
-
vt_1 = new OPQMatrix(d, opq_M, d_out);
|
199
|
-
d = d_out;
|
200
|
-
} else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
|
201
|
-
vt_1 = new OPQMatrix(d, opq_M);
|
202
|
-
} else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
|
203
|
-
vt_1 = new ITQTransform(d, d_out, true);
|
204
|
-
d = d_out;
|
205
|
-
} else if (stok == "ITQ") {
|
206
|
-
vt_1 = new ITQTransform(d, d, false);
|
207
|
-
} else if (sscanf(tok, "Pad%d", &d_out) == 1) {
|
208
|
-
if (d_out > d) {
|
209
|
-
vt_1 = new RemapDimensionsTransform(d, d_out, false);
|
210
|
-
d = d_out;
|
211
|
-
}
|
212
|
-
} else if (stok == "L2norm") {
|
213
|
-
vt_1 = new NormalizationTransform(d, 2.0);
|
214
|
-
|
215
|
-
// coarse quantizers
|
216
|
-
} else if (
|
217
|
-
!coarse_quantizer &&
|
218
|
-
sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
|
219
|
-
coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
|
220
|
-
|
221
|
-
} else if (
|
222
|
-
!coarse_quantizer &&
|
223
|
-
sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
|
224
|
-
coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
|
225
|
-
|
226
|
-
} else if (
|
227
|
-
!coarse_quantizer &&
|
228
|
-
sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
|
229
|
-
if (!parenthesis_ivf.empty()) {
|
230
|
-
coarse_quantizer_1 =
|
231
|
-
index_factory(d, parenthesis_ivf.c_str(), metric);
|
232
|
-
} else if (metric == METRIC_L2) {
|
233
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
234
|
-
} else {
|
235
|
-
coarse_quantizer_1 = new IndexFlatIP(d);
|
236
|
-
}
|
237
|
-
|
238
|
-
} else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
|
239
|
-
FAISS_THROW_IF_NOT_MSG(
|
240
|
-
metric == METRIC_L2,
|
241
|
-
"MultiIndex not implemented for inner prod search");
|
242
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
|
243
|
-
ncentroids = 1 << (2 * nbit);
|
244
|
-
|
245
|
-
} else if (
|
246
|
-
!coarse_quantizer &&
|
247
|
-
sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
|
248
|
-
FAISS_THROW_IF_NOT_MSG(
|
249
|
-
metric == METRIC_L2,
|
250
|
-
"MultiIndex not implemented for inner prod search");
|
251
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
|
252
|
-
ncentroids = int64_t(1) << (M * nbit);
|
253
|
-
use_2layer = true;
|
254
|
-
|
255
|
-
} else if (std::regex_match(
|
256
|
-
stok,
|
257
|
-
std::regex(
|
258
|
-
"(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
|
259
|
-
std::vector<size_t> nbits;
|
260
|
-
std::smatch sm;
|
261
|
-
bool is_RCQ = stok.find("RCQ") == 0;
|
262
|
-
while (std::regex_search(
|
263
|
-
stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
|
264
|
-
int M = std::stoi(sm[1].str());
|
265
|
-
int nbit = std::stoi(sm[2].str());
|
266
|
-
nbits.resize(nbits.size() + M, nbit);
|
267
|
-
stok = sm.suffix();
|
268
|
-
}
|
269
|
-
if (!is_RCQ) {
|
270
|
-
index_1 = new IndexResidual(d, nbits, metric);
|
271
|
-
} else {
|
272
|
-
index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
|
273
|
-
}
|
274
|
-
} else if (
|
275
|
-
!coarse_quantizer &&
|
276
|
-
sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
|
277
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
278
|
-
use_2layer = true;
|
279
|
-
|
280
|
-
} else if (stok == "IDMap") {
|
281
|
-
add_idmap = true;
|
282
|
-
|
283
|
-
// IVFs
|
284
|
-
} else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
|
285
|
-
if (coarse_quantizer) {
|
286
|
-
// if there was an IVF in front, then it is an IVFFlat
|
287
|
-
IndexIVF* index_ivf = stok == "Flat"
|
288
|
-
? new IndexIVFFlat(
|
289
|
-
coarse_quantizer, d, ncentroids, metric)
|
290
|
-
: new IndexIVFFlatDedup(
|
291
|
-
coarse_quantizer, d, ncentroids, metric);
|
292
|
-
index_ivf->quantizer_trains_alone =
|
293
|
-
get_trains_alone(coarse_quantizer);
|
294
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
295
|
-
del_coarse_quantizer.release();
|
296
|
-
index_ivf->own_fields = true;
|
297
|
-
index_1 = index_ivf;
|
298
|
-
} else if (hnsw_M > 0) {
|
299
|
-
index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
|
300
|
-
} else if (nsg_R > 0) {
|
301
|
-
index_1 = new IndexNSGFlat(d, nsg_R, metric);
|
302
|
-
} else {
|
303
|
-
FAISS_THROW_IF_NOT_MSG(
|
304
|
-
stok != "FlatDedup",
|
305
|
-
"dedup supported only for IVFFlat");
|
306
|
-
index_1 = new IndexFlat(d, metric);
|
307
|
-
}
|
308
|
-
} else if (
|
309
|
-
!index &&
|
310
|
-
(stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
|
311
|
-
stok == "SQfp16")) {
|
312
|
-
ScalarQuantizer::QuantizerType qt = stok == "SQ8"
|
313
|
-
? ScalarQuantizer::QT_8bit
|
314
|
-
: stok == "SQ6" ? ScalarQuantizer::QT_6bit
|
315
|
-
: stok == "SQ4" ? ScalarQuantizer::QT_4bit
|
316
|
-
: stok == "SQfp16" ? ScalarQuantizer::QT_fp16
|
317
|
-
: ScalarQuantizer::QT_4bit;
|
318
|
-
if (coarse_quantizer) {
|
319
|
-
FAISS_THROW_IF_NOT(!use_2layer);
|
320
|
-
IndexIVFScalarQuantizer* index_ivf =
|
321
|
-
new IndexIVFScalarQuantizer(
|
322
|
-
coarse_quantizer, d, ncentroids, qt, metric);
|
323
|
-
index_ivf->quantizer_trains_alone =
|
324
|
-
get_trains_alone(coarse_quantizer);
|
325
|
-
del_coarse_quantizer.release();
|
326
|
-
index_ivf->own_fields = true;
|
327
|
-
index_1 = index_ivf;
|
328
|
-
} else if (hnsw_M > 0) {
|
329
|
-
index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
|
330
|
-
} else {
|
331
|
-
index_1 = new IndexScalarQuantizer(d, qt, metric);
|
332
|
-
}
|
333
|
-
} else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
|
334
|
-
FAISS_THROW_IF_NOT_MSG(
|
335
|
-
coarse_quantizer, "PQ with + works only with an IVF");
|
336
|
-
FAISS_THROW_IF_NOT_MSG(
|
337
|
-
metric == METRIC_L2,
|
338
|
-
"IVFPQR not implemented for inner product search");
|
339
|
-
IndexIVFPQR* index_ivf = new IndexIVFPQR(
|
340
|
-
coarse_quantizer, d, ncentroids, M, 8, M2, 8);
|
341
|
-
index_ivf->quantizer_trains_alone =
|
342
|
-
get_trains_alone(coarse_quantizer);
|
343
|
-
del_coarse_quantizer.release();
|
344
|
-
index_ivf->own_fields = true;
|
345
|
-
index_1 = index_ivf;
|
346
|
-
} else if (
|
347
|
-
!index &&
|
348
|
-
(sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
|
349
|
-
(sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
|
350
|
-
(sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
|
351
|
-
if (bbs == -1) {
|
352
|
-
bbs = 32;
|
353
|
-
}
|
354
|
-
bool by_residual = str_ends_with(stok, "fsr");
|
355
|
-
if (coarse_quantizer) {
|
356
|
-
IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
|
357
|
-
coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
|
358
|
-
index_ivf->quantizer_trains_alone =
|
359
|
-
get_trains_alone(coarse_quantizer);
|
360
|
-
index_ivf->metric_type = metric;
|
361
|
-
index_ivf->by_residual = by_residual;
|
362
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
363
|
-
del_coarse_quantizer.release();
|
364
|
-
index_ivf->own_fields = true;
|
365
|
-
index_1 = index_ivf;
|
366
|
-
} else {
|
367
|
-
IndexPQFastScan* index_pq =
|
368
|
-
new IndexPQFastScan(d, M, 4, metric, bbs);
|
369
|
-
index_1 = index_pq;
|
370
|
-
}
|
371
|
-
} else if (
|
372
|
-
!index &&
|
373
|
-
(sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
|
374
|
-
sscanf(tok, "PQ%d", &M) == 1 ||
|
375
|
-
sscanf(tok, "PQ%dnp", &M) == 1)) {
|
376
|
-
bool do_polysemous_training = stok.find("np") == std::string::npos;
|
377
|
-
if (coarse_quantizer) {
|
378
|
-
if (!use_2layer) {
|
379
|
-
IndexIVFPQ* index_ivf = new IndexIVFPQ(
|
380
|
-
coarse_quantizer, d, ncentroids, M, nbit);
|
381
|
-
index_ivf->quantizer_trains_alone =
|
382
|
-
get_trains_alone(coarse_quantizer);
|
383
|
-
index_ivf->metric_type = metric;
|
384
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
385
|
-
del_coarse_quantizer.release();
|
386
|
-
index_ivf->own_fields = true;
|
387
|
-
index_ivf->do_polysemous_training = do_polysemous_training;
|
388
|
-
index_1 = index_ivf;
|
389
|
-
} else {
|
390
|
-
Index2Layer* index_2l = new Index2Layer(
|
391
|
-
coarse_quantizer, ncentroids, M, nbit);
|
392
|
-
index_2l->q1.quantizer_trains_alone =
|
393
|
-
get_trains_alone(coarse_quantizer);
|
394
|
-
index_2l->q1.own_fields = true;
|
395
|
-
index_1 = index_2l;
|
396
|
-
}
|
397
|
-
} else if (hnsw_M > 0) {
|
398
|
-
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
399
|
-
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
400
|
-
do_polysemous_training;
|
401
|
-
index_1 = ipq;
|
402
|
-
} else {
|
403
|
-
IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
|
404
|
-
index_pq->do_polysemous_training = do_polysemous_training;
|
405
|
-
index_1 = index_pq;
|
406
|
-
}
|
407
|
-
} else if (
|
408
|
-
!index &&
|
409
|
-
sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
|
410
|
-
Index* quant = new IndexFlatL2(d);
|
411
|
-
IndexHNSW2Level* hidx2l =
|
412
|
-
new IndexHNSW2Level(quant, ncent, pq_m, M);
|
413
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
414
|
-
idx2l->q1.own_fields = true;
|
415
|
-
index_1 = hidx2l;
|
416
|
-
} else if (
|
417
|
-
!index &&
|
418
|
-
sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
|
419
|
-
Index* quant = new MultiIndexQuantizer(d, 2, nbit);
|
420
|
-
IndexHNSW2Level* hidx2l =
|
421
|
-
new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
|
422
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
423
|
-
idx2l->q1.own_fields = true;
|
424
|
-
idx2l->q1.quantizer_trains_alone = 1;
|
425
|
-
index_1 = hidx2l;
|
426
|
-
} else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
|
427
|
-
index_1 = new IndexHNSWPQ(d, pq_m, M);
|
428
|
-
} else if (
|
429
|
-
!index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
|
430
|
-
pq_m == 8) {
|
431
|
-
index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
|
432
|
-
} else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
|
433
|
-
hnsw_M = M;
|
434
|
-
// here it is unclear what we want: HNSW flat or HNSWx,Y ?
|
435
|
-
} else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
|
436
|
-
nsg_R = R;
|
437
|
-
} else if (
|
438
|
-
!index &&
|
439
|
-
(stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
|
440
|
-
stok == "LSHt")) {
|
441
|
-
bool rotate_data = strstr(tok, "r") != nullptr;
|
442
|
-
bool train_thresholds = strstr(tok, "t") != nullptr;
|
443
|
-
index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
|
444
|
-
} else if (
|
445
|
-
!index &&
|
446
|
-
sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
|
447
|
-
FAISS_THROW_IF_NOT(!coarse_quantizer);
|
448
|
-
index_1 = new IndexLattice(d, M, nbit, r2);
|
449
|
-
} else if (stok == "RFlat") {
|
450
|
-
parenthesis_refine = "Flat";
|
451
|
-
} else if (stok == "Refine") {
|
452
|
-
FAISS_THROW_IF_NOT_MSG(
|
453
|
-
!parenthesis_refine.empty(),
|
454
|
-
"Refine index should be provided in parentheses");
|
455
|
-
} else {
|
456
|
-
FAISS_THROW_FMT(
|
457
|
-
"could not parse token \"%s\" in %s\n",
|
458
|
-
tok,
|
459
|
-
description_in);
|
342
|
+
return index_ivf;
|
343
|
+
}
|
344
|
+
if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
|
345
|
+
int outdim = mres_to_int(sm[2], d); // is also the number of bits
|
346
|
+
std::unique_ptr<VectorTransform> vt;
|
347
|
+
if (sm[1] == "ITQ") {
|
348
|
+
vt.reset(new ITQTransform(d, outdim, d != outdim));
|
349
|
+
} else if (sm[1] == "PCA") {
|
350
|
+
vt.reset(new PCAMatrix(d, outdim));
|
351
|
+
} else if (sm[1] == "PCAR") {
|
352
|
+
vt.reset(new PCAMatrix(d, outdim, 0, true));
|
460
353
|
}
|
354
|
+
// the rationale for -1e10 is that this corresponds to simple
|
355
|
+
// thresholding
|
356
|
+
float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
|
357
|
+
IndexIVFSpectralHash* index_ivf =
|
358
|
+
new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
|
359
|
+
index_ivf->replace_vt(vt.release(), true);
|
360
|
+
if (sm[4].length()) {
|
361
|
+
std::string s = sm[4].str();
|
362
|
+
index_ivf->threshold_type = s == "g"
|
363
|
+
? IndexIVFSpectralHash::Thresh_global
|
364
|
+
: s == "c"
|
365
|
+
? IndexIVFSpectralHash::Thresh_centroid
|
366
|
+
:
|
367
|
+
/* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
|
368
|
+
}
|
369
|
+
return index_ivf;
|
370
|
+
}
|
371
|
+
return nullptr;
|
372
|
+
}
|
373
|
+
|
374
|
+
/***************************************************************
|
375
|
+
* Parse IndexHNSW
|
376
|
+
*/
|
377
|
+
|
378
|
+
/// Build the HNSW index described by the code part of a factory token.
///
/// The whole `code_string` must match one of:
///   "Flat" or ""       -> IndexHNSWFlat(d, hnsw_M, mt)
///   PQ<M>[np]          -> IndexHNSWPQ(d, M, hnsw_M); "np" turns off
///                         do_polysemous_training on the PQ storage
///   SQ4|SQ8|SQ6|SQfp16 -> IndexHNSWSQ(d, <type>, hnsw_M, mt)
///   <ncent>+PQ<m>      -> IndexHNSW2Level over an IndexFlatL2 quantizer
///   2x<nbit>+PQ<m>     -> IndexHNSW2Level over a MultiIndexQuantizer with
///                         2^(2*nbit) centroids
///
/// @param code_string  text describing the storage of the HNSW index
/// @param d            vector dimension
/// @param mt           metric, forwarded to the Flat and SQ variants
/// @param hnsw_M       HNSW parameter forwarded to every constructor
/// @return a newly allocated index, or nullptr when nothing matches
///
/// Objects are held in unique_ptr until fully configured so that nothing is
/// leaked if a constructor or a check throws.
IndexHNSW* parse_IndexHNSW(
        const std::string code_string,
        int d,
        MetricType mt,
        int hnsw_M) {
    std::smatch sm;
    auto match = [&sm, &code_string](const std::string& pattern) {
        return re_match(code_string, pattern, sm);
    };

    if (match("Flat|")) {
        return new IndexHNSWFlat(d, hnsw_M, mt);
    }
    if (match("PQ([0-9]+)(np)?")) {
        int M = std::stoi(sm[1].str());
        std::unique_ptr<IndexHNSWPQ> ipq(new IndexHNSWPQ(d, M, hnsw_M));
        IndexPQ* storage_pq = dynamic_cast<IndexPQ*>(ipq->storage);
        FAISS_THROW_IF_NOT_MSG(
                storage_pq, "IndexHNSWPQ storage is not an IndexPQ");
        storage_pq->do_polysemous_training = sm[2].str() != "np";
        return ipq.release();
    }
    if (match(sq_pattern)) {
        return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
    }
    if (match("([0-9]+)\\+PQ([0-9]+)?")) {
        int ncent = mres_to_int(sm[1]);
        int pq_m = mres_to_int(sm[2]);
        // the PQ size group is optional in the pattern; without it
        // mres_to_int yields -1, which is not a usable PQ size
        FAISS_THROW_IF_NOT_MSG(
                pq_m > 0, "HNSW 2-level index: PQ size is missing");
        // NOTE(review): assumes the 2-level index does not take ownership of
        // the quantizer before q1.own_fields is set -- confirm.
        std::unique_ptr<Index> quant(new IndexFlatL2(d));
        std::unique_ptr<IndexHNSW2Level> hidx2l(
                new IndexHNSW2Level(quant.get(), ncent, pq_m, hnsw_M));
        Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
        FAISS_THROW_IF_NOT_MSG(
                idx2l, "IndexHNSW2Level storage is not an Index2Layer");
        idx2l->q1.own_fields = true;
        quant.release(); // now owned through q1
        return hidx2l.release();
    }
    if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
        int nbit = mres_to_int(sm[1]);
        int pq_m = mres_to_int(sm[2]);
        FAISS_THROW_IF_NOT_MSG(
                pq_m > 0, "HNSW 2-level index: PQ size is missing");
        std::unique_ptr<Index> quant(new MultiIndexQuantizer(d, 2, nbit));
        std::unique_ptr<IndexHNSW2Level> hidx2l(new IndexHNSW2Level(
                quant.get(), (size_t)1 << (2 * nbit), pq_m, hnsw_M));
        Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
        FAISS_THROW_IF_NOT_MSG(
                idx2l, "IndexHNSW2Level storage is not an Index2Layer");
        idx2l->q1.own_fields = true;
        quant.release(); // now owned through q1
        idx2l->q1.quantizer_trains_alone = 1;
        return hidx2l.release();
    }

    return nullptr;
}
|
423
|
+
|
424
|
+
/***************************************************************
|
425
|
+
* Parse basic indexes
|
426
|
+
*/
|
427
|
+
|
428
|
+
/**
 * Build one of the basic (non-composite) index types from `description`.
 *
 * Each supported family is tried in turn; the first pattern accepted by
 * re_match() decides which index is allocated.
 *
 * @param description  index description string
 * @param d            vector dimension forwarded to the index constructors
 * @param metric       metric forwarded to the constructors that take one
 * @return a newly allocated index, or nullptr when nothing matches
 */
Index* parse_other_indexes(
        const std::string& description,
        int d,
        MetricType metric) {
    std::smatch sm;
    // `description` outlives this lambda, so capture it by reference instead
    // of storing a copy of the string in the closure
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };

    // IndexFlat
    if (description == "Flat") {
        return new IndexFlat(d, metric);
    }

    // IndexLSH
    if (match("LSH(r?)(t?)")) {
        bool rotate_data = sm[1].length() > 0;
        bool train_thresholds = sm[2].length() > 0;
        // this branch is only accepted for the L2 metric
        FAISS_THROW_IF_NOT(metric == METRIC_L2);
        return new IndexLSH(d, d, rotate_data, train_thresholds);
    }

    // IndexLattice
    if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
        int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
        int nbit = std::stoi(sm[3].str());
        return new IndexLattice(d, M, nbit, r2);
    }

    // IndexNSGFlat
    if (match("NSG([0-9]+)(,Flat)?")) {
        return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
    }

    // IndexScalarQuantizer
    if (match(sq_pattern)) {
        return new IndexScalarQuantizer(d, sq_types[description], metric);
    }

    // IndexPQ
    if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
        int M = std::stoi(sm[1].str());
        // group 2 is "x<nbit>"; 8 is the fallback passed to mres_to_int
        int nbit = mres_to_int(sm[2], 8, 1);
        IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
        index_pq->do_polysemous_training = sm[3].str() != "np";
        return index_pq;
    }

    // IndexPQFastScan
    if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
        int M = std::stoi(sm[1].str());
        // group 2 is "_<bbs>"; 32 is the fallback passed to mres_to_int
        int bbs = mres_to_int(sm[2], 32, 1);
        return new IndexPQFastScan(d, M, 4, metric, bbs);
    }

    // IndexResidualCoarseQuantizer and IndexResidualQuantizer
    std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
    if (match(pattern)) {
        std::vector<size_t> nbits = aq_parse_nbits(description);
        if (sm[1].str() == "RCQ") {
            return new ResidualCoarseQuantizer(d, nbits, metric);
        }
        // the last capture group selects the search type
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexResidualQuantizer(d, nbits, metric, st);
    }

    // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
    if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
        // M and nbit come straight from the capture groups; no per-codebook
        // bit vector is needed for these constructors
        int M = mres_to_int(sm[2]);
        int nbit = mres_to_int(sm[3]);
        if (sm[1].str() == "LSCQ") {
            return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
        }
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
    }

    return nullptr;
}
|
510
|
+
|
511
|
+
/***************************************************************
|
512
|
+
* Driver function
|
513
|
+
*/
|
514
|
+
std::unique_ptr<Index> index_factory_sub(
|
515
|
+
int d,
|
516
|
+
std::string description,
|
517
|
+
MetricType metric) {
|
518
|
+
// handle composite indexes
|
519
|
+
|
520
|
+
bool verbose = index_factory_verbose;
|
473
521
|
|
474
|
-
|
475
|
-
|
476
|
-
|
522
|
+
if (verbose) {
|
523
|
+
printf("begin parse VectorTransforms: %s \n", description.c_str());
|
524
|
+
}
|
525
|
+
|
526
|
+
// for the current match
|
527
|
+
std::smatch sm;
|
528
|
+
|
529
|
+
// handle refines
|
530
|
+
if (re_match(description, "(.+),RFlat", sm) ||
|
531
|
+
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
|
532
|
+
std::unique_ptr<Index> filter_index =
|
533
|
+
index_factory_sub(d, sm[1].str(), metric);
|
534
|
+
std::unique_ptr<Index> refine_index;
|
535
|
+
|
536
|
+
if (sm.size() == 3) { // Refine
|
537
|
+
refine_index = index_factory_sub(d, sm[2].str(), metric);
|
538
|
+
} else { // RFlat
|
539
|
+
refine_index.reset(new IndexFlat(d, metric));
|
477
540
|
}
|
541
|
+
IndexRefine* index_rf =
|
542
|
+
new IndexRefine(filter_index.get(), refine_index.get());
|
543
|
+
index_rf->own_fields = true;
|
544
|
+
filter_index.release();
|
545
|
+
refine_index.release();
|
546
|
+
index_rf->own_refine_index = true;
|
547
|
+
return std::unique_ptr<Index>(index_rf);
|
548
|
+
}
|
478
549
|
|
479
|
-
|
480
|
-
|
481
|
-
|
550
|
+
// IndexPreTransform
|
551
|
+
// should handle this first (even before parentheses) because it changes d
|
552
|
+
std::vector<std::unique_ptr<VectorTransform>> vts;
|
553
|
+
VectorTransform* vt = nullptr;
|
554
|
+
while (re_match(description, "([^,]+),(.*)", sm) &&
|
555
|
+
(vt = parse_VectorTransform(sm[1], d))) {
|
556
|
+
// reset loop
|
557
|
+
description = sm[sm.size() - 1];
|
558
|
+
vts.emplace_back(vt);
|
559
|
+
d = vts.back()->d_out;
|
560
|
+
}
|
561
|
+
|
562
|
+
if (vts.size() > 0) {
|
563
|
+
std::unique_ptr<Index> sub_index =
|
564
|
+
index_factory_sub(d, description, metric);
|
565
|
+
IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
|
566
|
+
std::unique_ptr<Index> ret(index_pt);
|
567
|
+
index_pt->own_fields = true;
|
568
|
+
sub_index.release();
|
569
|
+
while (vts.size() > 0) {
|
570
|
+
if (verbose) {
|
571
|
+
printf("prepend trans %d -> %d\n",
|
572
|
+
vts.back()->d_in,
|
573
|
+
vts.back()->d_out);
|
574
|
+
}
|
575
|
+
index_pt->prepend_transform(vts.back().release());
|
576
|
+
vts.pop_back();
|
482
577
|
}
|
578
|
+
return ret;
|
483
579
|
}
|
484
580
|
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
581
|
+
// what we got from the parentheses
|
582
|
+
std::vector<std::unique_ptr<Index>> parenthesis_indexes;
|
583
|
+
|
584
|
+
int begin = 0;
|
585
|
+
while (description.find('(', begin) != std::string::npos) {
|
586
|
+
// replace indexes in () with Index0, Index1, etc.
|
587
|
+
int i0, i1;
|
588
|
+
find_matching_parentheses(description, i0, i1, begin);
|
589
|
+
std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
|
590
|
+
int no = parenthesis_indexes.size();
|
591
|
+
parenthesis_indexes.push_back(
|
592
|
+
index_factory_sub(d, sub_description, metric));
|
593
|
+
description = description.substr(0, i0 + 1) + "Index" +
|
594
|
+
std::to_string(no) + description.substr(i1);
|
595
|
+
begin = i1 + 1;
|
491
596
|
}
|
492
597
|
|
493
|
-
|
494
|
-
|
598
|
+
if (verbose) {
|
599
|
+
printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
|
600
|
+
description.c_str(),
|
601
|
+
parenthesis_indexes.size(),
|
602
|
+
d);
|
603
|
+
}
|
495
604
|
|
496
|
-
//
|
497
|
-
|
498
|
-
|
605
|
+
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
606
|
+
// support both
|
607
|
+
if (re_match(description, "(.+),IDMap", sm) ||
|
608
|
+
re_match(description, "IDMap,(.+)", sm)) {
|
609
|
+
IndexIDMap* idmap = new IndexIDMap(
|
610
|
+
index_factory_sub(d, sm[1].str(), metric).release());
|
611
|
+
idmap->own_fields = true;
|
612
|
+
return std::unique_ptr<Index>(idmap);
|
613
|
+
}
|
499
614
|
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
615
|
+
{ // handle basic index types
|
616
|
+
Index* index = parse_other_indexes(description, d, metric);
|
617
|
+
if (index) {
|
618
|
+
return std::unique_ptr<Index>(index);
|
619
|
+
}
|
504
620
|
}
|
505
621
|
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
622
|
+
// HNSW variants (it was unclear in the old version that the separator was a
|
623
|
+
// "," so we support both "_" and ",")
|
624
|
+
if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
|
625
|
+
int hnsw_M = mres_to_int(sm[1], 32);
|
626
|
+
// We also accept empty code string (synonym of Flat)
|
627
|
+
std::string code_string =
|
628
|
+
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
629
|
+
if (verbose) {
|
630
|
+
printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
|
631
|
+
description.c_str(),
|
632
|
+
code_string.c_str(),
|
633
|
+
hnsw_M);
|
513
634
|
}
|
514
|
-
|
635
|
+
|
636
|
+
IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
|
637
|
+
FAISS_THROW_IF_NOT_FMT(
|
638
|
+
index,
|
639
|
+
"could not parse HNSW code description %s in %s",
|
640
|
+
code_string.c_str(),
|
641
|
+
description.c_str());
|
642
|
+
return std::unique_ptr<Index>(index);
|
515
643
|
}
|
516
644
|
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
645
|
+
// IndexIVF
|
646
|
+
{
|
647
|
+
size_t nlist;
|
648
|
+
bool use_2layer;
|
649
|
+
size_t comma = description.find(",");
|
650
|
+
std::string coarse_string = description.substr(0, comma);
|
651
|
+
// Match coarse quantizer part first
|
652
|
+
std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
|
653
|
+
description.substr(0, comma),
|
654
|
+
d,
|
655
|
+
metric,
|
656
|
+
parenthesis_indexes,
|
657
|
+
nlist,
|
658
|
+
use_2layer));
|
659
|
+
|
660
|
+
if (comma != std::string::npos && quantizer.get()) {
|
661
|
+
std::string code_description = description.substr(comma + 1);
|
662
|
+
if (use_2layer) {
|
663
|
+
bool ok =
|
664
|
+
re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
|
665
|
+
FAISS_THROW_IF_NOT_FMT(
|
666
|
+
ok,
|
667
|
+
"could not parse 2 layer code description %s in %s",
|
668
|
+
code_description.c_str(),
|
669
|
+
description.c_str());
|
670
|
+
int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
|
671
|
+
Index2Layer* index_2l =
|
672
|
+
new Index2Layer(quantizer.release(), nlist, M, nbit);
|
673
|
+
index_2l->q1.own_fields = true;
|
674
|
+
index_2l->q1.quantizer_trains_alone =
|
675
|
+
get_trains_alone(index_2l->q1.quantizer);
|
676
|
+
return std::unique_ptr<Index>(index_2l);
|
677
|
+
}
|
678
|
+
|
679
|
+
IndexIVF* index_ivf =
|
680
|
+
parse_IndexIVF(code_description, quantizer, nlist, metric);
|
681
|
+
|
682
|
+
FAISS_THROW_IF_NOT_FMT(
|
683
|
+
index_ivf,
|
684
|
+
"could not parse code description %s in %s",
|
685
|
+
code_description.c_str(),
|
686
|
+
description.c_str());
|
687
|
+
return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
|
688
|
+
}
|
524
689
|
}
|
690
|
+
FAISS_THROW_FMT("could not parse index string %s", description.c_str());
|
691
|
+
return nullptr;
|
692
|
+
}
|
525
693
|
|
526
|
-
|
694
|
+
} // anonymous namespace
|
695
|
+
|
696
|
+
/**
 * Public entry point of the index factory.
 *
 * @param d            vector dimension
 * @param description  NUL-terminated index description; must not be null
 * @param metric       metric forwarded to index_factory_sub()
 * @return the index built by index_factory_sub(), released from its
 *         unique_ptr (so no smart pointer manages it any more)
 */
Index* index_factory(int d, const char* description, MetricType metric) {
    // a null pointer cannot be converted to std::string (undefined
    // behaviour), so reject it explicitly
    FAISS_THROW_IF_NOT(description != nullptr);
    return index_factory_sub(d, description, metric).release();
}
|
528
699
|
|
529
700
|
IndexBinary* index_binary_factory(int d, const char* description) {
|