faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,7 +9,6 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/IndexIVF.h>
|
|
11
11
|
|
|
12
|
-
|
|
13
12
|
#include <omp.h>
|
|
14
13
|
#include <mutex>
|
|
15
14
|
|
|
@@ -18,12 +17,12 @@
|
|
|
18
17
|
#include <cstdio>
|
|
19
18
|
#include <memory>
|
|
20
19
|
|
|
21
|
-
#include <faiss/utils/utils.h>
|
|
22
20
|
#include <faiss/utils/hamming.h>
|
|
21
|
+
#include <faiss/utils/utils.h>
|
|
23
22
|
|
|
24
|
-
#include <faiss/impl/FaissAssert.h>
|
|
25
23
|
#include <faiss/IndexFlat.h>
|
|
26
24
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
25
|
+
#include <faiss/impl/FaissAssert.h>
|
|
27
26
|
|
|
28
27
|
namespace faiss {
|
|
29
28
|
|
|
@@ -34,99 +33,104 @@ using ScopedCodes = InvertedLists::ScopedCodes;
|
|
|
34
33
|
* Level1Quantizer implementation
|
|
35
34
|
******************************************/
|
|
36
35
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
clustering_index (nullptr)
|
|
44
|
-
{
|
|
36
|
+
Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
|
|
37
|
+
: quantizer(quantizer),
|
|
38
|
+
nlist(nlist),
|
|
39
|
+
quantizer_trains_alone(0),
|
|
40
|
+
own_fields(false),
|
|
41
|
+
clustering_index(nullptr) {
|
|
45
42
|
// here we set a low # iterations because this is typically used
|
|
46
43
|
// for large clusterings (nb this is not used for the MultiIndex,
|
|
47
44
|
// for which quantizer_trains_alone = true)
|
|
48
45
|
cp.niter = 10;
|
|
49
46
|
}
|
|
50
47
|
|
|
51
|
-
Level1Quantizer::Level1Quantizer
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
{}
|
|
48
|
+
Level1Quantizer::Level1Quantizer()
|
|
49
|
+
: quantizer(nullptr),
|
|
50
|
+
nlist(0),
|
|
51
|
+
quantizer_trains_alone(0),
|
|
52
|
+
own_fields(false),
|
|
53
|
+
clustering_index(nullptr) {}
|
|
57
54
|
|
|
58
|
-
Level1Quantizer::~Level1Quantizer
|
|
59
|
-
|
|
60
|
-
|
|
55
|
+
Level1Quantizer::~Level1Quantizer() {
|
|
56
|
+
if (own_fields)
|
|
57
|
+
delete quantizer;
|
|
61
58
|
}
|
|
62
59
|
|
|
63
|
-
void Level1Quantizer::train_q1
|
|
64
|
-
|
|
60
|
+
void Level1Quantizer::train_q1(
|
|
61
|
+
size_t n,
|
|
62
|
+
const float* x,
|
|
63
|
+
bool verbose,
|
|
64
|
+
MetricType metric_type) {
|
|
65
65
|
size_t d = quantizer->d;
|
|
66
66
|
if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
|
|
67
67
|
if (verbose)
|
|
68
|
-
printf
|
|
68
|
+
printf("IVF quantizer does not need training.\n");
|
|
69
69
|
} else if (quantizer_trains_alone == 1) {
|
|
70
70
|
if (verbose)
|
|
71
|
-
printf
|
|
72
|
-
quantizer->train
|
|
71
|
+
printf("IVF quantizer trains alone...\n");
|
|
72
|
+
quantizer->train(n, x);
|
|
73
73
|
quantizer->verbose = verbose;
|
|
74
|
-
FAISS_THROW_IF_NOT_MSG
|
|
75
|
-
|
|
74
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
75
|
+
quantizer->ntotal == nlist,
|
|
76
|
+
"nlist not consistent with quantizer size");
|
|
76
77
|
} else if (quantizer_trains_alone == 0) {
|
|
77
78
|
if (verbose)
|
|
78
|
-
printf
|
|
79
|
-
n, d);
|
|
79
|
+
printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
|
|
80
80
|
|
|
81
|
-
Clustering clus
|
|
81
|
+
Clustering clus(d, nlist, cp);
|
|
82
82
|
quantizer->reset();
|
|
83
83
|
if (clustering_index) {
|
|
84
|
-
clus.train
|
|
85
|
-
quantizer->add
|
|
84
|
+
clus.train(n, x, *clustering_index);
|
|
85
|
+
quantizer->add(nlist, clus.centroids.data());
|
|
86
86
|
} else {
|
|
87
|
-
clus.train
|
|
87
|
+
clus.train(n, x, *quantizer);
|
|
88
88
|
}
|
|
89
89
|
quantizer->is_trained = true;
|
|
90
90
|
} else if (quantizer_trains_alone == 2) {
|
|
91
91
|
if (verbose) {
|
|
92
|
-
printf
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
92
|
+
printf("Training L2 quantizer on %zd vectors in %zdD%s\n",
|
|
93
|
+
n,
|
|
94
|
+
d,
|
|
95
|
+
clustering_index ? "(user provided index)" : "");
|
|
96
96
|
}
|
|
97
97
|
// also accept spherical centroids because in that case
|
|
98
98
|
// L2 and IP are equivalent
|
|
99
|
-
FAISS_THROW_IF_NOT
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
);
|
|
99
|
+
FAISS_THROW_IF_NOT(
|
|
100
|
+
metric_type == METRIC_L2 ||
|
|
101
|
+
(metric_type == METRIC_INNER_PRODUCT && cp.spherical));
|
|
103
102
|
|
|
104
|
-
Clustering clus
|
|
103
|
+
Clustering clus(d, nlist, cp);
|
|
105
104
|
if (!clustering_index) {
|
|
106
|
-
IndexFlatL2 assigner
|
|
105
|
+
IndexFlatL2 assigner(d);
|
|
107
106
|
clus.train(n, x, assigner);
|
|
108
107
|
} else {
|
|
109
108
|
clus.train(n, x, *clustering_index);
|
|
110
109
|
}
|
|
111
|
-
if (verbose)
|
|
112
|
-
printf
|
|
113
|
-
|
|
110
|
+
if (verbose) {
|
|
111
|
+
printf("Adding centroids to quantizer\n");
|
|
112
|
+
}
|
|
113
|
+
if (!quantizer->is_trained) {
|
|
114
|
+
if (verbose) {
|
|
115
|
+
printf("But training it first on centroids table...\n");
|
|
116
|
+
}
|
|
117
|
+
quantizer->train(nlist, clus.centroids.data());
|
|
118
|
+
}
|
|
119
|
+
quantizer->add(nlist, clus.centroids.data());
|
|
114
120
|
}
|
|
115
121
|
}
|
|
116
122
|
|
|
117
|
-
size_t Level1Quantizer::coarse_code_size
|
|
118
|
-
{
|
|
123
|
+
size_t Level1Quantizer::coarse_code_size() const {
|
|
119
124
|
size_t nl = nlist - 1;
|
|
120
125
|
size_t nbyte = 0;
|
|
121
126
|
while (nl > 0) {
|
|
122
|
-
nbyte
|
|
127
|
+
nbyte++;
|
|
123
128
|
nl >>= 8;
|
|
124
129
|
}
|
|
125
130
|
return nbyte;
|
|
126
131
|
}
|
|
127
132
|
|
|
128
|
-
void Level1Quantizer::encode_listno
|
|
129
|
-
{
|
|
133
|
+
void Level1Quantizer::encode_listno(Index::idx_t list_no, uint8_t* code) const {
|
|
130
134
|
// little endian
|
|
131
135
|
size_t nl = nlist - 1;
|
|
132
136
|
while (nl > 0) {
|
|
@@ -136,8 +140,7 @@ void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
|
|
|
136
140
|
}
|
|
137
141
|
}
|
|
138
142
|
|
|
139
|
-
Index::idx_t Level1Quantizer::decode_listno
|
|
140
|
-
{
|
|
143
|
+
Index::idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
|
|
141
144
|
size_t nl = nlist - 1;
|
|
142
145
|
int64_t list_no = 0;
|
|
143
146
|
int nbit = 0;
|
|
@@ -146,161 +149,198 @@ Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
|
|
|
146
149
|
nbit += 8;
|
|
147
150
|
nl >>= 8;
|
|
148
151
|
}
|
|
149
|
-
FAISS_THROW_IF_NOT
|
|
152
|
+
FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
|
|
150
153
|
return list_no;
|
|
151
154
|
}
|
|
152
155
|
|
|
153
|
-
|
|
154
|
-
|
|
155
156
|
/*****************************************
|
|
156
157
|
* IndexIVF implementation
|
|
157
158
|
******************************************/
|
|
158
159
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
160
|
+
IndexIVF::IndexIVF(
|
|
161
|
+
Index* quantizer,
|
|
162
|
+
size_t d,
|
|
163
|
+
size_t nlist,
|
|
164
|
+
size_t code_size,
|
|
165
|
+
MetricType metric)
|
|
166
|
+
: Index(d, metric),
|
|
167
|
+
Level1Quantizer(quantizer, nlist),
|
|
168
|
+
invlists(new ArrayInvertedLists(nlist, code_size)),
|
|
169
|
+
own_invlists(true),
|
|
170
|
+
code_size(code_size),
|
|
171
|
+
nprobe(1),
|
|
172
|
+
max_codes(0),
|
|
173
|
+
parallel_mode(0) {
|
|
174
|
+
FAISS_THROW_IF_NOT(d == quantizer->d);
|
|
173
175
|
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
|
174
176
|
// Spherical by default if the metric is inner_product
|
|
175
177
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
176
178
|
cp.spherical = true;
|
|
177
179
|
}
|
|
178
|
-
|
|
179
180
|
}
|
|
180
181
|
|
|
181
|
-
IndexIVF::IndexIVF
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
182
|
+
IndexIVF::IndexIVF()
|
|
183
|
+
: invlists(nullptr),
|
|
184
|
+
own_invlists(false),
|
|
185
|
+
code_size(0),
|
|
186
|
+
nprobe(1),
|
|
187
|
+
max_codes(0),
|
|
188
|
+
parallel_mode(0) {}
|
|
189
|
+
|
|
190
|
+
void IndexIVF::add(idx_t n, const float* x) {
|
|
191
|
+
add_with_ids(n, x, nullptr);
|
|
192
|
+
}
|
|
186
193
|
|
|
187
|
-
void IndexIVF::
|
|
188
|
-
|
|
189
|
-
|
|
194
|
+
void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
|
|
195
|
+
std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
|
|
196
|
+
quantizer->assign(n, x, coarse_idx.get());
|
|
197
|
+
add_core(n, x, xids, coarse_idx.get());
|
|
190
198
|
}
|
|
191
199
|
|
|
200
|
+
void IndexIVF::add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids) {
|
|
201
|
+
size_t coarse_size = coarse_code_size();
|
|
202
|
+
DirectMapAdd dm_adder(direct_map, n, xids);
|
|
203
|
+
|
|
204
|
+
for (idx_t i = 0; i < n; i++) {
|
|
205
|
+
const uint8_t* code = codes + (code_size + coarse_size) * i;
|
|
206
|
+
idx_t list_no = decode_listno(code);
|
|
207
|
+
idx_t id = xids ? xids[i] : ntotal + i;
|
|
208
|
+
size_t ofs = invlists->add_entry(list_no, id, code + coarse_size);
|
|
209
|
+
dm_adder.add(i, list_no, ofs);
|
|
210
|
+
}
|
|
211
|
+
ntotal += n;
|
|
212
|
+
}
|
|
192
213
|
|
|
193
|
-
void IndexIVF::
|
|
194
|
-
|
|
214
|
+
void IndexIVF::add_core(
|
|
215
|
+
idx_t n,
|
|
216
|
+
const float* x,
|
|
217
|
+
const idx_t* xids,
|
|
218
|
+
const idx_t* coarse_idx) {
|
|
195
219
|
// do some blocking to avoid excessive allocs
|
|
196
220
|
idx_t bs = 65536;
|
|
197
221
|
if (n > bs) {
|
|
198
222
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
199
|
-
idx_t i1 = std::min
|
|
223
|
+
idx_t i1 = std::min(n, i0 + bs);
|
|
200
224
|
if (verbose) {
|
|
201
|
-
printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
|
|
225
|
+
printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
|
|
226
|
+
i0,
|
|
227
|
+
i1);
|
|
202
228
|
}
|
|
203
|
-
|
|
204
|
-
|
|
229
|
+
add_core(
|
|
230
|
+
i1 - i0,
|
|
231
|
+
x + i0 * d,
|
|
232
|
+
xids ? xids + i0 : nullptr,
|
|
233
|
+
coarse_idx + i0);
|
|
205
234
|
}
|
|
206
235
|
return;
|
|
207
236
|
}
|
|
237
|
+
FAISS_THROW_IF_NOT(coarse_idx);
|
|
238
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
239
|
+
direct_map.check_can_add(xids);
|
|
208
240
|
|
|
209
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
210
|
-
direct_map.check_can_add (xids);
|
|
211
|
-
|
|
212
|
-
std::unique_ptr<idx_t []> idx(new idx_t[n]);
|
|
213
|
-
quantizer->assign (n, x, idx.get());
|
|
214
241
|
size_t nadd = 0, nminus1 = 0;
|
|
215
242
|
|
|
216
243
|
for (size_t i = 0; i < n; i++) {
|
|
217
|
-
if (
|
|
244
|
+
if (coarse_idx[i] < 0)
|
|
245
|
+
nminus1++;
|
|
218
246
|
}
|
|
219
247
|
|
|
220
|
-
std::unique_ptr<uint8_t
|
|
221
|
-
encode_vectors
|
|
248
|
+
std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]);
|
|
249
|
+
encode_vectors(n, x, coarse_idx, flat_codes.get());
|
|
222
250
|
|
|
223
251
|
DirectMapAdd dm_adder(direct_map, n, xids);
|
|
224
252
|
|
|
225
|
-
#pragma omp parallel reduction(
|
|
253
|
+
#pragma omp parallel reduction(+ : nadd)
|
|
226
254
|
{
|
|
227
255
|
int nt = omp_get_num_threads();
|
|
228
256
|
int rank = omp_get_thread_num();
|
|
229
257
|
|
|
230
258
|
// each thread takes care of a subset of lists
|
|
231
259
|
for (size_t i = 0; i < n; i++) {
|
|
232
|
-
idx_t list_no =
|
|
260
|
+
idx_t list_no = coarse_idx[i];
|
|
233
261
|
if (list_no >= 0 && list_no % nt == rank) {
|
|
234
262
|
idx_t id = xids ? xids[i] : ntotal + i;
|
|
235
|
-
size_t ofs = invlists->add_entry
|
|
236
|
-
|
|
237
|
-
flat_codes.get() + i * code_size
|
|
238
|
-
);
|
|
263
|
+
size_t ofs = invlists->add_entry(
|
|
264
|
+
list_no, id, flat_codes.get() + i * code_size);
|
|
239
265
|
|
|
240
|
-
dm_adder.add
|
|
266
|
+
dm_adder.add(i, list_no, ofs);
|
|
241
267
|
|
|
242
268
|
nadd++;
|
|
243
269
|
} else if (rank == 0 && list_no == -1) {
|
|
244
|
-
dm_adder.add
|
|
270
|
+
dm_adder.add(i, -1, 0);
|
|
245
271
|
}
|
|
246
272
|
}
|
|
247
273
|
}
|
|
248
274
|
|
|
249
|
-
|
|
250
275
|
if (verbose) {
|
|
251
|
-
printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n",
|
|
276
|
+
printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n",
|
|
277
|
+
nadd,
|
|
278
|
+
n,
|
|
279
|
+
nminus1);
|
|
252
280
|
}
|
|
253
281
|
|
|
254
282
|
ntotal += n;
|
|
255
283
|
}
|
|
256
284
|
|
|
257
|
-
void IndexIVF::make_direct_map
|
|
258
|
-
{
|
|
285
|
+
void IndexIVF::make_direct_map(bool b) {
|
|
259
286
|
if (b) {
|
|
260
|
-
direct_map.set_type
|
|
287
|
+
direct_map.set_type(DirectMap::Array, invlists, ntotal);
|
|
261
288
|
} else {
|
|
262
|
-
direct_map.set_type
|
|
289
|
+
direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
|
|
263
290
|
}
|
|
264
291
|
}
|
|
265
292
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
void IndexIVF::set_direct_map_type (DirectMap::Type type)
|
|
269
|
-
{
|
|
270
|
-
direct_map.set_type (type, invlists, ntotal);
|
|
293
|
+
void IndexIVF::set_direct_map_type(DirectMap::Type type) {
|
|
294
|
+
direct_map.set_type(type, invlists, ntotal);
|
|
271
295
|
}
|
|
272
296
|
|
|
273
297
|
/** It is a sad fact of software that a conceptually simple function like this
|
|
274
298
|
* becomes very complex when you factor in several ways of parallelizing +
|
|
275
299
|
* interrupt/error handling + collecting stats + min/max collection. The
|
|
276
300
|
* codepath that is used 95% of time is the one for parallel_mode = 0 */
|
|
277
|
-
void IndexIVF::search
|
|
278
|
-
|
|
279
|
-
|
|
301
|
+
void IndexIVF::search(
|
|
302
|
+
idx_t n,
|
|
303
|
+
const float* x,
|
|
304
|
+
idx_t k,
|
|
305
|
+
float* distances,
|
|
306
|
+
idx_t* labels) const {
|
|
307
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
280
308
|
|
|
309
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
|
310
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
|
281
311
|
|
|
282
312
|
// search function for a subset of queries
|
|
283
|
-
auto sub_search_func = [this, k]
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
313
|
+
auto sub_search_func = [this, k, nprobe](
|
|
314
|
+
idx_t n,
|
|
315
|
+
const float* x,
|
|
316
|
+
float* distances,
|
|
317
|
+
idx_t* labels,
|
|
318
|
+
IndexIVFStats* ivf_stats) {
|
|
287
319
|
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
|
288
320
|
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
289
321
|
|
|
290
322
|
double t0 = getmillisecs();
|
|
291
|
-
quantizer->search
|
|
323
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
|
|
292
324
|
|
|
293
325
|
double t1 = getmillisecs();
|
|
294
|
-
invlists->prefetch_lists
|
|
295
|
-
|
|
296
|
-
search_preassigned
|
|
297
|
-
|
|
326
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
|
327
|
+
|
|
328
|
+
search_preassigned(
|
|
329
|
+
n,
|
|
330
|
+
x,
|
|
331
|
+
k,
|
|
332
|
+
idx.get(),
|
|
333
|
+
coarse_dis.get(),
|
|
334
|
+
distances,
|
|
335
|
+
labels,
|
|
336
|
+
false,
|
|
337
|
+
nullptr,
|
|
338
|
+
ivf_stats);
|
|
298
339
|
double t2 = getmillisecs();
|
|
299
340
|
ivf_stats->quantization_time += t1 - t0;
|
|
300
341
|
ivf_stats->search_time += t2 - t0;
|
|
301
342
|
};
|
|
302
343
|
|
|
303
|
-
|
|
304
344
|
if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
|
|
305
345
|
int nt = std::min(omp_get_max_threads(), int(n));
|
|
306
346
|
std::vector<IndexIVFStats> stats(nt);
|
|
@@ -308,18 +348,19 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
|
|
|
308
348
|
std::string exception_string;
|
|
309
349
|
|
|
310
350
|
#pragma omp parallel for if (nt > 1)
|
|
311
|
-
for(idx_t slice = 0; slice < nt; slice++) {
|
|
351
|
+
for (idx_t slice = 0; slice < nt; slice++) {
|
|
312
352
|
IndexIVFStats local_stats;
|
|
313
353
|
idx_t i0 = n * slice / nt;
|
|
314
354
|
idx_t i1 = n * (slice + 1) / nt;
|
|
315
355
|
if (i1 > i0) {
|
|
316
356
|
try {
|
|
317
357
|
sub_search_func(
|
|
318
|
-
i1 - i0,
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
358
|
+
i1 - i0,
|
|
359
|
+
x + i0 * d,
|
|
360
|
+
distances + i0 * k,
|
|
361
|
+
labels + i0 * k,
|
|
362
|
+
&stats[slice]);
|
|
363
|
+
} catch (const std::exception& e) {
|
|
323
364
|
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
324
365
|
exception_string = e.what();
|
|
325
366
|
}
|
|
@@ -327,32 +368,38 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
|
|
|
327
368
|
}
|
|
328
369
|
|
|
329
370
|
if (!exception_string.empty()) {
|
|
330
|
-
FAISS_THROW_MSG
|
|
371
|
+
FAISS_THROW_MSG(exception_string.c_str());
|
|
331
372
|
}
|
|
332
373
|
|
|
333
374
|
// collect stats
|
|
334
|
-
for(idx_t slice = 0; slice < nt; slice++) {
|
|
375
|
+
for (idx_t slice = 0; slice < nt; slice++) {
|
|
335
376
|
indexIVF_stats.add(stats[slice]);
|
|
336
377
|
}
|
|
337
378
|
} else {
|
|
338
|
-
// handle paralellization at level below (or don't run in parallel at
|
|
379
|
+
// handle paralellization at level below (or don't run in parallel at
|
|
380
|
+
// all)
|
|
339
381
|
sub_search_func(n, x, distances, labels, &indexIVF_stats);
|
|
340
382
|
}
|
|
341
|
-
|
|
342
|
-
|
|
343
383
|
}
|
|
344
384
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
385
|
+
void IndexIVF::search_preassigned(
|
|
386
|
+
idx_t n,
|
|
387
|
+
const float* x,
|
|
388
|
+
idx_t k,
|
|
389
|
+
const idx_t* keys,
|
|
390
|
+
const float* coarse_dis,
|
|
391
|
+
float* distances,
|
|
392
|
+
idx_t* labels,
|
|
393
|
+
bool store_pairs,
|
|
394
|
+
const IVFSearchParameters* params,
|
|
395
|
+
IndexIVFStats* ivf_stats) const {
|
|
396
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
397
|
+
|
|
398
|
+
idx_t nprobe = params ? params->nprobe : this->nprobe;
|
|
399
|
+
nprobe = std::min((idx_t)nlist, nprobe);
|
|
400
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
|
401
|
+
|
|
402
|
+
idx_t max_codes = params ? params->max_codes : this->max_codes;
|
|
356
403
|
|
|
357
404
|
size_t nlistv = 0, ndis = 0, nheap = 0;
|
|
358
405
|
|
|
@@ -366,15 +413,15 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
366
413
|
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
|
367
414
|
bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
|
|
368
415
|
|
|
369
|
-
bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
370
|
-
pmode == 0
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
416
|
+
bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
417
|
+
(pmode == 0 ? false
|
|
418
|
+
: pmode == 3 ? n > 1
|
|
419
|
+
: pmode == 1 ? nprobe > 1
|
|
420
|
+
: nprobe * n > 1);
|
|
374
421
|
|
|
375
|
-
#pragma omp parallel if(do_parallel) reduction(
|
|
422
|
+
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
|
|
376
423
|
{
|
|
377
|
-
InvertedListScanner
|
|
424
|
+
InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
|
|
378
425
|
ScopeDeleter1<InvertedListScanner> del(scanner);
|
|
379
426
|
|
|
380
427
|
/*****************************************************
|
|
@@ -385,49 +432,52 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
385
432
|
|
|
386
433
|
// intialize + reorder a result heap
|
|
387
434
|
|
|
388
|
-
auto init_result = [&](float
|
|
389
|
-
if (!do_heap_init)
|
|
435
|
+
auto init_result = [&](float* simi, idx_t* idxi) {
|
|
436
|
+
if (!do_heap_init)
|
|
437
|
+
return;
|
|
390
438
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
391
|
-
heap_heapify<HeapForIP>
|
|
439
|
+
heap_heapify<HeapForIP>(k, simi, idxi);
|
|
392
440
|
} else {
|
|
393
|
-
heap_heapify<HeapForL2>
|
|
441
|
+
heap_heapify<HeapForL2>(k, simi, idxi);
|
|
394
442
|
}
|
|
395
443
|
};
|
|
396
444
|
|
|
397
|
-
auto add_local_results = [&](
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
445
|
+
auto add_local_results = [&](const float* local_dis,
|
|
446
|
+
const idx_t* local_idx,
|
|
447
|
+
float* simi,
|
|
448
|
+
idx_t* idxi) {
|
|
401
449
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
402
|
-
heap_addn<HeapForIP>
|
|
403
|
-
(k, simi, idxi, local_dis, local_idx, k);
|
|
450
|
+
heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
|
|
404
451
|
} else {
|
|
405
|
-
heap_addn<HeapForL2>
|
|
406
|
-
(k, simi, idxi, local_dis, local_idx, k);
|
|
452
|
+
heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
|
|
407
453
|
}
|
|
408
454
|
};
|
|
409
455
|
|
|
410
|
-
auto reorder_result = [&]
|
|
411
|
-
if (!do_heap_init)
|
|
456
|
+
auto reorder_result = [&](float* simi, idx_t* idxi) {
|
|
457
|
+
if (!do_heap_init)
|
|
458
|
+
return;
|
|
412
459
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
413
|
-
heap_reorder<HeapForIP>
|
|
460
|
+
heap_reorder<HeapForIP>(k, simi, idxi);
|
|
414
461
|
} else {
|
|
415
|
-
heap_reorder<HeapForL2>
|
|
462
|
+
heap_reorder<HeapForL2>(k, simi, idxi);
|
|
416
463
|
}
|
|
417
464
|
};
|
|
418
465
|
|
|
419
466
|
// single list scan using the current scanner (with query
|
|
420
467
|
// set porperly) and storing results in simi and idxi
|
|
421
|
-
auto scan_one_list = [&]
|
|
422
|
-
|
|
423
|
-
|
|
468
|
+
auto scan_one_list = [&](idx_t key,
|
|
469
|
+
float coarse_dis_i,
|
|
470
|
+
float* simi,
|
|
471
|
+
idx_t* idxi) {
|
|
424
472
|
if (key < 0) {
|
|
425
473
|
// not enough centroids for multiprobe
|
|
426
474
|
return (size_t)0;
|
|
427
475
|
}
|
|
428
|
-
FAISS_THROW_IF_NOT_FMT
|
|
429
|
-
|
|
430
|
-
|
|
476
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
477
|
+
key < (idx_t)nlist,
|
|
478
|
+
"Invalid key=%" PRId64 " nlist=%zd\n",
|
|
479
|
+
key,
|
|
480
|
+
nlist);
|
|
431
481
|
|
|
432
482
|
size_t list_size = invlists->list_size(key);
|
|
433
483
|
|
|
@@ -436,28 +486,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
436
486
|
return (size_t)0;
|
|
437
487
|
}
|
|
438
488
|
|
|
439
|
-
scanner->set_list
|
|
489
|
+
scanner->set_list(key, coarse_dis_i);
|
|
440
490
|
|
|
441
491
|
nlistv++;
|
|
442
492
|
|
|
443
493
|
try {
|
|
444
|
-
InvertedLists::ScopedCodes scodes
|
|
494
|
+
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
445
495
|
|
|
446
496
|
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
|
447
|
-
const Index::idx_t
|
|
497
|
+
const Index::idx_t* ids = nullptr;
|
|
448
498
|
|
|
449
|
-
if (!store_pairs)
|
|
450
|
-
sids.reset
|
|
499
|
+
if (!store_pairs) {
|
|
500
|
+
sids.reset(new InvertedLists::ScopedIds(invlists, key));
|
|
451
501
|
ids = sids->get();
|
|
452
502
|
}
|
|
453
503
|
|
|
454
|
-
nheap += scanner->scan_codes
|
|
455
|
-
|
|
504
|
+
nheap += scanner->scan_codes(
|
|
505
|
+
list_size, scodes.get(), ids, simi, idxi, k);
|
|
456
506
|
|
|
457
|
-
} catch(const std::exception
|
|
507
|
+
} catch (const std::exception& e) {
|
|
458
508
|
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
459
509
|
exception_string =
|
|
460
|
-
|
|
510
|
+
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
461
511
|
interrupt = true;
|
|
462
512
|
return size_t(0);
|
|
463
513
|
}
|
|
@@ -470,31 +520,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
470
520
|
****************************************************/
|
|
471
521
|
|
|
472
522
|
if (pmode == 0 || pmode == 3) {
|
|
473
|
-
|
|
474
523
|
#pragma omp for
|
|
475
524
|
for (idx_t i = 0; i < n; i++) {
|
|
476
|
-
|
|
477
525
|
if (interrupt) {
|
|
478
526
|
continue;
|
|
479
527
|
}
|
|
480
528
|
|
|
481
529
|
// loop over queries
|
|
482
|
-
scanner->set_query
|
|
483
|
-
float
|
|
484
|
-
idx_t
|
|
530
|
+
scanner->set_query(x + i * d);
|
|
531
|
+
float* simi = distances + i * k;
|
|
532
|
+
idx_t* idxi = labels + i * k;
|
|
485
533
|
|
|
486
|
-
init_result
|
|
534
|
+
init_result(simi, idxi);
|
|
487
535
|
|
|
488
|
-
|
|
536
|
+
idx_t nscan = 0;
|
|
489
537
|
|
|
490
538
|
// loop over probes
|
|
491
539
|
for (size_t ik = 0; ik < nprobe; ik++) {
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
);
|
|
540
|
+
nscan += scan_one_list(
|
|
541
|
+
keys[i * nprobe + ik],
|
|
542
|
+
coarse_dis[i * nprobe + ik],
|
|
543
|
+
simi,
|
|
544
|
+
idxi);
|
|
498
545
|
|
|
499
546
|
if (max_codes && nscan >= max_codes) {
|
|
500
547
|
break;
|
|
@@ -502,54 +549,55 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
502
549
|
}
|
|
503
550
|
|
|
504
551
|
ndis += nscan;
|
|
505
|
-
reorder_result
|
|
552
|
+
reorder_result(simi, idxi);
|
|
506
553
|
|
|
507
|
-
if (InterruptCallback::is_interrupted
|
|
554
|
+
if (InterruptCallback::is_interrupted()) {
|
|
508
555
|
interrupt = true;
|
|
509
556
|
}
|
|
510
557
|
|
|
511
558
|
} // parallel for
|
|
512
559
|
} else if (pmode == 1) {
|
|
513
|
-
std::vector
|
|
514
|
-
std::vector
|
|
560
|
+
std::vector<idx_t> local_idx(k);
|
|
561
|
+
std::vector<float> local_dis(k);
|
|
515
562
|
|
|
516
563
|
for (size_t i = 0; i < n; i++) {
|
|
517
|
-
scanner->set_query
|
|
518
|
-
init_result
|
|
564
|
+
scanner->set_query(x + i * d);
|
|
565
|
+
init_result(local_dis.data(), local_idx.data());
|
|
519
566
|
|
|
520
567
|
#pragma omp for schedule(dynamic)
|
|
521
|
-
for (
|
|
522
|
-
ndis += scan_one_list
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
568
|
+
for (idx_t ik = 0; ik < nprobe; ik++) {
|
|
569
|
+
ndis += scan_one_list(
|
|
570
|
+
keys[i * nprobe + ik],
|
|
571
|
+
coarse_dis[i * nprobe + ik],
|
|
572
|
+
local_dis.data(),
|
|
573
|
+
local_idx.data());
|
|
526
574
|
|
|
527
575
|
// can't do the test on max_codes
|
|
528
576
|
}
|
|
529
577
|
// merge thread-local results
|
|
530
578
|
|
|
531
|
-
float
|
|
532
|
-
idx_t
|
|
579
|
+
float* simi = distances + i * k;
|
|
580
|
+
idx_t* idxi = labels + i * k;
|
|
533
581
|
#pragma omp single
|
|
534
|
-
init_result
|
|
582
|
+
init_result(simi, idxi);
|
|
535
583
|
|
|
536
584
|
#pragma omp barrier
|
|
537
585
|
#pragma omp critical
|
|
538
586
|
{
|
|
539
|
-
add_local_results
|
|
540
|
-
|
|
587
|
+
add_local_results(
|
|
588
|
+
local_dis.data(), local_idx.data(), simi, idxi);
|
|
541
589
|
}
|
|
542
590
|
#pragma omp barrier
|
|
543
591
|
#pragma omp single
|
|
544
|
-
reorder_result
|
|
592
|
+
reorder_result(simi, idxi);
|
|
545
593
|
}
|
|
546
594
|
} else if (pmode == 2) {
|
|
547
|
-
std::vector
|
|
548
|
-
std::vector
|
|
595
|
+
std::vector<idx_t> local_idx(k);
|
|
596
|
+
std::vector<float> local_dis(k);
|
|
549
597
|
|
|
550
598
|
#pragma omp single
|
|
551
599
|
for (int64_t i = 0; i < n; i++) {
|
|
552
|
-
init_result
|
|
600
|
+
init_result(distances + i * k, labels + i * k);
|
|
553
601
|
}
|
|
554
602
|
|
|
555
603
|
#pragma omp for schedule(dynamic)
|
|
@@ -557,33 +605,37 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
557
605
|
size_t i = ij / nprobe;
|
|
558
606
|
size_t j = ij % nprobe;
|
|
559
607
|
|
|
560
|
-
scanner->set_query
|
|
561
|
-
init_result
|
|
562
|
-
ndis += scan_one_list
|
|
563
|
-
keys
|
|
564
|
-
|
|
608
|
+
scanner->set_query(x + i * d);
|
|
609
|
+
init_result(local_dis.data(), local_idx.data());
|
|
610
|
+
ndis += scan_one_list(
|
|
611
|
+
keys[ij],
|
|
612
|
+
coarse_dis[ij],
|
|
613
|
+
local_dis.data(),
|
|
614
|
+
local_idx.data());
|
|
565
615
|
#pragma omp critical
|
|
566
616
|
{
|
|
567
|
-
add_local_results
|
|
568
|
-
|
|
617
|
+
add_local_results(
|
|
618
|
+
local_dis.data(),
|
|
619
|
+
local_idx.data(),
|
|
620
|
+
distances + i * k,
|
|
621
|
+
labels + i * k);
|
|
569
622
|
}
|
|
570
623
|
}
|
|
571
624
|
#pragma omp single
|
|
572
625
|
for (int64_t i = 0; i < n; i++) {
|
|
573
|
-
reorder_result
|
|
626
|
+
reorder_result(distances + i * k, labels + i * k);
|
|
574
627
|
}
|
|
575
628
|
} else {
|
|
576
|
-
FAISS_THROW_FMT
|
|
577
|
-
pmode);
|
|
629
|
+
FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
|
|
578
630
|
}
|
|
579
631
|
} // parallel section
|
|
580
632
|
|
|
581
633
|
if (interrupt) {
|
|
582
634
|
if (!exception_string.empty()) {
|
|
583
|
-
FAISS_THROW_FMT
|
|
584
|
-
|
|
635
|
+
FAISS_THROW_FMT(
|
|
636
|
+
"search interrupted with: %s", exception_string.c_str());
|
|
585
637
|
} else {
|
|
586
|
-
FAISS_THROW_MSG
|
|
638
|
+
FAISS_THROW_MSG("computation interrupted");
|
|
587
639
|
}
|
|
588
640
|
}
|
|
589
641
|
|
|
@@ -595,38 +647,49 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
|
595
647
|
}
|
|
596
648
|
}
|
|
597
649
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
std::unique_ptr<idx_t[]> keys
|
|
605
|
-
std::unique_ptr<float
|
|
650
|
+
void IndexIVF::range_search(
|
|
651
|
+
idx_t nx,
|
|
652
|
+
const float* x,
|
|
653
|
+
float radius,
|
|
654
|
+
RangeSearchResult* result) const {
|
|
655
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
|
656
|
+
std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
|
|
657
|
+
std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
|
|
606
658
|
|
|
607
659
|
double t0 = getmillisecs();
|
|
608
|
-
quantizer->search
|
|
660
|
+
quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
|
|
609
661
|
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
610
662
|
|
|
611
663
|
t0 = getmillisecs();
|
|
612
|
-
invlists->prefetch_lists
|
|
613
|
-
|
|
614
|
-
range_search_preassigned
|
|
615
|
-
|
|
664
|
+
invlists->prefetch_lists(keys.get(), nx * nprobe);
|
|
665
|
+
|
|
666
|
+
range_search_preassigned(
|
|
667
|
+
nx,
|
|
668
|
+
x,
|
|
669
|
+
radius,
|
|
670
|
+
keys.get(),
|
|
671
|
+
coarse_dis.get(),
|
|
672
|
+
result,
|
|
673
|
+
false,
|
|
674
|
+
nullptr,
|
|
675
|
+
&indexIVF_stats);
|
|
616
676
|
|
|
617
677
|
indexIVF_stats.search_time += getmillisecs() - t0;
|
|
618
678
|
}
|
|
619
679
|
|
|
620
|
-
void IndexIVF::range_search_preassigned
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
680
|
+
void IndexIVF::range_search_preassigned(
|
|
681
|
+
idx_t nx,
|
|
682
|
+
const float* x,
|
|
683
|
+
float radius,
|
|
684
|
+
const idx_t* keys,
|
|
685
|
+
const float* coarse_dis,
|
|
686
|
+
RangeSearchResult* result,
|
|
687
|
+
bool store_pairs,
|
|
688
|
+
const IVFSearchParameters* params,
|
|
689
|
+
IndexIVFStats* stats) const {
|
|
690
|
+
idx_t nprobe = params ? params->nprobe : this->nprobe;
|
|
691
|
+
nprobe = std::min((idx_t)nlist, nprobe);
|
|
692
|
+
idx_t max_codes = params ? params->max_codes : this->max_codes;
|
|
630
693
|
|
|
631
694
|
size_t nlistv = 0, ndis = 0;
|
|
632
695
|
|
|
@@ -634,119 +697,116 @@ void IndexIVF::range_search_preassigned (
|
|
|
634
697
|
std::mutex exception_mutex;
|
|
635
698
|
std::string exception_string;
|
|
636
699
|
|
|
637
|
-
std::vector<RangeSearchPartialResult
|
|
700
|
+
std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
|
|
638
701
|
|
|
639
702
|
int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
|
|
640
703
|
// don't start parallel section if single query
|
|
641
|
-
bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
642
|
-
pmode == 3
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
704
|
+
bool do_parallel = omp_get_max_threads() >= 2 &&
|
|
705
|
+
(pmode == 3 ? false
|
|
706
|
+
: pmode == 0 ? nx > 1
|
|
707
|
+
: pmode == 1 ? nprobe > 1
|
|
708
|
+
: nprobe * nx > 1);
|
|
646
709
|
|
|
647
|
-
#pragma omp parallel if(do_parallel) reduction(
|
|
710
|
+
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
|
|
648
711
|
{
|
|
649
712
|
RangeSearchPartialResult pres(result);
|
|
650
|
-
std::unique_ptr<InvertedListScanner> scanner
|
|
651
|
-
|
|
652
|
-
FAISS_THROW_IF_NOT
|
|
713
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
|
714
|
+
get_InvertedListScanner(store_pairs));
|
|
715
|
+
FAISS_THROW_IF_NOT(scanner.get());
|
|
653
716
|
all_pres[omp_get_thread_num()] = &pres;
|
|
654
717
|
|
|
655
718
|
// prepare the list scanning function
|
|
656
719
|
|
|
657
|
-
auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
FAISS_THROW_IF_NOT_FMT
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
720
|
+
auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
|
|
721
|
+
idx_t key = keys[i * nprobe + ik]; /* select the list */
|
|
722
|
+
if (key < 0)
|
|
723
|
+
return;
|
|
724
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
725
|
+
key < (idx_t)nlist,
|
|
726
|
+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
|
727
|
+
key,
|
|
728
|
+
ik,
|
|
729
|
+
nlist);
|
|
665
730
|
const size_t list_size = invlists->list_size(key);
|
|
666
731
|
|
|
667
|
-
if (list_size == 0)
|
|
732
|
+
if (list_size == 0)
|
|
733
|
+
return;
|
|
668
734
|
|
|
669
735
|
try {
|
|
736
|
+
InvertedLists::ScopedCodes scodes(invlists, key);
|
|
737
|
+
InvertedLists::ScopedIds ids(invlists, key);
|
|
670
738
|
|
|
671
|
-
|
|
672
|
-
InvertedLists::ScopedIds ids (invlists, key);
|
|
673
|
-
|
|
674
|
-
scanner->set_list (key, coarse_dis[i * nprobe + ik]);
|
|
739
|
+
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
|
|
675
740
|
nlistv++;
|
|
676
741
|
ndis += list_size;
|
|
677
|
-
scanner->scan_codes_range
|
|
678
|
-
|
|
742
|
+
scanner->scan_codes_range(
|
|
743
|
+
list_size, scodes.get(), ids.get(), radius, qres);
|
|
679
744
|
|
|
680
|
-
} catch(const std::exception
|
|
745
|
+
} catch (const std::exception& e) {
|
|
681
746
|
std::lock_guard<std::mutex> lock(exception_mutex);
|
|
682
747
|
exception_string =
|
|
683
|
-
|
|
748
|
+
demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
|
|
684
749
|
interrupt = true;
|
|
685
750
|
}
|
|
686
|
-
|
|
687
751
|
};
|
|
688
752
|
|
|
689
753
|
if (parallel_mode == 0) {
|
|
690
|
-
|
|
691
754
|
#pragma omp for
|
|
692
755
|
for (idx_t i = 0; i < nx; i++) {
|
|
693
|
-
scanner->set_query
|
|
756
|
+
scanner->set_query(x + i * d);
|
|
694
757
|
|
|
695
|
-
RangeQueryResult
|
|
758
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
696
759
|
|
|
697
760
|
for (size_t ik = 0; ik < nprobe; ik++) {
|
|
698
|
-
scan_list_func
|
|
761
|
+
scan_list_func(i, ik, qres);
|
|
699
762
|
}
|
|
700
|
-
|
|
701
763
|
}
|
|
702
764
|
|
|
703
765
|
} else if (parallel_mode == 1) {
|
|
704
|
-
|
|
705
766
|
for (size_t i = 0; i < nx; i++) {
|
|
706
|
-
scanner->set_query
|
|
767
|
+
scanner->set_query(x + i * d);
|
|
707
768
|
|
|
708
|
-
RangeQueryResult
|
|
769
|
+
RangeQueryResult& qres = pres.new_result(i);
|
|
709
770
|
|
|
710
771
|
#pragma omp for schedule(dynamic)
|
|
711
772
|
for (int64_t ik = 0; ik < nprobe; ik++) {
|
|
712
|
-
scan_list_func
|
|
773
|
+
scan_list_func(i, ik, qres);
|
|
713
774
|
}
|
|
714
775
|
}
|
|
715
776
|
} else if (parallel_mode == 2) {
|
|
716
|
-
std::vector<RangeQueryResult
|
|
717
|
-
RangeQueryResult
|
|
777
|
+
std::vector<RangeQueryResult*> all_qres(nx);
|
|
778
|
+
RangeQueryResult* qres = nullptr;
|
|
718
779
|
|
|
719
780
|
#pragma omp for schedule(dynamic)
|
|
720
781
|
for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
|
|
721
782
|
idx_t i = iik / (idx_t)nprobe;
|
|
722
783
|
idx_t ik = iik % (idx_t)nprobe;
|
|
723
784
|
if (qres == nullptr || qres->qno != i) {
|
|
724
|
-
FAISS_ASSERT
|
|
725
|
-
qres = &pres.new_result
|
|
726
|
-
scanner->set_query
|
|
785
|
+
FAISS_ASSERT(!qres || i > qres->qno);
|
|
786
|
+
qres = &pres.new_result(i);
|
|
787
|
+
scanner->set_query(x + i * d);
|
|
727
788
|
}
|
|
728
|
-
scan_list_func
|
|
789
|
+
scan_list_func(i, ik, *qres);
|
|
729
790
|
}
|
|
730
791
|
} else {
|
|
731
|
-
FAISS_THROW_FMT
|
|
792
|
+
FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
|
|
732
793
|
}
|
|
733
794
|
if (parallel_mode == 0) {
|
|
734
|
-
pres.finalize
|
|
795
|
+
pres.finalize();
|
|
735
796
|
} else {
|
|
736
797
|
#pragma omp barrier
|
|
737
798
|
#pragma omp single
|
|
738
|
-
RangeSearchPartialResult::merge
|
|
799
|
+
RangeSearchPartialResult::merge(all_pres, false);
|
|
739
800
|
#pragma omp barrier
|
|
740
|
-
|
|
741
801
|
}
|
|
742
802
|
}
|
|
743
803
|
|
|
744
804
|
if (interrupt) {
|
|
745
805
|
if (!exception_string.empty()) {
|
|
746
|
-
FAISS_THROW_FMT
|
|
747
|
-
|
|
806
|
+
FAISS_THROW_FMT(
|
|
807
|
+
"search interrupted with: %s", exception_string.c_str());
|
|
748
808
|
} else {
|
|
749
|
-
FAISS_THROW_MSG
|
|
809
|
+
FAISS_THROW_MSG("computation interrupted");
|
|
750
810
|
}
|
|
751
811
|
}
|
|
752
812
|
|
|
@@ -757,27 +817,22 @@ void IndexIVF::range_search_preassigned (
|
|
|
757
817
|
}
|
|
758
818
|
}
|
|
759
819
|
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
bool /*store_pairs*/) const
|
|
763
|
-
{
|
|
820
|
+
InvertedListScanner* IndexIVF::get_InvertedListScanner(
|
|
821
|
+
bool /*store_pairs*/) const {
|
|
764
822
|
return nullptr;
|
|
765
823
|
}
|
|
766
824
|
|
|
767
|
-
void IndexIVF::reconstruct
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
|
|
825
|
+
void IndexIVF::reconstruct(idx_t key, float* recons) const {
|
|
826
|
+
idx_t lo = direct_map.get(key);
|
|
827
|
+
reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
|
|
771
828
|
}
|
|
772
829
|
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
{
|
|
776
|
-
FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
|
830
|
+
void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
|
831
|
+
FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
|
777
832
|
|
|
778
833
|
for (idx_t list_no = 0; list_no < nlist; list_no++) {
|
|
779
|
-
size_t list_size = invlists->list_size
|
|
780
|
-
ScopedIds idlist
|
|
834
|
+
size_t list_size = invlists->list_size(list_no);
|
|
835
|
+
ScopedIds idlist(invlists, list_no);
|
|
781
836
|
|
|
782
837
|
for (idx_t offset = 0; offset < list_size; offset++) {
|
|
783
838
|
idx_t id = idlist[offset];
|
|
@@ -786,46 +841,56 @@ void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
|
|
|
786
841
|
}
|
|
787
842
|
|
|
788
843
|
float* reconstructed = recons + (id - i0) * d;
|
|
789
|
-
reconstruct_from_offset
|
|
844
|
+
reconstruct_from_offset(list_no, offset, reconstructed);
|
|
790
845
|
}
|
|
791
846
|
}
|
|
792
847
|
}
|
|
793
848
|
|
|
794
|
-
|
|
795
849
|
/* standalone codec interface */
|
|
796
|
-
size_t IndexIVF::sa_code_size
|
|
797
|
-
{
|
|
850
|
+
size_t IndexIVF::sa_code_size() const {
|
|
798
851
|
size_t coarse_size = coarse_code_size();
|
|
799
852
|
return code_size + coarse_size;
|
|
800
853
|
}
|
|
801
854
|
|
|
802
|
-
void IndexIVF::sa_encode
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
quantizer->assign (n, x, idx.get());
|
|
808
|
-
encode_vectors (n, x, idx.get(), bytes, true);
|
|
855
|
+
void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
856
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
857
|
+
std::unique_ptr<int64_t[]> idx(new int64_t[n]);
|
|
858
|
+
quantizer->assign(n, x, idx.get());
|
|
859
|
+
encode_vectors(n, x, idx.get(), bytes, true);
|
|
809
860
|
}
|
|
810
861
|
|
|
862
|
+
void IndexIVF::search_and_reconstruct(
|
|
863
|
+
idx_t n,
|
|
864
|
+
const float* x,
|
|
865
|
+
idx_t k,
|
|
866
|
+
float* distances,
|
|
867
|
+
idx_t* labels,
|
|
868
|
+
float* recons) const {
|
|
869
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
811
870
|
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
float *recons) const
|
|
815
|
-
{
|
|
816
|
-
idx_t * idx = new idx_t [n * nprobe];
|
|
817
|
-
ScopeDeleter<idx_t> del (idx);
|
|
818
|
-
float * coarse_dis = new float [n * nprobe];
|
|
819
|
-
ScopeDeleter<float> del2 (coarse_dis);
|
|
871
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
|
872
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
|
820
873
|
|
|
821
|
-
|
|
874
|
+
idx_t* idx = new idx_t[n * nprobe];
|
|
875
|
+
ScopeDeleter<idx_t> del(idx);
|
|
876
|
+
float* coarse_dis = new float[n * nprobe];
|
|
877
|
+
ScopeDeleter<float> del2(coarse_dis);
|
|
822
878
|
|
|
823
|
-
|
|
879
|
+
quantizer->search(n, x, nprobe, coarse_dis, idx);
|
|
880
|
+
|
|
881
|
+
invlists->prefetch_lists(idx, n * nprobe);
|
|
824
882
|
|
|
825
883
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
|
826
884
|
// and offset into `codes` for reconstruction
|
|
827
|
-
search_preassigned
|
|
828
|
-
|
|
885
|
+
search_preassigned(
|
|
886
|
+
n,
|
|
887
|
+
x,
|
|
888
|
+
k,
|
|
889
|
+
idx,
|
|
890
|
+
coarse_dis,
|
|
891
|
+
distances,
|
|
892
|
+
labels,
|
|
893
|
+
true /* store_pairs */);
|
|
829
894
|
for (idx_t i = 0; i < n; ++i) {
|
|
830
895
|
for (idx_t j = 0; j < k; ++j) {
|
|
831
896
|
idx_t ij = i * k + j;
|
|
@@ -835,165 +900,151 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
|
|
|
835
900
|
// Fill with NaNs
|
|
836
901
|
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
|
837
902
|
} else {
|
|
838
|
-
int list_no = lo_listno
|
|
839
|
-
int offset = lo_offset
|
|
903
|
+
int list_no = lo_listno(key);
|
|
904
|
+
int offset = lo_offset(key);
|
|
840
905
|
|
|
841
906
|
// Update label to the actual id
|
|
842
|
-
labels[ij] = invlists->get_single_id
|
|
907
|
+
labels[ij] = invlists->get_single_id(list_no, offset);
|
|
843
908
|
|
|
844
|
-
reconstruct_from_offset
|
|
909
|
+
reconstruct_from_offset(list_no, offset, reconstructed);
|
|
845
910
|
}
|
|
846
911
|
}
|
|
847
912
|
}
|
|
848
913
|
}
|
|
849
914
|
|
|
850
915
|
void IndexIVF::reconstruct_from_offset(
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
916
|
+
int64_t /*list_no*/,
|
|
917
|
+
int64_t /*offset*/,
|
|
918
|
+
float* /*recons*/) const {
|
|
919
|
+
FAISS_THROW_MSG("reconstruct_from_offset not implemented");
|
|
855
920
|
}
|
|
856
921
|
|
|
857
|
-
void IndexIVF::reset
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
invlists->reset ();
|
|
922
|
+
void IndexIVF::reset() {
|
|
923
|
+
direct_map.clear();
|
|
924
|
+
invlists->reset();
|
|
861
925
|
ntotal = 0;
|
|
862
926
|
}
|
|
863
927
|
|
|
864
|
-
|
|
865
|
-
size_t
|
|
866
|
-
{
|
|
867
|
-
size_t nremove = direct_map.remove_ids (sel, invlists);
|
|
928
|
+
size_t IndexIVF::remove_ids(const IDSelector& sel) {
|
|
929
|
+
size_t nremove = direct_map.remove_ids(sel, invlists);
|
|
868
930
|
ntotal -= nremove;
|
|
869
931
|
return nremove;
|
|
870
932
|
}
|
|
871
933
|
|
|
872
|
-
|
|
873
|
-
void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
|
|
874
|
-
{
|
|
875
|
-
|
|
934
|
+
void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
|
|
876
935
|
if (direct_map.type == DirectMap::Hashtable) {
|
|
877
936
|
// just remove then add
|
|
878
937
|
IDSelectorArray sel(n, new_ids);
|
|
879
|
-
size_t nremove = remove_ids
|
|
880
|
-
FAISS_THROW_IF_NOT_MSG
|
|
881
|
-
|
|
882
|
-
add_with_ids
|
|
938
|
+
size_t nremove = remove_ids(sel);
|
|
939
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
940
|
+
nremove == n, "did not find all entries to remove");
|
|
941
|
+
add_with_ids(n, x, new_ids);
|
|
883
942
|
return;
|
|
884
943
|
}
|
|
885
944
|
|
|
886
|
-
FAISS_THROW_IF_NOT
|
|
945
|
+
FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array);
|
|
887
946
|
// here it is more tricky because we don't want to introduce holes
|
|
888
947
|
// in continuous range of ids
|
|
889
948
|
|
|
890
|
-
FAISS_THROW_IF_NOT
|
|
891
|
-
std::vector<idx_t> assign
|
|
892
|
-
quantizer->assign
|
|
893
|
-
|
|
894
|
-
std::vector<uint8_t> flat_codes (n * code_size);
|
|
895
|
-
encode_vectors (n, x, assign.data(), flat_codes.data());
|
|
949
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
950
|
+
std::vector<idx_t> assign(n);
|
|
951
|
+
quantizer->assign(n, x, assign.data());
|
|
896
952
|
|
|
897
|
-
|
|
953
|
+
std::vector<uint8_t> flat_codes(n * code_size);
|
|
954
|
+
encode_vectors(n, x, assign.data(), flat_codes.data());
|
|
898
955
|
|
|
956
|
+
direct_map.update_codes(
|
|
957
|
+
invlists, n, new_ids, assign.data(), flat_codes.data());
|
|
899
958
|
}
|
|
900
959
|
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
void IndexIVF::train (idx_t n, const float *x)
|
|
905
|
-
{
|
|
960
|
+
void IndexIVF::train(idx_t n, const float* x) {
|
|
906
961
|
if (verbose)
|
|
907
|
-
printf
|
|
962
|
+
printf("Training level-1 quantizer\n");
|
|
908
963
|
|
|
909
|
-
train_q1
|
|
964
|
+
train_q1(n, x, verbose, metric_type);
|
|
910
965
|
|
|
911
966
|
if (verbose)
|
|
912
|
-
printf
|
|
967
|
+
printf("Training IVF residual\n");
|
|
913
968
|
|
|
914
|
-
train_residual
|
|
969
|
+
train_residual(n, x);
|
|
915
970
|
is_trained = true;
|
|
916
|
-
|
|
917
971
|
}
|
|
918
972
|
|
|
919
973
|
void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
974
|
+
if (verbose)
|
|
975
|
+
printf("IndexIVF: no residual training\n");
|
|
976
|
+
// does nothing by default
|
|
923
977
|
}
|
|
924
978
|
|
|
925
|
-
|
|
926
|
-
void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
|
|
927
|
-
{
|
|
979
|
+
void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
|
|
928
980
|
// minimal sanity checks
|
|
929
|
-
FAISS_THROW_IF_NOT
|
|
930
|
-
FAISS_THROW_IF_NOT
|
|
931
|
-
FAISS_THROW_IF_NOT
|
|
932
|
-
FAISS_THROW_IF_NOT_MSG
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
981
|
+
FAISS_THROW_IF_NOT(other.d == d);
|
|
982
|
+
FAISS_THROW_IF_NOT(other.nlist == nlist);
|
|
983
|
+
FAISS_THROW_IF_NOT(other.code_size == code_size);
|
|
984
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
985
|
+
typeid(*this) == typeid(other),
|
|
986
|
+
"can only merge indexes of the same type");
|
|
987
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
988
|
+
this->direct_map.no() && other.direct_map.no(),
|
|
989
|
+
"merge direct_map not implemented");
|
|
936
990
|
}
|
|
937
991
|
|
|
992
|
+
void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
|
|
993
|
+
check_compatible_for_merge(other);
|
|
938
994
|
|
|
939
|
-
|
|
940
|
-
{
|
|
941
|
-
check_compatible_for_merge (other);
|
|
942
|
-
|
|
943
|
-
invlists->merge_from (other.invlists, add_id);
|
|
995
|
+
invlists->merge_from(other.invlists, add_id);
|
|
944
996
|
|
|
945
997
|
ntotal += other.ntotal;
|
|
946
998
|
other.ntotal = 0;
|
|
947
999
|
}
|
|
948
1000
|
|
|
949
|
-
|
|
950
|
-
void IndexIVF::replace_invlists (InvertedLists *il, bool own)
|
|
951
|
-
{
|
|
1001
|
+
void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
|
|
952
1002
|
if (own_invlists) {
|
|
953
1003
|
delete invlists;
|
|
954
1004
|
invlists = nullptr;
|
|
955
1005
|
}
|
|
956
1006
|
// FAISS_THROW_IF_NOT (ntotal == 0);
|
|
957
1007
|
if (il) {
|
|
958
|
-
FAISS_THROW_IF_NOT
|
|
959
|
-
FAISS_THROW_IF_NOT
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
);
|
|
1008
|
+
FAISS_THROW_IF_NOT(il->nlist == nlist);
|
|
1009
|
+
FAISS_THROW_IF_NOT(
|
|
1010
|
+
il->code_size == code_size ||
|
|
1011
|
+
il->code_size == InvertedLists::INVALID_CODE_SIZE);
|
|
963
1012
|
}
|
|
964
1013
|
invlists = il;
|
|
965
1014
|
own_invlists = own;
|
|
966
1015
|
}
|
|
967
1016
|
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
FAISS_THROW_IF_NOT
|
|
974
|
-
FAISS_THROW_IF_NOT
|
|
975
|
-
FAISS_THROW_IF_NOT
|
|
976
|
-
FAISS_THROW_IF_NOT_FMT
|
|
977
|
-
|
|
978
|
-
|
|
1017
|
+
void IndexIVF::copy_subset_to(
|
|
1018
|
+
IndexIVF& other,
|
|
1019
|
+
int subset_type,
|
|
1020
|
+
idx_t a1,
|
|
1021
|
+
idx_t a2) const {
|
|
1022
|
+
FAISS_THROW_IF_NOT(nlist == other.nlist);
|
|
1023
|
+
FAISS_THROW_IF_NOT(code_size == other.code_size);
|
|
1024
|
+
FAISS_THROW_IF_NOT(other.direct_map.no());
|
|
1025
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
1026
|
+
subset_type == 0 || subset_type == 1 || subset_type == 2,
|
|
1027
|
+
"subset type %d not implemented",
|
|
1028
|
+
subset_type);
|
|
979
1029
|
|
|
980
1030
|
size_t accu_n = 0;
|
|
981
1031
|
size_t accu_a1 = 0;
|
|
982
1032
|
size_t accu_a2 = 0;
|
|
983
1033
|
|
|
984
|
-
InvertedLists
|
|
1034
|
+
InvertedLists* oivf = other.invlists;
|
|
985
1035
|
|
|
986
1036
|
for (idx_t list_no = 0; list_no < nlist; list_no++) {
|
|
987
|
-
size_t n = invlists->list_size
|
|
988
|
-
ScopedIds ids_in
|
|
1037
|
+
size_t n = invlists->list_size(list_no);
|
|
1038
|
+
ScopedIds ids_in(invlists, list_no);
|
|
989
1039
|
|
|
990
1040
|
if (subset_type == 0) {
|
|
991
1041
|
for (idx_t i = 0; i < n; i++) {
|
|
992
1042
|
idx_t id = ids_in[i];
|
|
993
1043
|
if (a1 <= id && id < a2) {
|
|
994
|
-
oivf->add_entry
|
|
995
|
-
|
|
996
|
-
|
|
1044
|
+
oivf->add_entry(
|
|
1045
|
+
list_no,
|
|
1046
|
+
invlists->get_single_id(list_no, i),
|
|
1047
|
+
ScopedCodes(invlists, list_no, i).get());
|
|
997
1048
|
other.ntotal++;
|
|
998
1049
|
}
|
|
999
1050
|
}
|
|
@@ -1001,9 +1052,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
|
|
1001
1052
|
for (idx_t i = 0; i < n; i++) {
|
|
1002
1053
|
idx_t id = ids_in[i];
|
|
1003
1054
|
if (id % a1 == a2) {
|
|
1004
|
-
oivf->add_entry
|
|
1005
|
-
|
|
1006
|
-
|
|
1055
|
+
oivf->add_entry(
|
|
1056
|
+
list_no,
|
|
1057
|
+
invlists->get_single_id(list_no, i),
|
|
1058
|
+
ScopedCodes(invlists, list_no, i).get());
|
|
1007
1059
|
other.ntotal++;
|
|
1008
1060
|
}
|
|
1009
1061
|
}
|
|
@@ -1016,9 +1068,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
|
|
1016
1068
|
size_t i2 = next_accu_a2 - accu_a2;
|
|
1017
1069
|
|
|
1018
1070
|
for (idx_t i = i1; i < i2; i++) {
|
|
1019
|
-
oivf->add_entry
|
|
1020
|
-
|
|
1021
|
-
|
|
1071
|
+
oivf->add_entry(
|
|
1072
|
+
list_no,
|
|
1073
|
+
invlists->get_single_id(list_no, i),
|
|
1074
|
+
ScopedCodes(invlists, list_no, i).get());
|
|
1022
1075
|
}
|
|
1023
1076
|
|
|
1024
1077
|
other.ntotal += i2 - i1;
|
|
@@ -1028,48 +1081,87 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
|
|
|
1028
1081
|
accu_n += n;
|
|
1029
1082
|
}
|
|
1030
1083
|
FAISS_ASSERT(accu_n == ntotal);
|
|
1031
|
-
|
|
1032
1084
|
}
|
|
1033
1085
|
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
IndexIVF::~IndexIVF()
|
|
1038
|
-
{
|
|
1086
|
+
IndexIVF::~IndexIVF() {
|
|
1039
1087
|
if (own_invlists) {
|
|
1040
1088
|
delete invlists;
|
|
1041
1089
|
}
|
|
1042
1090
|
}
|
|
1043
1091
|
|
|
1092
|
+
/*************************************************************************
|
|
1093
|
+
* IndexIVFStats
|
|
1094
|
+
*************************************************************************/
|
|
1044
1095
|
|
|
1045
|
-
void IndexIVFStats::reset()
|
|
1046
|
-
|
|
1047
|
-
memset ((void*)this, 0, sizeof (*this));
|
|
1096
|
+
void IndexIVFStats::reset() {
|
|
1097
|
+
memset((void*)this, 0, sizeof(*this));
|
|
1048
1098
|
}
|
|
1049
1099
|
|
|
1050
|
-
void IndexIVFStats::add
|
|
1051
|
-
{
|
|
1100
|
+
void IndexIVFStats::add(const IndexIVFStats& other) {
|
|
1052
1101
|
nq += other.nq;
|
|
1053
1102
|
nlist += other.nlist;
|
|
1054
1103
|
ndis += other.ndis;
|
|
1055
1104
|
nheap_updates += other.nheap_updates;
|
|
1056
1105
|
quantization_time += other.quantization_time;
|
|
1057
1106
|
search_time += other.search_time;
|
|
1058
|
-
|
|
1059
1107
|
}
|
|
1060
1108
|
|
|
1061
|
-
|
|
1062
1109
|
IndexIVFStats indexIVF_stats;
|
|
1063
1110
|
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1111
|
+
/*************************************************************************
|
|
1112
|
+
* InvertedListScanner
|
|
1113
|
+
*************************************************************************/
|
|
1114
|
+
|
|
1115
|
+
size_t InvertedListScanner::scan_codes(
|
|
1116
|
+
size_t list_size,
|
|
1117
|
+
const uint8_t* codes,
|
|
1118
|
+
const idx_t* ids,
|
|
1119
|
+
float* simi,
|
|
1120
|
+
idx_t* idxi,
|
|
1121
|
+
size_t k) const {
|
|
1122
|
+
size_t nup = 0;
|
|
1123
|
+
|
|
1124
|
+
if (!keep_max) {
|
|
1125
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1126
|
+
float dis = distance_to_code(codes);
|
|
1127
|
+
if (dis < simi[0]) {
|
|
1128
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
1129
|
+
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
1130
|
+
nup++;
|
|
1131
|
+
}
|
|
1132
|
+
codes += code_size;
|
|
1133
|
+
}
|
|
1134
|
+
} else {
|
|
1135
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1136
|
+
float dis = distance_to_code(codes);
|
|
1137
|
+
if (dis > simi[0]) {
|
|
1138
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
1139
|
+
minheap_replace_top(k, simi, idxi, dis, id);
|
|
1140
|
+
nup++;
|
|
1141
|
+
}
|
|
1142
|
+
codes += code_size;
|
|
1143
|
+
}
|
|
1144
|
+
}
|
|
1145
|
+
return nup;
|
|
1071
1146
|
}
|
|
1072
1147
|
|
|
1073
|
-
|
|
1148
|
+
void InvertedListScanner::scan_codes_range(
|
|
1149
|
+
size_t list_size,
|
|
1150
|
+
const uint8_t* codes,
|
|
1151
|
+
const idx_t* ids,
|
|
1152
|
+
float radius,
|
|
1153
|
+
RangeQueryResult& res) const {
|
|
1154
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1155
|
+
float dis = distance_to_code(codes);
|
|
1156
|
+
bool keep = !keep_max
|
|
1157
|
+
? dis < radius
|
|
1158
|
+
: dis > radius; // TODO templatize to remove this test
|
|
1159
|
+
if (keep) {
|
|
1160
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
1161
|
+
res.add(dis, id);
|
|
1162
|
+
}
|
|
1163
|
+
codes += code_size;
|
|
1164
|
+
}
|
|
1165
|
+
}
|
|
1074
1166
|
|
|
1075
1167
|
} // namespace faiss
|