faiss 0.1.7 → 0.2.3
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +7 -7
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +8 -2
- data/ext/faiss/index.cpp +102 -69
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +0 -5
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +26 -12
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
data/vendor/faiss/faiss/IndexPQ.cpp +405 -464

The rendered per-file diff covers the whole of IndexPQ.cpp: the include block; IndexPQ itself (constructors, train, add, remove_ids, reset, reconstruct_n, reconstruct, get_distance_computer, search, search_core_polysemous, the standalone codec functions sa_code_size / sa_encode / sa_decode, hamming_distance_table, hamming_distance_histogram); the file-local helpers PQDistanceComputer, polysemous_inner_loop, PreSortedArray, ArgSort, SortedArray, partial_sort, SemiSortedArray and MinSumK; MultiIndexQuantizer (train, search, reconstruct, add, reset); and MultiIndexQuantizer2 (constructors, train, search).

Nearly every hunk restyles unchanged logic: brace and pointer placement, argument wrapping, include ordering, and added comments (for example the ///< doc comments on the K, M and N fields of MinSumK). The functional changes visible in the diff are:

- #include <cmath> added to the include block.
- FAISS_THROW_IF_NOT(k > 0) guards added at the top of IndexPQ::search, IndexPQ::search_core_polysemous and MultiIndexQuantizer::search.
- FAISS_THROW_IF_NOT(k2) added after k2 = std::min(K, int64_t(pq.ksub)) in MultiIndexQuantizer2::search.