faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,19 +5,19 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  /*
11
- * implementation of Hyper-parameter auto-tuning
9
+ * implementation of the index_factory function. Lots of regex parsing code.
12
10
  */
13
11
 
14
12
  #include <faiss/index_factory.h>
15
-
16
- #include <faiss/AutoTune.h>
13
+ #include "faiss/MetricType.h"
14
+ #include "faiss/impl/FaissAssert.h"
17
15
 
18
16
  #include <cinttypes>
19
17
  #include <cmath>
20
18
 
19
+ #include <map>
20
+
21
21
  #include <regex>
22
22
 
23
23
  #include <faiss/impl/FaissAssert.h>
@@ -25,13 +25,18 @@
25
25
  #include <faiss/utils/utils.h>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
+ #include <faiss/IndexAdditiveQuantizer.h>
29
+ #include <faiss/IndexAdditiveQuantizerFastScan.h>
28
30
  #include <faiss/IndexFlat.h>
29
31
  #include <faiss/IndexHNSW.h>
30
32
  #include <faiss/IndexIVF.h>
33
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
34
+ #include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
31
35
  #include <faiss/IndexIVFFlat.h>
32
36
  #include <faiss/IndexIVFPQ.h>
33
37
  #include <faiss/IndexIVFPQFastScan.h>
34
38
  #include <faiss/IndexIVFPQR.h>
39
+ #include <faiss/IndexIVFSpectralHash.h>
35
40
  #include <faiss/IndexLSH.h>
36
41
  #include <faiss/IndexLattice.h>
37
42
  #include <faiss/IndexNSG.h>
@@ -39,7 +44,7 @@
39
44
  #include <faiss/IndexPQFastScan.h>
40
45
  #include <faiss/IndexPreTransform.h>
41
46
  #include <faiss/IndexRefine.h>
42
- #include <faiss/IndexResidual.h>
47
+ #include <faiss/IndexRowwiseMinMax.h>
43
48
  #include <faiss/IndexScalarQuantizer.h>
44
49
  #include <faiss/MetaIndexes.h>
45
50
  #include <faiss/VectorTransform.h>
@@ -48,6 +53,7 @@
48
53
  #include <faiss/IndexBinaryHNSW.h>
49
54
  #include <faiss/IndexBinaryHash.h>
50
55
  #include <faiss/IndexBinaryIVF.h>
56
+ #include <string>
51
57
 
