faiss 0.2.3 → 0.2.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -5,19 +5,19 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
/*
|
11
|
-
* implementation of
|
9
|
+
* implementation of the index_factory function. Lots of regex parsing code.
|
12
10
|
*/
|
13
11
|
|
14
12
|
#include <faiss/index_factory.h>
|
15
|
-
|
16
|
-
#include
|
13
|
+
#include "faiss/MetricType.h"
|
14
|
+
#include "faiss/impl/FaissAssert.h"
|
17
15
|
|
18
16
|
#include <cinttypes>
|
19
17
|
#include <cmath>
|
20
18
|
|
19
|
+
#include <map>
|
20
|
+
|
21
21
|
#include <regex>
|
22
22
|
|
23
23
|
#include <faiss/impl/FaissAssert.h>
|
@@ -25,13 +25,18 @@
|
|
25
25
|
#include <faiss/utils/utils.h>
|
26
26
|
|
27
27
|
#include <faiss/Index2Layer.h>
|
28
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
29
|
+
#include <faiss/IndexAdditiveQuantizerFastScan.h>
|
28
30
|
#include <faiss/IndexFlat.h>
|
29
31
|
#include <faiss/IndexHNSW.h>
|
30
32
|
#include <faiss/IndexIVF.h>
|
33
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
34
|
+
#include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
|
31
35
|
#include <faiss/IndexIVFFlat.h>
|
32
36
|
#include <faiss/IndexIVFPQ.h>
|
33
37
|
#include <faiss/IndexIVFPQFastScan.h>
|
34
38
|
#include <faiss/IndexIVFPQR.h>
|
39
|
+
#include <faiss/IndexIVFSpectralHash.h>
|
35
40
|
#include <faiss/IndexLSH.h>
|
36
41
|
#include <faiss/IndexLattice.h>
|
37
42
|
#include <faiss/IndexNSG.h>
|
@@ -39,7 +44,7 @@
|
|
39
44
|
#include <faiss/IndexPQFastScan.h>
|
40
45
|
#include <faiss/IndexPreTransform.h>
|
41
46
|
#include <faiss/IndexRefine.h>
|
42
|
-
#include <faiss/
|
47
|
+
#include <faiss/IndexRowwiseMinMax.h>
|
43
48
|
#include <faiss/IndexScalarQuantizer.h>
|
44
49
|
#include <faiss/MetaIndexes.h>
|
45
50
|
#include <faiss/VectorTransform.h>
|
@@ -48,6 +53,7 @@
|
|
48
53
|
#include <faiss/IndexBinaryHNSW.h>
|
49
54
|
#include <faiss/IndexBinaryHash.h>
|
50
55
|
#include <faiss/IndexBinaryIVF.h>
|
56
|
+
#include <string>
|
51
57
|
|
52
58
|
namespace faiss {
|
53
59
|
|
@@ -55,16 +61,49 @@ namespace faiss {
|
|
55
61
|
* index_factory
|
56
62
|
***************************************************************/
|
57
63
|
|
64
|
+
int index_factory_verbose = 0;
|
65
|
+
|
58
66
|
namespace {
|
59
67
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
68
|
+
/***************************************************************
|
69
|
+
* Small functions
|
70
|
+
*/
|
71
|
+
|
72
|
+
bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
|
73
|
+
return std::regex_match(s, sm, std::regex(pat));
|
74
|
+
}
|
75
|
+
|
76
|
+
// find first pair of matching parentheses
|
77
|
+
void find_matching_parentheses(
|
78
|
+
const std::string& s,
|
79
|
+
int& i0,
|
80
|
+
int& i1,
|
81
|
+
int begin = 0) {
|
82
|
+
int st = 0;
|
83
|
+
i0 = i1 = 0;
|
84
|
+
for (int i = begin; i < s.length(); i++) {
|
85
|
+
if (s[i] == '(') {
|
86
|
+
if (st == 0) {
|
87
|
+
i0 = i;
|
88
|
+
}
|
89
|
+
st++;
|
90
|
+
}
|
91
|
+
|
92
|
+
if (s[i] == ')') {
|
93
|
+
st--;
|
94
|
+
if (st == 0) {
|
95
|
+
i1 = i;
|
96
|
+
return;
|
97
|
+
}
|
98
|
+
if (st < 0) {
|
99
|
+
FAISS_THROW_FMT(
|
100
|
+
"factory string %s: unbalanced parentheses", s.c_str());
|
101
|
+
}
|
65
102
|
}
|
66
103
|
}
|
67
|
-
|
104
|
+
FAISS_THROW_FMT(
|
105
|
+
"factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
|
106
|
+
}
|
68
107
|
|
69
108
|
/// what kind of training does this coarse quantizer require?
|
70
109
|
char get_trains_alone(const Index* coarse_quantizer) {
|
@@ -83,447 +122,768 @@ char get_trains_alone(const Index* coarse_quantizer) {
|
|
83
122
|
// kmeans index
|
84
123
|
}
|
85
124
|
|
86
|
-
|
87
|
-
|
125
|
+
// set the fields for factory-constructed IVF structures
|
126
|
+
IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
|
127
|
+
index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
|
128
|
+
index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
|
129
|
+
index_ivf->own_fields = true;
|
130
|
+
return index_ivf;
|
88
131
|
}
|
89
132
|
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
for (i = s.length() - 1; i >= 0; i--) {
|
94
|
-
if (!isdigit(s[i]))
|
95
|
-
break;
|
133
|
+
int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
|
134
|
+
if (mr.length() == 0) {
|
135
|
+
return deflt;
|
96
136
|
}
|
97
|
-
return
|
137
|
+
return std::stoi(mr.str().substr(begin));
|
98
138
|
}
|
99
139
|
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
}
|
140
|
+
std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
|
141
|
+
{"SQ8", ScalarQuantizer::QT_8bit},
|
142
|
+
{"SQ4", ScalarQuantizer::QT_4bit},
|
143
|
+
{"SQ6", ScalarQuantizer::QT_6bit},
|
144
|
+
{"SQfp16", ScalarQuantizer::QT_fp16},
|
145
|
+
};
|
146
|
+
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
|
147
|
+
|
148
|
+
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
|
149
|
+
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
|
150
|
+
{"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
|
151
|
+
{"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
|
152
|
+
{"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
|
153
|
+
{"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
|
154
|
+
{"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
|
155
|
+
{"_Nlsq2x4", AdditiveQuantizer::ST_norm_lsq2x4},
|
156
|
+
{"_Nrq2x4", AdditiveQuantizer::ST_norm_rq2x4},
|
157
|
+
};
|
110
158
|
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
159
|
+
const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
|
160
|
+
const std::string aq_norm_pattern =
|
161
|
+
"(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4|_Nlsq2x4|_Nrq2x4)";
|
162
|
+
|
163
|
+
const std::string paq_def_pattern = "([0-9]+)x([0-9]+)x([0-9]+)";
|
164
|
+
|
165
|
+
AdditiveQuantizer::Search_type_t aq_parse_search_type(
|
166
|
+
std::string stok,
|
167
|
+
MetricType metric) {
|
168
|
+
if (stok == "") {
|
169
|
+
return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
|
170
|
+
: AdditiveQuantizer::ST_LUT_nonorm;
|
122
171
|
}
|
123
|
-
|
124
|
-
|
172
|
+
int pos = stok.rfind("_");
|
173
|
+
return aq_search_type[stok.substr(pos)];
|
125
174
|
}
|
126
175
|
|
127
|
-
|
176
|
+
std::vector<size_t> aq_parse_nbits(std::string stok) {
|
177
|
+
std::vector<size_t> nbits;
|
178
|
+
std::smatch sm;
|
179
|
+
while (std::regex_search(stok, sm, std::regex("[^q]([0-9]+)x([0-9]+)"))) {
|
180
|
+
int M = std::stoi(sm[1].str());
|
181
|
+
int nbit = std::stoi(sm[2].str());
|
182
|
+
nbits.resize(nbits.size() + M, nbit);
|
183
|
+
stok = sm.suffix();
|
184
|
+
}
|
185
|
+
return nbits;
|
186
|
+
}
|
128
187
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
Index* coarse_quantizer = nullptr;
|
133
|
-
std::string parenthesis_ivf, parenthesis_refine;
|
134
|
-
Index* index = nullptr;
|
135
|
-
bool add_idmap = false;
|
136
|
-
int d_in = d;
|
188
|
+
/***************************************************************
|
189
|
+
* Parse VectorTransform
|
190
|
+
*/
|
137
191
|
|
138
|
-
|
192
|
+
VectorTransform* parse_VectorTransform(const std::string& description, int d) {
|
193
|
+
std::smatch sm;
|
194
|
+
auto match = [&sm, description](std::string pattern) {
|
195
|
+
return re_match(description, pattern, sm);
|
196
|
+
};
|
197
|
+
if (match("PCA(W?)(R?)([0-9]+)")) {
|
198
|
+
bool white = sm[1].length() > 0;
|
199
|
+
bool rot = sm[2].length() > 0;
|
200
|
+
return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
|
201
|
+
}
|
202
|
+
if (match("L2[nN]orm")) {
|
203
|
+
return new NormalizationTransform(d, 2.0);
|
204
|
+
}
|
205
|
+
if (match("RR([0-9]+)?")) {
|
206
|
+
return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
|
207
|
+
}
|
208
|
+
if (match("ITQ([0-9]+)?")) {
|
209
|
+
return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
|
210
|
+
}
|
211
|
+
if (match("OPQ([0-9]+)(_[0-9]+)?")) {
|
212
|
+
int M = std::stoi(sm[1].str());
|
213
|
+
int d_out = mres_to_int(sm[2], d, 1);
|
214
|
+
return new OPQMatrix(d, M, d_out);
|
215
|
+
}
|
216
|
+
if (match("Pad([0-9]+)")) {
|
217
|
+
int d_out = std::stoi(sm[1].str());
|
218
|
+
return new RemapDimensionsTransform(d, std::max(d_out, d), false);
|
219
|
+
}
|
220
|
+
return nullptr;
|
221
|
+
};
|
139
222
|
|
140
|
-
|
141
|
-
|
223
|
+
/***************************************************************
|
224
|
+
* Parse IndexIVF
|
225
|
+
*/
|
142
226
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
int
|
147
|
-
|
227
|
+
// parsing guard + function
|
228
|
+
Index* parse_coarse_quantizer(
|
229
|
+
const std::string& description,
|
230
|
+
int d,
|
231
|
+
MetricType mt,
|
232
|
+
std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
|
233
|
+
size_t& nlist,
|
234
|
+
bool& use_2layer) {
|
235
|
+
std::smatch sm;
|
236
|
+
auto match = [&sm, description](std::string pattern) {
|
237
|
+
return re_match(description, pattern, sm);
|
238
|
+
};
|
239
|
+
use_2layer = false;
|
240
|
+
|
241
|
+
if (match("IVF([0-9]+)")) {
|
242
|
+
nlist = std::stoi(sm[1].str());
|
243
|
+
return new IndexFlat(d, mt);
|
244
|
+
}
|
245
|
+
if (match("IMI2x([0-9]+)")) {
|
246
|
+
int nbit = std::stoi(sm[1].str());
|
247
|
+
FAISS_THROW_IF_NOT_MSG(
|
248
|
+
mt == METRIC_L2,
|
249
|
+
"MultiIndex not implemented for inner prod search");
|
250
|
+
nlist = (size_t)1 << (2 * nbit);
|
251
|
+
return new MultiIndexQuantizer(d, 2, nbit);
|
252
|
+
}
|
253
|
+
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
|
254
|
+
nlist = std::stoi(sm[1].str());
|
255
|
+
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
|
256
|
+
return new IndexHNSWFlat(d, hnsw_M, mt);
|
257
|
+
}
|
258
|
+
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
|
259
|
+
nlist = std::stoi(sm[1].str());
|
260
|
+
int R = std::stoi(sm[2]);
|
261
|
+
return new IndexNSGFlat(d, R, mt);
|
262
|
+
}
|
263
|
+
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
|
264
|
+
nlist = std::stoi(sm[1].str());
|
265
|
+
int no = std::stoi(sm[2].str());
|
266
|
+
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
|
267
|
+
return parenthesis_indexes[no].release();
|
268
|
+
}
|
148
269
|
|
149
|
-
|
270
|
+
// these two generate Index2Layer's not IndexIVF's
|
271
|
+
if (match("Residual([0-9]+)x([0-9]+)")) {
|
272
|
+
FAISS_THROW_IF_NOT_MSG(
|
273
|
+
mt == METRIC_L2,
|
274
|
+
"MultiIndex not implemented for inner prod search");
|
275
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2]);
|
276
|
+
nlist = (size_t)1 << (M * nbit);
|
277
|
+
use_2layer = true;
|
278
|
+
return new MultiIndexQuantizer(d, M, nbit);
|
279
|
+
}
|
280
|
+
if (match("Residual([0-9]+)")) {
|
281
|
+
FAISS_THROW_IF_NOT_MSG(
|
282
|
+
mt == METRIC_L2,
|
283
|
+
"Residual not implemented for inner prod search");
|
284
|
+
use_2layer = true;
|
285
|
+
nlist = mres_to_int(sm[1]);
|
286
|
+
return new IndexFlatL2(d);
|
287
|
+
}
|
288
|
+
return nullptr;
|
289
|
+
}
|
150
290
|
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
291
|
+
// parse the code description part of the IVF description
|
292
|
+
|
293
|
+
IndexIVF* parse_IndexIVF(
|
294
|
+
const std::string& code_string,
|
295
|
+
std::unique_ptr<Index>& quantizer,
|
296
|
+
size_t nlist,
|
297
|
+
MetricType mt) {
|
298
|
+
std::smatch sm;
|
299
|
+
auto match = [&sm, &code_string](const std::string pattern) {
|
300
|
+
return re_match(code_string, pattern, sm);
|
301
|
+
};
|
302
|
+
auto get_q = [&quantizer] { return quantizer.release(); };
|
303
|
+
int d = quantizer->d;
|
304
|
+
|
305
|
+
if (match("Flat")) {
|
306
|
+
return new IndexIVFFlat(get_q(), d, nlist, mt);
|
307
|
+
}
|
308
|
+
if (match("FlatDedup")) {
|
309
|
+
return new IndexIVFFlatDedup(get_q(), d, nlist, mt);
|
310
|
+
}
|
311
|
+
if (match(sq_pattern)) {
|
312
|
+
return new IndexIVFScalarQuantizer(
|
313
|
+
get_q(), d, nlist, sq_types[sm[1].str()], mt);
|
314
|
+
}
|
315
|
+
if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
|
316
|
+
int M = mres_to_int(sm[1]), nbit = mres_to_int(sm[2], 8, 1);
|
317
|
+
IndexIVFPQ* index_ivf = new IndexIVFPQ(get_q(), d, nlist, M, nbit, mt);
|
318
|
+
index_ivf->do_polysemous_training = sm[3].str() != "np";
|
319
|
+
return index_ivf;
|
320
|
+
}
|
321
|
+
if (match("PQ([0-9]+)\\+([0-9]+)")) {
|
322
|
+
FAISS_THROW_IF_NOT_MSG(
|
323
|
+
mt == METRIC_L2,
|
324
|
+
"IVFPQR not implemented for inner product search");
|
325
|
+
int M1 = mres_to_int(sm[1]), M2 = mres_to_int(sm[2]);
|
326
|
+
return new IndexIVFPQR(get_q(), d, nlist, M1, 8, M2, 8);
|
327
|
+
}
|
328
|
+
if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
|
329
|
+
int M = mres_to_int(sm[1]);
|
330
|
+
int bbs = mres_to_int(sm[3], 32, 1);
|
331
|
+
IndexIVFPQFastScan* index_ivf =
|
332
|
+
new IndexIVFPQFastScan(get_q(), d, nlist, M, 4, mt, bbs);
|
333
|
+
index_ivf->by_residual = sm[2].str() == "r";
|
334
|
+
return index_ivf;
|
335
|
+
}
|
336
|
+
if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
|
337
|
+
std::vector<size_t> nbits = aq_parse_nbits(sm.str());
|
338
|
+
AdditiveQuantizer::Search_type_t st =
|
339
|
+
aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
340
|
+
IndexIVF* index_ivf;
|
341
|
+
if (sm[1].str() == "RQ") {
|
342
|
+
index_ivf = new IndexIVFResidualQuantizer(
|
343
|
+
get_q(), d, nlist, nbits, mt, st);
|
155
344
|
} else {
|
156
|
-
|
345
|
+
FAISS_THROW_IF_NOT(nbits.size() > 0);
|
346
|
+
index_ivf = new IndexIVFLocalSearchQuantizer(
|
347
|
+
get_q(), d, nlist, nbits.size(), nbits[0], mt, st);
|
157
348
|
}
|
158
|
-
|
159
|
-
}
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
vt_1 = new PCAMatrix(d, d_out);
|
184
|
-
d = d_out;
|
185
|
-
} else if (sscanf(tok, "PCAR%d", &d_out) == 1) {
|
186
|
-
vt_1 = new PCAMatrix(d, d_out, 0, true);
|
187
|
-
d = d_out;
|
188
|
-
} else if (sscanf(tok, "RR%d", &d_out) == 1) {
|
189
|
-
vt_1 = new RandomRotationMatrix(d, d_out);
|
190
|
-
d = d_out;
|
191
|
-
} else if (sscanf(tok, "PCAW%d", &d_out) == 1) {
|
192
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, false);
|
193
|
-
d = d_out;
|
194
|
-
} else if (sscanf(tok, "PCAWR%d", &d_out) == 1) {
|
195
|
-
vt_1 = new PCAMatrix(d, d_out, -0.5, true);
|
196
|
-
d = d_out;
|
197
|
-
} else if (sscanf(tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
|
198
|
-
vt_1 = new OPQMatrix(d, opq_M, d_out);
|
199
|
-
d = d_out;
|
200
|
-
} else if (sscanf(tok, "OPQ%d", &opq_M) == 1) {
|
201
|
-
vt_1 = new OPQMatrix(d, opq_M);
|
202
|
-
} else if (sscanf(tok, "ITQ%d", &d_out) == 1) {
|
203
|
-
vt_1 = new ITQTransform(d, d_out, true);
|
204
|
-
d = d_out;
|
205
|
-
} else if (stok == "ITQ") {
|
206
|
-
vt_1 = new ITQTransform(d, d, false);
|
207
|
-
} else if (sscanf(tok, "Pad%d", &d_out) == 1) {
|
208
|
-
if (d_out > d) {
|
209
|
-
vt_1 = new RemapDimensionsTransform(d, d_out, false);
|
210
|
-
d = d_out;
|
211
|
-
}
|
212
|
-
} else if (stok == "L2norm") {
|
213
|
-
vt_1 = new NormalizationTransform(d, 2.0);
|
214
|
-
|
215
|
-
// coarse quantizers
|
216
|
-
} else if (
|
217
|
-
!coarse_quantizer &&
|
218
|
-
sscanf(tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
|
219
|
-
coarse_quantizer_1 = new IndexHNSWFlat(d, M, metric);
|
220
|
-
|
221
|
-
} else if (
|
222
|
-
!coarse_quantizer &&
|
223
|
-
sscanf(tok, "IVF%" PRId64 "_NSG%d", &ncentroids, &R) == 2) {
|
224
|
-
coarse_quantizer_1 = new IndexNSGFlat(d, R, metric);
|
225
|
-
|
226
|
-
} else if (
|
227
|
-
!coarse_quantizer &&
|
228
|
-
sscanf(tok, "IVF%" PRId64, &ncentroids) == 1) {
|
229
|
-
if (!parenthesis_ivf.empty()) {
|
230
|
-
coarse_quantizer_1 =
|
231
|
-
index_factory(d, parenthesis_ivf.c_str(), metric);
|
232
|
-
} else if (metric == METRIC_L2) {
|
233
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
234
|
-
} else {
|
235
|
-
coarse_quantizer_1 = new IndexFlatIP(d);
|
236
|
-
}
|
237
|
-
|
238
|
-
} else if (!coarse_quantizer && sscanf(tok, "IMI2x%d", &nbit) == 1) {
|
239
|
-
FAISS_THROW_IF_NOT_MSG(
|
240
|
-
metric == METRIC_L2,
|
241
|
-
"MultiIndex not implemented for inner prod search");
|
242
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, 2, nbit);
|
243
|
-
ncentroids = 1 << (2 * nbit);
|
244
|
-
|
245
|
-
} else if (
|
246
|
-
!coarse_quantizer &&
|
247
|
-
sscanf(tok, "Residual%dx%d", &M, &nbit) == 2) {
|
248
|
-
FAISS_THROW_IF_NOT_MSG(
|
249
|
-
metric == METRIC_L2,
|
250
|
-
"MultiIndex not implemented for inner prod search");
|
251
|
-
coarse_quantizer_1 = new MultiIndexQuantizer(d, M, nbit);
|
252
|
-
ncentroids = int64_t(1) << (M * nbit);
|
253
|
-
use_2layer = true;
|
254
|
-
|
255
|
-
} else if (std::regex_match(
|
256
|
-
stok,
|
257
|
-
std::regex(
|
258
|
-
"(RQ|RCQ)[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*"))) {
|
259
|
-
std::vector<size_t> nbits;
|
260
|
-
std::smatch sm;
|
261
|
-
bool is_RCQ = stok.find("RCQ") == 0;
|
262
|
-
while (std::regex_search(
|
263
|
-
stok, sm, std::regex("([0-9]+)x([0-9]+)"))) {
|
264
|
-
int M = std::stoi(sm[1].str());
|
265
|
-
int nbit = std::stoi(sm[2].str());
|
266
|
-
nbits.resize(nbits.size() + M, nbit);
|
267
|
-
stok = sm.suffix();
|
268
|
-
}
|
269
|
-
if (!is_RCQ) {
|
270
|
-
index_1 = new IndexResidual(d, nbits, metric);
|
271
|
-
} else {
|
272
|
-
index_1 = new ResidualCoarseQuantizer(d, nbits, metric);
|
273
|
-
}
|
274
|
-
} else if (
|
275
|
-
!coarse_quantizer &&
|
276
|
-
sscanf(tok, "Residual%" PRId64, &ncentroids) == 1) {
|
277
|
-
coarse_quantizer_1 = new IndexFlatL2(d);
|
278
|
-
use_2layer = true;
|
279
|
-
|
280
|
-
} else if (stok == "IDMap") {
|
281
|
-
add_idmap = true;
|
282
|
-
|
283
|
-
// IVFs
|
284
|
-
} else if (!index && (stok == "Flat" || stok == "FlatDedup")) {
|
285
|
-
if (coarse_quantizer) {
|
286
|
-
// if there was an IVF in front, then it is an IVFFlat
|
287
|
-
IndexIVF* index_ivf = stok == "Flat"
|
288
|
-
? new IndexIVFFlat(
|
289
|
-
coarse_quantizer, d, ncentroids, metric)
|
290
|
-
: new IndexIVFFlatDedup(
|
291
|
-
coarse_quantizer, d, ncentroids, metric);
|
292
|
-
index_ivf->quantizer_trains_alone =
|
293
|
-
get_trains_alone(coarse_quantizer);
|
294
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
295
|
-
del_coarse_quantizer.release();
|
296
|
-
index_ivf->own_fields = true;
|
297
|
-
index_1 = index_ivf;
|
298
|
-
} else if (hnsw_M > 0) {
|
299
|
-
index_1 = new IndexHNSWFlat(d, hnsw_M, metric);
|
300
|
-
} else if (nsg_R > 0) {
|
301
|
-
index_1 = new IndexNSGFlat(d, nsg_R, metric);
|
302
|
-
} else {
|
303
|
-
FAISS_THROW_IF_NOT_MSG(
|
304
|
-
stok != "FlatDedup",
|
305
|
-
"dedup supported only for IVFFlat");
|
306
|
-
index_1 = new IndexFlat(d, metric);
|
307
|
-
}
|
308
|
-
} else if (
|
309
|
-
!index &&
|
310
|
-
(stok == "SQ8" || stok == "SQ4" || stok == "SQ6" ||
|
311
|
-
stok == "SQfp16")) {
|
312
|
-
ScalarQuantizer::QuantizerType qt = stok == "SQ8"
|
313
|
-
? ScalarQuantizer::QT_8bit
|
314
|
-
: stok == "SQ6" ? ScalarQuantizer::QT_6bit
|
315
|
-
: stok == "SQ4" ? ScalarQuantizer::QT_4bit
|
316
|
-
: stok == "SQfp16" ? ScalarQuantizer::QT_fp16
|
317
|
-
: ScalarQuantizer::QT_4bit;
|
318
|
-
if (coarse_quantizer) {
|
319
|
-
FAISS_THROW_IF_NOT(!use_2layer);
|
320
|
-
IndexIVFScalarQuantizer* index_ivf =
|
321
|
-
new IndexIVFScalarQuantizer(
|
322
|
-
coarse_quantizer, d, ncentroids, qt, metric);
|
323
|
-
index_ivf->quantizer_trains_alone =
|
324
|
-
get_trains_alone(coarse_quantizer);
|
325
|
-
del_coarse_quantizer.release();
|
326
|
-
index_ivf->own_fields = true;
|
327
|
-
index_1 = index_ivf;
|
328
|
-
} else if (hnsw_M > 0) {
|
329
|
-
index_1 = new IndexHNSWSQ(d, qt, hnsw_M, metric);
|
330
|
-
} else {
|
331
|
-
index_1 = new IndexScalarQuantizer(d, qt, metric);
|
332
|
-
}
|
333
|
-
} else if (!index && sscanf(tok, "PQ%d+%d", &M, &M2) == 2) {
|
334
|
-
FAISS_THROW_IF_NOT_MSG(
|
335
|
-
coarse_quantizer, "PQ with + works only with an IVF");
|
336
|
-
FAISS_THROW_IF_NOT_MSG(
|
337
|
-
metric == METRIC_L2,
|
338
|
-
"IVFPQR not implemented for inner product search");
|
339
|
-
IndexIVFPQR* index_ivf = new IndexIVFPQR(
|
340
|
-
coarse_quantizer, d, ncentroids, M, 8, M2, 8);
|
341
|
-
index_ivf->quantizer_trains_alone =
|
342
|
-
get_trains_alone(coarse_quantizer);
|
343
|
-
del_coarse_quantizer.release();
|
344
|
-
index_ivf->own_fields = true;
|
345
|
-
index_1 = index_ivf;
|
346
|
-
} else if (
|
347
|
-
!index &&
|
348
|
-
(sscanf(tok, "PQ%dx4fs_%d", &M, &bbs) == 2 ||
|
349
|
-
(sscanf(tok, "PQ%dx4f%c", &M, &c) == 2 && c == 's') ||
|
350
|
-
(sscanf(tok, "PQ%dx4fs%c", &M, &c) == 2 && c == 'r'))) {
|
351
|
-
if (bbs == -1) {
|
352
|
-
bbs = 32;
|
353
|
-
}
|
354
|
-
bool by_residual = str_ends_with(stok, "fsr");
|
355
|
-
if (coarse_quantizer) {
|
356
|
-
IndexIVFPQFastScan* index_ivf = new IndexIVFPQFastScan(
|
357
|
-
coarse_quantizer, d, ncentroids, M, 4, metric, bbs);
|
358
|
-
index_ivf->quantizer_trains_alone =
|
359
|
-
get_trains_alone(coarse_quantizer);
|
360
|
-
index_ivf->metric_type = metric;
|
361
|
-
index_ivf->by_residual = by_residual;
|
362
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
363
|
-
del_coarse_quantizer.release();
|
364
|
-
index_ivf->own_fields = true;
|
365
|
-
index_1 = index_ivf;
|
366
|
-
} else {
|
367
|
-
IndexPQFastScan* index_pq =
|
368
|
-
new IndexPQFastScan(d, M, 4, metric, bbs);
|
369
|
-
index_1 = index_pq;
|
370
|
-
}
|
371
|
-
} else if (
|
372
|
-
!index &&
|
373
|
-
(sscanf(tok, "PQ%dx%d", &M, &nbit) == 2 ||
|
374
|
-
sscanf(tok, "PQ%d", &M) == 1 ||
|
375
|
-
sscanf(tok, "PQ%dnp", &M) == 1)) {
|
376
|
-
bool do_polysemous_training = stok.find("np") == std::string::npos;
|
377
|
-
if (coarse_quantizer) {
|
378
|
-
if (!use_2layer) {
|
379
|
-
IndexIVFPQ* index_ivf = new IndexIVFPQ(
|
380
|
-
coarse_quantizer, d, ncentroids, M, nbit);
|
381
|
-
index_ivf->quantizer_trains_alone =
|
382
|
-
get_trains_alone(coarse_quantizer);
|
383
|
-
index_ivf->metric_type = metric;
|
384
|
-
index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT;
|
385
|
-
del_coarse_quantizer.release();
|
386
|
-
index_ivf->own_fields = true;
|
387
|
-
index_ivf->do_polysemous_training = do_polysemous_training;
|
388
|
-
index_1 = index_ivf;
|
389
|
-
} else {
|
390
|
-
Index2Layer* index_2l = new Index2Layer(
|
391
|
-
coarse_quantizer, ncentroids, M, nbit);
|
392
|
-
index_2l->q1.quantizer_trains_alone =
|
393
|
-
get_trains_alone(coarse_quantizer);
|
394
|
-
index_2l->q1.own_fields = true;
|
395
|
-
index_1 = index_2l;
|
396
|
-
}
|
397
|
-
} else if (hnsw_M > 0) {
|
398
|
-
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
399
|
-
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
400
|
-
do_polysemous_training;
|
401
|
-
index_1 = ipq;
|
402
|
-
} else {
|
403
|
-
IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
|
404
|
-
index_pq->do_polysemous_training = do_polysemous_training;
|
405
|
-
index_1 = index_pq;
|
406
|
-
}
|
407
|
-
} else if (
|
408
|
-
!index &&
|
409
|
-
sscanf(tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) {
|
410
|
-
Index* quant = new IndexFlatL2(d);
|
411
|
-
IndexHNSW2Level* hidx2l =
|
412
|
-
new IndexHNSW2Level(quant, ncent, pq_m, M);
|
413
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
414
|
-
idx2l->q1.own_fields = true;
|
415
|
-
index_1 = hidx2l;
|
416
|
-
} else if (
|
417
|
-
!index &&
|
418
|
-
sscanf(tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) {
|
419
|
-
Index* quant = new MultiIndexQuantizer(d, 2, nbit);
|
420
|
-
IndexHNSW2Level* hidx2l =
|
421
|
-
new IndexHNSW2Level(quant, 1 << (2 * nbit), pq_m, M);
|
422
|
-
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
423
|
-
idx2l->q1.own_fields = true;
|
424
|
-
idx2l->q1.quantizer_trains_alone = 1;
|
425
|
-
index_1 = hidx2l;
|
426
|
-
} else if (!index && sscanf(tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) {
|
427
|
-
index_1 = new IndexHNSWPQ(d, pq_m, M);
|
428
|
-
} else if (
|
429
|
-
!index && sscanf(tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 &&
|
430
|
-
pq_m == 8) {
|
431
|
-
index_1 = new IndexHNSWSQ(d, ScalarQuantizer::QT_8bit, M);
|
432
|
-
} else if (!index && sscanf(tok, "HNSW%d", &M) == 1) {
|
433
|
-
hnsw_M = M;
|
434
|
-
// here it is unclear what we want: HNSW flat or HNSWx,Y ?
|
435
|
-
} else if (!index && sscanf(tok, "NSG%d", &R) == 1) {
|
436
|
-
nsg_R = R;
|
437
|
-
} else if (
|
438
|
-
!index &&
|
439
|
-
(stok == "LSH" || stok == "LSHr" || stok == "LSHrt" ||
|
440
|
-
stok == "LSHt")) {
|
441
|
-
bool rotate_data = strstr(tok, "r") != nullptr;
|
442
|
-
bool train_thresholds = strstr(tok, "t") != nullptr;
|
443
|
-
index_1 = new IndexLSH(d, d, rotate_data, train_thresholds);
|
444
|
-
} else if (
|
445
|
-
!index &&
|
446
|
-
sscanf(tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
|
447
|
-
FAISS_THROW_IF_NOT(!coarse_quantizer);
|
448
|
-
index_1 = new IndexLattice(d, M, nbit, r2);
|
449
|
-
} else if (stok == "RFlat") {
|
450
|
-
parenthesis_refine = "Flat";
|
451
|
-
} else if (stok == "Refine") {
|
452
|
-
FAISS_THROW_IF_NOT_MSG(
|
453
|
-
!parenthesis_refine.empty(),
|
454
|
-
"Refine index should be provided in parentheses");
|
349
|
+
return index_ivf;
|
350
|
+
}
|
351
|
+
if (match("(PRQ|PLSQ)" + paq_def_pattern + aq_norm_pattern)) {
|
352
|
+
int nsplits = mres_to_int(sm[2]);
|
353
|
+
int Msub = mres_to_int(sm[3]);
|
354
|
+
int nbit = mres_to_int(sm[4]);
|
355
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
356
|
+
IndexIVF* index_ivf;
|
357
|
+
if (sm[1].str() == "PRQ") {
|
358
|
+
index_ivf = new IndexIVFProductResidualQuantizer(
|
359
|
+
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
|
360
|
+
} else {
|
361
|
+
index_ivf = new IndexIVFProductLocalSearchQuantizer(
|
362
|
+
get_q(), d, nlist, nsplits, Msub, nbit, mt, st);
|
363
|
+
}
|
364
|
+
return index_ivf;
|
365
|
+
}
|
366
|
+
if (match("(RQ|LSQ)([0-9]+)x4fs(r?)(_[0-9]+)?" + aq_norm_pattern)) {
|
367
|
+
int M = std::stoi(sm[2].str());
|
368
|
+
int bbs = mres_to_int(sm[4], 32, 1);
|
369
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
370
|
+
IndexIVFAdditiveQuantizerFastScan* index_ivf;
|
371
|
+
if (sm[1].str() == "RQ") {
|
372
|
+
index_ivf = new IndexIVFResidualQuantizerFastScan(
|
373
|
+
get_q(), d, nlist, M, 4, mt, st, bbs);
|
455
374
|
} else {
|
456
|
-
|
457
|
-
|
458
|
-
tok,
|
459
|
-
description_in);
|
375
|
+
index_ivf = new IndexIVFLocalSearchQuantizerFastScan(
|
376
|
+
get_q(), d, nlist, M, 4, mt, st, bbs);
|
460
377
|
}
|
378
|
+
index_ivf->by_residual = (sm[3].str() == "r");
|
379
|
+
return index_ivf;
|
380
|
+
}
|
381
|
+
if (match("(PRQ|PLSQ)([0-9]+)x([0-9]+)x4fs(r?)(_[0-9]+)?" +
|
382
|
+
aq_norm_pattern)) {
|
383
|
+
int nsplits = std::stoi(sm[2].str());
|
384
|
+
int Msub = std::stoi(sm[3].str());
|
385
|
+
int bbs = mres_to_int(sm[5], 32, 1);
|
386
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), mt);
|
387
|
+
IndexIVFAdditiveQuantizerFastScan* index_ivf;
|
388
|
+
if (sm[1].str() == "PRQ") {
|
389
|
+
index_ivf = new IndexIVFProductResidualQuantizerFastScan(
|
390
|
+
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
|
391
|
+
} else {
|
392
|
+
index_ivf = new IndexIVFProductLocalSearchQuantizerFastScan(
|
393
|
+
get_q(), d, nlist, nsplits, Msub, 4, mt, st, bbs);
|
394
|
+
}
|
395
|
+
index_ivf->by_residual = (sm[4].str() == "r");
|
396
|
+
return index_ivf;
|
397
|
+
}
|
398
|
+
if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
|
399
|
+
int outdim = mres_to_int(sm[2], d); // is also the number of bits
|
400
|
+
std::unique_ptr<VectorTransform> vt;
|
401
|
+
if (sm[1] == "ITQ") {
|
402
|
+
vt.reset(new ITQTransform(d, outdim, d != outdim));
|
403
|
+
} else if (sm[1] == "PCA") {
|
404
|
+
vt.reset(new PCAMatrix(d, outdim));
|
405
|
+
} else if (sm[1] == "PCAR") {
|
406
|
+
vt.reset(new PCAMatrix(d, outdim, 0, true));
|
407
|
+
}
|
408
|
+
// the rationale for -1e10 is that this corresponds to simple
|
409
|
+
// thresholding
|
410
|
+
float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
|
411
|
+
IndexIVFSpectralHash* index_ivf =
|
412
|
+
new IndexIVFSpectralHash(get_q(), d, nlist, outdim, period);
|
413
|
+
index_ivf->replace_vt(vt.release(), true);
|
414
|
+
if (sm[4].length()) {
|
415
|
+
std::string s = sm[4].str();
|
416
|
+
index_ivf->threshold_type = s == "g"
|
417
|
+
? IndexIVFSpectralHash::Thresh_global
|
418
|
+
: s == "c"
|
419
|
+
? IndexIVFSpectralHash::Thresh_centroid
|
420
|
+
:
|
421
|
+
/* s == "m" ? */ IndexIVFSpectralHash::Thresh_median;
|
422
|
+
}
|
423
|
+
return index_ivf;
|
424
|
+
}
|
425
|
+
return nullptr;
|
426
|
+
}
|
427
|
+
|
428
|
+
/***************************************************************
|
429
|
+
* Parse IndexHNSW
|
430
|
+
*/
|
431
|
+
|
432
|
+
IndexHNSW* parse_IndexHNSW(
|
433
|
+
const std::string code_string,
|
434
|
+
int d,
|
435
|
+
MetricType mt,
|
436
|
+
int hnsw_M) {
|
437
|
+
std::smatch sm;
|
438
|
+
auto match = [&sm, &code_string](const std::string& pattern) {
|
439
|
+
return re_match(code_string, pattern, sm);
|
440
|
+
};
|
441
|
+
|
442
|
+
if (match("Flat|")) {
|
443
|
+
return new IndexHNSWFlat(d, hnsw_M, mt);
|
444
|
+
}
|
445
|
+
if (match("PQ([0-9]+)(np)?")) {
|
446
|
+
int M = std::stoi(sm[1].str());
|
447
|
+
IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
|
448
|
+
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
449
|
+
sm[2].str() != "np";
|
450
|
+
return ipq;
|
451
|
+
}
|
452
|
+
if (match(sq_pattern)) {
|
453
|
+
return new IndexHNSWSQ(d, sq_types[sm[1].str()], hnsw_M, mt);
|
454
|
+
}
|
455
|
+
if (match("([0-9]+)\\+PQ([0-9]+)?")) {
|
456
|
+
int ncent = mres_to_int(sm[1]);
|
457
|
+
int pq_m = mres_to_int(sm[2]);
|
458
|
+
IndexHNSW2Level* hidx2l =
|
459
|
+
new IndexHNSW2Level(new IndexFlatL2(d), ncent, pq_m, hnsw_M);
|
460
|
+
dynamic_cast<Index2Layer*>(hidx2l->storage)->q1.own_fields = true;
|
461
|
+
return hidx2l;
|
462
|
+
}
|
463
|
+
if (match("2x([0-9]+)\\+PQ([0-9]+)?")) {
|
464
|
+
int nbit = mres_to_int(sm[1]);
|
465
|
+
int pq_m = mres_to_int(sm[2]);
|
466
|
+
Index* quant = new MultiIndexQuantizer(d, 2, nbit);
|
467
|
+
IndexHNSW2Level* hidx2l = new IndexHNSW2Level(
|
468
|
+
quant, (size_t)1 << (2 * nbit), pq_m, hnsw_M);
|
469
|
+
Index2Layer* idx2l = dynamic_cast<Index2Layer*>(hidx2l->storage);
|
470
|
+
idx2l->q1.own_fields = true;
|
471
|
+
idx2l->q1.quantizer_trains_alone = 1;
|
472
|
+
return hidx2l;
|
473
|
+
}
|
474
|
+
|
475
|
+
return nullptr;
|
476
|
+
}
|
477
|
+
|
478
|
+
/***************************************************************
|
479
|
+
* Parse IndexNSG
|
480
|
+
*/
|
481
|
+
|
482
|
+
IndexNSG* parse_IndexNSG(
|
483
|
+
const std::string code_string,
|
484
|
+
int d,
|
485
|
+
MetricType mt,
|
486
|
+
int nsg_R) {
|
487
|
+
std::smatch sm;
|
488
|
+
auto match = [&sm, &code_string](const std::string& pattern) {
|
489
|
+
return re_match(code_string, pattern, sm);
|
490
|
+
};
|
491
|
+
|
492
|
+
if (match("Flat|")) {
|
493
|
+
return new IndexNSGFlat(d, nsg_R, mt);
|
494
|
+
}
|
495
|
+
if (match("PQ([0-9]+)(np)?")) {
|
496
|
+
int M = std::stoi(sm[1].str());
|
497
|
+
IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R);
|
498
|
+
dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
|
499
|
+
sm[2].str() != "np";
|
500
|
+
return ipq;
|
501
|
+
}
|
502
|
+
if (match(sq_pattern)) {
|
503
|
+
return new IndexNSGSQ(d, sq_types[sm[1].str()], nsg_R, mt);
|
504
|
+
}
|
505
|
+
|
506
|
+
return nullptr;
|
507
|
+
}
|
461
508
|
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
509
|
+
/***************************************************************
|
510
|
+
* Parse basic indexes
|
511
|
+
*/
|
512
|
+
|
513
|
+
Index* parse_other_indexes(
|
514
|
+
const std::string& description,
|
515
|
+
int d,
|
516
|
+
MetricType metric) {
|
517
|
+
std::smatch sm;
|
518
|
+
auto match = [&sm, description](const std::string& pattern) {
|
519
|
+
return re_match(description, pattern, sm);
|
520
|
+
};
|
521
|
+
|
522
|
+
// IndexFlat
|
523
|
+
if (description == "Flat") {
|
524
|
+
return new IndexFlat(d, metric);
|
525
|
+
}
|
526
|
+
|
527
|
+
// IndexLSH
|
528
|
+
if (match("LSH(r?)(t?)")) {
|
529
|
+
bool rotate_data = sm[1].length() > 0;
|
530
|
+
bool train_thresholds = sm[2].length() > 0;
|
531
|
+
FAISS_THROW_IF_NOT(metric == METRIC_L2);
|
532
|
+
return new IndexLSH(d, d, rotate_data, train_thresholds);
|
533
|
+
}
|
534
|
+
|
535
|
+
// IndexLattice
|
536
|
+
if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
|
537
|
+
int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
|
538
|
+
int nbit = std::stoi(sm[3].str());
|
539
|
+
return new IndexLattice(d, M, nbit, r2);
|
540
|
+
}
|
541
|
+
|
542
|
+
// IndexScalarQuantizer
|
543
|
+
if (match(sq_pattern)) {
|
544
|
+
return new IndexScalarQuantizer(d, sq_types[description], metric);
|
545
|
+
}
|
546
|
+
|
547
|
+
// IndexPQ
|
548
|
+
if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
|
549
|
+
int M = std::stoi(sm[1].str());
|
550
|
+
int nbit = mres_to_int(sm[2], 8, 1);
|
551
|
+
IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
|
552
|
+
index_pq->do_polysemous_training = sm[3].str() != "np";
|
553
|
+
return index_pq;
|
554
|
+
}
|
555
|
+
|
556
|
+
// IndexPQFastScan
|
557
|
+
if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
|
558
|
+
int M = std::stoi(sm[1].str());
|
559
|
+
int bbs = mres_to_int(sm[2], 32, 1);
|
560
|
+
return new IndexPQFastScan(d, M, 4, metric, bbs);
|
561
|
+
}
|
562
|
+
|
563
|
+
// IndexResidualCoarseQuantizer and IndexResidualQuantizer
|
564
|
+
std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
|
565
|
+
if (match(pattern)) {
|
566
|
+
std::vector<size_t> nbits = aq_parse_nbits(description);
|
567
|
+
if (sm[1].str() == "RCQ") {
|
568
|
+
return new ResidualCoarseQuantizer(d, nbits, metric);
|
468
569
|
}
|
570
|
+
AdditiveQuantizer::Search_type_t st =
|
571
|
+
aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
572
|
+
return new IndexResidualQuantizer(d, nbits, metric, st);
|
573
|
+
}
|
469
574
|
|
470
|
-
|
471
|
-
|
575
|
+
// LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
|
576
|
+
if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
|
577
|
+
std::vector<size_t> nbits = aq_parse_nbits(description);
|
578
|
+
int M = mres_to_int(sm[2]);
|
579
|
+
int nbit = mres_to_int(sm[3]);
|
580
|
+
if (sm[1].str() == "LSCQ") {
|
581
|
+
return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
|
472
582
|
}
|
583
|
+
AdditiveQuantizer::Search_type_t st =
|
584
|
+
aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
585
|
+
return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
|
586
|
+
}
|
587
|
+
|
588
|
+
// IndexProductResidualQuantizer
|
589
|
+
if (match("PRQ" + paq_def_pattern + aq_norm_pattern)) {
|
590
|
+
int nsplits = mres_to_int(sm[1]);
|
591
|
+
int Msub = mres_to_int(sm[2]);
|
592
|
+
int nbit = mres_to_int(sm[3]);
|
593
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
594
|
+
return new IndexProductResidualQuantizer(
|
595
|
+
d, nsplits, Msub, nbit, metric, st);
|
596
|
+
}
|
473
597
|
|
474
|
-
|
475
|
-
|
476
|
-
|
598
|
+
// IndexProductLocalSearchQuantizer
|
599
|
+
if (match("PLSQ" + paq_def_pattern + aq_norm_pattern)) {
|
600
|
+
int nsplits = mres_to_int(sm[1]);
|
601
|
+
int Msub = mres_to_int(sm[2]);
|
602
|
+
int nbit = mres_to_int(sm[3]);
|
603
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
604
|
+
return new IndexProductLocalSearchQuantizer(
|
605
|
+
d, nsplits, Msub, nbit, metric, st);
|
606
|
+
}
|
607
|
+
|
608
|
+
// IndexAdditiveQuantizerFastScan
|
609
|
+
// RQ{M}x4fs_{bbs}_{search_type}
|
610
|
+
pattern = "(LSQ|RQ)([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
|
611
|
+
if (match(pattern)) {
|
612
|
+
int M = std::stoi(sm[2].str());
|
613
|
+
int bbs = mres_to_int(sm[3], 32, 1);
|
614
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
615
|
+
|
616
|
+
if (sm[1].str() == "RQ") {
|
617
|
+
return new IndexResidualQuantizerFastScan(d, M, 4, metric, st, bbs);
|
618
|
+
} else if (sm[1].str() == "LSQ") {
|
619
|
+
return new IndexLocalSearchQuantizerFastScan(
|
620
|
+
d, M, 4, metric, st, bbs);
|
477
621
|
}
|
622
|
+
}
|
478
623
|
|
479
|
-
|
480
|
-
|
481
|
-
|
624
|
+
// IndexProductAdditiveQuantizerFastScan
|
625
|
+
// PRQ{nsplits}x{Msub}x4fs_{bbs}_{search_type}
|
626
|
+
pattern = "(PLSQ|PRQ)([0-9]+)x([0-9]+)x4fs(_[0-9]+)?" + aq_norm_pattern;
|
627
|
+
if (match(pattern)) {
|
628
|
+
int nsplits = std::stoi(sm[2].str());
|
629
|
+
int Msub = std::stoi(sm[3].str());
|
630
|
+
int bbs = mres_to_int(sm[4], 32, 1);
|
631
|
+
auto st = aq_parse_search_type(sm[sm.size() - 1].str(), metric);
|
632
|
+
|
633
|
+
if (sm[1].str() == "PRQ") {
|
634
|
+
return new IndexProductResidualQuantizerFastScan(
|
635
|
+
d, nsplits, Msub, 4, metric, st, bbs);
|
636
|
+
} else if (sm[1].str() == "PLSQ") {
|
637
|
+
return new IndexProductLocalSearchQuantizerFastScan(
|
638
|
+
d, nsplits, Msub, 4, metric, st, bbs);
|
482
639
|
}
|
483
640
|
}
|
484
641
|
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
642
|
+
return nullptr;
|
643
|
+
}
|
644
|
+
|
645
|
+
/***************************************************************
|
646
|
+
* Driver function
|
647
|
+
*/
|
648
|
+
std::unique_ptr<Index> index_factory_sub(
|
649
|
+
int d,
|
650
|
+
std::string description,
|
651
|
+
MetricType metric) {
|
652
|
+
// handle composite indexes
|
653
|
+
|
654
|
+
bool verbose = index_factory_verbose;
|
655
|
+
|
656
|
+
if (verbose) {
|
657
|
+
printf("begin parse VectorTransforms: %s \n", description.c_str());
|
491
658
|
}
|
492
659
|
|
493
|
-
|
494
|
-
|
660
|
+
// for the current match
|
661
|
+
std::smatch sm;
|
495
662
|
|
496
|
-
//
|
497
|
-
|
498
|
-
|
663
|
+
// handle refines
|
664
|
+
if (re_match(description, "(.+),RFlat", sm) ||
|
665
|
+
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
|
666
|
+
std::unique_ptr<Index> filter_index =
|
667
|
+
index_factory_sub(d, sm[1].str(), metric);
|
668
|
+
std::unique_ptr<Index> refine_index;
|
499
669
|
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
670
|
+
if (sm.size() == 3) { // Refine
|
671
|
+
refine_index = index_factory_sub(d, sm[2].str(), metric);
|
672
|
+
} else { // RFlat
|
673
|
+
refine_index.reset(new IndexFlat(d, metric));
|
674
|
+
}
|
675
|
+
IndexRefine* index_rf =
|
676
|
+
new IndexRefine(filter_index.get(), refine_index.get());
|
677
|
+
index_rf->own_fields = true;
|
678
|
+
filter_index.release();
|
679
|
+
refine_index.release();
|
680
|
+
index_rf->own_refine_index = true;
|
681
|
+
return std::unique_ptr<Index>(index_rf);
|
504
682
|
}
|
505
683
|
|
506
|
-
|
507
|
-
|
684
|
+
// IndexPreTransform
|
685
|
+
// should handle this first (even before parentheses) because it changes d
|
686
|
+
std::vector<std::unique_ptr<VectorTransform>> vts;
|
687
|
+
VectorTransform* vt = nullptr;
|
688
|
+
while (re_match(description, "([^,]+),(.*)", sm) &&
|
689
|
+
(vt = parse_VectorTransform(sm[1], d))) {
|
690
|
+
// reset loop
|
691
|
+
description = sm[sm.size() - 1];
|
692
|
+
vts.emplace_back(vt);
|
693
|
+
d = vts.back()->d_out;
|
694
|
+
}
|
695
|
+
|
696
|
+
if (vts.size() > 0) {
|
697
|
+
std::unique_ptr<Index> sub_index =
|
698
|
+
index_factory_sub(d, description, metric);
|
699
|
+
IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
|
700
|
+
std::unique_ptr<Index> ret(index_pt);
|
508
701
|
index_pt->own_fields = true;
|
509
|
-
|
510
|
-
while (vts.
|
511
|
-
|
512
|
-
|
702
|
+
sub_index.release();
|
703
|
+
while (vts.size() > 0) {
|
704
|
+
if (verbose) {
|
705
|
+
printf("prepend trans %d -> %d\n",
|
706
|
+
vts.back()->d_in,
|
707
|
+
vts.back()->d_out);
|
708
|
+
}
|
709
|
+
index_pt->prepend_transform(vts.back().release());
|
710
|
+
vts.pop_back();
|
513
711
|
}
|
514
|
-
|
712
|
+
return ret;
|
515
713
|
}
|
516
714
|
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
715
|
+
// what we got from the parentheses
|
716
|
+
std::vector<std::unique_ptr<Index>> parenthesis_indexes;
|
717
|
+
|
718
|
+
int begin = 0;
|
719
|
+
while (description.find('(', begin) != std::string::npos) {
|
720
|
+
// replace indexes in () with Index0, Index1, etc.
|
721
|
+
int i0, i1;
|
722
|
+
find_matching_parentheses(description, i0, i1, begin);
|
723
|
+
std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
|
724
|
+
int no = parenthesis_indexes.size();
|
725
|
+
parenthesis_indexes.push_back(
|
726
|
+
index_factory_sub(d, sub_description, metric));
|
727
|
+
description = description.substr(0, i0 + 1) + "Index" +
|
728
|
+
std::to_string(no) + description.substr(i1);
|
729
|
+
begin = i1 + 1;
|
524
730
|
}
|
525
731
|
|
526
|
-
|
732
|
+
if (verbose) {
|
733
|
+
printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
|
734
|
+
description.c_str(),
|
735
|
+
parenthesis_indexes.size(),
|
736
|
+
d);
|
737
|
+
}
|
738
|
+
|
739
|
+
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
740
|
+
// support both
|
741
|
+
if (re_match(description, "(.+),IDMap2", sm) ||
|
742
|
+
re_match(description, "IDMap2,(.+)", sm)) {
|
743
|
+
IndexIDMap2* idmap2 = new IndexIDMap2(
|
744
|
+
index_factory_sub(d, sm[1].str(), metric).release());
|
745
|
+
idmap2->own_fields = true;
|
746
|
+
return std::unique_ptr<Index>(idmap2);
|
747
|
+
}
|
748
|
+
|
749
|
+
if (re_match(description, "(.+),IDMap", sm) ||
|
750
|
+
re_match(description, "IDMap,(.+)", sm)) {
|
751
|
+
IndexIDMap* idmap = new IndexIDMap(
|
752
|
+
index_factory_sub(d, sm[1].str(), metric).release());
|
753
|
+
idmap->own_fields = true;
|
754
|
+
return std::unique_ptr<Index>(idmap);
|
755
|
+
}
|
756
|
+
|
757
|
+
{ // handle basic index types
|
758
|
+
Index* index = parse_other_indexes(description, d, metric);
|
759
|
+
if (index) {
|
760
|
+
return std::unique_ptr<Index>(index);
|
761
|
+
}
|
762
|
+
}
|
763
|
+
|
764
|
+
// HNSW variants (it was unclear in the old version that the separator was a
|
765
|
+
// "," so we support both "_" and ",")
|
766
|
+
if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
|
767
|
+
int hnsw_M = mres_to_int(sm[1], 32);
|
768
|
+
// We also accept empty code string (synonym of Flat)
|
769
|
+
std::string code_string =
|
770
|
+
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
771
|
+
if (verbose) {
|
772
|
+
printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
|
773
|
+
description.c_str(),
|
774
|
+
code_string.c_str(),
|
775
|
+
hnsw_M);
|
776
|
+
}
|
777
|
+
|
778
|
+
IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
|
779
|
+
FAISS_THROW_IF_NOT_FMT(
|
780
|
+
index,
|
781
|
+
"could not parse HNSW code description %s in %s",
|
782
|
+
code_string.c_str(),
|
783
|
+
description.c_str());
|
784
|
+
return std::unique_ptr<Index>(index);
|
785
|
+
}
|
786
|
+
|
787
|
+
// NSG variants (it was unclear in the old version that the separator was a
|
788
|
+
// "," so we support both "_" and ",")
|
789
|
+
if (re_match(description, "NSG([0-9]*)([,_].*)?", sm)) {
|
790
|
+
int nsg_R = mres_to_int(sm[1], 32);
|
791
|
+
// We also accept empty code string (synonym of Flat)
|
792
|
+
std::string code_string =
|
793
|
+
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
794
|
+
if (verbose) {
|
795
|
+
printf("parsing NSG string %s code_string=%s nsg_R=%d\n",
|
796
|
+
description.c_str(),
|
797
|
+
code_string.c_str(),
|
798
|
+
nsg_R);
|
799
|
+
}
|
800
|
+
|
801
|
+
IndexNSG* index = parse_IndexNSG(code_string, d, metric, nsg_R);
|
802
|
+
FAISS_THROW_IF_NOT_FMT(
|
803
|
+
index,
|
804
|
+
"could not parse NSG code description %s in %s",
|
805
|
+
code_string.c_str(),
|
806
|
+
description.c_str());
|
807
|
+
return std::unique_ptr<Index>(index);
|
808
|
+
}
|
809
|
+
|
810
|
+
// IndexRowwiseMinMax, fp32 version
|
811
|
+
if (description.compare(0, 7, "MinMax,") == 0) {
|
812
|
+
size_t comma = description.find(",");
|
813
|
+
std::string sub_index_string = description.substr(comma + 1);
|
814
|
+
auto sub_index = index_factory_sub(d, sub_index_string, metric);
|
815
|
+
|
816
|
+
auto index = new IndexRowwiseMinMax(sub_index.release());
|
817
|
+
index->own_fields = true;
|
818
|
+
|
819
|
+
return std::unique_ptr<Index>(index);
|
820
|
+
}
|
821
|
+
|
822
|
+
// IndexRowwiseMinMax, fp16 version
|
823
|
+
if (description.compare(0, 11, "MinMaxFP16,") == 0) {
|
824
|
+
size_t comma = description.find(",");
|
825
|
+
std::string sub_index_string = description.substr(comma + 1);
|
826
|
+
auto sub_index = index_factory_sub(d, sub_index_string, metric);
|
827
|
+
|
828
|
+
auto index = new IndexRowwiseMinMaxFP16(sub_index.release());
|
829
|
+
index->own_fields = true;
|
830
|
+
|
831
|
+
return std::unique_ptr<Index>(index);
|
832
|
+
}
|
833
|
+
|
834
|
+
// IndexIVF
|
835
|
+
{
|
836
|
+
size_t nlist;
|
837
|
+
bool use_2layer;
|
838
|
+
size_t comma = description.find(",");
|
839
|
+
std::string coarse_string = description.substr(0, comma);
|
840
|
+
// Match coarse quantizer part first
|
841
|
+
std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
|
842
|
+
description.substr(0, comma),
|
843
|
+
d,
|
844
|
+
metric,
|
845
|
+
parenthesis_indexes,
|
846
|
+
nlist,
|
847
|
+
use_2layer));
|
848
|
+
|
849
|
+
if (comma != std::string::npos && quantizer.get()) {
|
850
|
+
std::string code_description = description.substr(comma + 1);
|
851
|
+
if (use_2layer) {
|
852
|
+
bool ok =
|
853
|
+
re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
|
854
|
+
FAISS_THROW_IF_NOT_FMT(
|
855
|
+
ok,
|
856
|
+
"could not parse 2 layer code description %s in %s",
|
857
|
+
code_description.c_str(),
|
858
|
+
description.c_str());
|
859
|
+
int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
|
860
|
+
Index2Layer* index_2l =
|
861
|
+
new Index2Layer(quantizer.release(), nlist, M, nbit);
|
862
|
+
index_2l->q1.own_fields = true;
|
863
|
+
index_2l->q1.quantizer_trains_alone =
|
864
|
+
get_trains_alone(index_2l->q1.quantizer);
|
865
|
+
return std::unique_ptr<Index>(index_2l);
|
866
|
+
}
|
867
|
+
|
868
|
+
IndexIVF* index_ivf =
|
869
|
+
parse_IndexIVF(code_description, quantizer, nlist, metric);
|
870
|
+
|
871
|
+
FAISS_THROW_IF_NOT_FMT(
|
872
|
+
index_ivf,
|
873
|
+
"could not parse code description %s in %s",
|
874
|
+
code_description.c_str(),
|
875
|
+
description.c_str());
|
876
|
+
return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
|
877
|
+
}
|
878
|
+
}
|
879
|
+
FAISS_THROW_FMT("could not parse index string %s", description.c_str());
|
880
|
+
return nullptr;
|
881
|
+
}
|
882
|
+
|
883
|
+
} // anonymous namespace
|
884
|
+
|
885
|
+
Index* index_factory(int d, const char* description, MetricType metric) {
|
886
|
+
return index_factory_sub(d, description, metric).release();
|
527
887
|
}
|
528
888
|
|
529
889
|
IndexBinary* index_binary_factory(int d, const char* description) {
|