faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
@@ -5,92 +5,80 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
/*
|
11
|
-
* implementation of
|
9
|
+
* implementation of the index_factory function. Lots of regex parsing code.
|
12
10
|
*/
|
13
11
|
|
14
|
-
#include <faiss/
|
12
|
+
#include <faiss/index_factory.h>
|
13
|
+
#include "faiss/MetricType.h"
|
14
|
+
#include "faiss/impl/FaissAssert.h"
|
15
15
|
|
16
16
|
#include <cinttypes>
|
17
17
|
#include <cmath>
|
18
18
|
|
19
|
+
#include <map>
|
20
|
+
|
21
|
+
#include <regex>
|
22
|
+
|
19
23
|
#include <faiss/impl/FaissAssert.h>
|
20
|
-
#include <faiss/utils/utils.h>
|
21
24
|
#include <faiss/utils/random.h>
|
25
|
+
#include <faiss/utils/utils.h>
|
22
26
|
|
27
|
+
#include <faiss/Index2Layer.h>
|
28
|
+
#include <faiss/IndexAdditiveQuantizer.h>
|
23
29
|
#include <faiss/IndexFlat.h>
|
24
|
-
#include <faiss/
|
25
|
-
#include <faiss/IndexPreTransform.h>
|
26
|
-
#include <faiss/IndexLSH.h>
|
27
|
-
#include <faiss/IndexPQ.h>
|
30
|
+
#include <faiss/IndexHNSW.h>
|
28
31
|
#include <faiss/IndexIVF.h>
|
32
|
+
#include <faiss/IndexIVFAdditiveQuantizer.h>
|
33
|
+
#include <faiss/IndexIVFFlat.h>
|
29
34
|
#include <faiss/IndexIVFPQ.h>
|
35
|
+
#include <faiss/IndexIVFPQFastScan.h>
|
30
36
|
#include <faiss/IndexIVFPQR.h>
|
31
|
-
#include <faiss/
|
32
|
-
#include <faiss/
|
33
|
-
#include <faiss/MetaIndexes.h>
|
34
|
-
#include <faiss/IndexScalarQuantizer.h>
|
35
|
-
#include <faiss/IndexHNSW.h>
|
37
|
+
#include <faiss/IndexIVFSpectralHash.h>
|
38
|
+
#include <faiss/IndexLSH.h>
|
36
39
|
#include <faiss/IndexLattice.h>
|
40
|
+
#include <faiss/IndexNSG.h>
|
41
|
+
#include <faiss/IndexPQ.h>
|
37
42
|
#include <faiss/IndexPQFastScan.h>
|
38
|
-
#include <faiss/
|
43
|
+
#include <faiss/IndexPreTransform.h>
|
39
44
|
#include <faiss/IndexRefine.h>
|
40
|
-
|
45
|
+
#include <faiss/IndexScalarQuantizer.h>
|
46
|
+
#include <faiss/MetaIndexes.h>
|
47
|
+
#include <faiss/VectorTransform.h>
|
41
48
|
|
42
49
|
#include <faiss/IndexBinaryFlat.h>
|
43
50
|
#include <faiss/IndexBinaryHNSW.h>
|
44
|
-
#include <faiss/IndexBinaryIVF.h>
|
45
51
|
#include <faiss/IndexBinaryHash.h>
|
52
|
+
#include <faiss/IndexBinaryIVF.h>
|
53
|
+
#include <string>
|
46
54
|
|
47
55
|
namespace faiss {
|
48
56
|
|
49
|
-
|
50
57
|
/***************************************************************
|
51
58
|
* index_factory
|
52
59
|
***************************************************************/
|
53
60
|
|
54
|
-
|
61
|
+
int index_factory_verbose = 0;
|
55
62
|
|
56
|
-
|
57
|
-
std::vector<VectorTransform *> chain;
|
58
|
-
~VTChain () {
|
59
|
-
for (int i = 0; i < chain.size(); i++) {
|
60
|
-
delete chain[i];
|
61
|
-
}
|
62
|
-
}
|
63
|
-
};
|
64
|
-
|
65
|
-
|
66
|
-
/// what kind of training does this coarse quantizer require?
|
67
|
-
char get_trains_alone(const Index *coarse_quantizer) {
|
68
|
-
return
|
69
|
-
dynamic_cast<const IndexFlat*>(coarse_quantizer) ? 0 :
|
70
|
-
// multi index just needs to be quantized
|
71
|
-
dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ? 1 :
|
72
|
-
dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer) ? 2 :
|
73
|
-
2; // for complicated indexes, we assume they can't be used as a kmeans index
|
74
|
-
}
|
63
|
+
namespace {
|
75
64
|
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
}
|
65
|
+
/***************************************************************
|
66
|
+
* Small functions
|
67
|
+
*/
|
80
68
|
|
81
|
-
|
82
|
-
|
83
|
-
{
|
84
|
-
int i;
|
85
|
-
for(i = s.length() - 1; i >= 0; i--) {
|
86
|
-
if (!isdigit(s[i])) break;
|
87
|
-
}
|
88
|
-
return str_ends_with(s.substr(0, i + 1), suffix);
|
69
|
+
/// Test whether the *entire* string `s` matches the regular expression `pat`.
///
/// @param s    string to test
/// @param pat  regular expression source, compiled on every call
/// @param sm   receives the capture groups; its sub-matches refer to `s`,
///             so `s` must stay alive while `sm` is inspected
/// @return     true iff the whole of `s` matches `pat`
bool re_match(const std::string& s, const std::string& pat, std::smatch& sm) {
    const std::regex compiled(pat);
    return std::regex_match(s, sm, compiled);
}
|
90
72
|
|
91
|
-
|
73
|
+
// find first pair of matching parentheses
|
74
|
+
void find_matching_parentheses(
|
75
|
+
const std::string& s,
|
76
|
+
int& i0,
|
77
|
+
int& i1,
|
78
|
+
int begin = 0) {
|
92
79
|
int st = 0;
|
93
|
-
|
80
|
+
i0 = i1 = 0;
|
81
|
+
for (int i = begin; i < s.length(); i++) {
|
94
82
|
if (s[i] == '(') {
|
95
83
|
if (st == 0) {
|
96
84
|
i0 = i;
|
@@ -105,413 +93,647 @@ void find_matching_parentheses(const std::string &s, int & i0, int & i1) {
|
|
105
93
|
return;
|
106
94
|
}
|
107
95
|
if (st < 0) {
|
108
|
-
FAISS_THROW_FMT(
|
96
|
+
FAISS_THROW_FMT(
|
97
|
+
"factory string %s: unbalanced parentheses", s.c_str());
|
109
98
|
}
|
110
99
|
}
|
100
|
+
}
|
101
|
+
FAISS_THROW_FMT(
|
102
|
+
"factory string %s: unbalanced parentheses st=%d", s.c_str(), st);
|
103
|
+
}
|
111
104
|
|
105
|
+
/// what kind of training does this coarse quantizer require?
|
106
|
+
char get_trains_alone(const Index* coarse_quantizer) {
|
107
|
+
if (dynamic_cast<const IndexFlat*>(coarse_quantizer)) {
|
108
|
+
return 0;
|
109
|
+
}
|
110
|
+
// multi index just needs to be quantized
|
111
|
+
if (dynamic_cast<const MultiIndexQuantizer*>(coarse_quantizer) ||
|
112
|
+
dynamic_cast<const ResidualCoarseQuantizer*>(coarse_quantizer)) {
|
113
|
+
return 1;
|
112
114
|
}
|
113
|
-
|
115
|
+
if (dynamic_cast<const IndexHNSWFlat*>(coarse_quantizer)) {
|
116
|
+
return 2;
|
117
|
+
}
|
118
|
+
return 2; // for complicated indexes, we assume they can't be used as a
|
119
|
+
// kmeans index
|
120
|
+
}
|
114
121
|
|
122
|
+
// set the fields for factory-constructed IVF structures
|
123
|
+
IndexIVF* fix_ivf_fields(IndexIVF* index_ivf) {
|
124
|
+
index_ivf->quantizer_trains_alone = get_trains_alone(index_ivf->quantizer);
|
125
|
+
index_ivf->cp.spherical = index_ivf->metric_type == METRIC_INNER_PRODUCT;
|
126
|
+
index_ivf->own_fields = true;
|
127
|
+
return index_ivf;
|
115
128
|
}
|
116
129
|
|
117
|
-
|
130
|
+
/// Convert a regex sub-match to an integer.
///
/// @param mr     sub-match to convert
/// @param deflt  value returned when the sub-match is empty
/// @param begin  number of leading characters of the sub-match to skip
///               before parsing (e.g. 1 to drop a leading separator)
/// @return       the parsed integer, or `deflt` for an empty sub-match
int mres_to_int(const std::ssub_match& mr, int deflt = -1, int begin = 0) {
    if (mr.length() == 0) {
        return deflt;
    }
    const std::string text = mr.str();
    const std::string digits = text.substr(begin);
    return std::stoi(digits);
}
|
118
136
|
|
119
|
-
|
120
|
-
{
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
// handle indexes in parentheses
|
136
|
-
while (description.find('(') != std::string::npos) {
|
137
|
-
// then we make a sub-index and remove the () from the description
|
138
|
-
int i0, i1;
|
139
|
-
find_matching_parentheses(description, i0, i1);
|
137
|
+
std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
|
138
|
+
{"SQ8", ScalarQuantizer::QT_8bit},
|
139
|
+
{"SQ4", ScalarQuantizer::QT_4bit},
|
140
|
+
{"SQ6", ScalarQuantizer::QT_6bit},
|
141
|
+
{"SQfp16", ScalarQuantizer::QT_fp16},
|
142
|
+
};
|
143
|
+
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
|
144
|
+
|
145
|
+
std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
|
146
|
+
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
|
147
|
+
{"_Nnone", AdditiveQuantizer::ST_LUT_nonorm},
|
148
|
+
{"_Nqint8", AdditiveQuantizer::ST_norm_qint8},
|
149
|
+
{"_Nqint4", AdditiveQuantizer::ST_norm_qint4},
|
150
|
+
{"_Ncqint8", AdditiveQuantizer::ST_norm_cqint8},
|
151
|
+
{"_Ncqint4", AdditiveQuantizer::ST_norm_cqint4},
|
152
|
+
};
|
140
153
|
|
141
|
-
|
154
|
+
const std::string aq_def_pattern = "[0-9]+x[0-9]+(_[0-9]+x[0-9]+)*";
|
155
|
+
const std::string aq_norm_pattern =
|
156
|
+
"(|_Nnone|_Nfloat|_Nqint8|_Nqint4|_Ncqint8|_Ncqint4)";
|
142
157
|
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
int64_t ncentroids = -1;
|
154
|
-
bool use_2layer = false;
|
155
|
-
int hnsw_M = -1;
|
156
|
-
|
157
|
-
for (char *tok = strtok_r (&description[0], " ,", &ptr);
|
158
|
-
tok;
|
159
|
-
tok = strtok_r (nullptr, " ,", &ptr)) {
|
160
|
-
int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2;
|
161
|
-
std::string stok(tok);
|
162
|
-
nbit = 8;
|
163
|
-
int bbs = -1;
|
164
|
-
char c;
|
165
|
-
|
166
|
-
// to avoid mem leaks with exceptions:
|
167
|
-
// do all tests before any instanciation
|
168
|
-
|
169
|
-
VectorTransform *vt_1 = nullptr;
|
170
|
-
Index *coarse_quantizer_1 = nullptr;
|
171
|
-
Index *index_1 = nullptr;
|
172
|
-
|
173
|
-
// VectorTransforms
|
174
|
-
if (sscanf (tok, "PCA%d", &d_out) == 1) {
|
175
|
-
vt_1 = new PCAMatrix (d, d_out);
|
176
|
-
d = d_out;
|
177
|
-
} else if (sscanf (tok, "PCAR%d", &d_out) == 1) {
|
178
|
-
vt_1 = new PCAMatrix (d, d_out, 0, true);
|
179
|
-
d = d_out;
|
180
|
-
} else if (sscanf (tok, "RR%d", &d_out) == 1) {
|
181
|
-
vt_1 = new RandomRotationMatrix (d, d_out);
|
182
|
-
d = d_out;
|
183
|
-
} else if (sscanf (tok, "PCAW%d", &d_out) == 1) {
|
184
|
-
vt_1 = new PCAMatrix (d, d_out, -0.5, false);
|
185
|
-
d = d_out;
|
186
|
-
} else if (sscanf (tok, "PCAWR%d", &d_out) == 1) {
|
187
|
-
vt_1 = new PCAMatrix (d, d_out, -0.5, true);
|
188
|
-
d = d_out;
|
189
|
-
} else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
|
190
|
-
vt_1 = new OPQMatrix (d, opq_M, d_out);
|
191
|
-
d = d_out;
|
192
|
-
} else if (sscanf (tok, "OPQ%d", &opq_M) == 1) {
|
193
|
-
vt_1 = new OPQMatrix (d, opq_M);
|
194
|
-
} else if (sscanf (tok, "ITQ%d", &d_out) == 1) {
|
195
|
-
vt_1 = new ITQTransform (d, d_out, true);
|
196
|
-
d = d_out;
|
197
|
-
} else if (stok == "ITQ") {
|
198
|
-
vt_1 = new ITQTransform (d, d, false);
|
199
|
-
} else if (sscanf (tok, "Pad%d", &d_out) == 1) {
|
200
|
-
if (d_out > d) {
|
201
|
-
vt_1 = new RemapDimensionsTransform (d, d_out, false);
|
202
|
-
d = d_out;
|
203
|
-
}
|
204
|
-
} else if (stok == "L2norm") {
|
205
|
-
vt_1 = new NormalizationTransform (d, 2.0);
|
206
|
-
|
207
|
-
// coarse quantizers
|
208
|
-
} else if (!coarse_quantizer &&
|
209
|
-
sscanf (tok, "IVF%" PRId64 "_HNSW%d", &ncentroids, &M) == 2) {
|
210
|
-
coarse_quantizer_1 = new IndexHNSWFlat (d, M, metric);
|
211
|
-
|
212
|
-
} else if (!coarse_quantizer &&
|
213
|
-
sscanf (tok, "IVF%" PRId64, &ncentroids) == 1) {
|
214
|
-
if (!parenthesis_ivf.empty()) {
|
215
|
-
coarse_quantizer_1 =
|
216
|
-
index_factory(d, parenthesis_ivf.c_str(), metric);
|
217
|
-
|
218
|
-
} else if (metric == METRIC_L2) {
|
219
|
-
coarse_quantizer_1 = new IndexFlatL2 (d);
|
220
|
-
} else {
|
221
|
-
coarse_quantizer_1 = new IndexFlatIP (d);
|
222
|
-
}
|
158
|
+
/// Translate the norm-encoding suffix of an additive quantizer token into a
/// search type.
///
/// @param stok    either empty, or a string whose last '_'-introduced part is
///                one of the keys of `aq_search_type` (e.g. "_Nqint8")
/// @param metric  metric of the index; only used to pick the default when
///                `stok` is empty
/// @return        for an empty `stok`: ST_decompress for METRIC_L2 and
///                ST_LUT_nonorm otherwise; else the mapped search type
/// @throws        a Faiss exception (via FAISS_THROW_FMT) when the suffix is
///                not a known key
AdditiveQuantizer::Search_type_t aq_parse_search_type(
        std::string stok,
        MetricType metric) {
    if (stok.empty()) {
        return metric == METRIC_L2 ? AdditiveQuantizer::ST_decompress
                                   : AdditiveQuantizer::ST_LUT_nonorm;
    }
    // Keep the position as size_t: narrowing npos into an int would turn a
    // missing '_' into substr(-1).
    const size_t pos = stok.rfind('_');
    const std::string suffix =
            pos == std::string::npos ? stok : stok.substr(pos);
    // Look up with find(): operator[] would insert an entry for an unknown
    // suffix into the table and hand back a value-initialized enum.
    const auto it = aq_search_type.find(suffix);
    if (it == aq_search_type.end()) {
        FAISS_THROW_FMT(
                "unknown additive quantizer norm suffix %s", suffix.c_str());
    }
    return it->second;
}
|
223
168
|
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
bool rotate_data = strstr(tok, "r") != nullptr;
|
393
|
-
bool train_thresholds = strstr(tok, "t") != nullptr;
|
394
|
-
index_1 = new IndexLSH (d, d, rotate_data, train_thresholds);
|
395
|
-
} else if (!index &&
|
396
|
-
sscanf (tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) {
|
397
|
-
FAISS_THROW_IF_NOT(!coarse_quantizer);
|
398
|
-
index_1 = new IndexLattice(d, M, nbit, r2);
|
399
|
-
} else if (stok == "RFlat") {
|
400
|
-
parenthesis_refine = "Flat";
|
401
|
-
} else if (stok == "Refine") {
|
402
|
-
FAISS_THROW_IF_NOT_MSG(
|
403
|
-
!parenthesis_refine.empty(),
|
404
|
-
"Refine index should be provided in parentheses"
|
405
|
-
);
|
169
|
+
/// Expand every "<M>x<nbit>" group found in `stok` into per-codebook bit
/// counts.
///
/// Groups are processed left to right; each one appends `M` entries equal to
/// `nbit`. For example "2x8_3x4" yields {8, 8, 4, 4, 4}.
///
/// @param stok  token to scan; text around the groups is ignored
/// @return      the concatenated bit counts, empty if no group is present
std::vector<size_t> aq_parse_nbits(std::string stok) {
    // Compile the pattern once instead of once per loop iteration.
    const std::regex group_re("([0-9]+)x([0-9]+)");
    std::vector<size_t> nbits;
    const std::sregex_iterator last;
    for (std::sregex_iterator it(stok.begin(), stok.end(), group_re);
         it != last;
         ++it) {
        const std::smatch& sm = *it;
        // Both captures are digit-only, so the values are never negative.
        const int M = std::stoi(sm[1].str());
        const int nbit = std::stoi(sm[2].str());
        nbits.insert(
                nbits.end(),
                static_cast<size_t>(M),
                static_cast<size_t>(nbit));
    }
    return nbits;
}
|
180
|
+
|
181
|
+
/***************************************************************
|
182
|
+
* Parse VectorTransform
|
183
|
+
*/
|
184
|
+
|
185
|
+
/// Build the VectorTransform described by a single factory token.
///
/// Recognised tokens: "PCA[W][R]<dout>", "L2norm"/"L2Norm", "RR[<dout>]",
/// "ITQ[<dout>]", "OPQ<M>[_<dout>]" and "Pad<dout>". Where the output
/// dimension is optional it defaults to `d`; "Pad" never shrinks below `d`.
///
/// @param description  the token to parse (must match in full)
/// @param d            input dimension of the transform
/// @return             a transform allocated with `new` (the caller takes
///                     ownership), or nullptr if the token is not a
///                     VectorTransform description
VectorTransform* parse_VectorTransform(const std::string& description, int d) {
    std::smatch sm;
    // Capture by reference: `description` outlives the lambda, so there is
    // no need to copy it (or the pattern) for every match attempt.
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };

    if (match("PCA(W?)(R?)([0-9]+)")) {
        const bool white = sm[1].length() > 0;
        const bool rot = sm[2].length() > 0;
        return new PCAMatrix(d, std::stoi(sm[3].str()), white ? -0.5 : 0, rot);
    }
    if (match("L2[nN]orm")) {
        return new NormalizationTransform(d, 2.0);
    }
    if (match("RR([0-9]+)?")) {
        return new RandomRotationMatrix(d, mres_to_int(sm[1], d));
    }
    if (match("ITQ([0-9]+)?")) {
        // the last argument is true only when an explicit size was given
        return new ITQTransform(d, mres_to_int(sm[1], d), sm[1].length() > 0);
    }
    if (match("OPQ([0-9]+)(_[0-9]+)?")) {
        const int M = std::stoi(sm[1].str());
        // skip the leading '_' of the optional output dimension
        const int d_out = mres_to_int(sm[2], d, 1);
        return new OPQMatrix(d, M, d_out);
    }
    if (match("Pad([0-9]+)")) {
        const int d_out = std::stoi(sm[1].str());
        return new RemapDimensionsTransform(d, std::max(d_out, d), false);
    }
    return nullptr;
}
|
215
|
+
|
216
|
+
/***************************************************************
|
217
|
+
* Parse IndexIVF
|
218
|
+
*/
|
219
|
+
|
220
|
+
// parsing guard + function
|
221
|
+
/// Build the coarse quantizer described by a factory token.
///
/// Recognised tokens: "IVF<n>", "IMI2x<nbit>", "IVF<n>_HNSW[<M>]",
/// "IVF<n>_NSG<R>", "IVF<n>(Index<k>)", "Residual<M>x<nbit>" and
/// "Residual<n>". The two "Residual" forms generate Index2Layer's, not
/// IndexIVF's, and are the only ones that set `use_2layer`.
///
/// @param description          the token to parse (must match in full)
/// @param d                    vector dimension
/// @param mt                   metric; IMI and Residual forms require
///                             METRIC_L2
/// @param parenthesis_indexes  pre-built sub-indexes; for the
///                             "IVF<n>(Index<k>)" form entry k is released
///                             and returned
/// @param nlist                out: number of lists; written only when a
///                             pattern matches
/// @param use_2layer           out: always written; true only for the
///                             "Residual" forms
/// @return                     the quantizer (the caller takes ownership),
///                             or nullptr if the token is not recognised
Index* parse_coarse_quantizer(
        const std::string& description,
        int d,
        MetricType mt,
        std::vector<std::unique_ptr<Index>>& parenthesis_indexes,
        size_t& nlist,
        bool& use_2layer) {
    std::smatch sm;
    // Capture by reference: `description` outlives the lambda, so there is
    // no need to copy it (or the pattern) for every match attempt.
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };
    use_2layer = false;

    if (match("IVF([0-9]+)")) {
        nlist = std::stoi(sm[1].str());
        return new IndexFlat(d, mt);
    }
    if (match("IMI2x([0-9]+)")) {
        const int nbit = std::stoi(sm[1].str());
        FAISS_THROW_IF_NOT_MSG(
                mt == METRIC_L2,
                "MultiIndex not implemented for inner prod search");
        nlist = (size_t)1 << (2 * nbit);
        return new MultiIndexQuantizer(d, 2, nbit);
    }
    if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
        nlist = std::stoi(sm[1].str());
        // M defaults to 32 when omitted
        const int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
        return new IndexHNSWFlat(d, hnsw_M, mt);
    }
    if (match("IVF([0-9]+)_NSG([0-9]+)")) {
        nlist = std::stoi(sm[1].str());
        const int R = std::stoi(sm[2]);
        return new IndexNSGFlat(d, R, mt);
    }
    if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
        nlist = std::stoi(sm[1].str());
        const int no = std::stoi(sm[2].str());
        FAISS_ASSERT(
                no >= 0 &&
                static_cast<size_t>(no) < parenthesis_indexes.size());
        return parenthesis_indexes[no].release();
    }

    // these two generate Index2Layer's not IndexIVF's
    if (match("Residual([0-9]+)x([0-9]+)")) {
        FAISS_THROW_IF_NOT_MSG(
                mt == METRIC_L2,
                "MultiIndex not implemented for inner prod search");
        const int M = mres_to_int(sm[1]);
        const int nbit = mres_to_int(sm[2]);
        nlist = (size_t)1 << (M * nbit);
        use_2layer = true;
        return new MultiIndexQuantizer(d, M, nbit);
    }
    if (match("Residual([0-9]+)")) {
        FAISS_THROW_IF_NOT_MSG(
                mt == METRIC_L2,
                "Residual not implemented for inner prod search");
        use_2layer = true;
        nlist = mres_to_int(sm[1]);
        return new IndexFlatL2(d);
    }
    return nullptr;
}
|
283
|
+
|
284
|
+
// parse the code description part of the IVF description
|
285
|
+
|
286
|
+
/// Build the IndexIVF variant named by the code part of an IVF description.
///
/// Recognised code strings: "Flat", "FlatDedup", a scalar quantizer token
/// (see sq_pattern), "PQ<M>[x<nbit>][np]", "PQ<M1>+<M2>",
/// "PQ<M>x4fs[r][_<bbs>]", "(RQ|LSQ)<groups>[<norm suffix>]" and
/// "(ITQ|PCA|PCAR)[<dout>],SH[<period>][g|c|m]".
///
/// @param code_string  code description (must match in full)
/// @param quantizer    coarse quantizer; it is released into the new index
///                     as soon as a pattern matches and is left untouched
///                     otherwise
/// @param nlist        number of inverted lists
/// @param mt           metric; "PQ<M1>+<M2>" requires METRIC_L2
/// @return             the new index (the caller takes ownership), or
///                     nullptr if the code string is not recognised
IndexIVF* parse_IndexIVF(
        const std::string& code_string,
        std::unique_ptr<Index>& quantizer,
        size_t nlist,
        MetricType mt) {
    std::smatch sm;
    auto match = [&sm, &code_string](const std::string& pattern) {
        return re_match(code_string, pattern, sm);
    };
    // read the dimension before the quantizer can be released
    const int d = quantizer->d;

    if (match("Flat")) {
        return new IndexIVFFlat(quantizer.release(), d, nlist, mt);
    }

    if (match("FlatDedup")) {
        return new IndexIVFFlatDedup(quantizer.release(), d, nlist, mt);
    }

    if (match(sq_pattern)) {
        return new IndexIVFScalarQuantizer(
                quantizer.release(), d, nlist, sq_types[sm[1].str()], mt);
    }

    if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
        const int M = mres_to_int(sm[1]);
        // bits per sub-quantizer default to 8; skip the leading 'x'
        const int nbit = mres_to_int(sm[2], 8, 1);
        IndexIVFPQ* ivfpq =
                new IndexIVFPQ(quantizer.release(), d, nlist, M, nbit, mt);
        ivfpq->do_polysemous_training = sm[3].str() != "np";
        return ivfpq;
    }

    if (match("PQ([0-9]+)\\+([0-9]+)")) {
        FAISS_THROW_IF_NOT_MSG(
                mt == METRIC_L2,
                "IVFPQR not implemented for inner product search");
        const int M1 = mres_to_int(sm[1]);
        const int M2 = mres_to_int(sm[2]);
        return new IndexIVFPQR(quantizer.release(), d, nlist, M1, 8, M2, 8);
    }

    if (match("PQ([0-9]+)x4fs(r?)(_[0-9]+)?")) {
        const int M = mres_to_int(sm[1]);
        // block size defaults to 32; skip the leading '_'
        const int bbs = mres_to_int(sm[3], 32, 1);
        IndexIVFPQFastScan* fast_scan = new IndexIVFPQFastScan(
                quantizer.release(), d, nlist, M, 4, mt, bbs);
        fast_scan->by_residual = sm[2].str() == "r";
        return fast_scan;
    }

    if (match("(RQ|LSQ)" + aq_def_pattern + aq_norm_pattern)) {
        std::vector<size_t> nbits = aq_parse_nbits(sm.str());
        // the norm suffix is the last capture group of the pattern
        const AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), mt);
        if (sm[1].str() == "RQ") {
            return new IndexIVFResidualQuantizer(
                    quantizer.release(), d, nlist, nbits, mt, st);
        }
        // LSQ: only the first group's bit count is used
        FAISS_THROW_IF_NOT(nbits.size() > 0);
        return new IndexIVFLocalSearchQuantizer(
                quantizer.release(), d, nlist, nbits.size(), nbits[0], mt, st);
    }

    if (match("(ITQ|PCA|PCAR)([0-9]+)?,SH([-0-9.e]+)?([gcm])?")) {
        const int outdim = mres_to_int(sm[2], d); // is also the number of bits
        std::unique_ptr<VectorTransform> vt;
        if (sm[1] == "ITQ") {
            vt.reset(new ITQTransform(d, outdim, d != outdim));
        } else if (sm[1] == "PCA") {
            vt.reset(new PCAMatrix(d, outdim));
        } else if (sm[1] == "PCAR") {
            vt.reset(new PCAMatrix(d, outdim, 0, true));
        }
        // the rationale for -1e10 is that this corresponds to simple
        // thresholding
        const float period = sm[3].length() > 0 ? std::stof(sm[3]) : -1e10;
        IndexIVFSpectralHash* spectral = new IndexIVFSpectralHash(
                quantizer.release(), d, nlist, outdim, period);
        spectral->replace_vt(vt.release(), true);
        if (sm[4].length()) {
            const std::string kind = sm[4].str();
            if (kind == "g") {
                spectral->threshold_type = IndexIVFSpectralHash::Thresh_global;
            } else if (kind == "c") {
                spectral->threshold_type =
                        IndexIVFSpectralHash::Thresh_centroid;
            } else {
                // kind == "m"
                spectral->threshold_type = IndexIVFSpectralHash::Thresh_median;
            }
        }
        return spectral;
    }

    return nullptr;
}
|
410
373
|
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
374
|
+
/***************************************************************
|
375
|
+
* Parse IndexHNSW
|
376
|
+
*/
|
377
|
+
|
378
|
+
/**
 * Build an HNSW index from the storage part of a factory string.
 *
 * The patterns below are tried in order; the first one accepted by
 * re_match() decides which index class is instantiated.
 *
 * @param code_string  storage description (what follows the HNSW prefix)
 * @param d            dimension forwarded to the index constructors
 * @param mt           metric forwarded to the constructors that take one
 * @param hnsw_M       forwarded as the HNSW M argument of each constructor
 * @return a newly allocated index owned by the caller, or nullptr when
 *         none of the patterns matched
 */
