faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
@@ -9,318 +9,344 @@
|
|
9
9
|
|
10
10
|
#include <faiss/IndexBinaryIVF.h>
|
11
11
|
|
12
|
+
#include <omp.h>
|
12
13
|
#include <cinttypes>
|
13
14
|
#include <cstdio>
|
14
|
-
#include <omp.h>
|
15
15
|
|
16
|
+
#include <algorithm>
|
16
17
|
#include <memory>
|
17
18
|
|
18
|
-
|
19
|
-
#include <faiss/utils/hamming.h>
|
20
|
-
#include <faiss/utils/utils.h>
|
21
|
-
#include <faiss/impl/AuxIndexStructures.h>
|
22
|
-
#include <faiss/impl/FaissAssert.h>
|
23
19
|
#include <faiss/IndexFlat.h>
|
24
20
|
#include <faiss/IndexLSH.h>
|
25
|
-
|
21
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
22
|
+
#include <faiss/impl/FaissAssert.h>
|
23
|
+
#include <faiss/utils/hamming.h>
|
24
|
+
#include <faiss/utils/utils.h>
|
26
25
|
|
27
26
|
namespace faiss {
|
28
27
|
|
29
|
-
IndexBinaryIVF::IndexBinaryIVF(IndexBinary
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
cp.niter = 10;
|
28
|
+
IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
|
29
|
+
: IndexBinary(d),
|
30
|
+
invlists(new ArrayInvertedLists(nlist, code_size)),
|
31
|
+
own_invlists(true),
|
32
|
+
nprobe(1),
|
33
|
+
max_codes(0),
|
34
|
+
quantizer(quantizer),
|
35
|
+
nlist(nlist),
|
36
|
+
own_fields(false),
|
37
|
+
clustering_index(nullptr) {
|
38
|
+
FAISS_THROW_IF_NOT(d == quantizer->d);
|
39
|
+
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
40
|
+
|
41
|
+
cp.niter = 10;
|
44
42
|
}
|
45
43
|
|
46
44
|
IndexBinaryIVF::IndexBinaryIVF()
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
add_with_ids(n, x, nullptr);
|
45
|
+
: invlists(nullptr),
|
46
|
+
own_invlists(false),
|
47
|
+
nprobe(1),
|
48
|
+
max_codes(0),
|
49
|
+
quantizer(nullptr),
|
50
|
+
nlist(0),
|
51
|
+
own_fields(false),
|
52
|
+
clustering_index(nullptr) {}
|
53
|
+
|
54
|
+
void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
|
55
|
+
add_with_ids(n, x, nullptr);
|
59
56
|
}
|
60
57
|
|
61
|
-
void IndexBinaryIVF::add_with_ids(
|
62
|
-
|
58
|
+
void IndexBinaryIVF::add_with_ids(
|
59
|
+
idx_t n,
|
60
|
+
const uint8_t* x,
|
61
|
+
const idx_t* xids) {
|
62
|
+
add_core(n, x, xids, nullptr);
|
63
63
|
}
|
64
64
|
|
65
|
-
void IndexBinaryIVF::add_core(
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
65
|
+
void IndexBinaryIVF::add_core(
|
66
|
+
idx_t n,
|
67
|
+
const uint8_t* x,
|
68
|
+
const idx_t* xids,
|
69
|
+
const idx_t* precomputed_idx) {
|
70
|
+
FAISS_THROW_IF_NOT(is_trained);
|
71
|
+
assert(invlists);
|
72
|
+
direct_map.check_can_add(xids);
|
70
73
|
|
71
|
-
|
74
|
+
const idx_t* idx;
|
72
75
|
|
73
|
-
|
76
|
+
std::unique_ptr<idx_t[]> scoped_idx;
|
74
77
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
78
|
+
if (precomputed_idx) {
|
79
|
+
idx = precomputed_idx;
|
80
|
+
} else {
|
81
|
+
scoped_idx.reset(new idx_t[n]);
|
82
|
+
quantizer->assign(n, x, scoped_idx.get());
|
83
|
+
idx = scoped_idx.get();
|
84
|
+
}
|
82
85
|
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
86
|
+
idx_t n_add = 0;
|
87
|
+
for (size_t i = 0; i < n; i++) {
|
88
|
+
idx_t id = xids ? xids[i] : ntotal + i;
|
89
|
+
idx_t list_no = idx[i];
|
87
90
|
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
91
|
+
if (list_no < 0) {
|
92
|
+
direct_map.add_single_id(id, -1, 0);
|
93
|
+
} else {
|
94
|
+
const uint8_t* xi = x + i * code_size;
|
95
|
+
size_t offset = invlists->add_entry(list_no, id, xi);
|
93
96
|
|
94
|
-
|
95
|
-
|
97
|
+
direct_map.add_single_id(id, list_no, offset);
|
98
|
+
}
|
96
99
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
100
|
+
n_add++;
|
101
|
+
}
|
102
|
+
if (verbose) {
|
103
|
+
printf("IndexBinaryIVF::add_with_ids: added "
|
104
|
+
"%" PRId64 " / %" PRId64 " vectors\n",
|
105
|
+
n_add,
|
106
|
+
n);
|
107
|
+
}
|
108
|
+
ntotal += n_add;
|
104
109
|
}
|
105
110
|
|
106
|
-
void IndexBinaryIVF::make_direct_map
|
107
|
-
{
|
111
|
+
void IndexBinaryIVF::make_direct_map(bool b) {
|
108
112
|
if (b) {
|
109
|
-
direct_map.set_type
|
113
|
+
direct_map.set_type(DirectMap::Array, invlists, ntotal);
|
110
114
|
} else {
|
111
|
-
direct_map.set_type
|
115
|
+
direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
|
112
116
|
}
|
113
117
|
}
|
114
118
|
|
115
|
-
void IndexBinaryIVF::set_direct_map_type
|
116
|
-
|
117
|
-
direct_map.set_type (type, invlists, ntotal);
|
119
|
+
void IndexBinaryIVF::set_direct_map_type(DirectMap::Type type) {
|
120
|
+
direct_map.set_type(type, invlists, ntotal);
|
118
121
|
}
|
119
122
|
|
123
|
+
void IndexBinaryIVF::search(
|
124
|
+
idx_t n,
|
125
|
+
const uint8_t* x,
|
126
|
+
idx_t k,
|
127
|
+
int32_t* distances,
|
128
|
+
idx_t* labels) const {
|
129
|
+
FAISS_THROW_IF_NOT(k > 0);
|
130
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
120
131
|
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
132
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
133
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
134
|
+
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
125
135
|
|
126
|
-
|
127
|
-
|
128
|
-
|
136
|
+
double t0 = getmillisecs();
|
137
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
|
138
|
+
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
129
139
|
|
130
|
-
|
131
|
-
|
140
|
+
t0 = getmillisecs();
|
141
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
132
142
|
|
133
|
-
|
134
|
-
|
135
|
-
|
143
|
+
search_preassigned(
|
144
|
+
n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
|
145
|
+
indexIVF_stats.search_time += getmillisecs() - t0;
|
136
146
|
}
|
137
147
|
|
138
|
-
void IndexBinaryIVF::reconstruct(idx_t key, uint8_t
|
139
|
-
idx_t lo = direct_map.get
|
140
|
-
reconstruct_from_offset
|
148
|
+
void IndexBinaryIVF::reconstruct(idx_t key, uint8_t* recons) const {
|
149
|
+
idx_t lo = direct_map.get(key);
|
150
|
+
reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
|
141
151
|
}
|
142
152
|
|
143
|
-
void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t
|
144
|
-
|
153
|
+
void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
|
154
|
+
FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
145
155
|
|
146
|
-
|
147
|
-
|
148
|
-
|
156
|
+
for (idx_t list_no = 0; list_no < nlist; list_no++) {
|
157
|
+
size_t list_size = invlists->list_size(list_no);
|
158
|
+
const Index::idx_t* idlist = invlists->get_ids(list_no);
|
149
159
|
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
160
|
+
for (idx_t offset = 0; offset < list_size; offset++) {
|
161
|
+
idx_t id = idlist[offset];
|
162
|
+
if (!(id >= i0 && id < i0 + ni)) {
|
163
|
+
continue;
|
164
|
+
}
|
155
165
|
|
156
|
-
|
157
|
-
|
166
|
+
uint8_t* reconstructed = recons + (id - i0) * d;
|
167
|
+
reconstruct_from_offset(list_no, offset, reconstructed);
|
168
|
+
}
|
158
169
|
}
|
159
|
-
}
|
160
170
|
}
|
161
171
|
|
162
|
-
void IndexBinaryIVF::search_and_reconstruct(
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
172
|
+
void IndexBinaryIVF::search_and_reconstruct(
|
173
|
+
idx_t n,
|
174
|
+
const uint8_t* x,
|
175
|
+
idx_t k,
|
176
|
+
int32_t* distances,
|
177
|
+
idx_t* labels,
|
178
|
+
uint8_t* recons) const {
|
179
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
180
|
+
FAISS_THROW_IF_NOT(k > 0);
|
181
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
182
|
+
|
183
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
184
|
+
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
185
|
+
|
186
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
|
187
|
+
|
188
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
189
|
+
|
190
|
+
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
191
|
+
// and offset into `codes` for reconstruction
|
192
|
+
search_preassigned(
|
193
|
+
n,
|
194
|
+
x,
|
195
|
+
k,
|
196
|
+
idx.get(),
|
197
|
+
coarse_dis.get(),
|
198
|
+
distances,
|
199
|
+
labels,
|
200
|
+
/* store_pairs */ true);
|
201
|
+
for (idx_t i = 0; i < n; ++i) {
|
202
|
+
for (idx_t j = 0; j < k; ++j) {
|
203
|
+
idx_t ij = i * k + j;
|
204
|
+
idx_t key = labels[ij];
|
205
|
+
uint8_t* reconstructed = recons + ij * d;
|
206
|
+
if (key < 0) {
|
207
|
+
// Fill with NaNs
|
208
|
+
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
209
|
+
} else {
|
210
|
+
int list_no = key >> 32;
|
211
|
+
int offset = key & 0xffffffff;
|
212
|
+
|
213
|
+
// Update label to the actual id
|
214
|
+
labels[ij] = invlists->get_single_id(list_no, offset);
|
215
|
+
|
216
|
+
reconstruct_from_offset(list_no, offset, reconstructed);
|
217
|
+
}
|
218
|
+
}
|
193
219
|
}
|
194
|
-
}
|
195
220
|
}
|
196
221
|
|
197
|
-
void IndexBinaryIVF::reconstruct_from_offset(
|
198
|
-
|
199
|
-
|
222
|
+
void IndexBinaryIVF::reconstruct_from_offset(
|
223
|
+
idx_t list_no,
|
224
|
+
idx_t offset,
|
225
|
+
uint8_t* recons) const {
|
226
|
+
memcpy(recons, invlists->get_single_code(list_no, offset), code_size);
|
200
227
|
}
|
201
228
|
|
202
229
|
void IndexBinaryIVF::reset() {
|
203
|
-
|
204
|
-
|
205
|
-
|
230
|
+
direct_map.clear();
|
231
|
+
invlists->reset();
|
232
|
+
ntotal = 0;
|
206
233
|
}
|
207
234
|
|
208
235
|
size_t IndexBinaryIVF::remove_ids(const IDSelector& sel) {
|
209
|
-
size_t nremove = direct_map.remove_ids
|
236
|
+
size_t nremove = direct_map.remove_ids(sel, invlists);
|
210
237
|
ntotal -= nremove;
|
211
238
|
return nremove;
|
212
239
|
}
|
213
240
|
|
214
|
-
void IndexBinaryIVF::train(idx_t n, const uint8_t
|
215
|
-
if (verbose) {
|
216
|
-
printf("Training quantizer\n");
|
217
|
-
}
|
218
|
-
|
219
|
-
if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
|
220
|
-
if (verbose) {
|
221
|
-
printf("IVF quantizer does not need training.\n");
|
222
|
-
}
|
223
|
-
} else {
|
241
|
+
void IndexBinaryIVF::train(idx_t n, const uint8_t* x) {
|
224
242
|
if (verbose) {
|
225
|
-
|
243
|
+
printf("Training quantizer\n");
|
226
244
|
}
|
227
245
|
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
}
|
246
|
+
if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
|
247
|
+
if (verbose) {
|
248
|
+
printf("IVF quantizer does not need training.\n");
|
249
|
+
}
|
250
|
+
} else {
|
251
|
+
if (verbose) {
|
252
|
+
printf("Training quantizer on %" PRId64 " vectors in %dD\n", n, d);
|
253
|
+
}
|
237
254
|
|
238
|
-
|
239
|
-
|
255
|
+
Clustering clus(d, nlist, cp);
|
256
|
+
quantizer->reset();
|
240
257
|
|
241
|
-
|
258
|
+
IndexFlatL2 index_tmp(d);
|
242
259
|
|
243
|
-
|
244
|
-
|
245
|
-
|
260
|
+
if (clustering_index && verbose) {
|
261
|
+
printf("using clustering_index of dimension %d to do the clustering\n",
|
262
|
+
clustering_index->d);
|
263
|
+
}
|
246
264
|
|
247
|
-
|
248
|
-
|
249
|
-
}
|
265
|
+
// LSH codec that is able to convert the binary vectors to floats.
|
266
|
+
IndexLSH codec(d, d, false, false);
|
250
267
|
|
251
|
-
|
252
|
-
|
268
|
+
clus.train_encoded(
|
269
|
+
n, x, &codec, clustering_index ? *clustering_index : index_tmp);
|
253
270
|
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
FAISS_THROW_IF_NOT(other.nlist == nlist);
|
258
|
-
FAISS_THROW_IF_NOT(other.code_size == code_size);
|
259
|
-
FAISS_THROW_IF_NOT_MSG(direct_map.no() && other.direct_map.no(),
|
260
|
-
"direct map copy not implemented");
|
261
|
-
FAISS_THROW_IF_NOT_MSG(typeid (*this) == typeid (other),
|
262
|
-
"can only merge indexes of the same type");
|
271
|
+
// convert clusters to binary
|
272
|
+
std::unique_ptr<uint8_t[]> x_b(new uint8_t[clus.k * code_size]);
|
273
|
+
real_to_binary(d * clus.k, clus.centroids.data(), x_b.get());
|
263
274
|
|
264
|
-
|
275
|
+
quantizer->add(clus.k, x_b.get());
|
276
|
+
quantizer->is_trained = true;
|
277
|
+
}
|
265
278
|
|
266
|
-
|
267
|
-
other.ntotal = 0;
|
279
|
+
is_trained = true;
|
268
280
|
}
|
269
281
|
|
270
|
-
void IndexBinaryIVF::
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
282
|
+
void IndexBinaryIVF::merge_from(IndexBinaryIVF& other, idx_t add_id) {
|
283
|
+
// minimal sanity checks
|
284
|
+
FAISS_THROW_IF_NOT(other.d == d);
|
285
|
+
FAISS_THROW_IF_NOT(other.nlist == nlist);
|
286
|
+
FAISS_THROW_IF_NOT(other.code_size == code_size);
|
287
|
+
FAISS_THROW_IF_NOT_MSG(
|
288
|
+
direct_map.no() && other.direct_map.no(),
|
289
|
+
"direct map copy not implemented");
|
290
|
+
FAISS_THROW_IF_NOT_MSG(
|
291
|
+
typeid(*this) == typeid(other),
|
292
|
+
"can only merge indexes of the same type");
|
293
|
+
|
294
|
+
invlists->merge_from(other.invlists, add_id);
|
295
|
+
|
296
|
+
ntotal += other.ntotal;
|
297
|
+
other.ntotal = 0;
|
278
298
|
}
|
279
299
|
|
300
|
+
void IndexBinaryIVF::replace_invlists(InvertedLists* il, bool own) {
|
301
|
+
FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
|
302
|
+
if (own_invlists) {
|
303
|
+
delete invlists;
|
304
|
+
}
|
305
|
+
invlists = il;
|
306
|
+
own_invlists = own;
|
307
|
+
}
|
280
308
|
|
281
309
|
namespace {
|
282
310
|
|
283
311
|
using idx_t = Index::idx_t;
|
284
312
|
|
285
|
-
|
286
|
-
|
287
|
-
struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
288
|
-
|
313
|
+
template <class HammingComputer>
|
314
|
+
struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
|
289
315
|
HammingComputer hc;
|
290
316
|
size_t code_size;
|
291
317
|
bool store_pairs;
|
292
318
|
|
293
|
-
IVFBinaryScannerL2
|
294
|
-
|
295
|
-
{}
|
319
|
+
IVFBinaryScannerL2(size_t code_size, bool store_pairs)
|
320
|
+
: code_size(code_size), store_pairs(store_pairs) {}
|
296
321
|
|
297
|
-
void set_query
|
298
|
-
hc.set
|
322
|
+
void set_query(const uint8_t* query_vector) override {
|
323
|
+
hc.set(query_vector, code_size);
|
299
324
|
}
|
300
325
|
|
301
326
|
idx_t list_no;
|
302
|
-
void set_list
|
327
|
+
void set_list(idx_t list_no, uint8_t /* coarse_dis */) override {
|
303
328
|
this->list_no = list_no;
|
304
329
|
}
|
305
330
|
|
306
|
-
uint32_t distance_to_code
|
307
|
-
return hc.hamming
|
331
|
+
uint32_t distance_to_code(const uint8_t* code) const override {
|
332
|
+
return hc.hamming(code);
|
308
333
|
}
|
309
334
|
|
310
|
-
size_t scan_codes
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
335
|
+
size_t scan_codes(
|
336
|
+
size_t n,
|
337
|
+
const uint8_t* codes,
|
338
|
+
const idx_t* ids,
|
339
|
+
int32_t* simi,
|
340
|
+
idx_t* idxi,
|
341
|
+
size_t k) const override {
|
316
342
|
using C = CMax<int32_t, idx_t>;
|
317
343
|
|
318
344
|
size_t nup = 0;
|
319
345
|
for (size_t j = 0; j < n; j++) {
|
320
|
-
uint32_t dis = hc.hamming
|
346
|
+
uint32_t dis = hc.hamming(codes);
|
321
347
|
if (dis < simi[0]) {
|
322
348
|
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
323
|
-
heap_replace_top<C>
|
349
|
+
heap_replace_top<C>(k, simi, idxi, dis, id);
|
324
350
|
nup++;
|
325
351
|
}
|
326
352
|
codes += code_size;
|
@@ -328,40 +354,38 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
|
|
328
354
|
return nup;
|
329
355
|
}
|
330
356
|
|
331
|
-
void scan_codes_range
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
357
|
+
void scan_codes_range(
|
358
|
+
size_t n,
|
359
|
+
const uint8_t* codes,
|
360
|
+
const idx_t* ids,
|
361
|
+
int radius,
|
362
|
+
RangeQueryResult& result) const override {
|
337
363
|
size_t nup = 0;
|
338
364
|
for (size_t j = 0; j < n; j++) {
|
339
|
-
uint32_t dis = hc.hamming
|
365
|
+
uint32_t dis = hc.hamming(codes);
|
340
366
|
if (dis < radius) {
|
341
|
-
int64_t id = store_pairs ? lo_build
|
342
|
-
result.add
|
367
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
368
|
+
result.add(dis, id);
|
343
369
|
}
|
344
370
|
codes += code_size;
|
345
371
|
}
|
346
|
-
|
347
372
|
}
|
348
|
-
|
349
|
-
|
350
373
|
};
|
351
374
|
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
{
|
363
|
-
|
364
|
-
|
375
|
+
void search_knn_hamming_heap(
|
376
|
+
const IndexBinaryIVF& ivf,
|
377
|
+
size_t n,
|
378
|
+
const uint8_t* x,
|
379
|
+
idx_t k,
|
380
|
+
const idx_t* keys,
|
381
|
+
const int32_t* coarse_dis,
|
382
|
+
int32_t* distances,
|
383
|
+
idx_t* labels,
|
384
|
+
bool store_pairs,
|
385
|
+
const IVFSearchParameters* params) {
|
386
|
+
idx_t nprobe = params ? params->nprobe : ivf.nprobe;
|
387
|
+
nprobe = std::min((idx_t)ivf.nlist, nprobe);
|
388
|
+
idx_t max_codes = params ? params->max_codes : ivf.max_codes;
|
365
389
|
MetricType metric_type = ivf.metric_type;
|
366
390
|
|
367
391
|
// almost verbatim copy from IndexIVF::search_preassigned
|
@@ -370,57 +394,57 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
|
|
370
394
|
using HeapForIP = CMin<int32_t, idx_t>;
|
371
395
|
using HeapForL2 = CMax<int32_t, idx_t>;
|
372
396
|
|
373
|
-
#pragma omp parallel if(n > 1) reduction(
|
397
|
+
#pragma omp parallel if (n > 1) reduction(+ : nlistv, ndis, nheap)
|
374
398
|
{
|
375
|
-
std::unique_ptr<BinaryInvertedListScanner> scanner
|
376
|
-
|
399
|
+
std::unique_ptr<BinaryInvertedListScanner> scanner(
|
400
|
+
ivf.get_InvertedListScanner(store_pairs));
|
377
401
|
|
378
402
|
#pragma omp for
|
379
403
|
for (idx_t i = 0; i < n; i++) {
|
380
|
-
const uint8_t
|
404
|
+
const uint8_t* xi = x + i * ivf.code_size;
|
381
405
|
scanner->set_query(xi);
|
382
406
|
|
383
|
-
const idx_t
|
384
|
-
int32_t
|
385
|
-
idx_t
|
407
|
+
const idx_t* keysi = keys + i * nprobe;
|
408
|
+
int32_t* simi = distances + k * i;
|
409
|
+
idx_t* idxi = labels + k * i;
|
386
410
|
|
387
411
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
388
|
-
heap_heapify<HeapForIP>
|
412
|
+
heap_heapify<HeapForIP>(k, simi, idxi);
|
389
413
|
} else {
|
390
|
-
heap_heapify<HeapForL2>
|
414
|
+
heap_heapify<HeapForL2>(k, simi, idxi);
|
391
415
|
}
|
392
416
|
|
393
417
|
size_t nscan = 0;
|
394
418
|
|
395
419
|
for (size_t ik = 0; ik < nprobe; ik++) {
|
396
|
-
idx_t key = keysi[ik];
|
420
|
+
idx_t key = keysi[ik]; /* select the list */
|
397
421
|
if (key < 0) {
|
398
422
|
// not enough centroids for multiprobe
|
399
423
|
continue;
|
400
424
|
}
|
401
|
-
FAISS_THROW_IF_NOT_FMT
|
402
|
-
|
403
|
-
|
404
|
-
|
425
|
+
FAISS_THROW_IF_NOT_FMT(
|
426
|
+
key < (idx_t)ivf.nlist,
|
427
|
+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
428
|
+
key,
|
429
|
+
ik,
|
430
|
+
ivf.nlist);
|
405
431
|
|
406
|
-
scanner->set_list
|
432
|
+
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
|
407
433
|
|
408
434
|
nlistv++;
|
409
435
|
|
410
436
|
size_t list_size = ivf.invlists->list_size(key);
|
411
|
-
InvertedLists::ScopedCodes scodes
|
437
|
+
InvertedLists::ScopedCodes scodes(ivf.invlists, key);
|
412
438
|
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
413
|
-
const Index::idx_t
|
439
|
+
const Index::idx_t* ids = nullptr;
|
414
440
|
|
415
441
|
if (!store_pairs) {
|
416
|
-
sids.reset
|
442
|
+
sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
|
417
443
|
ids = sids->get();
|
418
444
|
}
|
419
445
|
|
420
|
-
nheap += scanner->scan_codes
|
421
|
-
list_size, scodes.get(),
|
422
|
-
ids, simi, idxi, k
|
423
|
-
);
|
446
|
+
nheap += scanner->scan_codes(
|
447
|
+
list_size, scodes.get(), ids, simi, idxi, k);
|
424
448
|
|
425
449
|
nscan += list_size;
|
426
450
|
if (max_codes && nscan >= max_codes)
|
@@ -429,208 +453,205 @@ void search_knn_hamming_heap(const IndexBinaryIVF& ivf,
|
|
429
453
|
|
430
454
|
ndis += nscan;
|
431
455
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
432
|
-
heap_reorder<HeapForIP>
|
456
|
+
heap_reorder<HeapForIP>(k, simi, idxi);
|
433
457
|
} else {
|
434
|
-
heap_reorder<HeapForL2>
|
458
|
+
heap_reorder<HeapForL2>(k, simi, idxi);
|
435
459
|
}
|
436
460
|
|
437
461
|
} // parallel for
|
438
|
-
}
|
462
|
+
} // parallel
|
439
463
|
|
440
464
|
indexIVF_stats.nq += n;
|
441
465
|
indexIVF_stats.nlist += nlistv;
|
442
466
|
indexIVF_stats.ndis += ndis;
|
443
467
|
indexIVF_stats.nheap_updates += nheap;
|
444
|
-
|
445
468
|
}
|
446
469
|
|
447
|
-
template<class HammingComputer, bool store_pairs>
|
448
|
-
void search_knn_hamming_count(
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
size_t nlistv = 0, ndis = 0;
|
475
|
-
|
476
|
-
#pragma omp parallel for reduction(+: nlistv, ndis)
|
477
|
-
for (int64_t i = 0; i < nx; i++) {
|
478
|
-
const idx_t * keysi = keys + i * nprobe;
|
479
|
-
HCounterState<HammingComputer>& csi = cs[i];
|
480
|
-
|
481
|
-
size_t nscan = 0;
|
482
|
-
|
483
|
-
for (size_t ik = 0; ik < nprobe; ik++) {
|
484
|
-
idx_t key = keysi[ik]; /* select the list */
|
485
|
-
if (key < 0) {
|
486
|
-
// not enough centroids for multiprobe
|
487
|
-
continue;
|
488
|
-
}
|
489
|
-
FAISS_THROW_IF_NOT_FMT (
|
490
|
-
key < (idx_t) ivf.nlist,
|
491
|
-
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
492
|
-
key, ik, ivf.nlist);
|
493
|
-
|
494
|
-
nlistv++;
|
495
|
-
size_t list_size = ivf.invlists->list_size(key);
|
496
|
-
InvertedLists::ScopedCodes scodes (ivf.invlists, key);
|
497
|
-
const uint8_t *list_vecs = scodes.get();
|
498
|
-
const Index::idx_t *ids = store_pairs
|
499
|
-
? nullptr
|
500
|
-
: ivf.invlists->get_ids(key);
|
501
|
-
|
502
|
-
for (size_t j = 0; j < list_size; j++) {
|
503
|
-
const uint8_t * yj = list_vecs + ivf.code_size * j;
|
504
|
-
|
505
|
-
idx_t id = store_pairs ? (key << 32 | j) : ids[j];
|
506
|
-
csi.update_counter(yj, id);
|
507
|
-
}
|
508
|
-
if (ids)
|
509
|
-
ivf.invlists->release_ids (key, ids);
|
510
|
-
|
511
|
-
nscan += list_size;
|
512
|
-
if (max_codes && nscan >= max_codes)
|
513
|
-
break;
|
514
|
-
}
|
515
|
-
ndis += nscan;
|
516
|
-
|
517
|
-
int nres = 0;
|
518
|
-
for (int b = 0; b < nBuckets && nres < k; b++) {
|
519
|
-
for (int l = 0; l < csi.counters[b] && nres < k; l++) {
|
520
|
-
labels[i * k + nres] = csi.ids_per_dis[b * k + l];
|
521
|
-
distances[i * k + nres] = b;
|
522
|
-
nres++;
|
523
|
-
}
|
524
|
-
}
|
525
|
-
while (nres < k) {
|
526
|
-
labels[i * k + nres] = -1;
|
527
|
-
distances[i * k + nres] = std::numeric_limits<int32_t>::max();
|
528
|
-
++nres;
|
470
|
+
template <class HammingComputer, bool store_pairs>
|
471
|
+
void search_knn_hamming_count(
|
472
|
+
const IndexBinaryIVF& ivf,
|
473
|
+
size_t nx,
|
474
|
+
const uint8_t* x,
|
475
|
+
const idx_t* keys,
|
476
|
+
int k,
|
477
|
+
int32_t* distances,
|
478
|
+
idx_t* labels,
|
479
|
+
const IVFSearchParameters* params) {
|
480
|
+
const int nBuckets = ivf.d + 1;
|
481
|
+
std::vector<int> all_counters(nx * nBuckets, 0);
|
482
|
+
std::unique_ptr<idx_t[]> all_ids_per_dis(new idx_t[nx * nBuckets * k]);
|
483
|
+
|
484
|
+
idx_t nprobe = params ? params->nprobe : ivf.nprobe;
|
485
|
+
nprobe = std::min((idx_t)ivf.nlist, nprobe);
|
486
|
+
idx_t max_codes = params ? params->max_codes : ivf.max_codes;
|
487
|
+
|
488
|
+
std::vector<HCounterState<HammingComputer>> cs;
|
489
|
+
for (size_t i = 0; i < nx; ++i) {
|
490
|
+
cs.push_back(HCounterState<HammingComputer>(
|
491
|
+
all_counters.data() + i * nBuckets,
|
492
|
+
all_ids_per_dis.get() + i * nBuckets * k,
|
493
|
+
x + i * ivf.code_size,
|
494
|
+
ivf.d,
|
495
|
+
k));
|
529
496
|
}
|
530
|
-
}
|
531
497
|
|
532
|
-
|
533
|
-
indexIVF_stats.nlist += nlistv;
|
534
|
-
indexIVF_stats.ndis += ndis;
|
535
|
-
}
|
498
|
+
size_t nlistv = 0, ndis = 0;
|
536
499
|
|
500
|
+
#pragma omp parallel for reduction(+ : nlistv, ndis)
|
501
|
+
for (int64_t i = 0; i < nx; i++) {
|
502
|
+
const idx_t* keysi = keys + i * nprobe;
|
503
|
+
HCounterState<HammingComputer>& csi = cs[i];
|
537
504
|
|
505
|
+
size_t nscan = 0;
|
538
506
|
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
507
|
+
for (size_t ik = 0; ik < nprobe; ik++) {
|
508
|
+
idx_t key = keysi[ik]; /* select the list */
|
509
|
+
if (key < 0) {
|
510
|
+
// not enough centroids for multiprobe
|
511
|
+
continue;
|
512
|
+
}
|
513
|
+
FAISS_THROW_IF_NOT_FMT(
|
514
|
+
key < (idx_t)ivf.nlist,
|
515
|
+
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
516
|
+
key,
|
517
|
+
ik,
|
518
|
+
ivf.nlist);
|
519
|
+
|
520
|
+
nlistv++;
|
521
|
+
size_t list_size = ivf.invlists->list_size(key);
|
522
|
+
InvertedLists::ScopedCodes scodes(ivf.invlists, key);
|
523
|
+
const uint8_t* list_vecs = scodes.get();
|
524
|
+
const Index::idx_t* ids =
|
525
|
+
store_pairs ? nullptr : ivf.invlists->get_ids(key);
|
526
|
+
|
527
|
+
for (size_t j = 0; j < list_size; j++) {
|
528
|
+
const uint8_t* yj = list_vecs + ivf.code_size * j;
|
529
|
+
|
530
|
+
idx_t id = store_pairs ? (key << 32 | j) : ids[j];
|
531
|
+
csi.update_counter(yj, id);
|
532
|
+
}
|
533
|
+
if (ids)
|
534
|
+
ivf.invlists->release_ids(key, ids);
|
535
|
+
|
536
|
+
nscan += list_size;
|
537
|
+
if (max_codes && nscan >= max_codes)
|
538
|
+
break;
|
539
|
+
}
|
540
|
+
ndis += nscan;
|
541
|
+
|
542
|
+
int nres = 0;
|
543
|
+
for (int b = 0; b < nBuckets && nres < k; b++) {
|
544
|
+
for (int l = 0; l < csi.counters[b] && nres < k; l++) {
|
545
|
+
labels[i * k + nres] = csi.ids_per_dis[b * k + l];
|
546
|
+
distances[i * k + nres] = b;
|
547
|
+
nres++;
|
548
|
+
}
|
549
|
+
}
|
550
|
+
while (nres < k) {
|
551
|
+
labels[i * k + nres] = -1;
|
552
|
+
distances[i * k + nres] = std::numeric_limits<int32_t>::max();
|
553
|
+
++nres;
|
572
554
|
}
|
573
|
-
break;
|
574
555
|
}
|
575
556
|
|
557
|
+
indexIVF_stats.nq += nx;
|
558
|
+
indexIVF_stats.nlist += nlistv;
|
559
|
+
indexIVF_stats.ndis += ndis;
|
576
560
|
}
|
577
561
|
|
578
|
-
|
562
|
+
template <bool store_pairs>
|
563
|
+
void search_knn_hamming_count_1(
|
564
|
+
const IndexBinaryIVF& ivf,
|
565
|
+
size_t nx,
|
566
|
+
const uint8_t* x,
|
567
|
+
const idx_t* keys,
|
568
|
+
int k,
|
569
|
+
int32_t* distances,
|
570
|
+
idx_t* labels,
|
571
|
+
const IVFSearchParameters* params) {
|
572
|
+
switch (ivf.code_size) {
|
573
|
+
#define HANDLE_CS(cs) \
|
574
|
+
case cs: \
|
575
|
+
search_knn_hamming_count<HammingComputer##cs, store_pairs>( \
|
576
|
+
ivf, nx, x, keys, k, distances, labels, params); \
|
577
|
+
break;
|
578
|
+
HANDLE_CS(4);
|
579
|
+
HANDLE_CS(8);
|
580
|
+
HANDLE_CS(16);
|
581
|
+
HANDLE_CS(20);
|
582
|
+
HANDLE_CS(32);
|
583
|
+
HANDLE_CS(64);
|
584
|
+
#undef HANDLE_CS
|
585
|
+
default:
|
586
|
+
search_knn_hamming_count<HammingComputerDefault, store_pairs>(
|
587
|
+
ivf, nx, x, keys, k, distances, labels, params);
|
588
|
+
break;
|
589
|
+
}
|
590
|
+
}
|
579
591
|
|
580
|
-
|
581
|
-
(bool store_pairs) const
|
582
|
-
{
|
592
|
+
} // namespace
|
583
593
|
|
584
|
-
|
594
|
+
BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
|
595
|
+
bool store_pairs) const {
|
596
|
+
#define HC(name) return new IVFBinaryScannerL2<name>(code_size, store_pairs)
|
585
597
|
switch (code_size) {
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
+
case 4:
|
599
|
+
HC(HammingComputer4);
|
600
|
+
case 8:
|
601
|
+
HC(HammingComputer8);
|
602
|
+
case 16:
|
603
|
+
HC(HammingComputer16);
|
604
|
+
case 20:
|
605
|
+
HC(HammingComputer20);
|
606
|
+
case 32:
|
607
|
+
HC(HammingComputer32);
|
608
|
+
case 64:
|
609
|
+
HC(HammingComputer64);
|
610
|
+
default:
|
598
611
|
HC(HammingComputerDefault);
|
599
|
-
}
|
600
612
|
}
|
601
613
|
#undef HC
|
602
|
-
|
603
614
|
}
|
604
615
|
|
605
|
-
void IndexBinaryIVF::search_preassigned(
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
616
|
+
void IndexBinaryIVF::search_preassigned(
|
617
|
+
idx_t n,
|
618
|
+
const uint8_t* x,
|
619
|
+
idx_t k,
|
620
|
+
const idx_t* idx,
|
621
|
+
const int32_t* coarse_dis,
|
622
|
+
int32_t* distances,
|
623
|
+
idx_t* labels,
|
624
|
+
bool store_pairs,
|
625
|
+
const IVFSearchParameters* params) const {
|
613
626
|
if (use_heap) {
|
614
|
-
search_knn_hamming_heap
|
615
|
-
|
616
|
-
|
627
|
+
search_knn_hamming_heap(
|
628
|
+
*this,
|
629
|
+
n,
|
630
|
+
x,
|
631
|
+
k,
|
632
|
+
idx,
|
633
|
+
coarse_dis,
|
634
|
+
distances,
|
635
|
+
labels,
|
636
|
+
store_pairs,
|
637
|
+
params);
|
617
638
|
} else {
|
618
639
|
if (store_pairs) {
|
619
|
-
search_knn_hamming_count_1<true>
|
620
|
-
|
640
|
+
search_knn_hamming_count_1<true>(
|
641
|
+
*this, n, x, idx, k, distances, labels, params);
|
621
642
|
} else {
|
622
|
-
search_knn_hamming_count_1<false>
|
623
|
-
|
643
|
+
search_knn_hamming_count_1<false>(
|
644
|
+
*this, n, x, idx, k, distances, labels, params);
|
624
645
|
}
|
625
646
|
}
|
626
647
|
}
|
627
648
|
|
628
|
-
|
629
649
|
void IndexBinaryIVF::range_search(
|
630
|
-
idx_t n,
|
631
|
-
|
632
|
-
|
633
|
-
|
650
|
+
idx_t n,
|
651
|
+
const uint8_t* x,
|
652
|
+
int radius,
|
653
|
+
RangeSearchResult* res) const {
|
654
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
634
655
|
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
635
656
|
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
636
657
|
|
@@ -641,77 +662,84 @@ void IndexBinaryIVF::range_search(
|
|
641
662
|
t0 = getmillisecs();
|
642
663
|
invlists->prefetch_lists(idx.get(), n * nprobe);
|
643
664
|
|
665
|
+
range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);
|
666
|
+
|
667
|
+
indexIVF_stats.search_time += getmillisecs() - t0;
|
668
|
+
}
|
669
|
+
|
670
|
+
void IndexBinaryIVF::range_search_preassigned(
|
671
|
+
idx_t n,
|
672
|
+
const uint8_t* x,
|
673
|
+
int radius,
|
674
|
+
const idx_t* assign,
|
675
|
+
const int32_t* centroid_dis,
|
676
|
+
RangeSearchResult* res) const {
|
677
|
+
const size_t nprobe = std::min(nlist, this->nprobe);
|
644
678
|
bool store_pairs = false;
|
645
679
|
size_t nlistv = 0, ndis = 0;
|
646
680
|
|
647
|
-
std::vector<RangeSearchPartialResult
|
681
|
+
std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
|
648
682
|
|
649
|
-
#pragma omp parallel reduction(
|
683
|
+
#pragma omp parallel reduction(+ : nlistv, ndis)
|
650
684
|
{
|
651
685
|
RangeSearchPartialResult pres(res);
|
652
|
-
std::unique_ptr<BinaryInvertedListScanner> scanner
|
653
|
-
|
654
|
-
FAISS_THROW_IF_NOT
|
686
|
+
std::unique_ptr<BinaryInvertedListScanner> scanner(
|
687
|
+
get_InvertedListScanner(store_pairs));
|
688
|
+
FAISS_THROW_IF_NOT(scanner.get());
|
655
689
|
|
656
690
|
all_pres[omp_get_thread_num()] = &pres;
|
657
691
|
|
658
|
-
auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
key < (idx_t) nlist,
|
692
|
+
auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
|
693
|
+
idx_t key = assign[i * nprobe + ik]; /* select the list */
|
694
|
+
if (key < 0)
|
695
|
+
return;
|
696
|
+
FAISS_THROW_IF_NOT_FMT(
|
697
|
+
key < (idx_t)nlist,
|
665
698
|
"Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
|
666
|
-
key,
|
699
|
+
key,
|
700
|
+
ik,
|
701
|
+
nlist);
|
667
702
|
const size_t list_size = invlists->list_size(key);
|
668
703
|
|
669
|
-
if (list_size == 0)
|
704
|
+
if (list_size == 0)
|
705
|
+
return;
|
670
706
|
|
671
|
-
InvertedLists::ScopedCodes scodes
|
672
|
-
InvertedLists::ScopedIds ids
|
707
|
+
InvertedLists::ScopedCodes scodes(invlists, key);
|
708
|
+
InvertedLists::ScopedIds ids(invlists, key);
|
673
709
|
|
674
|
-
scanner->set_list
|
710
|
+
scanner->set_list(key, assign[i * nprobe + ik]);
|
675
711
|
nlistv++;
|
676
712
|
ndis += list_size;
|
677
|
-
scanner->scan_codes_range
|
678
|
-
|
713
|
+
scanner->scan_codes_range(
|
714
|
+
list_size, scodes.get(), ids.get(), radius, qres);
|
679
715
|
};
|
680
716
|
|
681
717
|
#pragma omp for
|
682
718
|
for (idx_t i = 0; i < n; i++) {
|
683
|
-
scanner->set_query
|
719
|
+
scanner->set_query(x + i * code_size);
|
684
720
|
|
685
|
-
RangeQueryResult
|
721
|
+
RangeQueryResult& qres = pres.new_result(i);
|
686
722
|
|
687
723
|
for (size_t ik = 0; ik < nprobe; ik++) {
|
688
|
-
scan_list_func
|
724
|
+
scan_list_func(i, ik, qres);
|
689
725
|
}
|
690
|
-
|
691
726
|
}
|
692
727
|
|
693
728
|
pres.finalize();
|
694
|
-
|
695
729
|
}
|
696
730
|
indexIVF_stats.nq += n;
|
697
731
|
indexIVF_stats.nlist += nlistv;
|
698
732
|
indexIVF_stats.ndis += ndis;
|
699
|
-
indexIVF_stats.search_time += getmillisecs() - t0;
|
700
|
-
|
701
733
|
}
|
702
734
|
|
703
|
-
|
704
|
-
|
705
|
-
|
706
735
|
IndexBinaryIVF::~IndexBinaryIVF() {
|
707
|
-
|
708
|
-
|
709
|
-
|
736
|
+
if (own_invlists) {
|
737
|
+
delete invlists;
|
738
|
+
}
|
710
739
|
|
711
|
-
|
712
|
-
|
713
|
-
|
740
|
+
if (own_fields) {
|
741
|
+
delete quantizer;
|
742
|
+
}
|
714
743
|
}
|
715
744
|
|
716
|
-
|
717
|
-
} // namespace faiss
|
745
|
+
} // namespace faiss
|