faiss 0.2.3 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,19 +5,19 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  /*
11
- * implementation of Hyper-parameter auto-tuning
9
+ * implementation of the index_factory function. Lots of regex parsing code.
12
10
  */
13
11
 
14
12
  #include <faiss/index_factory.h>
15
-
16
- #include <faiss/AutoTune.h>
13
+ #include "faiss/MetricType.h"
14
+ #include "faiss/impl/FaissAssert.h"
17
15
 
18
16
  #include <cinttypes>
19
17
  #include <cmath>
20
18
 
19
+ #include <map>
20
+
21
21
  #include <regex>
22
22
 
23
23
  #include <faiss/impl/FaissAssert.h>
@@ -25,13 +25,16 @@
25
25
  #include <faiss/utils/utils.h>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
+ #include <faiss/IndexAdditiveQuantizer.h>
28
29
  #include <faiss/IndexFlat.h>
29
30
  #include <faiss/IndexHNSW.h>
30
31
  #include <faiss/IndexIVF.h>
32
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
31
33
  #include <faiss/IndexIVFFlat.h>
32
34
  #include <faiss/IndexIVFPQ.h>
33
35
  #include <faiss/IndexIVFPQFastScan.h>
34
36
  #include <faiss/IndexIVFPQR.h>
37
+ #include <faiss/IndexIVFSpectralHash.h>
35
38
  #include <faiss/IndexLSH.h>
36
39
  #include <faiss/IndexLattice.h>
37
40
  #include <faiss/IndexNSG.h>
@@ -39,7 +42,6 @@
39
42
  #include <faiss/IndexPQFastScan.h>
40
43
  #include <faiss/IndexPreTransform.h>
41
44
  #include <faiss/IndexRefine.h>
42
- #include <faiss/IndexResidual.h>
43
45
  #include <faiss/IndexScalarQuantizer.h>
44
46
  #include <faiss/MetaIndexes.h>
45
47
  #include <faiss/VectorTransform.h>
@@ -48,6 +50,7 @@
48
50
  #include <faiss/IndexBinaryHNSW.h>
49
51
  #include <faiss/IndexBinaryHash.h>
50
52
  #include <faiss/IndexBinaryIVF.h>
53
+ #include <string>
51
54
 