IndexHNSW* parse_IndexHNSW(
        const std::string code_string,
        int d,
        MetricType mt,
        int hnsw_M) {
    std::smatch groups;
    auto accepts = [&groups, &code_string](const std::string& regex_text) {
        return re_match(code_string, regex_text, groups);
    };

    // flat storage; the alternation also lists the empty string
    if (accepts("Flat|")) {
        return new IndexHNSWFlat(d, hnsw_M, mt);
    }

    // PQ storage, optional "np" suffix switches polysemous training off
    if (accepts("PQ([0-9]+)(np)?")) {
        const int n_subq = std::stoi(groups[1].str());
        IndexHNSWPQ* hnsw_pq = new IndexHNSWPQ(d, n_subq, hnsw_M);
        IndexPQ* pq_storage = dynamic_cast<IndexPQ*>(hnsw_pq->storage);
        pq_storage->do_polysemous_training = groups[2].str() != "np";
        return hnsw_pq;
    }

    // scalar quantizer storage
    if (accepts(sq_pattern)) {
        const std::string sq_name = groups[1].str();
        return new IndexHNSWSQ(d, sq_types[sq_name], hnsw_M, mt);
    }

    // two-level storage with a flat L2 first-level quantizer
    if (accepts("([0-9]+)\\+PQ([0-9]+)?")) {
        const int n_centroids = mres_to_int(groups[1]);
        const int n_subq = mres_to_int(groups[2]);
        IndexHNSW2Level* two_level = new IndexHNSW2Level(
                new IndexFlatL2(d), n_centroids, n_subq, hnsw_M);
        Index2Layer* layered = dynamic_cast<Index2Layer*>(two_level->storage);
        // the quantizer allocated above is handed over to the storage
        layered->q1.own_fields = true;
        return two_level;
    }

    // two-level storage with a 2-part MultiIndexQuantizer first level
    if (accepts("2x([0-9]+)\\+PQ([0-9]+)?")) {
        const int bits = mres_to_int(groups[1]);
        const int n_subq = mres_to_int(groups[2]);
        Index* multi_quantizer = new MultiIndexQuantizer(d, 2, bits);
        const size_t n_cells = (size_t)1 << (2 * bits);
        IndexHNSW2Level* two_level =
                new IndexHNSW2Level(multi_quantizer, n_cells, n_subq, hnsw_M);
        Index2Layer* layered = dynamic_cast<Index2Layer*>(two_level->storage);
        layered->q1.own_fields = true;
        layered->q1.quantizer_trains_alone = 1;
        return two_level;
    }

    return nullptr;
}
|
423
|
+
|
424
|
+
/***************************************************************
|
425
|
+
* Parse basic indexes
|
426
|
+
*/
|
427
|
+
|
428
|
+
/**
 * Build one of the "basic" (non-IVF, non-HNSW) indexes from a factory string.
 *
 * Each pattern is tried in order and the first one accepted by re_match()
 * decides the index class.
 *
 * @param description  factory string to interpret
 * @param d            dimension forwarded to the index constructors
 * @param metric       metric forwarded to the constructors that take one
 * @return a newly allocated index owned by the caller, or nullptr when the
 *         description matches none of the patterns handled here
 * @throws via FAISS_THROW_IF_NOT when an LSH index is requested with a
 *         metric other than METRIC_L2
 */
