faiss 0.2.3 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -5,19 +5,19 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  /*
11
- * implementation of Hyper-parameter auto-tuning
9
+ * implementation of the index_factory function. Lots of regex parsing code.
12
10
  */
13
11
 
14
12
  #include <faiss/index_factory.h>
15
-
16
- #include <faiss/AutoTune.h>
13
+ #include "faiss/MetricType.h"
14
+ #include "faiss/impl/FaissAssert.h"
17
15
 
18
16
  #include <cinttypes>
19
17
  #include <cmath>
20
18
 
19
+ #include <map>
20
+
21
21
  #include <regex>
22
22
 
23
23
  #include <faiss/impl/FaissAssert.h>
@@ -25,13 +25,16 @@
25
25
  #include <faiss/utils/utils.h>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
+ #include <faiss/IndexAdditiveQuantizer.h>
28
29
  #include <faiss/IndexFlat.h>
29
30
  #include <faiss/IndexHNSW.h>
30
31
  #include <faiss/IndexIVF.h>
32
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
31
33
  #include <faiss/IndexIVFFlat.h>
32
34
  #include <faiss/IndexIVFPQ.h>
33
35
  #include <faiss/IndexIVFPQFastScan.h>
34
36
  #include <faiss/IndexIVFPQR.h>
37
+ #include <faiss/IndexIVFSpectralHash.h>
35
38
  #include <faiss/IndexLSH.h>
36
39
  #include <faiss/IndexLattice.h>
37
40
  #include <faiss/IndexNSG.h>
@@ -39,7 +42,6 @@
39
42
  #include <faiss/IndexPQFastScan.h>
40
43
  #include <faiss/IndexPreTransform.h>
41
44
  #include <faiss/IndexRefine.h>
42
- #include <faiss/IndexResidual.h>
43
45
  #include <faiss/IndexScalarQuantizer.h>
44
46
  #include <faiss/MetaIndexes.h>
45
47
  #include <faiss/VectorTransform.h>
@@ -48,6 +50,7 @@
48
50
  #include <faiss/IndexBinaryHNSW.h>
49
51
  #include <faiss/IndexBinaryHash.h>
50
52
  #include <faiss/IndexBinaryIVF.h>
53
+ #include <string>
51
54
 
