faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -10,15 +10,15 @@
|
|
|
10
10
|
#include <faiss/IndexPQ.h>
|
|
11
11
|
|
|
12
12
|
#include <cinttypes>
|
|
13
|
+
#include <cmath>
|
|
13
14
|
#include <cstddef>
|
|
14
|
-
#include <cstring>
|
|
15
15
|
#include <cstdio>
|
|
16
|
-
#include <
|
|
16
|
+
#include <cstring>
|
|
17
17
|
|
|
18
18
|
#include <algorithm>
|
|
19
19
|
|
|
20
|
-
#include <faiss/impl/FaissAssert.h>
|
|
21
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
|
22
22
|
#include <faiss/utils/hamming.h>
|
|
23
23
|
|
|
24
24
|
namespace faiss {
|
|
@@ -27,19 +27,17 @@ namespace faiss {
|
|
|
27
27
|
* IndexPQ implementation
|
|
28
28
|
********************************************************/
|
|
29
29
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
Index(d, metric), pq(d, M, nbits)
|
|
33
|
-
{
|
|
30
|
+
IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
|
|
31
|
+
: IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
|
|
34
32
|
is_trained = false;
|
|
35
33
|
do_polysemous_training = false;
|
|
36
34
|
polysemous_ht = nbits * M + 1;
|
|
37
35
|
search_type = ST_PQ;
|
|
38
36
|
encode_signs = false;
|
|
37
|
+
code_size = pq.code_size;
|
|
39
38
|
}
|
|
40
39
|
|
|
41
|
-
IndexPQ::IndexPQ
|
|
42
|
-
{
|
|
40
|
+
IndexPQ::IndexPQ() {
|
|
43
41
|
metric_type = METRIC_L2;
|
|
44
42
|
is_trained = false;
|
|
45
43
|
do_polysemous_training = false;
|
|
@@ -48,10 +46,8 @@ IndexPQ::IndexPQ ()
|
|
|
48
46
|
encode_signs = false;
|
|
49
47
|
}
|
|
50
48
|
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
{
|
|
54
|
-
if (!do_polysemous_training) { // standard training
|
|
49
|
+
void IndexPQ::train(idx_t n, const float* x) {
|
|
50
|
+
if (!do_polysemous_training) { // standard training
|
|
55
51
|
pq.train(n, x);
|
|
56
52
|
} else {
|
|
57
53
|
idx_t ntrain_perm = polysemous_training.ntrain_permutation;
|
|
@@ -59,92 +55,38 @@ void IndexPQ::train (idx_t n, const float *x)
|
|
|
59
55
|
if (ntrain_perm > n / 4)
|
|
60
56
|
ntrain_perm = n / 4;
|
|
61
57
|
if (verbose) {
|
|
62
|
-
printf
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
58
|
+
printf("PQ training on %" PRId64 " points, remains %" PRId64
|
|
59
|
+
" points: "
|
|
60
|
+
"training polysemous on %s\n",
|
|
61
|
+
n - ntrain_perm,
|
|
62
|
+
ntrain_perm,
|
|
63
|
+
ntrain_perm == 0 ? "centroids" : "these");
|
|
66
64
|
}
|
|
67
65
|
pq.train(n - ntrain_perm, x);
|
|
68
66
|
|
|
69
|
-
polysemous_training.optimize_pq_for_hamming
|
|
70
|
-
|
|
67
|
+
polysemous_training.optimize_pq_for_hamming(
|
|
68
|
+
pq, ntrain_perm, x + (n - ntrain_perm) * d);
|
|
71
69
|
}
|
|
72
70
|
is_trained = true;
|
|
73
71
|
}
|
|
74
72
|
|
|
75
|
-
|
|
76
|
-
void IndexPQ::add (idx_t n, const float *x)
|
|
77
|
-
{
|
|
78
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
79
|
-
codes.resize ((n + ntotal) * pq.code_size);
|
|
80
|
-
pq.compute_codes (x, &codes[ntotal * pq.code_size], n);
|
|
81
|
-
ntotal += n;
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
size_t IndexPQ::remove_ids (const IDSelector & sel)
|
|
86
|
-
{
|
|
87
|
-
idx_t j = 0;
|
|
88
|
-
for (idx_t i = 0; i < ntotal; i++) {
|
|
89
|
-
if (sel.is_member (i)) {
|
|
90
|
-
// should be removed
|
|
91
|
-
} else {
|
|
92
|
-
if (i > j) {
|
|
93
|
-
memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
|
|
94
|
-
}
|
|
95
|
-
j++;
|
|
96
|
-
}
|
|
97
|
-
}
|
|
98
|
-
size_t nremove = ntotal - j;
|
|
99
|
-
if (nremove > 0) {
|
|
100
|
-
ntotal = j;
|
|
101
|
-
codes.resize (ntotal * pq.code_size);
|
|
102
|
-
}
|
|
103
|
-
return nremove;
|
|
104
|
-
}
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
void IndexPQ::reset()
|
|
108
|
-
{
|
|
109
|
-
codes.clear();
|
|
110
|
-
ntotal = 0;
|
|
111
|
-
}
|
|
112
|
-
|
|
113
|
-
void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
|
|
114
|
-
{
|
|
115
|
-
FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
|
116
|
-
for (idx_t i = 0; i < ni; i++) {
|
|
117
|
-
const uint8_t * code = &codes[(i0 + i) * pq.code_size];
|
|
118
|
-
pq.decode (code, recons + i * d);
|
|
119
|
-
}
|
|
120
|
-
}
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
void IndexPQ::reconstruct (idx_t key, float * recons) const
|
|
124
|
-
{
|
|
125
|
-
FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);
|
|
126
|
-
pq.decode (&codes[key * pq.code_size], recons);
|
|
127
|
-
}
|
|
128
|
-
|
|
129
|
-
|
|
130
73
|
namespace {
|
|
131
74
|
|
|
132
|
-
template<class PQDecoder>
|
|
133
|
-
struct PQDistanceComputer: DistanceComputer {
|
|
75
|
+
template <class PQDecoder>
|
|
76
|
+
struct PQDistanceComputer : DistanceComputer {
|
|
134
77
|
size_t d;
|
|
135
78
|
MetricType metric;
|
|
136
79
|
Index::idx_t nb;
|
|
137
|
-
const uint8_t
|
|
80
|
+
const uint8_t* codes;
|
|
138
81
|
size_t code_size;
|
|
139
|
-
const ProductQuantizer
|
|
140
|
-
const float
|
|
82
|
+
const ProductQuantizer& pq;
|
|
83
|
+
const float* sdc;
|
|
141
84
|
std::vector<float> precomputed_table;
|
|
142
85
|
size_t ndis;
|
|
143
86
|
|
|
144
|
-
float operator
|
|
145
|
-
|
|
146
|
-
const
|
|
147
|
-
const float *dt = precomputed_table.data();
|
|
87
|
+
float operator()(idx_t i) override {
|
|
88
|
+
const uint8_t* code = codes + i * code_size;
|
|
89
|
+
const float* dt = precomputed_table.data();
|
|
148
90
|
PQDecoder decoder(code, pq.nbits);
|
|
149
91
|
float accu = 0;
|
|
150
92
|
for (int j = 0; j < pq.M; j++) {
|
|
@@ -155,13 +97,12 @@ struct PQDistanceComputer: DistanceComputer {
|
|
|
155
97
|
return accu;
|
|
156
98
|
}
|
|
157
99
|
|
|
158
|
-
float symmetric_dis(idx_t i, idx_t j) override
|
|
159
|
-
{
|
|
100
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
160
101
|
FAISS_THROW_IF_NOT(sdc);
|
|
161
|
-
const float
|
|
102
|
+
const float* sdci = sdc;
|
|
162
103
|
float accu = 0;
|
|
163
|
-
PQDecoder codei
|
|
164
|
-
PQDecoder codej
|
|
104
|
+
PQDecoder codei(codes + i * code_size, pq.nbits);
|
|
105
|
+
PQDecoder codej(codes + j * code_size, pq.nbits);
|
|
165
106
|
|
|
166
107
|
for (int l = 0; l < pq.M; l++) {
|
|
167
108
|
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
|
|
@@ -171,8 +112,7 @@ struct PQDistanceComputer: DistanceComputer {
|
|
|
171
112
|
return accu;
|
|
172
113
|
}
|
|
173
114
|
|
|
174
|
-
explicit PQDistanceComputer(const IndexPQ& storage)
|
|
175
|
-
: pq(storage.pq) {
|
|
115
|
+
explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
|
|
176
116
|
precomputed_table.resize(pq.M * pq.ksub);
|
|
177
117
|
nb = storage.ntotal;
|
|
178
118
|
d = storage.d;
|
|
@@ -187,21 +127,18 @@ struct PQDistanceComputer: DistanceComputer {
|
|
|
187
127
|
ndis = 0;
|
|
188
128
|
}
|
|
189
129
|
|
|
190
|
-
void set_query(const float
|
|
130
|
+
void set_query(const float* x) override {
|
|
191
131
|
if (metric == METRIC_L2) {
|
|
192
132
|
pq.compute_distance_table(x, precomputed_table.data());
|
|
193
133
|
} else {
|
|
194
134
|
pq.compute_inner_prod_table(x, precomputed_table.data());
|
|
195
135
|
}
|
|
196
|
-
|
|
197
136
|
}
|
|
198
137
|
};
|
|
199
138
|
|
|
139
|
+
} // namespace
|
|
200
140
|
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
DistanceComputer * IndexPQ::get_distance_computer() const {
|
|
141
|
+
DistanceComputer* IndexPQ::get_distance_computer() const {
|
|
205
142
|
if (pq.nbits == 8) {
|
|
206
143
|
return new PQDistanceComputer<PQDecoder8>(*this);
|
|
207
144
|
} else if (pq.nbits == 16) {
|
|
@@ -211,142 +148,142 @@ DistanceComputer * IndexPQ::get_distance_computer() const {
|
|
|
211
148
|
}
|
|
212
149
|
}
|
|
213
150
|
|
|
214
|
-
|
|
215
151
|
/*****************************************
|
|
216
152
|
* IndexPQ polysemous search routines
|
|
217
153
|
******************************************/
|
|
218
154
|
|
|
155
|
+
void IndexPQ::search(
|
|
156
|
+
idx_t n,
|
|
157
|
+
const float* x,
|
|
158
|
+
idx_t k,
|
|
159
|
+
float* distances,
|
|
160
|
+
idx_t* labels) const {
|
|
161
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
219
162
|
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
void IndexPQ::search (idx_t n, const float *x, idx_t k,
|
|
224
|
-
float *distances, idx_t *labels) const
|
|
225
|
-
{
|
|
226
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
227
|
-
if (search_type == ST_PQ) { // Simple PQ search
|
|
163
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
164
|
+
if (search_type == ST_PQ) { // Simple PQ search
|
|
228
165
|
|
|
229
166
|
if (metric_type == METRIC_L2) {
|
|
230
167
|
float_maxheap_array_t res = {
|
|
231
|
-
|
|
232
|
-
pq.search
|
|
168
|
+
size_t(n), size_t(k), labels, distances};
|
|
169
|
+
pq.search(x, n, codes.data(), ntotal, &res, true);
|
|
233
170
|
} else {
|
|
234
171
|
float_minheap_array_t res = {
|
|
235
|
-
|
|
236
|
-
pq.search_ip
|
|
172
|
+
size_t(n), size_t(k), labels, distances};
|
|
173
|
+
pq.search_ip(x, n, codes.data(), ntotal, &res, true);
|
|
237
174
|
}
|
|
238
175
|
indexPQ_stats.nq += n;
|
|
239
176
|
indexPQ_stats.ncode += n * ntotal;
|
|
240
177
|
|
|
241
|
-
} else if (
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
FAISS_THROW_IF_NOT
|
|
178
|
+
} else if (
|
|
179
|
+
search_type == ST_polysemous ||
|
|
180
|
+
search_type == ST_polysemous_generalize) {
|
|
181
|
+
FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
|
|
245
182
|
|
|
246
|
-
search_core_polysemous
|
|
183
|
+
search_core_polysemous(n, x, k, distances, labels);
|
|
247
184
|
|
|
248
185
|
} else { // code-to-code distances
|
|
249
186
|
|
|
250
|
-
uint8_t
|
|
251
|
-
ScopeDeleter<uint8_t> del
|
|
252
|
-
|
|
187
|
+
uint8_t* q_codes = new uint8_t[n * pq.code_size];
|
|
188
|
+
ScopeDeleter<uint8_t> del(q_codes);
|
|
253
189
|
|
|
254
190
|
if (!encode_signs) {
|
|
255
|
-
pq.compute_codes
|
|
191
|
+
pq.compute_codes(x, q_codes, n);
|
|
256
192
|
} else {
|
|
257
|
-
FAISS_THROW_IF_NOT
|
|
258
|
-
memset
|
|
193
|
+
FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
|
|
194
|
+
memset(q_codes, 0, n * pq.code_size);
|
|
259
195
|
for (size_t i = 0; i < n; i++) {
|
|
260
|
-
const float
|
|
261
|
-
uint8_t
|
|
196
|
+
const float* xi = x + i * d;
|
|
197
|
+
uint8_t* code = q_codes + i * pq.code_size;
|
|
262
198
|
for (int j = 0; j < d; j++)
|
|
263
|
-
if (xi[j] > 0)
|
|
199
|
+
if (xi[j] > 0)
|
|
200
|
+
code[j >> 3] |= 1 << (j & 7);
|
|
264
201
|
}
|
|
265
202
|
}
|
|
266
203
|
|
|
267
|
-
if (search_type == ST_SDC)
|
|
268
|
-
|
|
204
|
+
if (search_type == ST_SDC) {
|
|
269
205
|
float_maxheap_array_t res = {
|
|
270
|
-
|
|
206
|
+
size_t(n), size_t(k), labels, distances};
|
|
271
207
|
|
|
272
|
-
pq.search_sdc
|
|
208
|
+
pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
|
|
273
209
|
|
|
274
210
|
} else {
|
|
275
|
-
int
|
|
276
|
-
ScopeDeleter<int> del
|
|
211
|
+
int* idistances = new int[n * k];
|
|
212
|
+
ScopeDeleter<int> del(idistances);
|
|
277
213
|
|
|
278
214
|
int_maxheap_array_t res = {
|
|
279
|
-
|
|
215
|
+
size_t(n), size_t(k), labels, idistances};
|
|
280
216
|
|
|
281
217
|
if (search_type == ST_HE) {
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
218
|
+
hammings_knn_hc(
|
|
219
|
+
&res,
|
|
220
|
+
q_codes,
|
|
221
|
+
codes.data(),
|
|
222
|
+
ntotal,
|
|
223
|
+
pq.code_size,
|
|
224
|
+
true);
|
|
285
225
|
|
|
286
226
|
} else if (search_type == ST_generalized_HE) {
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
227
|
+
generalized_hammings_knn_hc(
|
|
228
|
+
&res,
|
|
229
|
+
q_codes,
|
|
230
|
+
codes.data(),
|
|
231
|
+
ntotal,
|
|
232
|
+
pq.code_size,
|
|
233
|
+
true);
|
|
290
234
|
}
|
|
291
235
|
|
|
292
236
|
// convert distances to floats
|
|
293
237
|
for (int i = 0; i < k * n; i++)
|
|
294
238
|
distances[i] = idistances[i];
|
|
295
|
-
|
|
296
239
|
}
|
|
297
240
|
|
|
298
|
-
|
|
299
241
|
indexPQ_stats.nq += n;
|
|
300
242
|
indexPQ_stats.ncode += n * ntotal;
|
|
301
243
|
}
|
|
302
244
|
}
|
|
303
245
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
void IndexPQStats::reset()
|
|
309
|
-
{
|
|
246
|
+
void IndexPQStats::reset() {
|
|
310
247
|
nq = ncode = n_hamming_pass = 0;
|
|
311
248
|
}
|
|
312
249
|
|
|
313
250
|
IndexPQStats indexPQ_stats;
|
|
314
251
|
|
|
315
|
-
|
|
316
252
|
template <class HammingComputer>
|
|
317
|
-
static size_t polysemous_inner_loop
|
|
318
|
-
const IndexPQ
|
|
319
|
-
const float
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
253
|
+
static size_t polysemous_inner_loop(
|
|
254
|
+
const IndexPQ& index,
|
|
255
|
+
const float* dis_table_qi,
|
|
256
|
+
const uint8_t* q_code,
|
|
257
|
+
size_t k,
|
|
258
|
+
float* heap_dis,
|
|
259
|
+
int64_t* heap_ids) {
|
|
323
260
|
int M = index.pq.M;
|
|
324
261
|
int code_size = index.pq.code_size;
|
|
325
262
|
int ksub = index.pq.ksub;
|
|
326
263
|
size_t ntotal = index.ntotal;
|
|
327
264
|
int ht = index.polysemous_ht;
|
|
328
265
|
|
|
329
|
-
const uint8_t
|
|
266
|
+
const uint8_t* b_code = index.codes.data();
|
|
330
267
|
|
|
331
268
|
size_t n_pass_i = 0;
|
|
332
269
|
|
|
333
|
-
HammingComputer hc
|
|
270
|
+
HammingComputer hc(q_code, code_size);
|
|
334
271
|
|
|
335
272
|
for (int64_t bi = 0; bi < ntotal; bi++) {
|
|
336
|
-
int hd = hc.hamming
|
|
273
|
+
int hd = hc.hamming(b_code);
|
|
337
274
|
|
|
338
275
|
if (hd < ht) {
|
|
339
|
-
n_pass_i
|
|
276
|
+
n_pass_i++;
|
|
340
277
|
|
|
341
278
|
float dis = 0;
|
|
342
|
-
const float
|
|
279
|
+
const float* dis_table = dis_table_qi;
|
|
343
280
|
for (int m = 0; m < M; m++) {
|
|
344
|
-
dis += dis_table
|
|
281
|
+
dis += dis_table[b_code[m]];
|
|
345
282
|
dis_table += ksub;
|
|
346
283
|
}
|
|
347
284
|
|
|
348
285
|
if (dis < heap_dis[0]) {
|
|
349
|
-
maxheap_replace_top
|
|
286
|
+
maxheap_replace_top(k, heap_dis, heap_ids, dis, bi);
|
|
350
287
|
}
|
|
351
288
|
}
|
|
352
289
|
b_code += code_size;
|
|
@@ -354,201 +291,201 @@ static size_t polysemous_inner_loop (
|
|
|
354
291
|
return n_pass_i;
|
|
355
292
|
}
|
|
356
293
|
|
|
294
|
+
void IndexPQ::search_core_polysemous(
|
|
295
|
+
idx_t n,
|
|
296
|
+
const float* x,
|
|
297
|
+
idx_t k,
|
|
298
|
+
float* distances,
|
|
299
|
+
idx_t* labels) const {
|
|
300
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
357
301
|
|
|
358
|
-
|
|
359
|
-
float *distances, idx_t *labels) const
|
|
360
|
-
{
|
|
361
|
-
FAISS_THROW_IF_NOT (pq.nbits == 8);
|
|
302
|
+
FAISS_THROW_IF_NOT(pq.nbits == 8);
|
|
362
303
|
|
|
363
304
|
// PQ distance tables
|
|
364
|
-
float
|
|
365
|
-
ScopeDeleter<float> del
|
|
366
|
-
pq.compute_distance_tables
|
|
305
|
+
float* dis_tables = new float[n * pq.ksub * pq.M];
|
|
306
|
+
ScopeDeleter<float> del(dis_tables);
|
|
307
|
+
pq.compute_distance_tables(n, x, dis_tables);
|
|
367
308
|
|
|
368
309
|
// Hamming embedding queries
|
|
369
|
-
uint8_t
|
|
370
|
-
ScopeDeleter<uint8_t> del2
|
|
310
|
+
uint8_t* q_codes = new uint8_t[n * pq.code_size];
|
|
311
|
+
ScopeDeleter<uint8_t> del2(q_codes);
|
|
371
312
|
|
|
372
313
|
if (false) {
|
|
373
|
-
pq.compute_codes
|
|
314
|
+
pq.compute_codes(x, q_codes, n);
|
|
374
315
|
} else {
|
|
375
316
|
#pragma omp parallel for
|
|
376
317
|
for (idx_t qi = 0; qi < n; qi++) {
|
|
377
|
-
pq.compute_code_from_distance_table
|
|
378
|
-
|
|
379
|
-
|
|
318
|
+
pq.compute_code_from_distance_table(
|
|
319
|
+
dis_tables + qi * pq.M * pq.ksub,
|
|
320
|
+
q_codes + qi * pq.code_size);
|
|
380
321
|
}
|
|
381
322
|
}
|
|
382
323
|
|
|
383
324
|
size_t n_pass = 0;
|
|
384
325
|
|
|
385
|
-
#pragma omp parallel for reduction
|
|
326
|
+
#pragma omp parallel for reduction(+ : n_pass)
|
|
386
327
|
for (idx_t qi = 0; qi < n; qi++) {
|
|
387
|
-
const uint8_t
|
|
328
|
+
const uint8_t* q_code = q_codes + qi * pq.code_size;
|
|
388
329
|
|
|
389
|
-
const float
|
|
330
|
+
const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
|
|
390
331
|
|
|
391
|
-
int64_t
|
|
392
|
-
float
|
|
393
|
-
maxheap_heapify
|
|
332
|
+
int64_t* heap_ids = labels + qi * k;
|
|
333
|
+
float* heap_dis = distances + qi * k;
|
|
334
|
+
maxheap_heapify(k, heap_dis, heap_ids);
|
|
394
335
|
|
|
395
336
|
if (search_type == ST_polysemous) {
|
|
396
|
-
|
|
397
337
|
switch (pq.code_size) {
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
338
|
+
case 4:
|
|
339
|
+
n_pass += polysemous_inner_loop<HammingComputer4>(
|
|
340
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
341
|
+
break;
|
|
342
|
+
case 8:
|
|
343
|
+
n_pass += polysemous_inner_loop<HammingComputer8>(
|
|
344
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
345
|
+
break;
|
|
346
|
+
case 16:
|
|
347
|
+
n_pass += polysemous_inner_loop<HammingComputer16>(
|
|
348
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
349
|
+
break;
|
|
350
|
+
case 32:
|
|
351
|
+
n_pass += polysemous_inner_loop<HammingComputer32>(
|
|
352
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
353
|
+
break;
|
|
354
|
+
case 20:
|
|
355
|
+
n_pass += polysemous_inner_loop<HammingComputer20>(
|
|
356
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
357
|
+
break;
|
|
358
|
+
default:
|
|
359
|
+
if (pq.code_size % 4 == 0) {
|
|
360
|
+
n_pass += polysemous_inner_loop<HammingComputerDefault>(
|
|
361
|
+
*this,
|
|
362
|
+
dis_table_qi,
|
|
363
|
+
q_code,
|
|
364
|
+
k,
|
|
365
|
+
heap_dis,
|
|
366
|
+
heap_ids);
|
|
367
|
+
} else {
|
|
368
|
+
FAISS_THROW_FMT(
|
|
369
|
+
"code size %zd not supported for polysemous",
|
|
370
|
+
pq.code_size);
|
|
371
|
+
}
|
|
372
|
+
break;
|
|
431
373
|
}
|
|
432
374
|
} else {
|
|
433
375
|
switch (pq.code_size) {
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
376
|
+
case 8:
|
|
377
|
+
n_pass += polysemous_inner_loop<GenHammingComputer8>(
|
|
378
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
379
|
+
break;
|
|
380
|
+
case 16:
|
|
381
|
+
n_pass += polysemous_inner_loop<GenHammingComputer16>(
|
|
382
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
383
|
+
break;
|
|
384
|
+
case 32:
|
|
385
|
+
n_pass += polysemous_inner_loop<GenHammingComputer32>(
|
|
386
|
+
*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
|
|
387
|
+
break;
|
|
388
|
+
default:
|
|
389
|
+
if (pq.code_size % 8 == 0) {
|
|
390
|
+
n_pass += polysemous_inner_loop<GenHammingComputerM8>(
|
|
391
|
+
*this,
|
|
392
|
+
dis_table_qi,
|
|
393
|
+
q_code,
|
|
394
|
+
k,
|
|
395
|
+
heap_dis,
|
|
396
|
+
heap_ids);
|
|
397
|
+
} else {
|
|
398
|
+
FAISS_THROW_FMT(
|
|
399
|
+
"code size %zd not supported for polysemous",
|
|
400
|
+
pq.code_size);
|
|
401
|
+
}
|
|
402
|
+
break;
|
|
456
403
|
}
|
|
457
404
|
}
|
|
458
|
-
maxheap_reorder
|
|
405
|
+
maxheap_reorder(k, heap_dis, heap_ids);
|
|
459
406
|
}
|
|
460
407
|
|
|
461
408
|
indexPQ_stats.nq += n;
|
|
462
409
|
indexPQ_stats.ncode += n * ntotal;
|
|
463
410
|
indexPQ_stats.n_hamming_pass += n_pass;
|
|
464
|
-
|
|
465
|
-
|
|
466
411
|
}
|
|
467
412
|
|
|
468
|
-
|
|
469
413
|
/* The standalone codec interface (just remaps to the PQ functions) */
|
|
470
|
-
size_t IndexPQ::sa_code_size () const
|
|
471
|
-
{
|
|
472
|
-
return pq.code_size;
|
|
473
|
-
}
|
|
474
414
|
|
|
475
|
-
void IndexPQ::sa_encode
|
|
476
|
-
|
|
477
|
-
pq.compute_codes (x, bytes, n);
|
|
415
|
+
void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
416
|
+
pq.compute_codes(x, bytes, n);
|
|
478
417
|
}
|
|
479
418
|
|
|
480
|
-
void IndexPQ::sa_decode
|
|
481
|
-
|
|
482
|
-
pq.decode (bytes, x, n);
|
|
419
|
+
void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
420
|
+
pq.decode(bytes, x, n);
|
|
483
421
|
}
|
|
484
422
|
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
423
|
/*****************************************
|
|
489
424
|
* Stats of IndexPQ codes
|
|
490
425
|
******************************************/
|
|
491
426
|
|
|
427
|
+
void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
|
|
428
|
+
const {
|
|
429
|
+
uint8_t* q_codes = new uint8_t[n * pq.code_size];
|
|
430
|
+
ScopeDeleter<uint8_t> del(q_codes);
|
|
492
431
|
|
|
432
|
+
pq.compute_codes(x, q_codes, n);
|
|
493
433
|
|
|
494
|
-
|
|
495
|
-
void IndexPQ::hamming_distance_table (idx_t n, const float *x,
|
|
496
|
-
int32_t *dis) const
|
|
497
|
-
{
|
|
498
|
-
uint8_t * q_codes = new uint8_t [n * pq.code_size];
|
|
499
|
-
ScopeDeleter<uint8_t> del (q_codes);
|
|
500
|
-
|
|
501
|
-
pq.compute_codes (x, q_codes, n);
|
|
502
|
-
|
|
503
|
-
hammings (q_codes, codes.data(), n, ntotal, pq.code_size, dis);
|
|
434
|
+
hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
|
|
504
435
|
}
|
|
505
436
|
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
FAISS_THROW_IF_NOT
|
|
513
|
-
FAISS_THROW_IF_NOT
|
|
437
|
+
void IndexPQ::hamming_distance_histogram(
|
|
438
|
+
idx_t n,
|
|
439
|
+
const float* x,
|
|
440
|
+
idx_t nb,
|
|
441
|
+
const float* xb,
|
|
442
|
+
int64_t* hist) {
|
|
443
|
+
FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
|
|
444
|
+
FAISS_THROW_IF_NOT(pq.code_size % 8 == 0);
|
|
445
|
+
FAISS_THROW_IF_NOT(pq.nbits == 8);
|
|
514
446
|
|
|
515
447
|
// Hamming embedding queries
|
|
516
|
-
uint8_t
|
|
517
|
-
ScopeDeleter
|
|
518
|
-
pq.compute_codes
|
|
448
|
+
uint8_t* q_codes = new uint8_t[n * pq.code_size];
|
|
449
|
+
ScopeDeleter<uint8_t> del(q_codes);
|
|
450
|
+
pq.compute_codes(x, q_codes, n);
|
|
519
451
|
|
|
520
|
-
uint8_t
|
|
521
|
-
ScopeDeleter
|
|
452
|
+
uint8_t* b_codes;
|
|
453
|
+
ScopeDeleter<uint8_t> del_b_codes;
|
|
522
454
|
|
|
523
455
|
if (xb) {
|
|
524
|
-
b_codes = new uint8_t
|
|
525
|
-
del_b_codes.set
|
|
526
|
-
pq.compute_codes
|
|
456
|
+
b_codes = new uint8_t[nb * pq.code_size];
|
|
457
|
+
del_b_codes.set(b_codes);
|
|
458
|
+
pq.compute_codes(xb, b_codes, nb);
|
|
527
459
|
} else {
|
|
528
460
|
nb = ntotal;
|
|
529
461
|
b_codes = codes.data();
|
|
530
462
|
}
|
|
531
463
|
int nbits = pq.M * pq.nbits;
|
|
532
|
-
memset
|
|
464
|
+
memset(hist, 0, sizeof(*hist) * (nbits + 1));
|
|
533
465
|
size_t bs = 256;
|
|
534
466
|
|
|
535
467
|
#pragma omp parallel
|
|
536
468
|
{
|
|
537
|
-
std::vector<int64_t> histi
|
|
538
|
-
hamdis_t
|
|
539
|
-
ScopeDeleter<hamdis_t> del
|
|
469
|
+
std::vector<int64_t> histi(nbits + 1);
|
|
470
|
+
hamdis_t* distances = new hamdis_t[nb * bs];
|
|
471
|
+
ScopeDeleter<hamdis_t> del(distances);
|
|
540
472
|
#pragma omp for
|
|
541
473
|
for (idx_t q0 = 0; q0 < n; q0 += bs) {
|
|
542
474
|
// printf ("dis stats: %zd/%zd\n", q0, n);
|
|
543
475
|
size_t q1 = q0 + bs;
|
|
544
|
-
if (q1 > n)
|
|
476
|
+
if (q1 > n)
|
|
477
|
+
q1 = n;
|
|
545
478
|
|
|
546
|
-
hammings
|
|
547
|
-
|
|
548
|
-
|
|
479
|
+
hammings(
|
|
480
|
+
q_codes + q0 * pq.code_size,
|
|
481
|
+
b_codes,
|
|
482
|
+
q1 - q0,
|
|
483
|
+
nb,
|
|
484
|
+
pq.code_size,
|
|
485
|
+
distances);
|
|
549
486
|
|
|
550
487
|
for (size_t i = 0; i < nb * (q1 - q0); i++)
|
|
551
|
-
histi
|
|
488
|
+
histi[distances[i]]++;
|
|
552
489
|
}
|
|
553
490
|
#pragma omp critical
|
|
554
491
|
{
|
|
@@ -556,28 +493,8 @@ void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
|
|
|
556
493
|
hist[i] += histi[i];
|
|
557
494
|
}
|
|
558
495
|
}
|
|
559
|
-
|
|
560
496
|
}
|
|
561
497
|
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
498
|
/*****************************************
|
|
582
499
|
* MultiIndexQuantizer
|
|
583
500
|
******************************************/
|
|
@@ -586,90 +503,87 @@ namespace {
|
|
|
586
503
|
|
|
587
504
|
template <typename T>
|
|
588
505
|
struct PreSortedArray {
|
|
589
|
-
|
|
590
|
-
const T * x;
|
|
506
|
+
const T* x;
|
|
591
507
|
int N;
|
|
592
508
|
|
|
593
|
-
explicit PreSortedArray
|
|
594
|
-
|
|
595
|
-
void init (const T*x) {
|
|
509
|
+
explicit PreSortedArray(int N) : N(N) {}
|
|
510
|
+
void init(const T* x) {
|
|
596
511
|
this->x = x;
|
|
597
512
|
}
|
|
598
513
|
// get smallest value
|
|
599
|
-
T get_0
|
|
514
|
+
T get_0() {
|
|
600
515
|
return x[0];
|
|
601
516
|
}
|
|
602
517
|
|
|
603
518
|
// get delta between n-smallest and n-1 -smallest
|
|
604
|
-
T get_diff
|
|
519
|
+
T get_diff(int n) {
|
|
605
520
|
return x[n] - x[n - 1];
|
|
606
521
|
}
|
|
607
522
|
|
|
608
523
|
// remap orders counted from smallest to indices in array
|
|
609
|
-
int get_ord
|
|
524
|
+
int get_ord(int n) {
|
|
610
525
|
return n;
|
|
611
526
|
}
|
|
612
|
-
|
|
613
527
|
};
|
|
614
528
|
|
|
615
529
|
template <typename T>
|
|
616
530
|
struct ArgSort {
|
|
617
|
-
const T
|
|
618
|
-
bool operator()
|
|
531
|
+
const T* x;
|
|
532
|
+
bool operator()(size_t i, size_t j) {
|
|
619
533
|
return x[i] < x[j];
|
|
620
534
|
}
|
|
621
535
|
};
|
|
622
536
|
|
|
623
|
-
|
|
624
537
|
/** Array that maintains a permutation of its elements so that the
|
|
625
538
|
* array's elements are sorted
|
|
626
539
|
*/
|
|
627
540
|
template <typename T>
|
|
628
541
|
struct SortedArray {
|
|
629
|
-
const T
|
|
542
|
+
const T* x;
|
|
630
543
|
int N;
|
|
631
544
|
std::vector<int> perm;
|
|
632
545
|
|
|
633
|
-
explicit SortedArray
|
|
546
|
+
explicit SortedArray(int N) {
|
|
634
547
|
this->N = N;
|
|
635
|
-
perm.resize
|
|
548
|
+
perm.resize(N);
|
|
636
549
|
}
|
|
637
550
|
|
|
638
|
-
void init
|
|
551
|
+
void init(const T* x) {
|
|
639
552
|
this->x = x;
|
|
640
553
|
for (int n = 0; n < N; n++)
|
|
641
554
|
perm[n] = n;
|
|
642
|
-
ArgSort<T> cmp = {x
|
|
643
|
-
std::sort
|
|
555
|
+
ArgSort<T> cmp = {x};
|
|
556
|
+
std::sort(perm.begin(), perm.end(), cmp);
|
|
644
557
|
}
|
|
645
558
|
|
|
646
559
|
// get smallest value
|
|
647
|
-
T get_0
|
|
560
|
+
T get_0() {
|
|
648
561
|
return x[perm[0]];
|
|
649
562
|
}
|
|
650
563
|
|
|
651
564
|
// get delta between n-smallest and n-1 -smallest
|
|
652
|
-
T get_diff
|
|
565
|
+
T get_diff(int n) {
|
|
653
566
|
return x[perm[n]] - x[perm[n - 1]];
|
|
654
567
|
}
|
|
655
568
|
|
|
656
569
|
// remap orders counted from smallest to indices in array
|
|
657
|
-
int get_ord
|
|
570
|
+
int get_ord(int n) {
|
|
658
571
|
return perm[n];
|
|
659
572
|
}
|
|
660
573
|
};
|
|
661
574
|
|
|
662
|
-
|
|
663
|
-
|
|
664
575
|
/** Array has n values. Sort the k first ones and copy the other ones
|
|
665
576
|
* into elements k..n-1
|
|
666
577
|
*/
|
|
667
578
|
template <class C>
|
|
668
|
-
void partial_sort
|
|
669
|
-
|
|
579
|
+
void partial_sort(
|
|
580
|
+
int k,
|
|
581
|
+
int n,
|
|
582
|
+
const typename C::T* vals,
|
|
583
|
+
typename C::TI* perm) {
|
|
670
584
|
// insert first k elts in heap
|
|
671
585
|
for (int i = 1; i < k; i++) {
|
|
672
|
-
indirect_heap_push<C>
|
|
586
|
+
indirect_heap_push<C>(i + 1, vals, perm, perm[i]);
|
|
673
587
|
}
|
|
674
588
|
|
|
675
589
|
// insert next n - k elts in heap
|
|
@@ -678,8 +592,8 @@ void partial_sort (int k, int n,
|
|
|
678
592
|
typename C::TI top = perm[0];
|
|
679
593
|
|
|
680
594
|
if (C::cmp(vals[top], vals[id])) {
|
|
681
|
-
indirect_heap_pop<C>
|
|
682
|
-
indirect_heap_push<C>
|
|
595
|
+
indirect_heap_pop<C>(k, vals, perm);
|
|
596
|
+
indirect_heap_push<C>(k, vals, perm, id);
|
|
683
597
|
perm[i] = top;
|
|
684
598
|
} else {
|
|
685
599
|
// nothing, elt at i is good where it is.
|
|
@@ -689,7 +603,7 @@ void partial_sort (int k, int n,
|
|
|
689
603
|
// order the k first elements in heap
|
|
690
604
|
for (int i = k - 1; i > 0; i--) {
|
|
691
605
|
typename C::TI top = perm[0];
|
|
692
|
-
indirect_heap_pop<C>
|
|
606
|
+
indirect_heap_pop<C>(i + 1, vals, perm);
|
|
693
607
|
perm[i] = top;
|
|
694
608
|
}
|
|
695
609
|
}
|
|
@@ -697,69 +611,67 @@ void partial_sort (int k, int n,
|
|
|
697
611
|
/** same as SortedArray, but only the k first elements are sorted */
|
|
698
612
|
template <typename T>
|
|
699
613
|
struct SemiSortedArray {
|
|
700
|
-
const T
|
|
614
|
+
const T* x;
|
|
701
615
|
int N;
|
|
702
616
|
|
|
703
617
|
// type of the heap: CMax = sort ascending
|
|
704
618
|
typedef CMax<T, int> HC;
|
|
705
619
|
std::vector<int> perm;
|
|
706
620
|
|
|
707
|
-
int k;
|
|
621
|
+
int k; // k elements are sorted
|
|
708
622
|
|
|
709
623
|
int initial_k, k_factor;
|
|
710
624
|
|
|
711
|
-
explicit SemiSortedArray
|
|
625
|
+
explicit SemiSortedArray(int N) {
|
|
712
626
|
this->N = N;
|
|
713
|
-
perm.resize
|
|
714
|
-
perm.resize
|
|
627
|
+
perm.resize(N);
|
|
628
|
+
perm.resize(N);
|
|
715
629
|
initial_k = 3;
|
|
716
630
|
k_factor = 4;
|
|
717
631
|
}
|
|
718
632
|
|
|
719
|
-
void init
|
|
633
|
+
void init(const T* x) {
|
|
720
634
|
this->x = x;
|
|
721
635
|
for (int n = 0; n < N; n++)
|
|
722
636
|
perm[n] = n;
|
|
723
637
|
k = 0;
|
|
724
|
-
grow
|
|
638
|
+
grow(initial_k);
|
|
725
639
|
}
|
|
726
640
|
|
|
727
641
|
/// grow the sorted part of the array to size next_k
|
|
728
|
-
void grow
|
|
642
|
+
void grow(int next_k) {
|
|
729
643
|
if (next_k < N) {
|
|
730
|
-
partial_sort<HC>
|
|
644
|
+
partial_sort<HC>(next_k - k, N - k, x, &perm[k]);
|
|
731
645
|
k = next_k;
|
|
732
646
|
} else { // full sort of remainder of array
|
|
733
|
-
ArgSort<T> cmp = {x
|
|
734
|
-
std::sort
|
|
647
|
+
ArgSort<T> cmp = {x};
|
|
648
|
+
std::sort(perm.begin() + k, perm.end(), cmp);
|
|
735
649
|
k = N;
|
|
736
650
|
}
|
|
737
651
|
}
|
|
738
652
|
|
|
739
653
|
// get smallest value
|
|
740
|
-
T get_0
|
|
654
|
+
T get_0() {
|
|
741
655
|
return x[perm[0]];
|
|
742
656
|
}
|
|
743
657
|
|
|
744
658
|
// get delta between n-smallest and n-1 -smallest
|
|
745
|
-
T get_diff
|
|
659
|
+
T get_diff(int n) {
|
|
746
660
|
if (n >= k) {
|
|
747
661
|
// want to keep powers of 2 - 1
|
|
748
662
|
int next_k = (k + 1) * k_factor - 1;
|
|
749
|
-
grow
|
|
663
|
+
grow(next_k);
|
|
750
664
|
}
|
|
751
665
|
return x[perm[n]] - x[perm[n - 1]];
|
|
752
666
|
}
|
|
753
667
|
|
|
754
668
|
// remap orders counted from smallest to indices in array
|
|
755
|
-
int get_ord
|
|
756
|
-
assert
|
|
669
|
+
int get_ord(int n) {
|
|
670
|
+
assert(n < k);
|
|
757
671
|
return perm[n];
|
|
758
672
|
}
|
|
759
673
|
};
|
|
760
674
|
|
|
761
|
-
|
|
762
|
-
|
|
763
675
|
/*****************************************
|
|
764
676
|
* Find the k smallest sums of M terms, where each term is taken in a
|
|
765
677
|
* table x of n values.
|
|
@@ -779,19 +691,19 @@ struct SemiSortedArray {
|
|
|
779
691
|
* occasionally several t's are returned.
|
|
780
692
|
*
|
|
781
693
|
* @param x size M * n, values to add up
|
|
782
|
-
* @
|
|
694
|
+
* @param k nb of results to retrieve
|
|
783
695
|
* @param M nb of terms
|
|
784
696
|
* @param n nb of distinct values
|
|
785
697
|
* @param sums output, size k, sorted
|
|
786
|
-
* @
|
|
698
|
+
* @param terms output, size k, with encoding as above
|
|
787
699
|
*
|
|
788
700
|
******************************************/
|
|
789
701
|
template <typename T, class SSA, bool use_seen>
|
|
790
702
|
struct MinSumK {
|
|
791
|
-
int K;
|
|
792
|
-
int M;
|
|
703
|
+
int K; ///< nb of sums to return
|
|
704
|
+
int M; ///< nb of elements to sum up
|
|
793
705
|
int nbit; ///< nb of bits to encode one entry
|
|
794
|
-
int N;
|
|
706
|
+
int N; ///< nb of possible elements for each of the M terms
|
|
795
707
|
|
|
796
708
|
/** the heap.
|
|
797
709
|
* We use a heap to maintain a queue of sums, with the associated
|
|
@@ -799,21 +711,20 @@ struct MinSumK {
|
|
|
799
711
|
*/
|
|
800
712
|
typedef CMin<T, int64_t> HC;
|
|
801
713
|
size_t heap_capacity, heap_size;
|
|
802
|
-
T
|
|
803
|
-
int64_t
|
|
714
|
+
T* bh_val;
|
|
715
|
+
int64_t* bh_ids;
|
|
804
716
|
|
|
805
|
-
std::vector
|
|
717
|
+
std::vector<SSA> ssx;
|
|
806
718
|
|
|
807
719
|
// all results get pushed several times. When there are ties, they
|
|
808
720
|
// are popped interleaved with others, so it is not easy to
|
|
809
721
|
// identify them. Therefore, this bit array just marks elements
|
|
810
722
|
// that were seen before.
|
|
811
|
-
std::vector
|
|
723
|
+
std::vector<uint8_t> seen;
|
|
812
724
|
|
|
813
|
-
MinSumK
|
|
814
|
-
K(K), M(M), nbit(nbit), N(N) {
|
|
725
|
+
MinSumK(int K, int M, int nbit, int N) : K(K), M(M), nbit(nbit), N(N) {
|
|
815
726
|
heap_capacity = K * M;
|
|
816
|
-
assert
|
|
727
|
+
assert(N <= (1 << nbit));
|
|
817
728
|
|
|
818
729
|
// we'll do k steps, each step pushes at most M vals
|
|
819
730
|
bh_val = new T[heap_capacity];
|
|
@@ -821,29 +732,27 @@ struct MinSumK {
|
|
|
821
732
|
|
|
822
733
|
if (use_seen) {
|
|
823
734
|
int64_t n_ids = weight(M);
|
|
824
|
-
seen.resize
|
|
735
|
+
seen.resize((n_ids + 7) / 8);
|
|
825
736
|
}
|
|
826
737
|
|
|
827
738
|
for (int m = 0; m < M; m++)
|
|
828
|
-
ssx.push_back
|
|
829
|
-
|
|
739
|
+
ssx.push_back(SSA(N));
|
|
830
740
|
}
|
|
831
741
|
|
|
832
|
-
int64_t weight
|
|
742
|
+
int64_t weight(int i) {
|
|
833
743
|
return 1 << (i * nbit);
|
|
834
744
|
}
|
|
835
745
|
|
|
836
|
-
bool is_seen
|
|
746
|
+
bool is_seen(int64_t i) {
|
|
837
747
|
return (seen[i >> 3] >> (i & 7)) & 1;
|
|
838
748
|
}
|
|
839
749
|
|
|
840
|
-
void mark_seen
|
|
750
|
+
void mark_seen(int64_t i) {
|
|
841
751
|
if (use_seen)
|
|
842
|
-
seen
|
|
752
|
+
seen[i >> 3] |= 1 << (i & 7);
|
|
843
753
|
}
|
|
844
754
|
|
|
845
|
-
void run
|
|
846
|
-
T * sums, int64_t * terms) {
|
|
755
|
+
void run(const T* x, int64_t ldx, T* sums, int64_t* terms) {
|
|
847
756
|
heap_size = 0;
|
|
848
757
|
|
|
849
758
|
for (int m = 0; m < M; m++) {
|
|
@@ -854,38 +763,41 @@ struct MinSumK {
|
|
|
854
763
|
{ // initial result: take min for all elements
|
|
855
764
|
T sum = 0;
|
|
856
765
|
terms[0] = 0;
|
|
857
|
-
mark_seen
|
|
766
|
+
mark_seen(0);
|
|
858
767
|
for (int m = 0; m < M; m++) {
|
|
859
768
|
sum += ssx[m].get_0();
|
|
860
769
|
}
|
|
861
770
|
sums[0] = sum;
|
|
862
771
|
for (int m = 0; m < M; m++) {
|
|
863
|
-
heap_push<HC>
|
|
864
|
-
|
|
865
|
-
|
|
772
|
+
heap_push<HC>(
|
|
773
|
+
++heap_size,
|
|
774
|
+
bh_val,
|
|
775
|
+
bh_ids,
|
|
776
|
+
sum + ssx[m].get_diff(1),
|
|
777
|
+
weight(m));
|
|
866
778
|
}
|
|
867
779
|
}
|
|
868
780
|
|
|
869
781
|
for (int k = 1; k < K; k++) {
|
|
870
782
|
// pop smallest value from heap
|
|
871
|
-
if (use_seen) {// skip already seen elements
|
|
872
|
-
while (is_seen
|
|
873
|
-
assert
|
|
874
|
-
heap_pop<HC>
|
|
783
|
+
if (use_seen) { // skip already seen elements
|
|
784
|
+
while (is_seen(bh_ids[0])) {
|
|
785
|
+
assert(heap_size > 0);
|
|
786
|
+
heap_pop<HC>(heap_size--, bh_val, bh_ids);
|
|
875
787
|
}
|
|
876
788
|
}
|
|
877
|
-
assert
|
|
789
|
+
assert(heap_size > 0);
|
|
878
790
|
|
|
879
791
|
T sum = sums[k] = bh_val[0];
|
|
880
792
|
int64_t ti = terms[k] = bh_ids[0];
|
|
881
793
|
|
|
882
794
|
if (use_seen) {
|
|
883
|
-
mark_seen
|
|
884
|
-
heap_pop<HC>
|
|
795
|
+
mark_seen(ti);
|
|
796
|
+
heap_pop<HC>(heap_size--, bh_val, bh_ids);
|
|
885
797
|
} else {
|
|
886
798
|
do {
|
|
887
|
-
heap_pop<HC>
|
|
888
|
-
}
|
|
799
|
+
heap_pop<HC>(heap_size--, bh_val, bh_ids);
|
|
800
|
+
} while (heap_size > 0 && bh_ids[0] == ti);
|
|
889
801
|
}
|
|
890
802
|
|
|
891
803
|
// enqueue followers
|
|
@@ -893,9 +805,10 @@ struct MinSumK {
|
|
|
893
805
|
for (int m = 0; m < M; m++) {
|
|
894
806
|
int64_t n = ii & ((1L << nbit) - 1);
|
|
895
807
|
ii >>= nbit;
|
|
896
|
-
if (n + 1 >= N)
|
|
808
|
+
if (n + 1 >= N)
|
|
809
|
+
continue;
|
|
897
810
|
|
|
898
|
-
enqueue_follower
|
|
811
|
+
enqueue_follower(ti, m, n, sum);
|
|
899
812
|
}
|
|
900
813
|
}
|
|
901
814
|
|
|
@@ -922,37 +835,29 @@ struct MinSumK {
|
|
|
922
835
|
}
|
|
923
836
|
}
|
|
924
837
|
|
|
925
|
-
|
|
926
|
-
void enqueue_follower (int64_t ti, int m, int n, T sum) {
|
|
838
|
+
void enqueue_follower(int64_t ti, int m, int n, T sum) {
|
|
927
839
|
T next_sum = sum + ssx[m].get_diff(n + 1);
|
|
928
840
|
int64_t next_ti = ti + weight(m);
|
|
929
|
-
heap_push<HC>
|
|
841
|
+
heap_push<HC>(++heap_size, bh_val, bh_ids, next_sum, next_ti);
|
|
930
842
|
}
|
|
931
843
|
|
|
932
|
-
~MinSumK
|
|
933
|
-
delete
|
|
934
|
-
delete
|
|
844
|
+
~MinSumK() {
|
|
845
|
+
delete[] bh_ids;
|
|
846
|
+
delete[] bh_val;
|
|
935
847
|
}
|
|
936
848
|
};
|
|
937
849
|
|
|
938
850
|
} // anonymous namespace
|
|
939
851
|
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
size_t M,
|
|
943
|
-
size_t nbits):
|
|
944
|
-
Index(d, METRIC_L2), pq(d, M, nbits)
|
|
945
|
-
{
|
|
852
|
+
MultiIndexQuantizer::MultiIndexQuantizer(int d, size_t M, size_t nbits)
|
|
853
|
+
: Index(d, METRIC_L2), pq(d, M, nbits) {
|
|
946
854
|
is_trained = false;
|
|
947
855
|
pq.verbose = verbose;
|
|
948
856
|
}
|
|
949
857
|
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
void MultiIndexQuantizer::train(idx_t n, const float *x)
|
|
953
|
-
{
|
|
858
|
+
void MultiIndexQuantizer::train(idx_t n, const float* x) {
|
|
954
859
|
pq.verbose = verbose;
|
|
955
|
-
pq.train
|
|
860
|
+
pq.train(n, x);
|
|
956
861
|
is_trained = true;
|
|
957
862
|
// count virtual elements in index
|
|
958
863
|
ntotal = 1;
|
|
@@ -960,10 +865,16 @@ void MultiIndexQuantizer::train(idx_t n, const float *x)
|
|
|
960
865
|
ntotal *= pq.ksub;
|
|
961
866
|
}
|
|
962
867
|
|
|
868
|
+
void MultiIndexQuantizer::search(
|
|
869
|
+
idx_t n,
|
|
870
|
+
const float* x,
|
|
871
|
+
idx_t k,
|
|
872
|
+
float* distances,
|
|
873
|
+
idx_t* labels) const {
|
|
874
|
+
if (n == 0)
|
|
875
|
+
return;
|
|
963
876
|
|
|
964
|
-
|
|
965
|
-
float *distances, idx_t *labels) const {
|
|
966
|
-
if (n == 0) return;
|
|
877
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
967
878
|
|
|
968
879
|
// the allocation just below can be severe...
|
|
969
880
|
idx_t bs = 32768;
|
|
@@ -971,27 +882,28 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
|
|
|
971
882
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
972
883
|
idx_t i1 = std::min(i0 + bs, n);
|
|
973
884
|
if (verbose) {
|
|
974
|
-
printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
|
|
975
|
-
|
|
885
|
+
printf("MultiIndexQuantizer::search: %" PRId64 ":%" PRId64
|
|
886
|
+
" / %" PRId64 "\n",
|
|
887
|
+
i0,
|
|
888
|
+
i1,
|
|
889
|
+
n);
|
|
976
890
|
}
|
|
977
|
-
search
|
|
978
|
-
distances + i0 * k,
|
|
979
|
-
labels + i0 * k);
|
|
891
|
+
search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
|
|
980
892
|
}
|
|
981
893
|
return;
|
|
982
894
|
}
|
|
983
895
|
|
|
984
|
-
float
|
|
985
|
-
ScopeDeleter<float> del
|
|
896
|
+
float* dis_tables = new float[n * pq.ksub * pq.M];
|
|
897
|
+
ScopeDeleter<float> del(dis_tables);
|
|
986
898
|
|
|
987
|
-
pq.compute_distance_tables
|
|
899
|
+
pq.compute_distance_tables(n, x, dis_tables);
|
|
988
900
|
|
|
989
901
|
if (k == 1) {
|
|
990
902
|
// simple version that just finds the min in each table
|
|
991
903
|
|
|
992
904
|
#pragma omp parallel for
|
|
993
905
|
for (int i = 0; i < n; i++) {
|
|
994
|
-
const float
|
|
906
|
+
const float* dis_table = dis_tables + i * pq.ksub * pq.M;
|
|
995
907
|
float dis = 0;
|
|
996
908
|
idx_t label = 0;
|
|
997
909
|
|
|
@@ -1010,32 +922,27 @@ void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
|
|
|
1010
922
|
dis_table += pq.ksub;
|
|
1011
923
|
}
|
|
1012
924
|
|
|
1013
|
-
distances
|
|
1014
|
-
labels
|
|
925
|
+
distances[i] = dis;
|
|
926
|
+
labels[i] = label;
|
|
1015
927
|
}
|
|
1016
928
|
|
|
1017
|
-
|
|
1018
929
|
} else {
|
|
1019
|
-
|
|
1020
|
-
#pragma omp parallel if(n > 1)
|
|
930
|
+
#pragma omp parallel if (n > 1)
|
|
1021
931
|
{
|
|
1022
|
-
MinSumK
|
|
1023
|
-
|
|
932
|
+
MinSumK<float, SemiSortedArray<float>, false> msk(
|
|
933
|
+
k, pq.M, pq.nbits, pq.ksub);
|
|
1024
934
|
#pragma omp for
|
|
1025
935
|
for (int i = 0; i < n; i++) {
|
|
1026
|
-
msk.run
|
|
1027
|
-
|
|
1028
|
-
|
|
936
|
+
msk.run(dis_tables + i * pq.ksub * pq.M,
|
|
937
|
+
pq.ksub,
|
|
938
|
+
distances + i * k,
|
|
939
|
+
labels + i * k);
|
|
1029
940
|
}
|
|
1030
941
|
}
|
|
1031
942
|
}
|
|
1032
|
-
|
|
1033
943
|
}
|
|
1034
944
|
|
|
1035
|
-
|
|
1036
|
-
void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
|
|
1037
|
-
{
|
|
1038
|
-
|
|
945
|
+
void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
|
|
1039
946
|
int64_t jj = key;
|
|
1040
947
|
for (int m = 0; m < pq.M; m++) {
|
|
1041
948
|
int64_t n = jj & ((1L << pq.nbits) - 1);
|
|
@@ -1046,65 +953,53 @@ void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
|
|
|
1046
953
|
}
|
|
1047
954
|
|
|
1048
955
|
void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/) {
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
956
|
+
FAISS_THROW_MSG(
|
|
957
|
+
"This index has virtual elements, "
|
|
958
|
+
"it does not support add");
|
|
1052
959
|
}
|
|
1053
960
|
|
|
1054
|
-
void MultiIndexQuantizer::reset
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
961
|
+
void MultiIndexQuantizer::reset() {
|
|
962
|
+
FAISS_THROW_MSG(
|
|
963
|
+
"This index has virtual elements, "
|
|
964
|
+
"it does not support reset");
|
|
1058
965
|
}
|
|
1059
966
|
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
967
|
/*****************************************
|
|
1070
968
|
* MultiIndexQuantizer2
|
|
1071
969
|
******************************************/
|
|
1072
970
|
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
Index
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
assign_indexes.resize (M);
|
|
971
|
+
MultiIndexQuantizer2::MultiIndexQuantizer2(
|
|
972
|
+
int d,
|
|
973
|
+
size_t M,
|
|
974
|
+
size_t nbits,
|
|
975
|
+
Index** indexes)
|
|
976
|
+
: MultiIndexQuantizer(d, M, nbits) {
|
|
977
|
+
assign_indexes.resize(M);
|
|
1081
978
|
for (int i = 0; i < M; i++) {
|
|
1082
979
|
FAISS_THROW_IF_NOT_MSG(
|
|
1083
|
-
|
|
1084
|
-
|
|
980
|
+
indexes[i]->d == pq.dsub,
|
|
981
|
+
"Provided sub-index has incorrect size");
|
|
1085
982
|
assign_indexes[i] = indexes[i];
|
|
1086
983
|
}
|
|
1087
984
|
own_fields = false;
|
|
1088
985
|
}
|
|
1089
986
|
|
|
1090
|
-
MultiIndexQuantizer2::MultiIndexQuantizer2
|
|
1091
|
-
int d,
|
|
1092
|
-
|
|
1093
|
-
Index
|
|
1094
|
-
|
|
1095
|
-
{
|
|
987
|
+
MultiIndexQuantizer2::MultiIndexQuantizer2(
|
|
988
|
+
int d,
|
|
989
|
+
size_t nbits,
|
|
990
|
+
Index* assign_index_0,
|
|
991
|
+
Index* assign_index_1)
|
|
992
|
+
: MultiIndexQuantizer(d, 2, nbits) {
|
|
1096
993
|
FAISS_THROW_IF_NOT_MSG(
|
|
1097
|
-
assign_index_0->d == pq.dsub &&
|
|
1098
|
-
assign_index_1->d == pq.dsub,
|
|
994
|
+
assign_index_0->d == pq.dsub && assign_index_1->d == pq.dsub,
|
|
1099
995
|
"Provided sub-index has incorrect size");
|
|
1100
|
-
assign_indexes.resize
|
|
1101
|
-
assign_indexes
|
|
1102
|
-
assign_indexes
|
|
996
|
+
assign_indexes.resize(2);
|
|
997
|
+
assign_indexes[0] = assign_index_0;
|
|
998
|
+
assign_indexes[1] = assign_index_1;
|
|
1103
999
|
own_fields = false;
|
|
1104
1000
|
}
|
|
1105
1001
|
|
|
1106
|
-
void MultiIndexQuantizer2::train(idx_t n, const float* x)
|
|
1107
|
-
{
|
|
1002
|
+
void MultiIndexQuantizer2::train(idx_t n, const float* x) {
|
|
1108
1003
|
MultiIndexQuantizer::train(n, x);
|
|
1109
1004
|
// add centroids to sub-indexes
|
|
1110
1005
|
for (int i = 0; i < pq.M; i++) {
|
|
@@ -1112,15 +1007,17 @@ void MultiIndexQuantizer2::train(idx_t n, const float* x)
|
|
|
1112
1007
|
}
|
|
1113
1008
|
}
|
|
1114
1009
|
|
|
1115
|
-
|
|
1116
1010
|
void MultiIndexQuantizer2::search(
|
|
1117
|
-
idx_t n,
|
|
1118
|
-
float*
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1011
|
+
idx_t n,
|
|
1012
|
+
const float* x,
|
|
1013
|
+
idx_t K,
|
|
1014
|
+
float* distances,
|
|
1015
|
+
idx_t* labels) const {
|
|
1016
|
+
if (n == 0)
|
|
1017
|
+
return;
|
|
1122
1018
|
|
|
1123
1019
|
int k2 = std::min(K, int64_t(pq.ksub));
|
|
1020
|
+
FAISS_THROW_IF_NOT(k2);
|
|
1124
1021
|
|
|
1125
1022
|
int64_t M = pq.M;
|
|
1126
1023
|
int64_t dsub = pq.dsub, ksub = pq.ksub;
|
|
@@ -1131,8 +1028,8 @@ void MultiIndexQuantizer2::search(
|
|
|
1131
1028
|
std::vector<float> xsub(n * dsub);
|
|
1132
1029
|
|
|
1133
1030
|
for (int m = 0; m < M; m++) {
|
|
1134
|
-
float
|
|
1135
|
-
const float
|
|
1031
|
+
float* xdest = xsub.data();
|
|
1032
|
+
const float* xsrc = x + m * dsub;
|
|
1136
1033
|
for (int j = 0; j < n; j++) {
|
|
1137
1034
|
memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
|
|
1138
1035
|
xsrc += d;
|
|
@@ -1140,14 +1037,12 @@ void MultiIndexQuantizer2::search(
|
|
|
1140
1037
|
}
|
|
1141
1038
|
|
|
1142
1039
|
assign_indexes[m]->search(
|
|
1143
|
-
|
|
1144
|
-
&sub_dis[k2 * n * m],
|
|
1145
|
-
&sub_ids[k2 * n * m]);
|
|
1040
|
+
n, xsub.data(), k2, &sub_dis[k2 * n * m], &sub_ids[k2 * n * m]);
|
|
1146
1041
|
}
|
|
1147
1042
|
|
|
1148
1043
|
if (K == 1) {
|
|
1149
1044
|
// simple version that just finds the min in each table
|
|
1150
|
-
assert
|
|
1045
|
+
assert(k2 == 1);
|
|
1151
1046
|
|
|
1152
1047
|
for (int i = 0; i < n; i++) {
|
|
1153
1048
|
float dis = 0;
|
|
@@ -1159,30 +1054,28 @@ void MultiIndexQuantizer2::search(
|
|
|
1159
1054
|
dis += vmin;
|
|
1160
1055
|
label |= lmin << (m * pq.nbits);
|
|
1161
1056
|
}
|
|
1162
|
-
distances
|
|
1163
|
-
labels
|
|
1057
|
+
distances[i] = dis;
|
|
1058
|
+
labels[i] = label;
|
|
1164
1059
|
}
|
|
1165
1060
|
|
|
1166
1061
|
} else {
|
|
1167
|
-
|
|
1168
|
-
#pragma omp parallel if(n > 1)
|
|
1062
|
+
#pragma omp parallel if (n > 1)
|
|
1169
1063
|
{
|
|
1170
|
-
MinSumK
|
|
1171
|
-
|
|
1064
|
+
MinSumK<float, PreSortedArray<float>, false> msk(
|
|
1065
|
+
K, pq.M, pq.nbits, k2);
|
|
1172
1066
|
#pragma omp for
|
|
1173
1067
|
for (int i = 0; i < n; i++) {
|
|
1174
|
-
idx_t
|
|
1175
|
-
msk.run
|
|
1176
|
-
distances + i * K, li);
|
|
1068
|
+
idx_t* li = labels + i * K;
|
|
1069
|
+
msk.run(&sub_dis[i * k2], k2 * n, distances + i * K, li);
|
|
1177
1070
|
|
|
1178
1071
|
// remap ids
|
|
1179
1072
|
|
|
1180
|
-
const idx_t
|
|
1073
|
+
const idx_t* idmap0 = sub_ids.data() + i * k2;
|
|
1181
1074
|
int64_t ld_idmap = k2 * n;
|
|
1182
1075
|
int64_t mask1 = ksub - 1L;
|
|
1183
1076
|
|
|
1184
1077
|
for (int k = 0; k < K; k++) {
|
|
1185
|
-
const idx_t
|
|
1078
|
+
const idx_t* idmap = idmap0;
|
|
1186
1079
|
int64_t vin = li[k];
|
|
1187
1080
|
int64_t vout = 0;
|
|
1188
1081
|
int bs = 0;
|
|
@@ -1200,5 +1093,4 @@ void MultiIndexQuantizer2::search(
|
|
|
1200
1093
|
}
|
|
1201
1094
|
}
|
|
1202
1095
|
|
|
1203
|
-
|
|
1204
1096
|
} // namespace faiss
|