Index* parse_other_indexes(
        const std::string& description,
        int d,
        MetricType metric) {
    std::smatch sm;
    // `description` is a reference parameter that outlives the lambda, so it
    // is captured by reference instead of being copied into the closure.
    auto match = [&sm, &description](const std::string& pattern) {
        return re_match(description, pattern, sm);
    };

    // IndexFlat
    if (description == "Flat") {
        return new IndexFlat(d, metric);
    }

    // IndexLSH
    if (match("LSH(r?)(t?)")) {
        bool rotate_data = sm[1].length() > 0;
        bool train_thresholds = sm[2].length() > 0;
        FAISS_THROW_IF_NOT(metric == METRIC_L2);
        return new IndexLSH(d, d, rotate_data, train_thresholds);
    }

    // IndexLattice
    if (match("ZnLattice([0-9]+)x([0-9]+)_([0-9]+)")) {
        int M = std::stoi(sm[1].str()), r2 = std::stoi(sm[2].str());
        int nbit = std::stoi(sm[3].str());
        return new IndexLattice(d, M, nbit, r2);
    }

    // IndexNSGFlat
    if (match("NSG([0-9]+)(,Flat)?")) {
        return new IndexNSGFlat(d, std::stoi(sm[1].str()), metric);
    }

    // IndexScalarQuantizer
    if (match(sq_pattern)) {
        return new IndexScalarQuantizer(d, sq_types[description], metric);
    }

    // IndexPQ
    if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
        int M = std::stoi(sm[1].str());
        // the optional group is "x<nbit>"; skip the leading 'x', default 8
        int nbit = mres_to_int(sm[2], 8, 1);
        IndexPQ* index_pq = new IndexPQ(d, M, nbit, metric);
        index_pq->do_polysemous_training = sm[3].str() != "np";
        return index_pq;
    }

    // IndexPQFastScan
    if (match("PQ([0-9]+)x4fs(_[0-9]+)?")) {
        int M = std::stoi(sm[1].str());
        // the optional group is "_<bbs>"; skip the leading '_', default 32
        int bbs = mres_to_int(sm[2], 32, 1);
        return new IndexPQFastScan(d, M, 4, metric, bbs);
    }

    // IndexResidualCoarseQuantizer and IndexResidualQuantizer
    std::string pattern = "(RQ|RCQ)" + aq_def_pattern + aq_norm_pattern;
    if (match(pattern)) {
        std::vector<size_t> nbits = aq_parse_nbits(description);
        if (sm[1].str() == "RCQ") {
            return new ResidualCoarseQuantizer(d, nbits, metric);
        }
        // the search type is read from the last capture group
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexResidualQuantizer(d, nbits, metric, st);
    }

    // LocalSearchCoarseQuantizer and IndexLocalSearchQuantizer
    if (match("(LSQ|LSCQ)([0-9]+)x([0-9]+)" + aq_norm_pattern)) {
        // M and nbit come straight from the capture groups; no per-codebook
        // size vector is needed by either constructor below
        int M = mres_to_int(sm[2]);
        int nbit = mres_to_int(sm[3]);
        if (sm[1].str() == "LSCQ") {
            return new LocalSearchCoarseQuantizer(d, M, nbit, metric);
        }
        AdditiveQuantizer::Search_type_t st =
                aq_parse_search_type(sm[sm.size() - 1].str(), metric);
        return new IndexLocalSearchQuantizer(d, M, nbit, metric, st);
    }

    return nullptr;
}
|
510
|
+
|
511
|
+
/***************************************************************
|
512
|
+
* Driver function
|
513
|
+
*/
|
514
|
+
std::unique_ptr<Index> index_factory_sub(
|
515
|
+
int d,
|
516
|
+
std::string description,
|
517
|
+
MetricType metric) {
|
518
|
+
// handle composite indexes
|
519
|
+
|
520
|
+
bool verbose = index_factory_verbose;
|
422
521
|
|
423
|
-
|
424
|
-
|
425
|
-
|
522
|
+
if (verbose) {
|
523
|
+
printf("begin parse VectorTransforms: %s \n", description.c_str());
|
524
|
+
}
|
525
|
+
|
526
|
+
// for the current match
|
527
|
+
std::smatch sm;
|
528
|
+
|
529
|
+
// handle refines
|
530
|
+
if (re_match(description, "(.+),RFlat", sm) ||
|
531
|
+
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
|
532
|
+
std::unique_ptr<Index> filter_index =
|
533
|
+
index_factory_sub(d, sm[1].str(), metric);
|
534
|
+
std::unique_ptr<Index> refine_index;
|
535
|
+
|
536
|
+
if (sm.size() == 3) { // Refine
|
537
|
+
refine_index = index_factory_sub(d, sm[2].str(), metric);
|
538
|
+
} else { // RFlat
|
539
|
+
refine_index.reset(new IndexFlat(d, metric));
|
426
540
|
}
|
541
|
+
IndexRefine* index_rf =
|
542
|
+
new IndexRefine(filter_index.get(), refine_index.get());
|
543
|
+
index_rf->own_fields = true;
|
544
|
+
filter_index.release();
|
545
|
+
refine_index.release();
|
546
|
+
index_rf->own_refine_index = true;
|
547
|
+
return std::unique_ptr<Index>(index_rf);
|
548
|
+
}
|
427
549
|
|
428
|
-
|
429
|
-
|
430
|
-
|
550
|
+
// IndexPreTransform
|
551
|
+
// should handle this first (even before parentheses) because it changes d
|
552
|
+
std::vector<std::unique_ptr<VectorTransform>> vts;
|
553
|
+
VectorTransform* vt = nullptr;
|
554
|
+
while (re_match(description, "([^,]+),(.*)", sm) &&
|
555
|
+
(vt = parse_VectorTransform(sm[1], d))) {
|
556
|
+
// reset loop
|
557
|
+
description = sm[sm.size() - 1];
|
558
|
+
vts.emplace_back(vt);
|
559
|
+
d = vts.back()->d_out;
|
560
|
+
}
|
561
|
+
|
562
|
+
if (vts.size() > 0) {
|
563
|
+
std::unique_ptr<Index> sub_index =
|
564
|
+
index_factory_sub(d, description, metric);
|
565
|
+
IndexPreTransform* index_pt = new IndexPreTransform(sub_index.get());
|
566
|
+
std::unique_ptr<Index> ret(index_pt);
|
567
|
+
index_pt->own_fields = true;
|
568
|
+
sub_index.release();
|
569
|
+
while (vts.size() > 0) {
|
570
|
+
if (verbose) {
|
571
|
+
printf("prepend trans %d -> %d\n",
|
572
|
+
vts.back()->d_in,
|
573
|
+
vts.back()->d_out);
|
574
|
+
}
|
575
|
+
index_pt->prepend_transform(vts.back().release());
|
576
|
+
vts.pop_back();
|
431
577
|
}
|
578
|
+
return ret;
|
432
579
|
}
|
433
580
|
|
434
|
-
|
435
|
-
|
436
|
-
|
581
|
+
// what we got from the parentheses
|
582
|
+
std::vector<std::unique_ptr<Index>> parenthesis_indexes;
|
583
|
+
|
584
|
+
int begin = 0;
|
585
|
+
while (description.find('(', begin) != std::string::npos) {
|
586
|
+
// replace indexes in () with Index0, Index1, etc.
|
587
|
+
int i0, i1;
|
588
|
+
find_matching_parentheses(description, i0, i1, begin);
|
589
|
+
std::string sub_description = description.substr(i0 + 1, i1 - i0 - 1);
|
590
|
+
int no = parenthesis_indexes.size();
|
591
|
+
parenthesis_indexes.push_back(
|
592
|
+
index_factory_sub(d, sub_description, metric));
|
593
|
+
description = description.substr(0, i0 + 1) + "Index" +
|
594
|
+
std::to_string(no) + description.substr(i1);
|
595
|
+
begin = i1 + 1;
|
437
596
|
}
|
438
597
|
|
439
|
-
|
440
|
-
|
598
|
+
if (verbose) {
|
599
|
+
printf("after () normalization: %s %ld parenthesis indexes d=%d\n",
|
600
|
+
description.c_str(),
|
601
|
+
parenthesis_indexes.size(),
|
602
|
+
d);
|
603
|
+
}
|
441
604
|
|
442
|
-
//
|
443
|
-
|
444
|
-
|
605
|
+
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
|
606
|
+
// support both
|
607
|
+
if (re_match(description, "(.+),IDMap", sm) ||
|
608
|
+
re_match(description, "IDMap,(.+)", sm)) {
|
609
|
+
IndexIDMap* idmap = new IndexIDMap(
|
610
|
+
index_factory_sub(d, sm[1].str(), metric).release());
|
611
|
+
idmap->own_fields = true;
|
612
|
+
return std::unique_ptr<Index>(idmap);
|
613
|
+
}
|
445
614
|
|
446
|
-
|
447
|
-
|
448
|
-
|
615
|
+
{ // handle basic index types
|
616
|
+
Index* index = parse_other_indexes(description, d, metric);
|
617
|
+
if (index) {
|
618
|
+
return std::unique_ptr<Index>(index);
|
619
|
+
}
|
449
620
|
}
|
450
621
|
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
622
|
+
// HNSW variants (it was unclear in the old version that the separator was a
|
623
|
+
// "," so we support both "_" and ",")
|
624
|
+
if (re_match(description, "HNSW([0-9]*)([,_].*)?", sm)) {
|
625
|
+
int hnsw_M = mres_to_int(sm[1], 32);
|
626
|
+
// We also accept empty code string (synonym of Flat)
|
627
|
+
std::string code_string =
|
628
|
+
sm[2].length() > 0 ? sm[2].str().substr(1) : "";
|
629
|
+
if (verbose) {
|
630
|
+
printf("parsing HNSW string %s code_string=%s hnsw_M=%d\n",
|
631
|
+
description.c_str(),
|
632
|
+
code_string.c_str(),
|
633
|
+
hnsw_M);
|
458
634
|
}
|
459
|
-
|
635
|
+
|
636
|
+
IndexHNSW* index = parse_IndexHNSW(code_string, d, metric, hnsw_M);
|
637
|
+
FAISS_THROW_IF_NOT_FMT(
|
638
|
+
index,
|
639
|
+
"could not parse HNSW code description %s in %s",
|
640
|
+
code_string.c_str(),
|
641
|
+
description.c_str());
|
642
|
+
return std::unique_ptr<Index>(index);
|
460
643
|
}
|
461
644
|
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
645
|
+
// IndexIVF
|
646
|
+
{
|
647
|
+
size_t nlist;
|
648
|
+
bool use_2layer;
|
649
|
+
size_t comma = description.find(",");
|
650
|
+
std::string coarse_string = description.substr(0, comma);
|
651
|
+
// Match coarse quantizer part first
|
652
|
+
std::unique_ptr<Index> quantizer(parse_coarse_quantizer(
|
653
|
+
description.substr(0, comma),
|
654
|
+
d,
|
655
|
+
metric,
|
656
|
+
parenthesis_indexes,
|
657
|
+
nlist,
|
658
|
+
use_2layer));
|
659
|
+
|
660
|
+
if (comma != std::string::npos && quantizer.get()) {
|
661
|
+
std::string code_description = description.substr(comma + 1);
|
662
|
+
if (use_2layer) {
|
663
|
+
bool ok =
|
664
|
+
re_match(code_description, "PQ([0-9]+)(x[0-9]+)?", sm);
|
665
|
+
FAISS_THROW_IF_NOT_FMT(
|
666
|
+
ok,
|
667
|
+
"could not parse 2 layer code description %s in %s",
|
668
|
+
code_description.c_str(),
|
669
|
+
description.c_str());
|
670
|
+
int M = std::stoi(sm[1].str()), nbit = mres_to_int(sm[2], 8, 1);
|
671
|
+
Index2Layer* index_2l =
|
672
|
+
new Index2Layer(quantizer.release(), nlist, M, nbit);
|
673
|
+
index_2l->q1.own_fields = true;
|
674
|
+
index_2l->q1.quantizer_trains_alone =
|
675
|
+
get_trains_alone(index_2l->q1.quantizer);
|
676
|
+
return std::unique_ptr<Index>(index_2l);
|
677
|
+
}
|
678
|
+
|
679
|
+
IndexIVF* index_ivf =
|
680
|
+
parse_IndexIVF(code_description, quantizer, nlist, metric);
|
681
|
+
|
682
|
+
FAISS_THROW_IF_NOT_FMT(
|
683
|
+
index_ivf,
|
684
|
+
"could not parse code description %s in %s",
|
685
|
+
code_description.c_str(),
|
686
|
+
description.c_str());
|
687
|
+
return std::unique_ptr<Index>(fix_ivf_fields(index_ivf));
|
688
|
+
}
|
468
689
|
}
|
690
|
+
FAISS_THROW_FMT("could not parse index string %s", description.c_str());
|
691
|
+
return nullptr;
|
692
|
+
}
|
469
693
|
|
470
|
-
|
694
|
+
} // anonymous namespace
|
695
|
+
|
696
|
+
/**
 * Public entry point of the factory: builds the index through
 * index_factory_sub() and hands the raw pointer to the caller, who becomes
 * responsible for deleting it.
 */