52
55
  namespace faiss {
53
56
 
@@ -55,16 +58,49 @@ namespace faiss {
55
58
  * index_factory
56
59
  ***************************************************************/
57
60
 
61
+ int index_factory_verbose = 0;
62
+
58
63
  namespace {
59
64
 
60
- struct VTChain {
61
- std::vector<VectorTransform*> chain;
62
- ~VTChain() {
63
- for (int i = 0; i < chain.size(); i++) {
64
- delete chain[i];
65
+ /***************************************************************
66
+ * Small functions
67
+ */
68
+
69
+ bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
70
+ return std::regex_match(s, sm, std::regex(pat));
71
+ }
72
+
73
+ // find first pair of matching parentheses
74
+ void find_matching_parentheses(
75
+ const std::string& s,
76
+ int& i0,
77
+ int& i1,
78
+ int begin = 0) {
79
+ int st = 0;
80
+ i0 = i1 = 0;
81
+ for (int i = begin; i < s.length(); i++) {
82
+ if (s[i] == '(') {
83
+ if (st == 0) {
84
+ i0 = i;
85
+ }
86
+ st++;
87
+ }
88
+
89
+ if (s[i] == ')') {
90
+ st--;
91
+ if (st == 0) {
92
+ i1 = i;
93
+ return;
94
+ }
95
+ if (st < 0) {
96
+ FAISS_THROW_FMT(
97
+ "factory string %s: unbalanced parentheses", s.c_str());
98
+ }
65
99
  }
66
100
  }
67
- };
101
+ FAISS_THROW_FMT(
102
+ "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
103
+ }
68
104
 
69
105
  /// what kind of training does this coarse quantizer require?
70
106
  char get_trains_alone(const Index* coarse_quantizer) {
@@ -83,447 +119,582 @@ char get_trains_alone(const Index* coarse_quantizer) {
83
119
  // kmeans index
84
120
  }
85
121
 
86
- bool str_ends_with(const std::string& s, const std::string& suffix) {
87
- return s.rfind(suffix) == std::abs(int(s.size() - suffix.size()));
122
+ // set the fields for factory-constructed IVF structures
123
+ IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
124
+ index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
125
+ index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
126
+ index_ivf->own_fields = true;
127
+ return index_ivf;
88
128
  }
89
129
 
90
- // check if ends with suffix followed by digits
91
- bool str_ends_with_digits(const std::string& s, const std::string& suffix) {
92
- int i;
93
- for (i = s.length() - 1; i >= 0; i--) {
94
- if (!isdigit(s[i]))
95
- break;
130
+ int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
131
+ if (mr.length() == 0) {
132
+ return deflt;
96
133
  }
97
- return str_ends_with(s.substr(0, i + 1), suffix);
134
+ return std::stoi(mr.str().substr(begin));
98
135
  }
99
136
 
100
- void find_matching_parentheses(const std::string& s, int& i0, int& i1) {
101
- int st = 0;
102
- i0 = i1 = 0;
103
- for (int i = 0; i < s.length(); i++) {
104
- if (s[i] == '(') {
105
- if (st == 0) {
106
- i0 = i;
107
- }
108
- st++;
109
- }
137
+ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
138
+ {"SQ8", ScalarQuantizer::QT_8bit},
139
+ {"SQ4", ScalarQuantizer::QT_4bit},
140
+ {"SQ6", ScalarQuantizer::QT_6bit},
141
+ {"SQfp16", ScalarQuantizer::QT_fp16},
142
+ };
143
+ const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
144
+
145
+ std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
146
+ {"_Nfloat", AdditiveQuantizer::ST_norm_float},
147
+ {"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
148
+ {"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
149
+ {"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
150
+ {"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
151
+ {"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
152
+ };
110
153
 
111
- if (s[i] == ')') {
112
- st--;
113
- if (st == 0) {
114
- i1 = i;
115
- return;
116
- }
117
- if (st < 0) {
118
- FAISS_THROW_FMT(
119
- "factory string %s: unbalanced parentheses", s.c_str());
120
- }
121
- }
154
+ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
155
+ const std::string aq_norm_pattern =
156
+ "(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)";
157
+
158
+ AdditiveQuantizer::Search_type_t aq_parse_search_type(
159
+ std::string stok,
160
+ MetricType metric) {
161
+ if (stok == "") {
162
+ return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
163
+ : AdditiveQuantizer::ST_LUT_nonorm;
122
164
  }
123
- FAISS_THROW_FMT(
124
- "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
165
+ int pos = stok.rfind("_");
166
+ return aq_search_type[stok.substr(pos)];
125
167
  }
126
168
 
127
- } // anonymous namespace
169
+ std::vector<size_t> aq_parse_nbits(std::string stok) {
170
+ std::vector<size_t> nbits;
171
+ std::smatch sm;
172
+ while (std::regex_search(stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
173
+ int M = std::stoi(sm[1].str());
174
+ int nbit = std::stoi(sm[2].str());
175
+ nbits.resize(nbits.size() + M, nbit);
176
+ stok = sm.suffix();
177
+ }
178
+ return nbits;
179
+ }
128
180
 
129
- Index* index_factory(int d, const char* description_in, MetricType metric) {
130
- FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
131
- VTChain vts;
132
- Index* coarse_quantizer = nullptr;
133
- std::string parenthesis_ivf, parenthesis_refine;
134
- Index* index = nullptr;
135
- bool add_idmap = false;
136
- int d_in = d;
181
+ /***************************************************************
182
+ * Parse VectorTransform
183
+ */
137
184
 
138
- ScopeDeleter1<Index> del_coarse_quantizer, del_index;
185
+ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
186
+ std::smatch sm;
187
+ auto match = [&sm, description](std::string pattern) {
188
+ return re_match(description, pattern, sm);
189
+ };
190
+ if (match("PCA(W?)(R?)([0-9]+)")) {
191
+ bool white = sm[1].length() > 0;
192
+ bool rot = sm[2].length() > 0;
193
+ return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
194
+ }
195
+ if (match("L2[nN]orm")) {
196
+ return new NormalizationTransform(d, 2.0);
197
+ }
198
+ if (match("RR([0-9]+)?")) {
199
+ return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
200
+ }
201
+ if (match("ITQ([0-9]+)?")) {
202
+ return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
203
+ }
204
+ if (match("OPQ([0-9]+)(_[0-9]+)?")) {
205
+ int M = std::stoi(sm[1].str());
206
+ int d_out = mres_to_int(sm[2], d, 1);
207
+ return new OPQMatrix(d, M, d_out);
208
+ }
209
+ if (match("Pad([0-9]+)")) {
210
+ int d_out = std::stoi(sm[1].str());
211
+ return new RemapDimensionsTransform(d, std::max(d_out, d), false);
212
+ }
213
+ return nullptr;
214
+ };
139
215
 
140
- std::string description(description_in);
141
- char* ptr;
216
+ /***************************************************************
217
+ * Parse IndexIVF
218
+ */
142
219
 
143
- // handle indexes in parentheses
144
- while (description.find('(') != std::string::npos) {
145
- // then we make a sub-index and remove the () from the description
146
- int i0, i1;
147
- find_matching_parentheses(description, i0, i1);
220
+ // parsing guard + function
221
+ Index* parse_coarse_quantizer(
222
+ const std::string& description,
223
+ int d,
224
+ MetricType mt,
225
+ std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
226
+ size_t& nlist,
227
+ bool& use_2layer) {
228
+ std::smatch sm;
229
+ auto match = [&sm, description](std::string pattern) {
230
+ return re_match(description, pattern, sm);
231
+ };
232
+ use_2layer = false;
233
+
234
+ if (match("IVF([0-9]+)")) {
235
+ nlist = std::stoi(sm[1].str());
236
+ return new IndexFlat(d, mt);
237
+ }
238
+ if (match("IMI2x([0-9]+)")) {
239
+ int nbit = std::stoi(sm[1].str());
240
+ FAISS_THROW_IF_NOT_MSG(
241
+ mt == METRIC_L2,
242
+ "MultiIndex not implemented for inner prod search");
243
+ nlist = (size_t)1 << (2 * nbit);
244
+ return new MultiIndexQuantizer(d, 2, nbit);
245
+ }
246
+ if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
247
+ nlist = std::stoi(sm[1].str());
248
+ int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
249
+ return new IndexHNSWFlat(d, hnsw_M, mt);
250
+ }
251
+ if (match("IVF([0-9]+)_NSG([0-9]+)")) {
252
+ nlist = std::stoi(sm[1].str());
253
+ int R = std::stoi(sm[2]);
254
+ return new IndexNSGFlat(d, R, mt);
255
+ }
256
+ if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
257
+ nlist = std::stoi(sm[1].str());
258
+ int no = std::stoi(sm[2].str());
259
+ FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
260
+ return parenthesis_indexes[no].release();
261
+ }
148
262
 
149
- std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
263
+ // these two generate Index2Layer's not IndexIVF's
264
+ if (match("Residual([0-9]+)x([0-9]+)")) {
265
+ FAISS_THROW_IF_NOT_MSG(
266
+ mt == METRIC_L2,
267
+ "MultiIndex not implemented for inner prod search");
268
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
269
+ nlist = (size_t)1 << (M * nbit);
270
+ use_2layer = true;
271
+ return new MultiIndexQuantizer(d, M, nbit);
272
+ }
273
+ if (match("Residual([0-9]+)")) {
274
+ FAISS_THROW_IF_NOT_MSG(
275
+ mt == METRIC_L2,
276
+ "Residual not implemented for inner prod search");
277
+ use_2layer = true;
278
+ nlist = mres_to_int(sm[1]);
279
+ return new IndexFlatL2(d);
280
+ }
281
+ return nullptr;
282
+ }
150
283
 
151
- if (str_ends_with_digits(description.substr(0, i0), "IVF")) {
152
- parenthesis_ivf = sub_description;
153
- } else if (str_ends_with(description.substr(0, i0), "Refine")) {
154
- parenthesis_refine = sub_description;
284
+ // parse the code description part of the IVF description
285
+
286
+ IndexIVF* parse_IndexIVF(
287
+ const std::string& code_string,
288
+ std::unique_ptr<Index>& quantizer,
289
+ size_t nlist,
290
+ MetricType mt) {
291
+ std::smatch sm;
292
+ auto match = [&sm, &code_string](const std::string pattern) {
293
+ return re_match(code_string, pattern, sm);
294
+ };
295
+ auto get_q = [&quantizer] { return quantizer.release(); };
296
+ int d = quantizer->d;
297
+
298
+ if (match("Flat")) {
299
+ return new IndexIVFFlat(get_q(), d, nlist, mt);
300
+ }
301
+ if (match("FlatDedup")) {
302
+ return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
303
+ }
304
+ if (match(sq_pattern)) {
305
+ return new IndexIVFScalarQuantizer(
306
+ get_q(), d, nlist, sq_types[sm[1].str()], mt);
307
+ }
308
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
309
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
310
+ IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
311
+ index_ivf->do_polysemous_training = sm[3].str() != "np";
312
+ return index_ivf;
313
+ }
314
+ if (match("PQ([0-9]+)\\+([0-9]+)")) {
315
+ FAISS_THROW_IF_NOT_MSG(
316
+ mt == METRIC_L2,
317
+ "IVFPQR not implemented for inner product search");
318
+ int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
319
+ return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
320
+ }
321
+ if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
322
+ int M = mres_to_int(sm[1]);
323
+ int bbs = mres_to_int(sm[3], 32, 1);
324
+ IndexIVFPQFastScan* index_ivf =
325
+ new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
326
+ index_ivf->by_residual = sm[2].str() == "r";
327
+ return index_ivf;
328
+ }
329
+ if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
330
+ std::vector<size_t> nbits = aq_parse_nbits(sm.str());
331
+ AdditiveQuantizer::Search_type_t st =
332
+ aq_parse_search_type(sm[sm.size() - 1].str(), mt);
333
+ IndexIVF* index_ivf;
334
+ if (sm[1].str() == "RQ") {
335
+ index_ivf = new IndexIVFResidualQuantizer(
336
+ get_q(), d, nlist, nbits, mt, st);
155
337
  } else {
156
- FAISS_THROW_MSG("don't know what to do with parenthesis index");
338
+ FAISS_THROW_IF_NOT(nbits.size() > 0);
339
+ index_ivf = new IndexIVFLocalSearchQuantizer(
340
+ get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
157
341
  }
158
- description = description.erase(i0, i1 - i0 + 1);
159
- }
160
-
161
- int64_t ncentroids = -1;
162
- bool use_2layer = false;
163
- int hnsw_M = -1;
164
- int nsg_R = -1;
165
-
166
- for (char* tok = strtok_r(&description[0], " ,", &ptr); tok;
167
- tok = strtok_r(nullptr, " ,", &ptr)) {
168
- int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2, R;
169
- std::string stok(tok);
170
- nbit = 8;
171
- int bbs = -1;
172
- char c;
173
-
174
- // to avoid mem leaks with exceptions:
175
- // do all tests before any instanciation
176
-
177
- VectorTransform* vt_1 = nullptr;
178
- Index* coarse_quantizer_1 = nullptr;
179
- Index* index_1 = nullptr;
180
-
181
- // VectorTransforms
182
- if (sscanf(tok, "PCA%d", &d_out) == 1) {
183
- vt_1 = new PCAMatrix(d, d_out);
184
- d = d_out;
185
- } else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
186
- vt_1 = new PCAMatrix(d, d_out, 0, true);
187
- d = d_out;
188
- } else if (sscanf(tok, "RR%d", &d_out) == 1) {
189
- vt_1 = new RandomRotationMatrix(d, d_out);
190
- d = d_out;
191
- } else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
192
- vt_1 = new PCAMatrix(d, d_out, -0.5, false);
193
- d = d_out;
194
- } else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
195
- vt_1 = new PCAMatrix(d, d_out, -0.5, true);
196
- d = d_out;
197
- } else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
198
- vt_1 = new OPQMatrix(d, opq_M, d_out);
199
- d = d_out;
200
- } else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
201
- vt_1 = new OPQMatrix(d, opq_M);
202
- } else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
203
- vt_1 = new ITQTransform(d, d_out, true);
204
- d = d_out;
205
- } else if (stok == "ITQ") {
206
- vt_1 = new ITQTransform(d, d, false);
207
- } else if (sscanf(tok, "Pad%d", &d_out) == 1) {
208
- if (d_out > d) {
209
- vt_1 = new RemapDimensionsTransform(d, d_out, false);
210
- d = d_out;
211
- }
212
- } else if (stok == "L2norm") {
213
- vt_1 = new NormalizationTransform(d, 2.0);
214
-
215
- // coarse quantizers
216
- } else if (
217
- !coarse_quantizer &&
218
- sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
219
- coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
220
-
221
- } else if (
222
- !coarse_quantizer &&
223
- sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
224
- coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
225
-
226
- } else if (
227
- !coarse_quantizer &&
228
- sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
229
- if (!parenthesis_ivf.empty()) {
230
- coarse_quantizer_1 =
231
- index_factory(d, parenthesis_ivf.c_str(), metric);
232
- } else if (metric == METRIC_L2) {
233
- coarse_quantizer_1 = new IndexFlatL2(d);
234
- } else {
235
- coarse_quantizer_1 = new IndexFlatIP(d);
236
- }
237
-
238
- } else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
239
- FAISS_THROW_IF_NOT_MSG(
240
- metric == METRIC_L2,
241
- "MultiIndex not implemented for inner prod search");
242
- coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
243
- ncentroids = 1 << (2 * nbit);
244
-
245
- } else if (
246
- !coarse_quantizer &&
247
- sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
248
- FAISS_THROW_IF_NOT_MSG(
249
- metric == METRIC_L2,
250
- "MultiIndex not implemented for inner prod search");
251
- coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
252
- ncentroids = int64_t(1) << (M * nbit);
253
- use_2layer = true;
254
-
255
- } else if (std::regex_match(
256
- stok,
257
- std::regex(
258
- "(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
259
- std::vector<size_t> nbits;
260
- std::smatch sm;
261
- bool is_RCQ = stok.find("RCQ") == 0;
262
- while (std::regex_search(
263
- stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
264
- int M = std::stoi(sm[1].str());
265
- int nbit = std::stoi(sm[2].str());
266
- nbits.resize(nbits.size() + M, nbit);
267
- stok = sm.suffix();
268
- }
269
- if (!is_RCQ) {
270
- index_1 = new IndexResidual(d, nbits, metric);
271
- } else {
272
- index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
273
- }
274
- } else if (
275
- !coarse_quantizer &&
276
- sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
277
- coarse_quantizer_1 = new IndexFlatL2(d);
278
- use_2layer = true;
279
-
280
- } else if (stok == "IDMap") {
281
- add_idmap = true;
282
-
283
- // IVFs
284
- } else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
285
- if (coarse_quantizer) {
286
- // if there was an IVF in front, then it is an IVFFlat
287
- IndexIVF* index_ivf = stok == "Flat"
288
- ? new IndexIVFFlat(
289
- coarse_quantizer, d, ncentroids, metric)
290
- : new IndexIVFFlatDedup(
291
- coarse_quantizer, d, ncentroids, metric);
292
- index_ivf->quantizer_trains_alone =
293
- get_trains_alone(coarse_quantizer);
294
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
295
- del_coarse_quantizer.release();
296
- index_ivf->own_fields = true;
297
- index_1 = index_ivf;
298
- } else if (hnsw_M > 0) {
299
- index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
300
- } else if (nsg_R > 0) {
301
- index_1 = new IndexNSGFlat(d, nsg_R, metric);
302
- } else {
303
- FAISS_THROW_IF_NOT_MSG(
304
- stok != "FlatDedup",
305
- "dedup supported only for IVFFlat");
306
- index_1 = new IndexFlat(d, metric);
307
- }
308
- } else if (
309
- !index &&
310
- (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
311
- stok == "SQfp16")) {
312
- ScalarQuantizer::QuantizerType qt = stok == "SQ8"
313
- ? ScalarQuantizer::QT_8bit
314
- : stok == "SQ6" ? ScalarQuantizer::QT_6bit
315
- : stok == "SQ4" ? ScalarQuantizer::QT_4bit
316
- : stok == "SQfp16" ? ScalarQuantizer::QT_fp16
317
- : ScalarQuantizer::QT_4bit;
318
- if (coarse_quantizer) {
319
- FAISS_THROW_IF_NOT(!use_2layer);
320
- IndexIVFScalarQuantizer* index_ivf =
321
- new IndexIVFScalarQuantizer(
322
- coarse_quantizer, d, ncentroids, qt, metric);
323
- index_ivf->quantizer_trains_alone =
324
- get_trains_alone(coarse_quantizer);
325
- del_coarse_quantizer.release();
326
- index_ivf->own_fields = true;
327
- index_1 = index_ivf;
328
- } else if (hnsw_M > 0) {
329
- index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
330
- } else {
331
- index_1 = new IndexScalarQuantizer(d, qt, metric);
332
- }
333
- } else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
334
- FAISS_THROW_IF_NOT_MSG(
335
- coarse_quantizer, "PQ with + works only with an IVF");
336
- FAISS_THROW_IF_NOT_MSG(
337
- metric == METRIC_L2,
338
- "IVFPQR not implemented for inner product search");
339
- IndexIVFPQR* index_ivf = new IndexIVFPQR(
340
- coarse_quantizer, d, ncentroids, M, 8, M2, 8);
341
- index_ivf->quantizer_trains_alone =
342
- get_trains_alone(coarse_quantizer);
343
- del_coarse_quantizer.release();
344
- index_ivf->own_fields = true;
345
- index_1 = index_ivf;
346
- } else if (
347
- !index &&
348
- (sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
349
- (sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
350
- (sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
351
- if (bbs == -1) {
352
- bbs = 32;
353
- }
354
- bool by_residual = str_ends_with(stok, "fsr");
355
- if (coarse_quantizer) {
356
- IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
357
- coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
358
- index_ivf->quantizer_trains_alone =
359
- get_trains_alone(coarse_quantizer);
360
- index_ivf->metric_type = metric;
361
- index_ivf->by_residual = by_residual;
362
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
363
- del_coarse_quantizer.release();
364
- index_ivf->own_fields = true;
365
- index_1 = index_ivf;
366
- } else {
367
- IndexPQFastScan* index_pq =
368
- new IndexPQFastScan(d, M, 4, metric, bbs);
369
- index_1 = index_pq;
370
- }
371
- } else if (
372
- !index &&
373
- (sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
374
- sscanf(tok, "PQ%d", &M) == 1 ||
375
- sscanf(tok, "PQ%dnp", &M) == 1)) {
376
- bool do_polysemous_training = stok.find("np") == std::string::npos;
377
- if (coarse_quantizer) {
378
- if (!use_2layer) {
379
- IndexIVFPQ* index_ivf = new IndexIVFPQ(
380
- coarse_quantizer, d, ncentroids, M, nbit);
381
- index_ivf->quantizer_trains_alone =
382
- get_trains_alone(coarse_quantizer);
383
- index_ivf->metric_type = metric;
384
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
385
- del_coarse_quantizer.release();
386
- index_ivf->own_fields = true;
387
- index_ivf->do_polysemous_training = do_polysemous_training;
388
- index_1 = index_ivf;
389
- } else {
390
- Index2Layer* index_2l = new Index2Layer(
391
- coarse_quantizer, ncentroids, M, nbit);
392
- index_2l->q1.quantizer_trains_alone =
393
- get_trains_alone(coarse_quantizer);
394
- index_2l->q1.own_fields = true;
395
- index_1 = index_2l;
396
- }
397
- } else if (hnsw_M > 0) {
398
- IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
399
- dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
400
- do_polysemous_training;
401
- index_1 = ipq;
402
- } else {
403
- IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
404
- index_pq->do_polysemous_training = do_polysemous_training;
405
- index_1 = index_pq;
406
- }
407
- } else if (
408
- !index &&
409
- sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
410
- Index* quant = new IndexFlatL2(d);
411
- IndexHNSW2Level* hidx2l =
412
- new IndexHNSW2Level(quant, ncent, pq_m, M);
413
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
414
- idx2l->q1.own_fields = true;
415
- index_1 = hidx2l;
416
- } else if (
417
- !index &&
418
- sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
419
- Index* quant = new MultiIndexQuantizer(d, 2, nbit);
420
- IndexHNSW2Level* hidx2l =
421
- new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
422
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
423
- idx2l->q1.own_fields = true;
424
- idx2l->q1.quantizer_trains_alone = 1;
425
- index_1 = hidx2l;
426
- } else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
427
- index_1 = new IndexHNSWPQ(d, pq_m, M);
428
- } else if (
429
- !index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
430
- pq_m == 8) {
431
- index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
432
- } else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
433
- hnsw_M = M;
434
- // here it is unclear what we want: HNSW flat or HNSWx,Y ?
435
- } else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
436
- nsg_R = R;
437
- } else if (
438
- !index &&
439
- (stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
440
- stok == "LSHt")) {
441
- bool rotate_data = strstr(tok, "r") != nullptr;
442
- bool train_thresholds = strstr(tok, "t") != nullptr;
443
- index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
444
- } else if (
445
- !index &&
446
- sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
447
- FAISS_THROW_IF_NOT(!coarse_quantizer);
448
- index_1 = new IndexLattice(d, M, nbit, r2);
449
- } else if (stok == "RFlat") {
450
- parenthesis_refine = "Flat";
451
- } else if (stok == "Refine") {
452
- FAISS_THROW_IF_NOT_MSG(
453
- !parenthesis_refine.empty(),
454
- "Refine index should be provided in parentheses");
455
- } else {
456
- FAISS_THROW_FMT(
457
- "could not parse token \"%s\" in %s\n",
458
- tok,
459
- description_in);
342
+ return index_ivf;
343
+ }
344
+ if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
345
+ int outdim = mres_to_int(sm[2], d); // is also the number of bits
346
+ std::unique_ptr<VectorTransform> vt;
347
+ if (sm[1] == "ITQ") {
348
+ vt.reset(new ITQTransform(d, outdim, d != outdim));
349
+ } else if (sm[1] == "PCA") {
350
+ vt.reset(new PCAMatrix(d, outdim));
351
+ } else if (sm[1] == "PCAR") {
352
+ vt.reset(new PCAMatrix(d, outdim, 0, true));
460
353
  }
354
+ // the rationale for -1e10 is that this corresponds to simple
355
+ // thresholding
356
+ float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
357
+ IndexIVFSpectralHash* index_ivf =
358
+ new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
359
+ index_ivf->replace_vt(vt.release(), true);
360
+ if (sm[4].length()) {
361
+ std::string s = sm[4].str();
362
+ index_ivf->threshold_type = s == "g"
363
+ ? IndexIVFSpectralHash::Thresh_global
364
+ : s == "c"
365
+ ? IndexIVFSpectralHash::Thresh_centroid
366
+ :
367
+ /* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
368
+ }
369
+ return index_ivf;
370
+ }
371
+ return nullptr;
372
+ }
373
+
374
+ /***************************************************************
375
+ * Parse IndexHNSW
376
+ */
377
+
378
+ IndexHNSW* parse_IndexHNSW(
379
+ const std::string code_string,
380
+ int d,
381
+ MetricType mt,
382
+ int hnsw_M) {
383
+ std::smatch sm;
384
+ auto match = [&sm, &code_string](const std::string& pattern) {
385
+ return re_match(code_string, pattern, sm);
386
+ };
387
+
388
+ if (match("Flat|")) {
389
+ return new IndexHNSWFlat(d, hnsw_M, mt);
390
+ }
391
+ if (match("PQ([0-9]+)(np)?")) {
392
+ int M = std::stoi(sm[1].str());
393
+ IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
394
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
395
+ sm[2].str() != "np";
396
+ return ipq;
397
+ }
398
+ if (match(sq_pattern)) {
399
+ return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
400
+ }
401
+ if (match("([0-9]+)\\+PQ([0-9]+)?")) {
402
+ int ncent = mres_to_int(sm[1]);
403
+ int pq_m = mres_to_int(sm[2]);
404
+ IndexHNSW2Level* hidx2l =
405
+ new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
406
+ dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
407
+ return hidx2l;
408
+ }
409
+ if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
410
+ int nbit = mres_to_int(sm[1]);
411
+ int pq_m = mres_to_int(sm[2]);
412
+ Index* quant = new MultiIndexQuantizer(d, 2, nbit);
413
+ IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
414
+ quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
415
+ Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
416
+ idx2l->q1.own_fields = true;
417
+ idx2l->q1.quantizer_trains_alone = 1;
418
+ return hidx2l;
419
+ }
461
420
 
462
- if (index_1 && add_idmap) {
463
- IndexIDMap* idmap = new IndexIDMap(index_1);
464
- del_index.set(idmap);
465
- idmap->own_fields = true;
466
- index_1 = idmap;
467
- add_idmap = false;
421
+ return nullptr;
422
+ }
423
+
424
+ /***************************************************************
425
+ * Parse basic indexes
426
+ */
427
+
428
+ Index* parse_other_indexes(
429
+ const std::string& description,
430
+ int d,
431
+ MetricType metric) {
432
+ std::smatch sm;
433
+ auto match = [&sm, description](const std::string& pattern) {
434
+ return re_match(description, pattern, sm);
435
+ };
436
+
437
+ // IndexFlat
438
+ if (description == "Flat") {
439
+ return new IndexFlat(d, metric);
440
+ }
441
+
442
+ // IndexLSH
443
+ if (match("LSH(r?)(t?)")) {
444
+ bool rotate_data = sm[1].length() > 0;
445
+ bool train_thresholds = sm[2].length() > 0;
446
+ FAISS_THROW_IF_NOT(metric == METRIC_L2);
447
+ return new IndexLSH(d, d, rotate_data, train_thresholds);
448
+ }
449
+
450
+ // IndexLattice
451
+ if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
452
+ int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
453
+ int nbit = std::stoi(sm[3].str());
454
+ return new IndexLattice(d, M, nbit, r2);
455
+ }
456
+
457
+ // IndexNSGFlat
458
+ if (match("NSG([0-9]+)(,Flat)?")) {
459
+ return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
460
+ }
461
+
462
+ // IndexScalarQuantizer
463
+ if (match(sq_pattern)) {
464
+ return new IndexScalarQuantizer(d, sq_types[description], metric);
465
+ }
466
+
467
+ // IndexPQ
468
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
469
+ int M = std::stoi(sm[1].str());
470
+ int nbit = mres_to_int(sm[2], 8, 1);
471
+ IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
472
+ index_pq->do_polysemous_training = sm[3].str() != "np";
473
+ return index_pq;
474
+ }
475
+
476
+ // IndexPQFastScan
477
+ if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
478
+ int M = std::stoi(sm[1].str());
479
+ int bbs = mres_to_int(sm[2], 32, 1);
480
+ return new IndexPQFastScan(d, M, 4, metric, bbs);
481
+ }
482
+
483
+ // IndexResidualCoarseQuantizer and IndexResidualQuantizer
484
+ std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
485
+ if (match(pattern)) {
486
+ std::vector<size_t> nbits = aq_parse_nbits(description);
487
+ if (sm[1].str() == "RCQ") {
488
+ return new ResidualCoarseQuantizer(d, nbits, metric);
468
489
  }
490
+ AdditiveQuantizer::Search_type_t st =
491
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
492
+ return new IndexResidualQuantizer(d, nbits, metric, st);
493
+ }
469
494
 
470
- if (vt_1) {
471
- vts.chain.push_back(vt_1);
495
+ // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
496
+ if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
497
+ std::vector<size_t> nbits = aq_parse_nbits(description);
498
+ int M = mres_to_int(sm[2]);
499
+ int nbit = mres_to_int(sm[3]);
500
+ if (sm[1].str() == "LSCQ") {
501
+ return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
472
502
  }
503
+ AdditiveQuantizer::Search_type_t st =
504
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
505
+ return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
506
+ }
507
+
508
+ return nullptr;
509
+ }
510
+
511
+ /***************************************************************
512
+ * Driver function
513
+ */
514
+ std::unique_ptr<Index> index_factory_sub(
515
+ int d,
516
+ std::string description,
517
+ MetricType metric) {
518
+ // handle composite indexes
519
+
520
+ bool verbose = index_factory_verbose;
473
521
 
474
- if (coarse_quantizer_1) {
475
- coarse_quantizer = coarse_quantizer_1;
476
- del_coarse_quantizer.set(coarse_quantizer);
522
+ if (verbose) {
523
+ printf("begin parse VectorTransforms: %s \n", description.c_str());
524
+ }
525
+
526
+ // for the current match
527
+ std::smatch sm;
528
+
529
+ // handle refines
530
+ if (re_match(description, "(.+),RFlat", sm) ||
531
+ re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
532
+ std::unique_ptr<Index> filter_index =
533
+ index_factory_sub(d, sm[1].str(), metric);
534
+ std::unique_ptr<Index> refine_index;
535
+
536
+ if (sm.size() == 3) { // Refine
537
+ refine_index = index_factory_sub(d, sm[2].str(), metric);
538
+ } else { // RFlat
539
+ refine_index.reset(new IndexFlat(d, metric));
477
540
  }
541
+ IndexRefine* index_rf =
542
+ new IndexRefine(filter_index.get(), refine_index.get());
543
+ index_rf->own_fields = true;
544
+ filter_index.release();
545
+ refine_index.release();
546
+ index_rf->own_refine_index = true;
547
+ return std::unique_ptr<Index>(index_rf);
548
+ }
478
549
 
479
- if (index_1) {
480
- index = index_1;
481
- del_index.set(index);
550
+ // IndexPreTransform
551
+ // should handle this first (even before parentheses) because it changes d
552
+ std::vector<std::unique_ptr<VectorTransform>> vts;
553
+ VectorTransform* vt = nullptr;
554
+ while (re_match(description, "([^,]+),(.*)", sm) &&
555
+ (vt = parse_VectorTransform(sm[1], d))) {
556
+ // reset loop
557
+ description = sm[sm.size() - 1];
558
+ vts.emplace_back(vt);
559
+ d = vts.back()->d_out;
560
+ }
561
+
562
+ if (vts.size() > 0) {
563
+ std::unique_ptr<Index> sub_index =
564
+ index_factory_sub(d, description, metric);
565
+ IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
566
+ std::unique_ptr<Index> ret(index_pt);
567
+ index_pt->own_fields = true;
568
+ sub_index.release();
569
+ while (vts.size() > 0) {
570
+ if (verbose) {
571
+ printf("prepend trans %d -> %d\n",
572
+ vts.back()->d_in,
573
+ vts.back()->d_out);
574
+ }
575
+ index_pt->prepend_transform(vts.back().release());
576
+ vts.pop_back();
482
577
  }
578
+ return ret;
483
579
  }
484
580
 
485
- if (!index && hnsw_M > 0) {
486
- index = new IndexHNSWFlat(d, hnsw_M, metric);
487
- del_index.set(index);
488
- } else if (!index && nsg_R > 0) {
489
- index = new IndexNSGFlat(d, nsg_R, metric);
490
- del_index.set(index);
581
+ // what we got from the parentheses
582
+ std::vector<std::unique_ptr<Index>> parenthesis_indexes;
583
+
584
+ int begin = 0;
585
+ while (description.find('(', begin) != std::string::npos) {
586
+ // replace indexes in () with Index0, Index1, etc.
587
+ int i0, i1;
588
+ find_matching_parentheses(description, i0, i1, begin);
589
+ std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
590
+ int no = parenthesis_indexes.size();
591
+ parenthesis_indexes.push_back(
592
+ index_factory_sub(d, sub_description, metric));
593
+ description = description.substr(0, i0 + 1) + "Index" +
594
+ std::to_string(no) + description.substr(i1);
595
+ begin = i1 + 1;
491
596
  }
492
597
 
493
- FAISS_THROW_IF_NOT_FMT(
494
- index, "description %s did not generate an index", description_in);
598
+ if (verbose) {
599
+ printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
600
+ description.c_str(),
601
+ parenthesis_indexes.size(),
602
+ d);
603
+ }
495
604
 
496
- // nothing can go wrong now
497
- del_index.release();
498
- del_coarse_quantizer.release();
605
+ // IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
606
+ // support both
607
+ if (re_match(description, "(.+),IDMap", sm) ||
608
+ re_match(description, "IDMap,(.+)", sm)) {
609
+ IndexIDMap* idmap = new IndexIDMap(
610
+ index_factory_sub(d, sm[1].str(), metric).release());
611
+ idmap->own_fields = true;
612
+ return std::unique_ptr<Index>(idmap);
613
+ }
499
614
 
500
- if (add_idmap) {
501
- fprintf(stderr,
502
- "index_factory: WARNING: "
503
- "IDMap option not used\n");
615
+ { // handle basic index types
616
+ Index* index = parse_other_indexes(description, d, metric);
617
+ if (index) {
618
+ return std::unique_ptr<Index>(index);
619
+ }
504
620
  }
505
621
 
506
- if (vts.chain.size() > 0) {
507
- IndexPreTransform* index_pt = new IndexPreTransform(index);
508
- index_pt->own_fields = true;
509
- // add from back
510
- while (vts.chain.size() > 0) {
511
- index_pt->prepend_transform(vts.chain.back());
512
- vts.chain.pop_back();
622
+ // HNSW variants (it was unclear in the old version that the separator was a
623
+ // "," so we support both "_" and ",")
624
+ if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
625
+ int hnsw_M = mres_to_int(sm[1], 32);
626
+ // We also accept empty code string (synonym of Flat)
627
+ std::string code_string =
628
+ sm[2].length() > 0 ? sm[2].str().substr(1) : "";
629
+ if (verbose) {
630
+ printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
631
+ description.c_str(),
632
+ code_string.c_str(),
633
+ hnsw_M);
513
634
  }
514
- index = index_pt;
635
+
636
+ IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
637
+ FAISS_THROW_IF_NOT_FMT(
638
+ index,
639
+ "could not parse HNSW code description %s in %s",
640
+ code_string.c_str(),
641
+ description.c_str());
642
+ return std::unique_ptr<Index>(index);
515
643
  }
516
644
 
517
- if (!parenthesis_refine.empty()) {
518
- Index* refine_index =
519
- index_factory(d_in, parenthesis_refine.c_str(), metric);
520
- IndexRefine* index_rf = new IndexRefine(index, refine_index);
521
- index_rf->own_refine_index = true;
522
- index_rf->own_fields = true;
523
- index = index_rf;
645
+ // IndexIVF
646
+ {
647
+ size_t nlist;
648
+ bool use_2layer;
649
+ size_t comma = description.find(",");
650
+ std::string coarse_string = description.substr(0, comma);
651
+ // Match coarse quantizer part first
652
+ std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
653
+ description.substr(0, comma),
654
+ d,
655
+ metric,
656
+ parenthesis_indexes,
657
+ nlist,
658
+ use_2layer));
659
+
660
+ if (comma != std::string::npos && quantizer.get()) {
661
+ std::string code_description = description.substr(comma + 1);
662
+ if (use_2layer) {
663
+ bool ok =
664
+ re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
665
+ FAISS_THROW_IF_NOT_FMT(
666
+ ok,
667
+ "could not parse 2 layer code description %s in %s",
668
+ code_description.c_str(),
669
+ description.c_str());
670
+ int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
671
+ Index2Layer* index_2l =
672
+ new Index2Layer(quantizer.release(), nlist, M, nbit);
673
+ index_2l->q1.own_fields = true;
674
+ index_2l->q1.quantizer_trains_alone =
675
+ get_trains_alone(index_2l->q1.quantizer);
676
+ return std::unique_ptr<Index>(index_2l);
677
+ }
678
+
679
+ IndexIVF* index_ivf =
680
+ parse_IndexIVF(code_description, quantizer, nlist, metric);
681
+
682
+ FAISS_THROW_IF_NOT_FMT(
683
+ index_ivf,
684
+ "could not parse code description %s in %s",
685
+ code_description.c_str(),
686
+ description.c_str());
687
+ return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
688
+ }
524
689
  }
690
+ FAISS_THROW_FMT("could not parse index string %s", description.c_str());
691
+ return nullptr;
692
+ }
525
693
 
526
- return index;
694
+ } // anonymous namespace
695
+
696
+ Index* index_factory(int d, const char* description, MetricType metric) {
697
+ return index_factory_sub(d, description, metric).release();
527
698
  }
528
699
 
529
700
  IndexBinary* index_binary_factory(int d, const char* description) {