52
58
  namespace faiss {
53
59
 
@@ -55,16 +61,49 @@ namespace faiss {
55
61
  * index_factory
56
62
  ***************************************************************/
57
63
 
64
+ int index_factory_verbose = 0;
65
+
58
66
  namespace {
59
67
 
60
- struct VTChain {
61
- std::vector<VectorTransform*> chain;
62
- ~VTChain() {
63
- for (int i = 0; i < chain.size(); i++) {
64
- delete chain[i];
68
+ /***************************************************************
69
+ * Small functions
70
+ */
71
+
72
+ bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
73
+ return std::regex_match(s, sm, std::regex(pat));
74
+ }
75
+
76
+ // find first pair of matching parentheses
77
+ void find_matching_parentheses(
78
+ const std::string& s,
79
+ int& i0,
80
+ int& i1,
81
+ int begin = 0) {
82
+ int st = 0;
83
+ i0 = i1 = 0;
84
+ for (int i = begin; i < s.length(); i++) {
85
+ if (s[i] == '(') {
86
+ if (st == 0) {
87
+ i0 = i;
88
+ }
89
+ st++;
90
+ }
91
+
92
+ if (s[i] == ')') {
93
+ st--;
94
+ if (st == 0) {
95
+ i1 = i;
96
+ return;
97
+ }
98
+ if (st < 0) {
99
+ FAISS_THROW_FMT(
100
+ "factory string %s: unbalanced parentheses", s.c_str());
101
+ }
65
102
  }
66
103
  }
67
- };
104
+ FAISS_THROW_FMT(
105
+ "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
106
+ }
68
107
 
69
108
  /// what kind of training does this coarse quantizer require?
70
109
  char get_trains_alone(const Index* coarse_quantizer) {
@@ -83,447 +122,768 @@ char get_trains_alone(const Index* coarse_quantizer) {
83
122
  // kmeans index
84
123
  }
85
124
 
86
- bool str_ends_with(const std::string& s, const std::string& suffix) {
87
- return s.rfind(suffix) == std::abs(int(s.size() - suffix.size()));
125
+ // set the fields for factory-constructed IVF structures
126
+ IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
127
+ index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
128
+ index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
129
+ index_ivf->own_fields = true;
130
+ return index_ivf;
88
131
  }
89
132
 
90
- // check if ends with suffix followed by digits
91
- bool str_ends_with_digits(const std::string& s, const std::string& suffix) {
92
- int i;
93
- for (i = s.length() - 1; i >= 0; i--) {
94
- if (!isdigit(s[i]))
95
- break;
133
+ int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
134
+ if (mr.length() == 0) {
135
+ return deflt;
96
136
  }
97
- return str_ends_with(s.substr(0, i + 1), suffix);
137
+ return std::stoi(mr.str().substr(begin));
98
138
  }
99
139
 
100
- void find_matching_parentheses(const std::string& s, int& i0, int& i1) {
101
- int st = 0;
102
- i0 = i1 = 0;
103
- for (int i = 0; i < s.length(); i++) {
104
- if (s[i] == '(') {
105
- if (st == 0) {
106
- i0 = i;
107
- }
108
- st++;
109
- }
140
+ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
141
+ {"SQ8", ScalarQuantizer::QT_8bit},
142
+ {"SQ4", ScalarQuantizer::QT_4bit},
143
+ {"SQ6", ScalarQuantizer::QT_6bit},
144
+ {"SQfp16", ScalarQuantizer::QT_fp16},
145
+ };
146
+ const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
147
+
148
+ std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
149
+ {"_Nfloat", AdditiveQuantizer::ST_norm_float},
150
+ {"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
151
+ {"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
152
+ {"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
153
+ {"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
154
+ {"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
155
+ {"_Nlsq2x4", AdditiveQuantizer::ST_norm_lsq2x4},
156
+ {"_Nrq2x4", AdditiveQuantizer::ST_norm_rq2x4},
157
+ };
110
158
 
111
- if (s[i] == ')') {
112
- st--;
113
- if (st == 0) {
114
- i1 = i;
115
- return;
116
- }
117
- if (st < 0) {
118
- FAISS_THROW_FMT(
119
- "factory string %s: unbalanced parentheses", s.c_str());
120
- }
121
- }
159
+ const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
160
+ const std::string aq_norm_pattern =
161
+ "(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4|_Nlsq2x4|_Nrq2x4)";
162
+
163
+ const std::string paq_def_pattern = "([0-9]+)x([0-9]+)x([0-9]+)";
164
+
165
+ AdditiveQuantizer::Search_type_t aq_parse_search_type(
166
+ std::string stok,
167
+ MetricType metric) {
168
+ if (stok == "") {
169
+ return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
170
+ : AdditiveQuantizer::ST_LUT_nonorm;
122
171
  }
123
- FAISS_THROW_FMT(
124
- "factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
172
+ int pos = stok.rfind("_");
173
+ return aq_search_type[stok.substr(pos)];
125
174
  }
126
175
 
127
- } // anonymous namespace
176
+ std::vector<size_t> aq_parse_nbits(std::string stok) {
177
+ std::vector<size_t> nbits;
178
+ std::smatch sm;
179
+ while (std::regex_search(stok, sm, std::regex("[^q]([0-9]+)x([0-9]+)"))) {
180
+ int M = std::stoi(sm[1].str());
181
+ int nbit = std::stoi(sm[2].str());
182
+ nbits.resize(nbits.size() + M, nbit);
183
+ stok = sm.suffix();
184
+ }
185
+ return nbits;
186
+ }
128
187
 
129
- Index* index_factory(int d, const char* description_in, MetricType metric) {
130
- FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
131
- VTChain vts;
132
- Index* coarse_quantizer = nullptr;
133
- std::string parenthesis_ivf, parenthesis_refine;
134
- Index* index = nullptr;
135
- bool add_idmap = false;
136
- int d_in = d;
188
+ /***************************************************************
189
+ * Parse VectorTransform
190
+ */
137
191
 
138
- ScopeDeleter1<Index> del_coarse_quantizer, del_index;
192
+ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
193
+ std::smatch sm;
194
+ auto match = [&sm, description](std::string pattern) {
195
+ return re_match(description, pattern, sm);
196
+ };
197
+ if (match("PCA(W?)(R?)([0-9]+)")) {
198
+ bool white = sm[1].length() > 0;
199
+ bool rot = sm[2].length() > 0;
200
+ return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
201
+ }
202
+ if (match("L2[nN]orm")) {
203
+ return new NormalizationTransform(d, 2.0);
204
+ }
205
+ if (match("RR([0-9]+)?")) {
206
+ return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
207
+ }
208
+ if (match("ITQ([0-9]+)?")) {
209
+ return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
210
+ }
211
+ if (match("OPQ([0-9]+)(_[0-9]+)?")) {
212
+ int M = std::stoi(sm[1].str());
213
+ int d_out = mres_to_int(sm[2], d, 1);
214
+ return new OPQMatrix(d, M, d_out);
215
+ }
216
+ if (match("Pad([0-9]+)")) {
217
+ int d_out = std::stoi(sm[1].str());
218
+ return new RemapDimensionsTransform(d, std::max(d_out, d), false);
219
+ }
220
+ return nullptr;
221
+ };
139
222
 
140
- std::string description(description_in);
141
- char* ptr;
223
+ /***************************************************************
224
+ * Parse IndexIVF
225
+ */
142
226
 
143
- // handle indexes in parentheses
144
- while (description.find('(') != std::string::npos) {
145
- // then we make a sub-index and remove the () from the description
146
- int i0, i1;
147
- find_matching_parentheses(description, i0, i1);
227
+ // parsing guard + function
228
+ Index* parse_coarse_quantizer(
229
+ const std::string& description,
230
+ int d,
231
+ MetricType mt,
232
+ std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
233
+ size_t& nlist,
234
+ bool& use_2layer) {
235
+ std::smatch sm;
236
+ auto match = [&sm, description](std::string pattern) {
237
+ return re_match(description, pattern, sm);
238
+ };
239
+ use_2layer = false;
240
+
241
+ if (match("IVF([0-9]+)")) {
242
+ nlist = std::stoi(sm[1].str());
243
+ return new IndexFlat(d, mt);
244
+ }
245
+ if (match("IMI2x([0-9]+)")) {
246
+ int nbit = std::stoi(sm[1].str());
247
+ FAISS_THROW_IF_NOT_MSG(
248
+ mt == METRIC_L2,
249
+ "MultiIndex not implemented for inner prod search");
250
+ nlist = (size_t)1 << (2 * nbit);
251
+ return new MultiIndexQuantizer(d, 2, nbit);
252
+ }
253
+ if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
254
+ nlist = std::stoi(sm[1].str());
255
+ int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
256
+ return new IndexHNSWFlat(d, hnsw_M, mt);
257
+ }
258
+ if (match("IVF([0-9]+)_NSG([0-9]+)")) {
259
+ nlist = std::stoi(sm[1].str());
260
+ int R = std::stoi(sm[2]);
261
+ return new IndexNSGFlat(d, R, mt);
262
+ }
263
+ if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
264
+ nlist = std::stoi(sm[1].str());
265
+ int no = std::stoi(sm[2].str());
266
+ FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
267
+ return parenthesis_indexes[no].release();
268
+ }
148
269
 
149
- std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
270
+ // these two generate Index2Layer's not IndexIVF's
271
+ if (match("Residual([0-9]+)x([0-9]+)")) {
272
+ FAISS_THROW_IF_NOT_MSG(
273
+ mt == METRIC_L2,
274
+ "MultiIndex not implemented for inner prod search");
275
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
276
+ nlist = (size_t)1 << (M * nbit);
277
+ use_2layer = true;
278
+ return new MultiIndexQuantizer(d, M, nbit);
279
+ }
280
+ if (match("Residual([0-9]+)")) {
281
+ FAISS_THROW_IF_NOT_MSG(
282
+ mt == METRIC_L2,
283
+ "Residual not implemented for inner prod search");
284
+ use_2layer = true;
285
+ nlist = mres_to_int(sm[1]);
286
+ return new IndexFlatL2(d);
287
+ }
288
+ return nullptr;
289
+ }
150
290
 
151
- if (str_ends_with_digits(description.substr(0, i0), "IVF")) {
152
- parenthesis_ivf = sub_description;
153
- } else if (str_ends_with(description.substr(0, i0), "Refine")) {
154
- parenthesis_refine = sub_description;
291
+ // parse the code description part of the IVF description
292
+
293
+ IndexIVF* parse_IndexIVF(
294
+ const std::string& code_string,
295
+ std::unique_ptr<Index>& quantizer,
296
+ size_t nlist,
297
+ MetricType mt) {
298
+ std::smatch sm;
299
+ auto match = [&sm, &code_string](const std::string pattern) {
300
+ return re_match(code_string, pattern, sm);
301
+ };
302
+ auto get_q = [&quantizer] { return quantizer.release(); };
303
+ int d = quantizer->d;
304
+
305
+ if (match("Flat")) {
306
+ return new IndexIVFFlat(get_q(), d, nlist, mt);
307
+ }
308
+ if (match("FlatDedup")) {
309
+ return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
310
+ }
311
+ if (match(sq_pattern)) {
312
+ return new IndexIVFScalarQuantizer(
313
+ get_q(), d, nlist, sq_types[sm[1].str()], mt);
314
+ }
315
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
316
+ int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
317
+ IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
318
+ index_ivf->do_polysemous_training = sm[3].str() != "np";
319
+ return index_ivf;
320
+ }
321
+ if (match("PQ([0-9]+)\\+([0-9]+)")) {
322
+ FAISS_THROW_IF_NOT_MSG(
323
+ mt == METRIC_L2,
324
+ "IVFPQR not implemented for inner product search");
325
+ int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
326
+ return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
327
+ }
328
+ if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
329
+ int M = mres_to_int(sm[1]);
330
+ int bbs = mres_to_int(sm[3], 32, 1);
331
+ IndexIVFPQFastScan* index_ivf =
332
+ new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
333
+ index_ivf->by_residual = sm[2].str() == "r";
334
+ return index_ivf;
335
+ }
336
+ if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
337
+ std::vector<size_t> nbits = aq_parse_nbits(sm.str());
338
+ AdditiveQuantizer::Search_type_t st =
339
+ aq_parse_search_type(sm[sm.size() - 1].str(), mt);
340
+ IndexIVF* index_ivf;
341
+ if (sm[1].str() == "RQ") {
342
+ index_ivf = new IndexIVFResidualQuantizer(
343
+ get_q(), d, nlist, nbits, mt, st);
155
344
  } else {
156
- FAISS_THROW_MSG("don't know what to do with parenthesis index");
345
+ FAISS_THROW_IF_NOT(nbits.size() > 0);
346
+ index_ivf = new IndexIVFLocalSearchQuantizer(
347
+ get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
157
348
  }
158
- description = description.erase(i0, i1 - i0 + 1);
159
- }
160
-
161
- int64_t ncentroids = -1;
162
- bool use_2layer = false;
163
- int hnsw_M = -1;
164
- int nsg_R = -1;
165
-
166
- for (char* tok = strtok_r(&description[0], " ,", &ptr); tok;
167
- tok = strtok_r(nullptr, " ,", &ptr)) {
168
- int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2, R;
169
- std::string stok(tok);
170
- nbit = 8;
171
- int bbs = -1;
172
- char c;
173
-
174
- // to avoid mem leaks with exceptions:
175
- // do all tests before any instanciation
176
-
177
- VectorTransform* vt_1 = nullptr;
178
- Index* coarse_quantizer_1 = nullptr;
179
- Index* index_1 = nullptr;
180
-
181
- // VectorTransforms
182
- if (sscanf(tok, "PCA%d", &d_out) == 1) {
183
- vt_1 = new PCAMatrix(d, d_out);
184
- d = d_out;
185
- } else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
186
- vt_1 = new PCAMatrix(d, d_out, 0, true);
187
- d = d_out;
188
- } else if (sscanf(tok, "RR%d", &d_out) == 1) {
189
- vt_1 = new RandomRotationMatrix(d, d_out);
190
- d = d_out;
191
- } else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
192
- vt_1 = new PCAMatrix(d, d_out, -0.5, false);
193
- d = d_out;
194
- } else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
195
- vt_1 = new PCAMatrix(d, d_out, -0.5, true);
196
- d = d_out;
197
- } else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
198
- vt_1 = new OPQMatrix(d, opq_M, d_out);
199
- d = d_out;
200
- } else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
201
- vt_1 = new OPQMatrix(d, opq_M);
202
- } else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
203
- vt_1 = new ITQTransform(d, d_out, true);
204
- d = d_out;
205
- } else if (stok == "ITQ") {
206
- vt_1 = new ITQTransform(d, d, false);
207
- } else if (sscanf(tok, "Pad%d", &d_out) == 1) {
208
- if (d_out > d) {
209
- vt_1 = new RemapDimensionsTransform(d, d_out, false);
210
- d = d_out;
211
- }
212
- } else if (stok == "L2norm") {
213
- vt_1 = new NormalizationTransform(d, 2.0);
214
-
215
- // coarse quantizers
216
- } else if (
217
- !coarse_quantizer &&
218
- sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
219
- coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
220
-
221
- } else if (
222
- !coarse_quantizer &&
223
- sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
224
- coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
225
-
226
- } else if (
227
- !coarse_quantizer &&
228
- sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
229
- if (!parenthesis_ivf.empty()) {
230
- coarse_quantizer_1 =
231
- index_factory(d, parenthesis_ivf.c_str(), metric);
232
- } else if (metric == METRIC_L2) {
233
- coarse_quantizer_1 = new IndexFlatL2(d);
234
- } else {
235
- coarse_quantizer_1 = new IndexFlatIP(d);
236
- }
237
-
238
- } else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
239
- FAISS_THROW_IF_NOT_MSG(
240
- metric == METRIC_L2,
241
- "MultiIndex not implemented for inner prod search");
242
- coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
243
- ncentroids = 1 << (2 * nbit);
244
-
245
- } else if (
246
- !coarse_quantizer &&
247
- sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
248
- FAISS_THROW_IF_NOT_MSG(
249
- metric == METRIC_L2,
250
- "MultiIndex not implemented for inner prod search");
251
- coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
252
- ncentroids = int64_t(1) << (M * nbit);
253
- use_2layer = true;
254
-
255
- } else if (std::regex_match(
256
- stok,
257
- std::regex(
258
- "(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
259
- std::vector<size_t> nbits;
260
- std::smatch sm;
261
- bool is_RCQ = stok.find("RCQ") == 0;
262
- while (std::regex_search(
263
- stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
264
- int M = std::stoi(sm[1].str());
265
- int nbit = std::stoi(sm[2].str());
266
- nbits.resize(nbits.size() + M, nbit);
267
- stok = sm.suffix();
268
- }
269
- if (!is_RCQ) {
270
- index_1 = new IndexResidual(d, nbits, metric);
271
- } else {
272
- index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
273
- }
274
- } else if (
275
- !coarse_quantizer &&
276
- sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
277
- coarse_quantizer_1 = new IndexFlatL2(d);
278
- use_2layer = true;
279
-
280
- } else if (stok == "IDMap") {
281
- add_idmap = true;
282
-
283
- // IVFs
284
- } else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
285
- if (coarse_quantizer) {
286
- // if there was an IVF in front, then it is an IVFFlat
287
- IndexIVF* index_ivf = stok == "Flat"
288
- ? new IndexIVFFlat(
289
- coarse_quantizer, d, ncentroids, metric)
290
- : new IndexIVFFlatDedup(
291
- coarse_quantizer, d, ncentroids, metric);
292
- index_ivf->quantizer_trains_alone =
293
- get_trains_alone(coarse_quantizer);
294
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
295
- del_coarse_quantizer.release();
296
- index_ivf->own_fields = true;
297
- index_1 = index_ivf;
298
- } else if (hnsw_M > 0) {
299
- index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
300
- } else if (nsg_R > 0) {
301
- index_1 = new IndexNSGFlat(d, nsg_R, metric);
302
- } else {
303
- FAISS_THROW_IF_NOT_MSG(
304
- stok != "FlatDedup",
305
- "dedup supported only for IVFFlat");
306
- index_1 = new IndexFlat(d, metric);
307
- }
308
- } else if (
309
- !index &&
310
- (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
311
- stok == "SQfp16")) {
312
- ScalarQuantizer::QuantizerType qt = stok == "SQ8"
313
- ? ScalarQuantizer::QT_8bit
314
- : stok == "SQ6" ? ScalarQuantizer::QT_6bit
315
- : stok == "SQ4" ? ScalarQuantizer::QT_4bit
316
- : stok == "SQfp16" ? ScalarQuantizer::QT_fp16
317
- : ScalarQuantizer::QT_4bit;
318
- if (coarse_quantizer) {
319
- FAISS_THROW_IF_NOT(!use_2layer);
320
- IndexIVFScalarQuantizer* index_ivf =
321
- new IndexIVFScalarQuantizer(
322
- coarse_quantizer, d, ncentroids, qt, metric);
323
- index_ivf->quantizer_trains_alone =
324
- get_trains_alone(coarse_quantizer);
325
- del_coarse_quantizer.release();
326
- index_ivf->own_fields = true;
327
- index_1 = index_ivf;
328
- } else if (hnsw_M > 0) {
329
- index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
330
- } else {
331
- index_1 = new IndexScalarQuantizer(d, qt, metric);
332
- }
333
- } else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
334
- FAISS_THROW_IF_NOT_MSG(
335
- coarse_quantizer, "PQ with + works only with an IVF");
336
- FAISS_THROW_IF_NOT_MSG(
337
- metric == METRIC_L2,
338
- "IVFPQR not implemented for inner product search");
339
- IndexIVFPQR* index_ivf = new IndexIVFPQR(
340
- coarse_quantizer, d, ncentroids, M, 8, M2, 8);
341
- index_ivf->quantizer_trains_alone =
342
- get_trains_alone(coarse_quantizer);
343
- del_coarse_quantizer.release();
344
- index_ivf->own_fields = true;
345
- index_1 = index_ivf;
346
- } else if (
347
- !index &&
348
- (sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
349
- (sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
350
- (sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
351
- if (bbs == -1) {
352
- bbs = 32;
353
- }
354
- bool by_residual = str_ends_with(stok, "fsr");
355
- if (coarse_quantizer) {
356
- IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
357
- coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
358
- index_ivf->quantizer_trains_alone =
359
- get_trains_alone(coarse_quantizer);
360
- index_ivf->metric_type = metric;
361
- index_ivf->by_residual = by_residual;
362
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
363
- del_coarse_quantizer.release();
364
- index_ivf->own_fields = true;
365
- index_1 = index_ivf;
366
- } else {
367
- IndexPQFastScan* index_pq =
368
- new IndexPQFastScan(d, M, 4, metric, bbs);
369
- index_1 = index_pq;
370
- }
371
- } else if (
372
- !index &&
373
- (sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
374
- sscanf(tok, "PQ%d", &M) == 1 ||
375
- sscanf(tok, "PQ%dnp", &M) == 1)) {
376
- bool do_polysemous_training = stok.find("np") == std::string::npos;
377
- if (coarse_quantizer) {
378
- if (!use_2layer) {
379
- IndexIVFPQ* index_ivf = new IndexIVFPQ(
380
- coarse_quantizer, d, ncentroids, M, nbit);
381
- index_ivf->quantizer_trains_alone =
382
- get_trains_alone(coarse_quantizer);
383
- index_ivf->metric_type = metric;
384
- index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
385
- del_coarse_quantizer.release();
386
- index_ivf->own_fields = true;
387
- index_ivf->do_polysemous_training = do_polysemous_training;
388
- index_1 = index_ivf;
389
- } else {
390
- Index2Layer* index_2l = new Index2Layer(
391
- coarse_quantizer, ncentroids, M, nbit);
392
- index_2l->q1.quantizer_trains_alone =
393
- get_trains_alone(coarse_quantizer);
394
- index_2l->q1.own_fields = true;
395
- index_1 = index_2l;
396
- }
397
- } else if (hnsw_M > 0) {
398
- IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
399
- dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
400
- do_polysemous_training;
401
- index_1 = ipq;
402
- } else {
403
- IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
404
- index_pq->do_polysemous_training = do_polysemous_training;
405
- index_1 = index_pq;
406
- }
407
- } else if (
408
- !index &&
409
- sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
410
- Index* quant = new IndexFlatL2(d);
411
- IndexHNSW2Level* hidx2l =
412
- new IndexHNSW2Level(quant, ncent, pq_m, M);
413
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
414
- idx2l->q1.own_fields = true;
415
- index_1 = hidx2l;
416
- } else if (
417
- !index &&
418
- sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
419
- Index* quant = new MultiIndexQuantizer(d, 2, nbit);
420
- IndexHNSW2Level* hidx2l =
421
- new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
422
- Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
423
- idx2l->q1.own_fields = true;
424
- idx2l->q1.quantizer_trains_alone = 1;
425
- index_1 = hidx2l;
426
- } else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
427
- index_1 = new IndexHNSWPQ(d, pq_m, M);
428
- } else if (
429
- !index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
430
- pq_m == 8) {
431
- index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
432
- } else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
433
- hnsw_M = M;
434
- // here it is unclear what we want: HNSW flat or HNSWx,Y ?
435
- } else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
436
- nsg_R = R;
437
- } else if (
438
- !index &&
439
- (stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
440
- stok == "LSHt")) {
441
- bool rotate_data = strstr(tok, "r") != nullptr;
442
- bool train_thresholds = strstr(tok, "t") != nullptr;
443
- index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
444
- } else if (
445
- !index &&
446
- sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
447
- FAISS_THROW_IF_NOT(!coarse_quantizer);
448
- index_1 = new IndexLattice(d, M, nbit, r2);
449
- } else if (stok == "RFlat") {
450
- parenthesis_refine = "Flat";
451
- } else if (stok == "Refine") {
452
- FAISS_THROW_IF_NOT_MSG(
453
- !parenthesis_refine.empty(),
454
- "Refine index should be provided in parentheses");
349
+ return index_ivf;
350
+ }
351
+ if (match("(PRQ|PLSQ)" + paq_def_pattern + aq_norm_pattern)) {
352
+ int nsplits = mres_to_int(sm[2]);
353
+ int Msub = mres_to_int(sm[3]);
354
+ int nbit = mres_to_int(sm[4]);
355
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
356
+ IndexIVF* index_ivf;
357
+ if (sm[1].str() == "PRQ") {
358
+ index_ivf = new IndexIVFProductResidualQuantizer(
359
+ get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
360
+ } else {
361
+ index_ivf = new IndexIVFProductLocalSearchQuantizer(
362
+ get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
363
+ }
364
+ return index_ivf;
365
+ }
366
+ if (match("(RQ|LSQ)([0-9]+)x4fs(r?)(_[0-9]+)?" + aq_norm_pattern)) {
367
+ int M = std::stoi(sm[2].str());
368
+ int bbs = mres_to_int(sm[4], 32, 1);
369
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
370
+ IndexIVFAdditiveQuantizerFastScan* index_ivf;
371
+ if (sm[1].str() == "RQ") {
372
+ index_ivf = new IndexIVFResidualQuantizerFastScan(
373
+ get_q(), d, nlist, M, 4, mt, st, bbs);
455
374
  } else {
456
- FAISS_THROW_FMT(
457
- "could not parse token \"%s\" in %s\n",
458
- tok,
459
- description_in);
375
+ index_ivf = new IndexIVFLocalSearchQuantizerFastScan(
376
+ get_q(), d, nlist, M, 4, mt, st, bbs);
460
377
  }
378
+ index_ivf->by_residual = (sm[3].str() == "r");
379
+ return index_ivf;
380
+ }
381
+ if (match("(PRQ|PLSQ)([0-9]+)x([0-9]+)x4fs(r?)(_[0-9]+)?" +
382
+ aq_norm_pattern)) {
383
+ int nsplits = std::stoi(sm[2].str());
384
+ int Msub = std::stoi(sm[3].str());
385
+ int bbs = mres_to_int(sm[5], 32, 1);
386
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
387
+ IndexIVFAdditiveQuantizerFastScan* index_ivf;
388
+ if (sm[1].str() == "PRQ") {
389
+ index_ivf = new IndexIVFProductResidualQuantizerFastScan(
390
+ get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
391
+ } else {
392
+ index_ivf = new IndexIVFProductLocalSearchQuantizerFastScan(
393
+ get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
394
+ }
395
+ index_ivf->by_residual = (sm[4].str() == "r");
396
+ return index_ivf;
397
+ }
398
+ if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
399
+ int outdim = mres_to_int(sm[2], d); // is also the number of bits
400
+ std::unique_ptr<VectorTransform> vt;
401
+ if (sm[1] == "ITQ") {
402
+ vt.reset(new ITQTransform(d, outdim, d != outdim));
403
+ } else if (sm[1] == "PCA") {
404
+ vt.reset(new PCAMatrix(d, outdim));
405
+ } else if (sm[1] == "PCAR") {
406
+ vt.reset(new PCAMatrix(d, outdim, 0, true));
407
+ }
408
+ // the rationale for -1e10 is that this corresponds to simple
409
+ // thresholding
410
+ float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
411
+ IndexIVFSpectralHash* index_ivf =
412
+ new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
413
+ index_ivf->replace_vt(vt.release(), true);
414
+ if (sm[4].length()) {
415
+ std::string s = sm[4].str();
416
+ index_ivf->threshold_type = s == "g"
417
+ ? IndexIVFSpectralHash::Thresh_global
418
+ : s == "c"
419
+ ? IndexIVFSpectralHash::Thresh_centroid
420
+ :
421
+ /* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
422
+ }
423
+ return index_ivf;
424
+ }
425
+ return nullptr;
426
+ }
427
+
428
+ /***************************************************************
429
+ * Parse IndexHNSW
430
+ */
431
+
432
+ IndexHNSW* parse_IndexHNSW(
433
+ const std::string code_string,
434
+ int d,
435
+ MetricType mt,
436
+ int hnsw_M) {
437
+ std::smatch sm;
438
+ auto match = [&sm, &code_string](const std::string& pattern) {
439
+ return re_match(code_string, pattern, sm);
440
+ };
441
+
442
+ if (match("Flat|")) {
443
+ return new IndexHNSWFlat(d, hnsw_M, mt);
444
+ }
445
+ if (match("PQ([0-9]+)(np)?")) {
446
+ int M = std::stoi(sm[1].str());
447
+ IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
448
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
449
+ sm[2].str() != "np";
450
+ return ipq;
451
+ }
452
+ if (match(sq_pattern)) {
453
+ return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
454
+ }
455
+ if (match("([0-9]+)\\+PQ([0-9]+)?")) {
456
+ int ncent = mres_to_int(sm[1]);
457
+ int pq_m = mres_to_int(sm[2]);
458
+ IndexHNSW2Level* hidx2l =
459
+ new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
460
+ dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
461
+ return hidx2l;
462
+ }
463
+ if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
464
+ int nbit = mres_to_int(sm[1]);
465
+ int pq_m = mres_to_int(sm[2]);
466
+ Index* quant = new MultiIndexQuantizer(d, 2, nbit);
467
+ IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
468
+ quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
469
+ Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
470
+ idx2l->q1.own_fields = true;
471
+ idx2l->q1.quantizer_trains_alone = 1;
472
+ return hidx2l;
473
+ }
474
+
475
+ return nullptr;
476
+ }
477
+
478
+ /***************************************************************
479
+ * Parse IndexNSG
480
+ */
481
+
482
+ IndexNSG* parse_IndexNSG(
483
+ const std::string code_string,
484
+ int d,
485
+ MetricType mt,
486
+ int nsg_R) {
487
+ std::smatch sm;
488
+ auto match = [&sm, &code_string](const std::string& pattern) {
489
+ return re_match(code_string, pattern, sm);
490
+ };
491
+
492
+ if (match("Flat|")) {
493
+ return new IndexNSGFlat(d, nsg_R, mt);
494
+ }
495
+ if (match("PQ([0-9]+)(np)?")) {
496
+ int M = std::stoi(sm[1].str());
497
+ IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R);
498
+ dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
499
+ sm[2].str() != "np";
500
+ return ipq;
501
+ }
502
+ if (match(sq_pattern)) {
503
+ return new IndexNSGSQ(d, sq_types[sm[1].str()], nsg_R, mt);
504
+ }
505
+
506
+ return nullptr;
507
+ }
461
508
 
462
- if (index_1 && add_idmap) {
463
- IndexIDMap* idmap = new IndexIDMap(index_1);
464
- del_index.set(idmap);
465
- idmap->own_fields = true;
466
- index_1 = idmap;
467
- add_idmap = false;
509
+ /***************************************************************
510
+ * Parse basic indexes
511
+ */
512
+
513
+ Index* parse_other_indexes(
514
+ const std::string& description,
515
+ int d,
516
+ MetricType metric) {
517
+ std::smatch sm;
518
+ auto match = [&sm, description](const std::string& pattern) {
519
+ return re_match(description, pattern, sm);
520
+ };
521
+
522
+ // IndexFlat
523
+ if (description == "Flat") {
524
+ return new IndexFlat(d, metric);
525
+ }
526
+
527
+ // IndexLSH
528
+ if (match("LSH(r?)(t?)")) {
529
+ bool rotate_data = sm[1].length() > 0;
530
+ bool train_thresholds = sm[2].length() > 0;
531
+ FAISS_THROW_IF_NOT(metric == METRIC_L2);
532
+ return new IndexLSH(d, d, rotate_data, train_thresholds);
533
+ }
534
+
535
+ // IndexLattice
536
+ if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
537
+ int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
538
+ int nbit = std::stoi(sm[3].str());
539
+ return new IndexLattice(d, M, nbit, r2);
540
+ }
541
+
542
+ // IndexScalarQuantizer
543
+ if (match(sq_pattern)) {
544
+ return new IndexScalarQuantizer(d, sq_types[description], metric);
545
+ }
546
+
547
+ // IndexPQ
548
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
549
+ int M = std::stoi(sm[1].str());
550
+ int nbit = mres_to_int(sm[2], 8, 1);
551
+ IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
552
+ index_pq->do_polysemous_training = sm[3].str() != "np";
553
+ return index_pq;
554
+ }
555
+
556
+ // IndexPQFastScan
557
+ if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
558
+ int M = std::stoi(sm[1].str());
559
+ int bbs = mres_to_int(sm[2], 32, 1);
560
+ return new IndexPQFastScan(d, M, 4, metric, bbs);
561
+ }
562
+
563
+ // IndexResidualCoarseQuantizer and IndexResidualQuantizer
564
+ std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
565
+ if (match(pattern)) {
566
+ std::vector<size_t> nbits = aq_parse_nbits(description);
567
+ if (sm[1].str() == "RCQ") {
568
+ return new ResidualCoarseQuantizer(d, nbits, metric);
468
569
  }
570
+ AdditiveQuantizer::Search_type_t st =
571
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
572
+ return new IndexResidualQuantizer(d, nbits, metric, st);
573
+ }
469
574
 
470
- if (vt_1) {
471
- vts.chain.push_back(vt_1);
575
+ // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
576
+ if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
577
+ std::vector<size_t> nbits = aq_parse_nbits(description);
578
+ int M = mres_to_int(sm[2]);
579
+ int nbit = mres_to_int(sm[3]);
580
+ if (sm[1].str() == "LSCQ") {
581
+ return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
472
582
  }
583
+ AdditiveQuantizer::Search_type_t st =
584
+ aq_parse_search_type(sm[sm.size() - 1].str(), metric);
585
+ return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
586
+ }
587
+
588
+ // IndexProductResidualQuantizer
589
+ if (match("PRQ" + paq_def_pattern + aq_norm_pattern)) {
590
+ int nsplits = mres_to_int(sm[1]);
591
+ int Msub = mres_to_int(sm[2]);
592
+ int nbit = mres_to_int(sm[3]);
593
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
594
+ return new IndexProductResidualQuantizer(
595
+ d, nsplits, Msub, nbit, metric, st);
596
+ }
473
597
 
474
- if (coarse_quantizer_1) {
475
- coarse_quantizer = coarse_quantizer_1;
476
- del_coarse_quantizer.set(coarse_quantizer);
598
+ // IndexProductLocalSearchQuantizer
599
+ if (match("PLSQ" + paq_def_pattern + aq_norm_pattern)) {
600
+ int nsplits = mres_to_int(sm[1]);
601
+ int Msub = mres_to_int(sm[2]);
602
+ int nbit = mres_to_int(sm[3]);
603
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
604
+ return new IndexProductLocalSearchQuantizer(
605
+ d, nsplits, Msub, nbit, metric, st);
606
+ }
607
+
608
+ // IndexAdditiveQuantizerFastScan
609
+ // RQ{M}x4fs_{bbs}_{search_type}
610
+ pattern = "(LSQ|RQ)([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
611
+ if (match(pattern)) {
612
+ int M = std::stoi(sm[2].str());
613
+ int bbs = mres_to_int(sm[3], 32, 1);
614
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
615
+
616
+ if (sm[1].str() == "RQ") {
617
+ return new IndexResidualQuantizerFastScan(d, M, 4, metric, st, bbs);
618
+ } else if (sm[1].str() == "LSQ") {
619
+ return new IndexLocalSearchQuantizerFastScan(
620
+ d, M, 4, metric, st, bbs);
477
621
  }
622
+ }
478
623
 
479
- if (index_1) {
480
- index = index_1;
481
- del_index.set(index);
624
+ // IndexProductAdditiveQuantizerFastScan
625
+ // PRQ{nsplits}x{Msub}x4fs_{bbs}_{search_type}
626
+ pattern = "(PLSQ|PRQ)([0-9]+)x([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
627
+ if (match(pattern)) {
628
+ int nsplits = std::stoi(sm[2].str());
629
+ int Msub = std::stoi(sm[3].str());
630
+ int bbs = mres_to_int(sm[4], 32, 1);
631
+ auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
632
+
633
+ if (sm[1].str() == "PRQ") {
634
+ return new IndexProductResidualQuantizerFastScan(
635
+ d, nsplits, Msub, 4, metric, st, bbs);
636
+ } else if (sm[1].str() == "PLSQ") {
637
+ return new IndexProductLocalSearchQuantizerFastScan(
638
+ d, nsplits, Msub, 4, metric, st, bbs);
482
639
  }
483
640
  }
484
641
 
485
- if (!index && hnsw_M > 0) {
486
- index = new IndexHNSWFlat(d, hnsw_M, metric);
487
- del_index.set(index);
488
- } else if (!index && nsg_R > 0) {
489
- index = new IndexNSGFlat(d, nsg_R, metric);
490
- del_index.set(index);
642
+ return nullptr;
643
+ }
644
+
645
+ /***************************************************************
646
+ * Driver function
647
+ */
648
+ std::unique_ptr<Index> index_factory_sub(
649
+ int d,
650
+ std::string description,
651
+ MetricType metric) {
652
+ // handle composite indexes
653
+
654
+ bool verbose = index_factory_verbose;
655
+
656
+ if (verbose) {
657
+ printf("begin parse VectorTransforms: %s \n", description.c_str());
491
658
  }
492
659
 
493
- FAISS_THROW_IF_NOT_FMT(
494
- index, "description %s did not generate an index", description_in);
660
+ // for the current match
661
+ std::smatch sm;
495
662
 
496
- // nothing can go wrong now
497
- del_index.release();
498
- del_coarse_quantizer.release();
663
+ // handle refines
664
+ if (re_match(description, "(.+),RFlat", sm) ||
665
+ re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
666
+ std::unique_ptr<Index> filter_index =
667
+ index_factory_sub(d, sm[1].str(), metric);
668
+ std::unique_ptr<Index> refine_index;
499
669
 
500
- if (add_idmap) {
501
- fprintf(stderr,
502
- "index_factory: WARNING: "
503
- "IDMap option not used\n");
670
+ if (sm.size() == 3) { // Refine
671
+ refine_index = index_factory_sub(d, sm[2].str(), metric);
672
+ } else { // RFlat
673
+ refine_index.reset(new IndexFlat(d, metric));
674
+ }
675
+ IndexRefine* index_rf =
676
+ new IndexRefine(filter_index.get(), refine_index.get());
677
+ index_rf->own_fields = true;
678
+ filter_index.release();
679
+ refine_index.release();
680
+ index_rf->own_refine_index = true;
681
+ return std::unique_ptr<Index>(index_rf);
504
682
  }
505
683
 
506
- if (vts.chain.size() > 0) {
507
- IndexPreTransform* index_pt = new IndexPreTransform(index);
684
+ // IndexPreTransform
685
+ // should handle this first (even before parentheses) because it changes d
686
+ std::vector<std::unique_ptr<VectorTransform>> vts;
687
+ VectorTransform* vt = nullptr;
688
+ while (re_match(description, "([^,]+),(.*)", sm) &&
689
+ (vt = parse_VectorTransform(sm[1], d))) {
690
+ // reset loop
691
+ description = sm[sm.size() - 1];
692
+ vts.emplace_back(vt);
693
+ d = vts.back()->d_out;
694
+ }
695
+
696
+ if (vts.size() > 0) {
697
+ std::unique_ptr<Index> sub_index =
698
+ index_factory_sub(d, description, metric);
699
+ IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
700
+ std::unique_ptr<Index> ret(index_pt);
508
701
  index_pt->own_fields = true;
509
- // add from back
510
- while (vts.chain.size() > 0) {
511
- index_pt->prepend_transform(vts.chain.back());
512
- vts.chain.pop_back();
702
+ sub_index.release();
703
+ while (vts.size() > 0) {
704
+ if (verbose) {
705
+ printf("prepend trans %d -> %d\n",
706
+ vts.back()->d_in,
707
+ vts.back()->d_out);
708
+ }
709
+ index_pt->prepend_transform(vts.back().release());
710
+ vts.pop_back();
513
711
  }
514
- index = index_pt;
712
+ return ret;
515
713
  }
516
714
 
517
- if (!parenthesis_refine.empty()) {
518
- Index* refine_index =
519
- index_factory(d_in, parenthesis_refine.c_str(), metric);
520
- IndexRefine* index_rf = new IndexRefine(index, refine_index);
521
- index_rf->own_refine_index = true;
522
- index_rf->own_fields = true;
523
- index = index_rf;
715
+ // what we got from the parentheses
716
+ std::vector<std::unique_ptr<Index>> parenthesis_indexes;
717
+
718
+ int begin = 0;
719
+ while (description.find('(', begin) != std::string::npos) {
720
+ // replace indexes in () with Index0, Index1, etc.
721
+ int i0, i1;
722
+ find_matching_parentheses(description, i0, i1, begin);
723
+ std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
724
+ int no = parenthesis_indexes.size();
725
+ parenthesis_indexes.push_back(
726
+ index_factory_sub(d, sub_description, metric));
727
+ description = description.substr(0, i0 + 1) + "Index" +
728
+ std::to_string(no) + description.substr(i1);
729
+ begin = i1 + 1;
524
730
  }
525
731
 
526
- return index;
732
+ if (verbose) {
733
+ printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
734
+ description.c_str(),
735
+ parenthesis_indexes.size(),
736
+ d);
737
+ }
738
+
739
+ // IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
740
+ // support both
741
+ if (re_match(description, "(.+),IDMap2", sm) ||
742
+ re_match(description, "IDMap2,(.+)", sm)) {
743
+ IndexIDMap2* idmap2 = new IndexIDMap2(
744
+ index_factory_sub(d, sm[1].str(), metric).release());
745
+ idmap2->own_fields = true;
746
+ return std::unique_ptr<Index>(idmap2);
747
+ }
748
+
749
+ if (re_match(description, "(.+),IDMap", sm) ||
750
+ re_match(description, "IDMap,(.+)", sm)) {
751
+ IndexIDMap* idmap = new IndexIDMap(
752
+ index_factory_sub(d, sm[1].str(), metric).release());
753
+ idmap->own_fields = true;
754
+ return std::unique_ptr<Index>(idmap);
755
+ }
756
+
757
+ { // handle basic index types
758
+ Index* index = parse_other_indexes(description, d, metric);
759
+ if (index) {
760
+ return std::unique_ptr<Index>(index);
761
+ }
762
+ }
763
+
764
+ // HNSW variants (it was unclear in the old version that the separator was a
765
+ // "," so we support both "_" and ",")
766
+ if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
767
+ int hnsw_M = mres_to_int(sm[1], 32);
768
+ // We also accept empty code string (synonym of Flat)
769
+ std::string code_string =
770
+ sm[2].length() > 0 ? sm[2].str().substr(1) : "";
771
+ if (verbose) {
772
+ printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
773
+ description.c_str(),
774
+ code_string.c_str(),
775
+ hnsw_M);
776
+ }
777
+
778
+ IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
779
+ FAISS_THROW_IF_NOT_FMT(
780
+ index,
781
+ "could not parse HNSW code description %s in %s",
782
+ code_string.c_str(),
783
+ description.c_str());
784
+ return std::unique_ptr<Index>(index);
785
+ }
786
+
787
+ // NSG variants (it was unclear in the old version that the separator was a
788
+ // "," so we support both "_" and ",")
789
+ if (re_match(description, "NSG([0-9]*)([,_].*)?", sm)) {
790
+ int nsg_R = mres_to_int(sm[1], 32);
791
+ // We also accept empty code string (synonym of Flat)
792
+ std::string code_string =
793
+ sm[2].length() > 0 ? sm[2].str().substr(1) : "";
794
+ if (verbose) {
795
+ printf("parsing NSG string %s code_string=%s nsg_R=%d\n",
796
+ description.c_str(),
797
+ code_string.c_str(),
798
+ nsg_R);
799
+ }
800
+
801
+ IndexNSG* index = parse_IndexNSG(code_string, d, metric, nsg_R);
802
+ FAISS_THROW_IF_NOT_FMT(
803
+ index,
804
+ "could not parse NSG code description %s in %s",
805
+ code_string.c_str(),
806
+ description.c_str());
807
+ return std::unique_ptr<Index>(index);
808
+ }
809
+
810
+ // IndexRowwiseMinMax, fp32 version
811
+ if (description.compare(0, 7, "MinMax,") == 0) {
812
+ size_t comma = description.find(",");
813
+ std::string sub_index_string = description.substr(comma + 1);
814
+ auto sub_index = index_factory_sub(d, sub_index_string, metric);
815
+
816
+ auto index = new IndexRowwiseMinMax(sub_index.release());
817
+ index->own_fields = true;
818
+
819
+ return std::unique_ptr<Index>(index);
820
+ }
821
+
822
+ // IndexRowwiseMinMax, fp16 version
823
+ if (description.compare(0, 11, "MinMaxFP16,") == 0) {
824
+ size_t comma = description.find(",");
825
+ std::string sub_index_string = description.substr(comma + 1);
826
+ auto sub_index = index_factory_sub(d, sub_index_string, metric);
827
+
828
+ auto index = new IndexRowwiseMinMaxFP16(sub_index.release());
829
+ index->own_fields = true;
830
+
831
+ return std::unique_ptr<Index>(index);
832
+ }
833
+
834
+ // IndexIVF
835
+ {
836
+ size_t nlist;
837
+ bool use_2layer;
838
+ size_t comma = description.find(",");
839
+ std::string coarse_string = description.substr(0, comma);
840
+ // Match coarse quantizer part first
841
+ std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
842
+ description.substr(0, comma),
843
+ d,
844
+ metric,
845
+ parenthesis_indexes,
846
+ nlist,
847
+ use_2layer));
848
+
849
+ if (comma != std::string::npos && quantizer.get()) {
850
+ std::string code_description = description.substr(comma + 1);
851
+ if (use_2layer) {
852
+ bool ok =
853
+ re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
854
+ FAISS_THROW_IF_NOT_FMT(
855
+ ok,
856
+ "could not parse 2 layer code description %s in %s",
857
+ code_description.c_str(),
858
+ description.c_str());
859
+ int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
860
+ Index2Layer* index_2l =
861
+ new Index2Layer(quantizer.release(), nlist, M, nbit);
862
+ index_2l->q1.own_fields = true;
863
+ index_2l->q1.quantizer_trains_alone =
864
+ get_trains_alone(index_2l->q1.quantizer);
865
+ return std::unique_ptr<Index>(index_2l);
866
+ }
867
+
868
+ IndexIVF* index_ivf =
869
+ parse_IndexIVF(code_description, quantizer, nlist, metric);
870
+
871
+ FAISS_THROW_IF_NOT_FMT(
872
+ index_ivf,
873
+ "could not parse code description %s in %s",
874
+ code_description.c_str(),
875
+ description.c_str());
876
+ return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
877
+ }
878
+ }
879
+ FAISS_THROW_FMT("could not parse index string %s", description.c_str());
880
+ return nullptr;
881
+ }
882
+
883
+ } // anonymous namespace
884
+
885
+ Index* index_factory(int d, const char* description, MetricType metric) {
886
+ return index_factory_sub(d, description, metric).release();
527
887
  }
528
888
 
529
889
  IndexBinary* index_binary_factory(int d, const char* description) {