Index* index_factory(int d, const char* description, MetricType metric) {
    std::unique_ptr<Index> built = index_factory_sub(d, description, metric);
    return built.release();
}
|
472
699
|
|
473
|
-
IndexBinary
|
474
|
-
|
475
|
-
IndexBinary *index = nullptr;
|
700
|
+
/**
 * Build a binary index from a description string.
 *
 * The formats are tried in the order below and the first one whose sscanf()
 * conversion count is complete wins:
 *   "BIVF<n>_HNSW<M>", "BIVF<n>", "BHNSW<M>", "BHash<nhash>x<b>",
 *   "BHash<b>", and the exact string "BFlat".
 *
 * @param d            dimension forwarded to the index constructors
 * @param description  factory string
 * @return a newly allocated index owned by the caller
 * @throws via FAISS_THROW_IF_NOT_FMT when no format applies
 */
IndexBinary* index_binary_factory(int d, const char* description) {
    int n_lists = -1;
    int hnsw_m, n_hash, n_bits;

    // IVF with an HNSW coarse quantizer
    if (sscanf(description, "BIVF%d_HNSW%d", &n_lists, &hnsw_m) == 2) {
        IndexBinaryIVF* ivf =
                new IndexBinaryIVF(new IndexBinaryHNSW(d, hnsw_m), d, n_lists);
        ivf->own_fields = true;
        return ivf;
    }

    // IVF with a flat coarse quantizer
    if (sscanf(description, "BIVF%d", &n_lists) == 1) {
        IndexBinaryIVF* ivf =
                new IndexBinaryIVF(new IndexBinaryFlat(d), d, n_lists);
        ivf->own_fields = true;
        return ivf;
    }

    if (sscanf(description, "BHNSW%d", &hnsw_m) == 1) {
        return new IndexBinaryHNSW(d, hnsw_m);
    }

    if (sscanf(description, "BHash%dx%d", &n_hash, &n_bits) == 2) {
        return new IndexBinaryMultiHash(d, n_hash, n_bits);
    }

    if (sscanf(description, "BHash%d", &n_bits) == 1) {
        return new IndexBinaryHash(d, n_bits);
    }

    if (std::string(description) == "BFlat") {
        return new IndexBinaryFlat(d);
    }

    // Nothing matched: `index` is null, so the check below always throws.
    IndexBinary* index = nullptr;
    FAISS_THROW_IF_NOT_FMT(
            index, "description %s did not generate an index", description);
    return index;
}
|
514
738
|
|
515
|
-
|
516
|
-
|
517
739
|
} // namespace faiss
|