52
55
  namespace faiss {
53
56
 
@@ -55,16 +58,49 @@ namespace faiss {
55
58
  * index_factory
56
59
  ***************************************************************/
57
60
 
61
+ int index_factory_verbose = 0;
62
+
58
63
  namespace {
59
64
 
60
- struct VTChain {
61
- std::vector<VectorTransform*> chain;
62
- ~VTChain() {
63
- for (int i = 0; i < chain.size(); i++) {
64
- delete chain[i];
65
+ /***************************************************************
66
+ * Small functions
67
+ */
68
+
69
+ bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
70
+ return std::regex_match(s, sm, std::regex(pat));
71
+ }
72
+
73
+ // find first pair of matching parentheses
74
+ void find_matching_parentheses(
75
+ const std::string& s,
76
+ int& i0,
77
+ int& i1,
78
+ int begin = 0) {
79
+ int st = 0;
80
+ i0 = i1 = 0;
81
+ for (int i = begin; i < s.length(); i++) {
82
+ if (s[i] == '(') {
83
+ if (st == 0) {
84
+ i0 = i;
85
+ }
86
+ st++;
87
+ }
88
+
89
+ if (s[i] == ')') {
90
+ st--;
91
+ if (st == 0) {
92
+ i1 = i;
93
+ return;
94
+ }
95
+ if (st < 0) {
96
+ FAISS_THROW_FMT(
97
+ "factory string %s: unbalanced parentheses", s.c_str());
98
+ }
65
99
  }
66
100
  }
67
- };
101
+ FAISS_THROW_FMT(
102
+ "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
103
+ }
68
104
 
69
105
  /// what kind of training does this coarse quantizer require?
70
106
  char get_trains_alone(const Index* coarse_quantizer) {
@@ -83,447 +119,582 @@ char get_trains_alone(const Index* coarse_quantizer) {
83
119
  // kmeans index
84
120
  }
85
121
 
86
- bool str_ends_with(const std::string& s, const std::string& suffix) {
87
- return s.rfind(suffix) == std::abs(int(s.size() - suffix.size()));
122
+ // set the fields for factory-constructed IVF structures
123
+ IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
124
+ index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
125
+ index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
126
+ index_ivf->own_fields = true;
127
+ return index_ivf;
88
128
  }
89
129
 
90
- // check if ends with suffix followed by digits
91
- bool str_ends_with_digits(const std::string& s, const std::string& suffix) {
92
- int i;
93
- for (i = s.length() - 1; i >= 0; i--) {
94
- if (!isdigit(s[i]))
95
- break;
130
+ int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
131
+ if (mr.length() == 0) {
132
+ return deflt;
96
133
  }
97
- return str_ends_with(s.substr(0, i + 1), suffix);
134
+ return std::stoi(mr.str().substr(begin));
98
135
  }
99
136
 
100
- void find_matching_parentheses(const std::string& s, int& i0, int& i1) {
101
- int st = 0;
102
- i0 = i1 = 0;
103
- for (int i = 0; i < s.length(); i++) {
104
- if (s[i] == '(') {
105
- if (st == 0) {
106
- i0 = i;
107
- }
108
- st++;
109
- }
137
+ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
138
+ {"SQ8", ScalarQuantizer::QT_8bit},
139
+ {"SQ4", ScalarQuantizer::QT_4bit},
140
+ {"SQ6", ScalarQuantizer::QT_6bit},
141
+ {"SQfp16", ScalarQuantizer::QT_fp16},
142
+ };
143
+ const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
144
+
145
+ std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
146
+ {"_Nfloat", AdditiveQuantizer::ST_norm_float},
147
+ {"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
148
+ {"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
149
+ {"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
150
+ {"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
151
+ {"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
152
+ };
110
153
 
111
- if (s[i] == ')') {
112
- st--;
113
- if (st == 0) {
114
- i1 = i;
115
- return;
116
- }
117
- if (st < 0) {
118
- FAISS_THROW_FMT(
119
- "factory string %s: unbalanced parentheses", s.c_str());
120
- }
121
- }
154
+ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
155
+ const std::string aq_norm_pattern =
156
+ "(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)";
157
+
158
+ AdditiveQuantizer::Search_type_t aq_parse_search_type(
159
+ std::string stok,
160
+ MetricType metric) {
161
+ if (stok == "") {
162
+ return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
163
+ : AdditiveQuantizer::ST_LUT_nonorm;
122
164
  }
123
- FAISS_THROW_FMT(
124
- "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
165
+ int pos = stok.rfind("_");
166
+ return aq_search_type[stok.substr(pos)];
125
167
  }
126
168
 
127
- } // anonymous namespace
169
+ std::vector<size_t> aq_parse_nbits(std::string stok) {
170
+ std::vector<size_t> nbits;
171
+ std::smatch sm;
172
+ while (std::regex_search(stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
173
+ int M = std::stoi(sm[1].str());
174
+ int nbit = std::stoi(sm[2].str());
175
+ nbits.resize(nbits.size() + M, nbit);
176
+ stok = sm.suffix();
177
+ }
178
+ return nbits;
179
+ }
128
180
 
129
- Index* index_factory(int d, const char* description_in, MetricType metric) {
130
- FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
131
- VTChain vts;
132
- Index* coarse_quantizer = nullptr;
133
- std::string parenthesis_ivf, parenthesis_refine;
134
- Index* index = nullptr;
135
- bool add_idmap = false;
136
- int d_in = d;
181
+ /***************************************************************
182
+ * Parse VectorTransform
183
+ */
137
184
 
138
- ScopeDeleter1<Index> del_coarse_quantizer, del_index;
185
+ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
186
+ std::smatch sm;
187
+ auto match = [&sm, description](std::string pattern) {
188
+ return re_match(description, pattern, sm);
189
+ };
190
+ if (match("PCA(W?)(R?)([0-9]+)")) {
191
+ bool white = sm[1].length() > 0;
192
+ bool rot = sm[2].length() > 0;
193
+ return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
194
+ }
195
+ if (match("L2[nN]orm")) {
196
+ return new NormalizationTransform(d, 2.0);
197
+ }
198
+ if (match("RR([0-9]+)?")) {
199
+ return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
200
+ }
201
+ if (match("ITQ([0-9]+)?")) {
202
+ return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
203
+ }
204
+ if (match("OPQ([0-9]+)(_[0-9]+)?")) {
205
+ int M = std::stoi(sm[1].str());
206
+ int d_out = mres_to_int(sm[2], d, 1);
207
+ return new OPQMatrix(d, M, d_out);
208
+ }
209
+ if (match("Pad([0-9]+)")) {
210
+ int d_out = std::stoi(sm[1].str());
211
+ return new RemapDimensionsTransform(d, std::max(d_out, d), false);
212
+ }
213
+ return nullptr;
214
+ };
139
215
 
140
- std::string description(description_in);
141
- char* ptr;
216
+ /***************************************************************
217
+ * Parse IndexIVF
218
+ */
142
219
 
143
- // handle indexes in parentheses
144
- while (description.find('(') != std::string::npos) {
145
- // then we make a sub-index and remove the () from the description
146
- int i0, i1;
147
- find_matching_parentheses(description, i0, i1);
220
+ // parsing guard + function
221
+ Index* parse_coarse_quantizer(
222
+ const std::string& description,
223
+ int d,
224
+ MetricType mt,
225
+ std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
226
+ size_t& nlist,
227
+ bool& use_2layer) {
228
+ std::smatch sm;
229
+ auto match = [&sm, description](std::string pattern) {
230
+ return re_match(description, pattern, sm);
231
+ };
232
+ use_2layer = false;
233
+
234
+ if (match("IVF([0-9]+)")) {
235
+ nlist = std::stoi(sm[1].str());
236
+ return new IndexFlat(d, mt);
237
+ }
238
+ if (match("IMI2x([0-9]+)")) {
239
+ int nbit = std::stoi(sm[1].str());
240
+ FAISS_THROW_IF_NOT_MSG(
241
+ mt == METRIC_L2,
242
+ "MultiIndex not implemented for inner prod search");
243
+ nlist = (size_t)1 << (2 * nbit);
244
+ return new MultiIndexQuantizer(d, 2, nbit);
245
+ }
246
+ if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
247
+ nlist = std::stoi(sm[1].str());
248
+ int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
249
+ return new IndexHNSWFlat(d, hnsw_M, mt);
250
+ }
251
+ if (match("IVF([0-9]+)_NSG([0-9]+)")) {
252
+ nlist = std::stoi(sm[1].str());
253
+ int R = std::stoi(sm[2]);
254
+ return new IndexNSGFlat(d, R, mt);
255
+ }
256
+ if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
257
+ nlist = std::stoi(sm[1].str());
258
+ int no = std::stoi(sm[2].str());
259
+ FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
260
+ return parenthesis_indexes[no].release();
261
+ }
148
262
 
149
- std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
263
+ // these two generate Index2Layer's not IndexIVF's
264
+ if (match("Residual([0-9]+)x([0-9]+)")) {
265
+ FAISS_THROW_IF_NOT_MSG(
266
+ mt == METRIC_L2,
267
+ "MultiIndex not implemented for inner prod search");
268
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
269
+ nlist = (size_t)1 << (M * nbit);
270
+ use_2layer = true;
271
+ return new MultiIndexQuantizer(d, M, nbit);
272
+ }
273
+ if (match("Residual([0-9]+)")) {
274
+ FAISS_THROW_IF_NOT_MSG(
275
+ mt == METRIC_L2,
276
+ "Residual not implemented for inner prod search");
277
+ use_2layer = true;
278
+ nlist = mres_to_int(sm[1]);
279
+ return new IndexFlatL2(d);
280
+ }
281
+ return nullptr;
282
+ }
150
283
 
151
- if (str_ends_with_digits(description.substr(0, i0), "IVF")) {
152
- parenthesis_ivf = sub_description;
153
- } else if (str_ends_with(description.substr(0, i0), "Refine")) {
154
- parenthesis_refine = sub_description;
284
+ // parse the code description part of the IVF description
285
+
286
+ IndexIVF* parse_IndexIVF(
287
+ const std::string& code_string,
288
+ std::unique_ptr<Index>& quantizer,
289
+ size_t nlist,
290
+ MetricType mt) {
291
+ std::smatch sm;
292
+ auto match = [&sm, &code_string](const std::string pattern) {
293
+ return re_match(code_string, pattern, sm);
294
+ };
295
+ auto get_q = [&quantizer] { return quantizer.release(); };
296
+ int d = quantizer->d;
297
+
298
+ if (match("Flat")) {
299
+ return new IndexIVFFlat(get_q(), d, nlist, mt);
300
+ }
301
+ if (match("FlatDedup")) {
302
+ return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
303
+ }
304
+ if (match(sq_pattern)) {
305
+ return new IndexIVFScalarQuantizer(
306
+ get_q(), d, nlist, sq_types[sm[1].str()], mt);
307
+ }
308
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
309
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
310
+ IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
311
+ index_ivf->do_polysemous_training = sm[3].str() != "np";
312
+ return index_ivf;
313
+ }
314
+ if (match("PQ([0-9]+)\\+([0-9]+)")) {
315
+ FAISS_THROW_IF_NOT_MSG(
316
+ mt == METRIC_L2,
317
+ "IVFPQR not implemented for inner product search");
318
+ int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
319
+ return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
320
+ }
321
+ if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
322
+ int M = mres_to_int(sm[1]);
323
+ int bbs = mres_to_int(sm[3], 32, 1);
324
+ IndexIVFPQFastScan* index_ivf =
325
+ new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
326
+ index_ivf->by_residual = sm[2].str() == "r";
327
+ return index_ivf;
328
+ }
329
+ if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
330
+ std::vector<size_t> nbits = aq_parse_nbits(sm.str());
331
+ AdditiveQuantizer::Search_type_t st =
332
+ aq_parse_search_type(sm[sm.size() - 1].str(), mt);
333
+ IndexIVF* index_ivf;
334
+ if (sm[1].str() == "RQ") {
335
+ index_ivf = new IndexIVFResidualQuantizer(
336
+ get_q(), d, nlist, nbits, mt, st);
155
337
  } else {
156
- FAISS_THROW_MSG("don't know what to do with parenthesis index");
338
+ FAISS_THROW_IF_NOT(nbits.size() > 0);
339
+ index_ivf = new IndexIVFLocalSearchQuantizer(
340
+ get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
157
341
  }
158
- description = description.erase(i0, i1 - i0 + 1);
159
- }
160
-
161
- int64_t ncentroids = -1;
162
- bool use_2layer = false;
163
- int hnsw_M = -1;
164
- int nsg_R = -1;
165
-
166
- for (char* tok = strtok_r(&description[0], " ,", &ptr); tok;
167
- tok = strtok_r(nullptr, " ,", &ptr)) {
168
- int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2, R;
169
- std::string stok(tok);
170
- nbit = 8;
171
- int bbs = -1;
172
- char c;
173
-
174
- // to avoid mem leaks with exceptions:
175
- // do all tests before any instanciation
176
-
177
- VectorTransform* vt_1 = nullptr;
178
- Index* coarse_quantizer_1 = nullptr;
179
- Index* index_1 = nullptr;
180
-
181
- // VectorTransforms
182
- if (sscanf(tok, "PCA%d", &d_out) == 1) {
183
- vt_1 = new PCAMatrix(d, d_out);
184
- d = d_out;
185
- } else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
186
- vt_1 = new PCAMatrix(d, d_out, 0, true);
187
- d = d_out;
188
- } else if (sscanf(tok, "RR%d", &d_out) == 1) {
189
- vt_1 = new RandomRotationMatrix(d, d_out);
190
- d = d_out;
191
- } else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
192
- vt_1 = new PCAMatrix(d, d_out, -0.5, false);
193
- d = d_out;
194
- } else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
195
- vt_1 = new PCAMatrix(d, d_out, -0.5, true);
196
- d = d_out;
197
- } else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
198
- vt_1 = new OPQMatrix(d, opq_M, d_out);
199
- d = d_out;
200
- } else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
201
- vt_1 = new OPQMatrix(d, opq_M);
202
- } else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
203
- vt_1 = new ITQTransform(d, d_out, true);
204
- d = d_out;
205
- } else if (stok == "ITQ") {
206
- vt_1 = new ITQTransform(d, d, false);
207
- } else if (sscanf(tok, "Pad%d", &d_out) == 1) {
208
- if (d_out > d) {
209
- vt_1 = new RemapDimensionsTransform(d, d_out, false);
210
- d = d_out;
211
- }
212
- } else if (stok == "L2norm") {
213
- vt_1 = new NormalizationTransform(d, 2.0);
214
-
215
- // coarse quantizers
216
- } else if (
217
- !coarse_quantizer &&
218
- sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
219
- coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
220
-
221
- } else if (
222
- !coarse_quantizer &&
223
- sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
224
- coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
225
-
226
- } else if (
227
- !coarse_quantizer &&
228
- sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
229
- if (!parenthesis_ivf.empty()) {
230
- coarse_quantizer_1 =
231
- index_factory(d, parenthesis_ivf.c_str(), metric);
232
- } else if (metric == METRIC_L2) {
233
- coarse_quantizer_1 = new IndexFlatL2(d);
234
- } else {
235
- coarse_quantizer_1 = new IndexFlatIP(d);
236
- }
237
-
238
- } else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
239
- FAISS_THROW_IF_NOT_MSG(
240
- metric == METRIC_L2,
241
- "MultiIndex not implemented for inner prod search");
242
- coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
243
- ncentroids = 1 << (2 * nbit);
244
-
245
- } else if (
246
- !coarse_quantizer &&
247
- sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
248
- FAISS_THROW_IF_NOT_MSG(
249
- metric == METRIC_L2,
250
- "MultiIndex not implemented for inner prod search");
251
- coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
252
- ncentroids = int64_t(1) << (M * nbit);
253
- use_2layer = true;
254
-
255
- } else if (std::regex_match(
256
- stok,
257
- std::regex(
258
- "(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
259
- std::vector<size_t> nbits;
260
- std::smatch sm;
261
- bool is_RCQ = stok.find("RCQ") == 0;
262
- while (std::regex_search(
263
- stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
264
- int M = std::stoi(sm[1].str());
265
- int nbit = std::stoi(sm[2].str());
266
- nbits.resize(nbits.size() + M, nbit);
267
- stok = sm.suffix();
268
- }
269
- if (!is_RCQ) {
270
- index_1 = new IndexResidual(d, nbits, metric);
271
- } else {
272
- index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
273
- }
274
- } else if (
275
- !coarse_quantizer &&
276
- sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
277
- coarse_quantizer_1 = new IndexFlatL2(d);
278
- use_2layer = true;
279
-
280
- } else if (stok == "IDMap") {
281
- add_idmap = true;
282
-
283
- // IVFs
284
- } else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
285
- if (coarse_quantizer) {
286
- // if there was an IVF in front, then it is an IVFFlat
287
- IndexIVF* index_ivf = stok == "Flat"
288
- ? new IndexIVFFlat(
289
- coarse_quantizer, d, ncentroids, metric)
290
- : new IndexIVFFlatDedup(
291
- coarse_quantizer, d, ncentroids, metric);
292
- index_ivf->quantizer_trains_alone =
293
- get_trains_alone(coarse_quantizer);
294
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
295
- del_coarse_quantizer.release();
296
- index_ivf->own_fields = true;
297
- index_1 = index_ivf;
298
- } else if (hnsw_M > 0) {
299
- index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
300
- } else if (nsg_R > 0) {
301
- index_1 = new IndexNSGFlat(d, nsg_R, metric);
302
- } else {
303
- FAISS_THROW_IF_NOT_MSG(
304
- stok != "FlatDedup",
305
- "dedup supported only for IVFFlat");
306
- index_1 = new IndexFlat(d, metric);
307
- }
308
- } else if (
309
- !index &&
310
- (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
311
- stok == "SQfp16")) {
312
- ScalarQuantizer::QuantizerType qt = stok == "SQ8"
313
- ? ScalarQuantizer::QT_8bit
314
- : stok == "SQ6" ? ScalarQuantizer::QT_6bit
315
- : stok == "SQ4" ? ScalarQuantizer::QT_4bit
316
- : stok == "SQfp16" ? ScalarQuantizer::QT_fp16
317
- : ScalarQuantizer::QT_4bit;
318
- if (coarse_quantizer) {
319
- FAISS_THROW_IF_NOT(!use_2layer);
320
- IndexIVFScalarQuantizer* index_ivf =
321
- new IndexIVFScalarQuantizer(
322
- coarse_quantizer, d, ncentroids, qt, metric);
323
- index_ivf->quantizer_trains_alone =
324
- get_trains_alone(coarse_quantizer);
325
- del_coarse_quantizer.release();
326
- index_ivf->own_fields = true;
327
- index_1 = index_ivf;
328
- } else if (hnsw_M > 0) {
329
- index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
330
- } else {
331
- index_1 = new IndexScalarQuantizer(d, qt, metric);
332
- }
333
- } else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
334
- FAISS_THROW_IF_NOT_MSG(
335
- coarse_quantizer, "PQ with + works only with an IVF");
336
- FAISS_THROW_IF_NOT_MSG(
337
- metric == METRIC_L2,
338
- "IVFPQR not implemented for inner product search");
339
- IndexIVFPQR* index_ivf = new IndexIVFPQR(
340
- coarse_quantizer, d, ncentroids, M, 8, M2, 8);
341
- index_ivf->quantizer_trains_alone =
342
- get_trains_alone(coarse_quantizer);
343
- del_coarse_quantizer.release();
344
- index_ivf->own_fields = true;
345
- index_1 = index_ivf;
346
- } else if (
347
- !index &&
348
- (sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
349
- (sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
350
- (sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
351
- if (bbs == -1) {
352
- bbs = 32;
353
- }
354
- bool by_residual = str_ends_with(stok, "fsr");
355
- if (coarse_quantizer) {
356
- IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
357
- coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
358
- index_ivf->quantizer_trains_alone =
359
- get_trains_alone(coarse_quantizer);
360
- index_ivf->metric_type = metric;
361
- index_ivf->by_residual = by_residual;
362
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
363
- del_coarse_quantizer.release();
364
- index_ivf->own_fields = true;
365
- index_1 = index_ivf;
366
- } else {
367
- IndexPQFastScan* index_pq =
368
- new IndexPQFastScan(d, M, 4, metric, bbs);
369
- index_1 = index_pq;
370
- }
371
- } else if (
372
- !index &&
373
- (sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
374
- sscanf(tok, "PQ%d", &M) == 1 ||
375
- sscanf(tok, "PQ%dnp", &M) == 1)) {
376
- bool do_polysemous_training = stok.find("np") == std::string::npos;
377
- if (coarse_quantizer) {
378
- if (!use_2layer) {
379
- IndexIVFPQ* index_ivf = new IndexIVFPQ(
380
- coarse_quantizer, d, ncentroids, M, nbit);
381
- index_ivf->quantizer_trains_alone =
382
- get_trains_alone(coarse_quantizer);
383
- index_ivf->metric_type = metric;
384
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
385
- del_coarse_quantizer.release();
386
- index_ivf->own_fields = true;
387
- index_ivf->do_polysemous_training = do_polysemous_training;
388
- index_1 = index_ivf;
389
- } else {
390
- Index2Layer* index_2l = new Index2Layer(
391
- coarse_quantizer, ncentroids, M, nbit);
392
- index_2l->q1.quantizer_trains_alone =
393
- get_trains_alone(coarse_quantizer);
394
- index_2l->q1.own_fields = true;
395
- index_1 = index_2l;
396
- }
397
- } else if (hnsw_M > 0) {
398
- IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
399
- dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
400
- do_polysemous_training;
401
- index_1 = ipq;
402
- } else {
403
- IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
404
- index_pq->do_polysemous_training = do_polysemous_training;
405
- index_1 = index_pq;
406
- }
407
- } else if (
408
- !index &&
409
- sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
410
- Index* quant = new IndexFlatL2(d);
411
- IndexHNSW2Level* hidx2l =
412
- new IndexHNSW2Level(quant, ncent, pq_m, M);
413
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
414
- idx2l->q1.own_fields = true;
415
- index_1 = hidx2l;
416
- } else if (
417
- !index &&
418
- sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
419
- Index* quant = new MultiIndexQuantizer(d, 2, nbit);
420
- IndexHNSW2Level* hidx2l =
421
- new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
422
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
423
- idx2l->q1.own_fields = true;
424
- idx2l->q1.quantizer_trains_alone = 1;
425
- index_1 = hidx2l;
426
- } else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
427
- index_1 = new IndexHNSWPQ(d, pq_m, M);
428
- } else if (
429
- !index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
430
- pq_m == 8) {
431
- index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
432
- } else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
433
- hnsw_M = M;
434
- // here it is unclear what we want: HNSW flat or HNSWx,Y ?
435
- } else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
436
- nsg_R = R;
437
- } else if (
438
- !index &&
439
- (stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
440
- stok == "LSHt")) {
441
- bool rotate_data = strstr(tok, "r") != nullptr;
442
- bool train_thresholds = strstr(tok, "t") != nullptr;
443
- index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
444
- } else if (
445
- !index &&
446
- sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
447
- FAISS_THROW_IF_NOT(!coarse_quantizer);
448
- index_1 = new IndexLattice(d, M, nbit, r2);
449
- } else if (stok == "RFlat") {
450
- parenthesis_refine = "Flat";
451
- } else if (stok == "Refine") {
452
- FAISS_THROW_IF_NOT_MSG(
453
- !parenthesis_refine.empty(),
454
- "Refine index should be provided in parentheses");
455
- } else {
456
- FAISS_THROW_FMT(
457
- "could not parse token \"%s\" in %s\n",
458
- tok,
459
- description_in);
342
+ return index_ivf;
343
+ }
344
+ if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
345
+ int outdim = mres_to_int(sm[2], d); // is also the number of bits
346
+ std::unique_ptr<VectorTransform> vt;
347
+ if (sm[1] == "ITQ") {
348
+ vt.reset(new ITQTransform(d, outdim, d != outdim));
349
+ } else if (sm[1] == "PCA") {
350
+ vt.reset(new PCAMatrix(d, outdim));
351
+ } else if (sm[1] == "PCAR") {
352
+ vt.reset(new PCAMatrix(d, outdim, 0, true));
460
353
  }
354
+ // the rationale for -1e10 is that this corresponds to simple
355
+ // thresholding
356
+ float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
357
+ IndexIVFSpectralHash* index_ivf =
358
+ new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
359
+ index_ivf->replace_vt(vt.release(), true);
360
+ if (sm[4].length()) {
361
+ std::string s = sm[4].str();
362
+ index_ivf->threshold_type = s == "g"
363
+ ? IndexIVFSpectralHash::Thresh_global
364
+ : s == "c"
365
+ ? IndexIVFSpectralHash::Thresh_centroid
366
+ :
367
+ /* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
368
+ }
369
+ return index_ivf;
370
+ }
371
+ return nullptr;
372
+ }
373
+
374
+ /***************************************************************
375
+ * Parse IndexHNSW
376
+ */
377
+
378
+ IndexHNSW* parse_IndexHNSW(
379
+ const std::string code_string,
380
+ int d,
381
+ MetricType mt,
382
+ int hnsw_M) {
383
+ std::smatch sm;
384
+ auto match = [&sm, &code_string](const std::string& pattern) {
385
+ return re_match(code_string, pattern, sm);
386
+ };
387
+
388
+ if (match("Flat|")) {
389
+ return new IndexHNSWFlat(d, hnsw_M, mt);
390
+ }
391
+ if (match("PQ([0-9]+)(np)?")) {
392
+ int M = std::stoi(sm[1].str());
393
+ IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
394
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
395
+ sm[2].str() != "np";
396
+ return ipq;
397
+ }
398
+ if (match(sq_pattern)) {
399
+ return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
400
+ }
401
+ if (match("([0-9]+)\\+PQ([0-9]+)?")) {
402
+ int ncent = mres_to_int(sm[1]);
403
+ int pq_m = mres_to_int(sm[2]);
404
+ IndexHNSW2Level* hidx2l =
405
+ new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
406
+ dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
407
+ return hidx2l;
408
+ }
409
+ if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
410
+ int nbit = mres_to_int(sm[1]);
411
+ int pq_m = mres_to_int(sm[2]);
412
+ Index* quant = new MultiIndexQuantizer(d, 2, nbit);
413
+ IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
414
+ quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
415
+ Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
416
+ idx2l->q1.own_fields = true;
417
+ idx2l->q1.quantizer_trains_alone = 1;
418
+ return hidx2l;
419
+ }
461
420
 
462
- if (index_1 && add_idmap) {
463
- IndexIDMap* idmap = new IndexIDMap(index_1);
464
- del_index.set(idmap);
465
- idmap->own_fields = true;
466
- index_1 = idmap;
467
- add_idmap = false;
421
+ return nullptr;
422
+ }
423
+
424
+ /***************************************************************
425
+ * Parse basic indexes
426
+ */
427
+
428
+ Index* parse_other_indexes(
429
+ const std::string& description,
430
+ int d,
431
+ MetricType metric) {
432
+ std::smatch sm;
433
+ auto match = [&sm, description](const std::string& pattern) {
434
+ return re_match(description, pattern, sm);
435
+ };
436
+
437
+ // IndexFlat
438
+ if (description == "Flat") {
439
+ return new IndexFlat(d, metric);
440
+ }
441
+
442
+ // IndexLSH
443
+ if (match("LSH(r?)(t?)")) {
444
+ bool rotate_data = sm[1].length() > 0;
445
+ bool train_thresholds = sm[2].length() > 0;
446
+ FAISS_THROW_IF_NOT(metric == METRIC_L2);
447
+ return new IndexLSH(d, d, rotate_data, train_thresholds);
448
+ }
449
+
450
+ // IndexLattice
451
+ if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
452
+ int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
453
+ int nbit = std::stoi(sm[3].str());
454
+ return new IndexLattice(d, M, nbit, r2);
455
+ }
456
+
457
+ // IndexNSGFlat
458
+ if (match("NSG([0-9]+)(,Flat)?")) {
459
+ return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
460
+ }
461
+
462
+ // IndexScalarQuantizer
463
+ if (match(sq_pattern)) {
464
+ return new IndexScalarQuantizer(d, sq_types[description], metric);
465
+ }
466
+
467
+ // IndexPQ
468
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
469
+ int M = std::stoi(sm[1].str());
470
+ int nbit = mres_to_int(sm[2], 8, 1);
471
+ IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
472
+ index_pq->do_polysemous_training = sm[3].str() != "np";
473
+ return index_pq;
474
+ }
475
+
476
+ // IndexPQFastScan
477
+ if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
478
+ int M = std::stoi(sm[1].str());
479
+ int bbs = mres_to_int(sm[2], 32, 1);
480
+ return new IndexPQFastScan(d, M, 4, metric, bbs);
481
+ }
482
+
483
+ // IndexResidualCoarseQuantizer and IndexResidualQuantizer
484
+ std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
485
+ if (match(pattern)) {
486
+ std::vector<size_t> nbits = aq_parse_nbits(description);
487
+ if (sm[1].str() == "RCQ") {
488
+ return new ResidualCoarseQuantizer(d, nbits, metric);
468
489
  }
490
+ AdditiveQuantizer::Search_type_t st =
491
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
492
+ return new IndexResidualQuantizer(d, nbits, metric, st);
493
+ }
469
494
 
470
- if (vt_1) {
471
- vts.chain.push_back(vt_1);
495
+ // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
496
+ if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
497
+ std::vector<size_t> nbits = aq_parse_nbits(description);
498
+ int M = mres_to_int(sm[2]);
499
+ int nbit = mres_to_int(sm[3]);
500
+ if (sm[1].str() == "LSCQ") {
501
+ return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
472
502
  }
503
+ AdditiveQuantizer::Search_type_t st =
504
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
505
+ return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
506
+ }
507
+
508
+ return nullptr;
509
+ }
510
+
511
+ /***************************************************************
512
+ * Driver function
513
+ */
514
+ std::unique_ptr<Index> index_factory_sub(
515
+ int d,
516
+ std::string description,
517
+ MetricType metric) {
518
+ // handle composite indexes
519
+
520
+ bool verbose = index_factory_verbose;
473
521
 
474
- if (coarse_quantizer_1) {
475
- coarse_quantizer = coarse_quantizer_1;
476
- del_coarse_quantizer.set(coarse_quantizer);
522
+ if (verbose) {
523
+ printf("begin parse VectorTransforms: %s \n", description.c_str());
524
+ }
525
+
526
+ // for the current match
527
+ std::smatch sm;
528
+
529
+ // handle refines
530
+ if (re_match(description, "(.+),RFlat", sm) ||
531
+ re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
532
+ std::unique_ptr<Index> filter_index =
533
+ index_factory_sub(d, sm[1].str(), metric);
534
+ std::unique_ptr<Index> refine_index;
535
+
536
+ if (sm.size() == 3) { // Refine
537
+ refine_index = index_factory_sub(d, sm[2].str(), metric);
538
+ } else { // RFlat
539
+ refine_index.reset(new IndexFlat(d, metric));
477
540
  }
541
+ IndexRefine* index_rf =
542
+ new IndexRefine(filter_index.get(), refine_index.get());
543
+ index_rf->own_fields = true;
544
+ filter_index.release();
545
+ refine_index.release();
546
+ index_rf->own_refine_index = true;
547
+ return std::unique_ptr<Index>(index_rf);
548
+ }
478
549
 
479
- if (index_1) {
480
- index = index_1;
481
- del_index.set(index);
550
+ // IndexPreTransform
551
+ // should handle this first (even before parentheses) because it changes d
552
+ std::vector<std::unique_ptr<VectorTransform>> vts;
553
+ VectorTransform* vt = nullptr;
554
+ while (re_match(description, "([^,]+),(.*)", sm) &&
555
+ (vt = parse_VectorTransform(sm[1], d))) {
556
+ // reset loop
557
+ description = sm[sm.size() - 1];
558
+ vts.emplace_back(vt);
559
+ d = vts.back()->d_out;
560
+ }
561
+
562
+ if (vts.size() > 0) {
563
+ std::unique_ptr<Index> sub_index =
564
+ index_factory_sub(d, description, metric);
565
+ IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
566
+ std::unique_ptr<Index> ret(index_pt);
567
+ index_pt->own_fields = true;
568
+ sub_index.release();
569
+ while (vts.size() > 0) {
570
+ if (verbose) {
571
+ printf("prepend trans %d -> %d\n",
572
+ vts.back()->d_in,
573
+ vts.back()->d_out);
574
+ }
575
+ index_pt->prepend_transform(vts.back().release());
576
+ vts.pop_back();
482
577
  }
578
+ return ret;
483
579
  }
484
580
 
485
- if (!index && hnsw_M > 0) {
486
- index = new IndexHNSWFlat(d, hnsw_M, metric);
487
- del_index.set(index);
488
- } else if (!index && nsg_R > 0) {
489
- index = new IndexNSGFlat(d, nsg_R, metric);
490
- del_index.set(index);
581
+ // what we got from the parentheses
582
+ std::vector<std::unique_ptr<Index>> parenthesis_indexes;
583
+
584
+ int begin = 0;
585
+ while (description.find('(', begin) != std::string::npos) {
586
+ // replace indexes in () with Index0, Index1, etc.
587
+ int i0, i1;
588
+ find_matching_parentheses(description, i0, i1, begin);
589
+ std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
590
+ int no = parenthesis_indexes.size();
591
+ parenthesis_indexes.push_back(
592
+ index_factory_sub(d, sub_description, metric));
593
+ description = description.substr(0, i0 + 1) + "Index" +
594
+ std::to_string(no) + description.substr(i1);
595
+ begin = i1 + 1;
491
596
  }
492
597
 
493
- FAISS_THROW_IF_NOT_FMT(
494
- index, "description %s did not generate an index", description_in);
598
+ if (verbose) {
599
+ printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
600
+ description.c_str(),
601
+ parenthesis_indexes.size(),
602
+ d);
603
+ }
495
604
 
496
- // nothing can go wrong now
497
- del_index.release();
498
- del_coarse_quantizer.release();
605
+ // IndexIDMap -- it turns out it was used both as a prefix and a suffix, so
606
+ // support both
607
+ if (re_match(description, "(.+),IDMap", sm) ||
608
+ re_match(description, "IDMap,(.+)", sm)) {
609
+ IndexIDMap* idmap = new IndexIDMap(
610
+ index_factory_sub(d, sm[1].str(), metric).release());
611
+ idmap->own_fields = true;
612
+ return std::unique_ptr<Index>(idmap);
613
+ }
499
614
 
500
- if (add_idmap) {
501
- fprintf(stderr,
502
- "index_factory: WARNING: "
503
- "IDMap option not used\n");
615
+ { // handle basic index types
616
+ Index* index = parse_other_indexes(description, d, metric);
617
+ if (index) {
618
+ return std::unique_ptr<Index>(index);
619
+ }
504
620
  }
505
621
 
506
- if (vts.chain.size() > 0) {
507
- IndexPreTransform* index_pt = new IndexPreTransform(index);
508
- index_pt->own_fields = true;
509
- // add from back
510
- while (vts.chain.size() > 0) {
511
- index_pt->prepend_transform(vts.chain.back());
512
- vts.chain.pop_back();
622
+ // HNSW variants (it was unclear in the old version that the separator was a
623
+ // "," so we support both "_" and ",")
624
+ if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
625
+ int hnsw_M = mres_to_int(sm[1], 32);
626
+ // We also accept empty code string (synonym of Flat)
627
+ std::string code_string =
628
+ sm[2].length() > 0 ? sm[2].str().substr(1) : "";
629
+ if (verbose) {
630
+ printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
631
+ description.c_str(),
632
+ code_string.c_str(),
633
+ hnsw_M);
513
634
  }
514
- index = index_pt;
635
+
636
+ IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
637
+ FAISS_THROW_IF_NOT_FMT(
638
+ index,
639
+ "could not parse HNSW code description %s in %s",
640
+ code_string.c_str(),
641
+ description.c_str());
642
+ return std::unique_ptr<Index>(index);
515
643
  }
516
644
 
517
- if (!parenthesis_refine.empty()) {
518
- Index* refine_index =
519
- index_factory(d_in, parenthesis_refine.c_str(), metric);
520
- IndexRefine* index_rf = new IndexRefine(index, refine_index);
521
- index_rf->own_refine_index = true;
522
- index_rf->own_fields = true;
523
- index = index_rf;
645
+ // IndexIVF
646
+ {
647
+ size_t nlist;
648
+ bool use_2layer;
649
+ size_t comma = description.find(",");
650
+ std::string coarse_string = description.substr(0, comma);
651
+ // Match coarse quantizer part first
652
+ std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
653
+ description.substr(0, comma),
654
+ d,
655
+ metric,
656
+ parenthesis_indexes,
657
+ nlist,
658
+ use_2layer));
659
+
660
+ if (comma != std::string::npos && quantizer.get()) {
661
+ std::string code_description = description.substr(comma + 1);
662
+ if (use_2layer) {
663
+ bool ok =
664
+ re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
665
+ FAISS_THROW_IF_NOT_FMT(
666
+ ok,
667
+ "could not parse 2 layer code description %s in %s",
668
+ code_description.c_str(),
669
+ description.c_str());
670
+ int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
671
+ Index2Layer* index_2l =
672
+ new Index2Layer(quantizer.release(), nlist, M, nbit);
673
+ index_2l->q1.own_fields = true;
674
+ index_2l->q1.quantizer_trains_alone =
675
+ get_trains_alone(index_2l->q1.quantizer);
676
+ return std::unique_ptr<Index>(index_2l);
677
+ }
678
+
679
+ IndexIVF* index_ivf =
680
+ parse_IndexIVF(code_description, quantizer, nlist, metric);
681
+
682
+ FAISS_THROW_IF_NOT_FMT(
683
+ index_ivf,
684
+ "could not parse code description %s in %s",
685
+ code_description.c_str(),
686
+ description.c_str());
687
+ return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
688
+ }
524
689
  }
690
+ FAISS_THROW_FMT("could not parse index string %s", description.c_str());
691
+ return nullptr;
692
+ }
525
693
 
526
- return index;
694
+ } // anonymous namespace
695
+
696
+ Index* index_factory(int d, const char* description, MetricType metric) {
697
+ return index_factory_sub(d, description, metric).release();
527
698
  }
528
699
 
529
700
  IndexBinary* index_binary_factory(int d, const char* description) {