faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,113 +9,118 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/ProductQuantizer.h>
|
|
11
11
|
|
|
12
|
-
|
|
13
12
|
#include <cstddef>
|
|
14
|
-
#include <cstring>
|
|
15
13
|
#include <cstdio>
|
|
14
|
+
#include <cstring>
|
|
16
15
|
#include <memory>
|
|
17
16
|
|
|
18
17
|
#include <algorithm>
|
|
19
18
|
|
|
20
|
-
#include <faiss/impl/FaissAssert.h>
|
|
21
|
-
#include <faiss/VectorTransform.h>
|
|
22
19
|
#include <faiss/IndexFlat.h>
|
|
20
|
+
#include <faiss/VectorTransform.h>
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
|
23
22
|
#include <faiss/utils/distances.h>
|
|
24
23
|
|
|
25
|
-
|
|
26
24
|
extern "C" {
|
|
27
25
|
|
|
28
26
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
29
27
|
|
|
30
|
-
int sgemm_
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
28
|
+
int sgemm_(
|
|
29
|
+
const char* transa,
|
|
30
|
+
const char* transb,
|
|
31
|
+
FINTEGER* m,
|
|
32
|
+
FINTEGER* n,
|
|
33
|
+
FINTEGER* k,
|
|
34
|
+
const float* alpha,
|
|
35
|
+
const float* a,
|
|
36
|
+
FINTEGER* lda,
|
|
37
|
+
const float* b,
|
|
38
|
+
FINTEGER* ldb,
|
|
39
|
+
float* beta,
|
|
40
|
+
float* c,
|
|
41
|
+
FINTEGER* ldc);
|
|
35
42
|
}
|
|
36
43
|
|
|
37
|
-
|
|
38
44
|
namespace faiss {
|
|
39
45
|
|
|
40
|
-
|
|
41
46
|
/* compute an estimator using look-up tables for typical values of M */
|
|
42
47
|
template <typename CT, class C>
|
|
43
|
-
void pq_estimators_from_tables_Mmul4
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
48
|
+
void pq_estimators_from_tables_Mmul4(
|
|
49
|
+
int M,
|
|
50
|
+
const CT* codes,
|
|
51
|
+
size_t ncodes,
|
|
52
|
+
const float* __restrict dis_table,
|
|
53
|
+
size_t ksub,
|
|
54
|
+
size_t k,
|
|
55
|
+
float* heap_dis,
|
|
56
|
+
int64_t* heap_ids) {
|
|
52
57
|
for (size_t j = 0; j < ncodes; j++) {
|
|
53
58
|
float dis = 0;
|
|
54
|
-
const float
|
|
59
|
+
const float* dt = dis_table;
|
|
55
60
|
|
|
56
|
-
for (size_t m = 0; m < M; m+=4) {
|
|
61
|
+
for (size_t m = 0; m < M; m += 4) {
|
|
57
62
|
float dism = 0;
|
|
58
|
-
dism
|
|
59
|
-
|
|
60
|
-
dism += dt[*codes++];
|
|
61
|
-
|
|
63
|
+
dism = dt[*codes++];
|
|
64
|
+
dt += ksub;
|
|
65
|
+
dism += dt[*codes++];
|
|
66
|
+
dt += ksub;
|
|
67
|
+
dism += dt[*codes++];
|
|
68
|
+
dt += ksub;
|
|
69
|
+
dism += dt[*codes++];
|
|
70
|
+
dt += ksub;
|
|
62
71
|
dis += dism;
|
|
63
72
|
}
|
|
64
73
|
|
|
65
|
-
if (C::cmp
|
|
66
|
-
heap_replace_top<C>
|
|
74
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
75
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
|
67
76
|
}
|
|
68
77
|
}
|
|
69
78
|
}
|
|
70
79
|
|
|
71
|
-
|
|
72
80
|
template <typename CT, class C>
|
|
73
|
-
void pq_estimators_from_tables_M4
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
{
|
|
81
|
-
|
|
81
|
+
void pq_estimators_from_tables_M4(
|
|
82
|
+
const CT* codes,
|
|
83
|
+
size_t ncodes,
|
|
84
|
+
const float* __restrict dis_table,
|
|
85
|
+
size_t ksub,
|
|
86
|
+
size_t k,
|
|
87
|
+
float* heap_dis,
|
|
88
|
+
int64_t* heap_ids) {
|
|
82
89
|
for (size_t j = 0; j < ncodes; j++) {
|
|
83
90
|
float dis = 0;
|
|
84
|
-
const float
|
|
85
|
-
dis
|
|
86
|
-
|
|
87
|
-
dis += dt[*codes++];
|
|
91
|
+
const float* dt = dis_table;
|
|
92
|
+
dis = dt[*codes++];
|
|
93
|
+
dt += ksub;
|
|
94
|
+
dis += dt[*codes++];
|
|
95
|
+
dt += ksub;
|
|
96
|
+
dis += dt[*codes++];
|
|
97
|
+
dt += ksub;
|
|
88
98
|
dis += dt[*codes++];
|
|
89
99
|
|
|
90
|
-
if (C::cmp
|
|
91
|
-
heap_replace_top<C>
|
|
100
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
101
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
|
92
102
|
}
|
|
93
103
|
}
|
|
94
104
|
}
|
|
95
105
|
|
|
96
|
-
|
|
97
106
|
template <typename CT, class C>
|
|
98
|
-
static inline void pq_estimators_from_tables
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
{
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
|
|
110
|
-
dis_table, pq.ksub, k,
|
|
111
|
-
heap_dis, heap_ids);
|
|
107
|
+
static inline void pq_estimators_from_tables(
|
|
108
|
+
const ProductQuantizer& pq,
|
|
109
|
+
const CT* codes,
|
|
110
|
+
size_t ncodes,
|
|
111
|
+
const float* dis_table,
|
|
112
|
+
size_t k,
|
|
113
|
+
float* heap_dis,
|
|
114
|
+
int64_t* heap_ids) {
|
|
115
|
+
if (pq.M == 4) {
|
|
116
|
+
pq_estimators_from_tables_M4<CT, C>(
|
|
117
|
+
codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
|
|
112
118
|
return;
|
|
113
119
|
}
|
|
114
120
|
|
|
115
121
|
if (pq.M % 4 == 0) {
|
|
116
|
-
pq_estimators_from_tables_Mmul4<CT, C>
|
|
117
|
-
|
|
118
|
-
heap_dis, heap_ids);
|
|
122
|
+
pq_estimators_from_tables_Mmul4<CT, C>(
|
|
123
|
+
pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
|
|
119
124
|
return;
|
|
120
125
|
}
|
|
121
126
|
|
|
@@ -124,132 +129,124 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
|
|
|
124
129
|
const size_t ksub = pq.ksub;
|
|
125
130
|
for (size_t j = 0; j < ncodes; j++) {
|
|
126
131
|
float dis = 0;
|
|
127
|
-
const float
|
|
132
|
+
const float* __restrict dt = dis_table;
|
|
128
133
|
for (int m = 0; m < M; m++) {
|
|
129
134
|
dis += dt[*codes++];
|
|
130
135
|
dt += ksub;
|
|
131
136
|
}
|
|
132
|
-
if (C::cmp
|
|
133
|
-
heap_replace_top<C>
|
|
137
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
138
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
|
134
139
|
}
|
|
135
140
|
}
|
|
136
141
|
}
|
|
137
142
|
|
|
138
143
|
template <class C>
|
|
139
|
-
static inline void pq_estimators_from_tables_generic(
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
{
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
dt += ksub;
|
|
160
|
-
}
|
|
144
|
+
static inline void pq_estimators_from_tables_generic(
|
|
145
|
+
const ProductQuantizer& pq,
|
|
146
|
+
size_t nbits,
|
|
147
|
+
const uint8_t* codes,
|
|
148
|
+
size_t ncodes,
|
|
149
|
+
const float* dis_table,
|
|
150
|
+
size_t k,
|
|
151
|
+
float* heap_dis,
|
|
152
|
+
int64_t* heap_ids) {
|
|
153
|
+
const size_t M = pq.M;
|
|
154
|
+
const size_t ksub = pq.ksub;
|
|
155
|
+
for (size_t j = 0; j < ncodes; ++j) {
|
|
156
|
+
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
|
|
157
|
+
float dis = 0;
|
|
158
|
+
const float* __restrict dt = dis_table;
|
|
159
|
+
for (size_t m = 0; m < M; m++) {
|
|
160
|
+
uint64_t c = decoder.decode();
|
|
161
|
+
dis += dt[c];
|
|
162
|
+
dt += ksub;
|
|
163
|
+
}
|
|
161
164
|
|
|
162
|
-
|
|
163
|
-
|
|
165
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
166
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
|
167
|
+
}
|
|
164
168
|
}
|
|
165
|
-
}
|
|
166
169
|
}
|
|
167
170
|
|
|
168
171
|
/*********************************************
|
|
169
172
|
* PQ implementation
|
|
170
173
|
*********************************************/
|
|
171
174
|
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
d(d), M(M), nbits(nbits), assign_index(nullptr)
|
|
176
|
-
{
|
|
177
|
-
set_derived_values ();
|
|
175
|
+
ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
|
|
176
|
+
: d(d), M(M), nbits(nbits), assign_index(nullptr) {
|
|
177
|
+
set_derived_values();
|
|
178
178
|
}
|
|
179
179
|
|
|
180
|
-
ProductQuantizer::ProductQuantizer ()
|
|
181
|
-
: ProductQuantizer(0, 1, 0) {}
|
|
180
|
+
ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
|
|
182
181
|
|
|
183
|
-
void ProductQuantizer::set_derived_values
|
|
182
|
+
void ProductQuantizer::set_derived_values() {
|
|
184
183
|
// quite a few derived values
|
|
185
|
-
FAISS_THROW_IF_NOT_MSG
|
|
184
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
185
|
+
d % M == 0,
|
|
186
|
+
"The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
|
|
186
187
|
dsub = d / M;
|
|
187
188
|
code_size = (nbits * M + 7) / 8;
|
|
188
189
|
ksub = 1 << nbits;
|
|
189
|
-
centroids.resize
|
|
190
|
+
centroids.resize(d * ksub);
|
|
190
191
|
verbose = false;
|
|
191
192
|
train_type = Train_default;
|
|
192
193
|
}
|
|
193
194
|
|
|
194
|
-
void ProductQuantizer::set_params
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
195
|
+
void ProductQuantizer::set_params(const float* centroids_, int m) {
|
|
196
|
+
memcpy(get_centroids(m, 0),
|
|
197
|
+
centroids_,
|
|
198
|
+
ksub * dsub * sizeof(centroids_[0]));
|
|
198
199
|
}
|
|
199
200
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
std::vector<float> mean
|
|
201
|
+
static void init_hypercube(
|
|
202
|
+
int d,
|
|
203
|
+
int nbits,
|
|
204
|
+
int n,
|
|
205
|
+
const float* x,
|
|
206
|
+
float* centroids) {
|
|
207
|
+
std::vector<float> mean(d);
|
|
207
208
|
for (int i = 0; i < n; i++)
|
|
208
209
|
for (int j = 0; j < d; j++)
|
|
209
|
-
mean
|
|
210
|
+
mean[j] += x[i * d + j];
|
|
210
211
|
|
|
211
212
|
float maxm = 0;
|
|
212
213
|
for (int j = 0; j < d; j++) {
|
|
213
|
-
mean
|
|
214
|
-
if (fabs(mean[j]) > maxm)
|
|
214
|
+
mean[j] /= n;
|
|
215
|
+
if (fabs(mean[j]) > maxm)
|
|
216
|
+
maxm = fabs(mean[j]);
|
|
215
217
|
}
|
|
216
218
|
|
|
217
219
|
for (int i = 0; i < (1 << nbits); i++) {
|
|
218
|
-
float
|
|
220
|
+
float* cent = centroids + i * d;
|
|
219
221
|
for (int j = 0; j < nbits; j++)
|
|
220
|
-
cent[j] = mean
|
|
222
|
+
cent[j] = mean[j] + (((i >> j) & 1) ? 1 : -1) * maxm;
|
|
221
223
|
for (int j = nbits; j < d; j++)
|
|
222
|
-
cent[j] = mean
|
|
224
|
+
cent[j] = mean[j];
|
|
223
225
|
}
|
|
224
|
-
|
|
225
|
-
|
|
226
226
|
}
|
|
227
227
|
|
|
228
|
-
static void init_hypercube_pca
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
228
|
+
static void init_hypercube_pca(
|
|
229
|
+
int d,
|
|
230
|
+
int nbits,
|
|
231
|
+
int n,
|
|
232
|
+
const float* x,
|
|
233
|
+
float* centroids) {
|
|
234
|
+
PCAMatrix pca(d, nbits);
|
|
235
|
+
pca.train(n, x);
|
|
235
236
|
|
|
236
237
|
for (int i = 0; i < (1 << nbits); i++) {
|
|
237
|
-
float
|
|
238
|
+
float* cent = centroids + i * d;
|
|
238
239
|
for (int j = 0; j < d; j++) {
|
|
239
240
|
cent[j] = pca.mean[j];
|
|
240
241
|
float f = 1.0;
|
|
241
242
|
for (int k = 0; k < nbits; k++)
|
|
242
|
-
cent[j] += f *
|
|
243
|
-
|
|
244
|
-
(((i >> k) & 1) ? 1 : -1) *
|
|
245
|
-
pca.PCAMat [j + k * d];
|
|
243
|
+
cent[j] += f * sqrt(pca.eigenvalues[k]) *
|
|
244
|
+
(((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d];
|
|
246
245
|
}
|
|
247
246
|
}
|
|
248
|
-
|
|
249
247
|
}
|
|
250
248
|
|
|
251
|
-
void ProductQuantizer::train
|
|
252
|
-
{
|
|
249
|
+
void ProductQuantizer::train(int n, const float* x) {
|
|
253
250
|
if (train_type != Train_shared) {
|
|
254
251
|
train_type_t final_train_type;
|
|
255
252
|
final_train_type = train_type;
|
|
@@ -257,234 +254,229 @@ void ProductQuantizer::train (int n, const float * x)
|
|
|
257
254
|
train_type == Train_hypercube_pca) {
|
|
258
255
|
if (dsub < nbits) {
|
|
259
256
|
final_train_type = Train_default;
|
|
260
|
-
printf
|
|
261
|
-
|
|
257
|
+
printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
|
|
258
|
+
nbits,
|
|
259
|
+
dsub);
|
|
262
260
|
}
|
|
263
261
|
}
|
|
264
262
|
|
|
265
|
-
float
|
|
266
|
-
ScopeDeleter<float> del
|
|
263
|
+
float* xslice = new float[n * dsub];
|
|
264
|
+
ScopeDeleter<float> del(xslice);
|
|
267
265
|
for (int m = 0; m < M; m++) {
|
|
268
266
|
for (int j = 0; j < n; j++)
|
|
269
|
-
memcpy
|
|
270
|
-
|
|
271
|
-
|
|
267
|
+
memcpy(xslice + j * dsub,
|
|
268
|
+
x + j * d + m * dsub,
|
|
269
|
+
dsub * sizeof(float));
|
|
272
270
|
|
|
273
|
-
Clustering clus
|
|
271
|
+
Clustering clus(dsub, ksub, cp);
|
|
274
272
|
|
|
275
273
|
// we have some initialization for the centroids
|
|
276
274
|
if (final_train_type != Train_default) {
|
|
277
|
-
clus.centroids.resize
|
|
275
|
+
clus.centroids.resize(dsub * ksub);
|
|
278
276
|
}
|
|
279
277
|
|
|
280
278
|
switch (final_train_type) {
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
279
|
+
case Train_hypercube:
|
|
280
|
+
init_hypercube(
|
|
281
|
+
dsub, nbits, n, xslice, clus.centroids.data());
|
|
282
|
+
break;
|
|
283
|
+
case Train_hypercube_pca:
|
|
284
|
+
init_hypercube_pca(
|
|
285
|
+
dsub, nbits, n, xslice, clus.centroids.data());
|
|
286
|
+
break;
|
|
287
|
+
case Train_hot_start:
|
|
288
|
+
memcpy(clus.centroids.data(),
|
|
289
|
+
get_centroids(m, 0),
|
|
290
|
+
dsub * ksub * sizeof(float));
|
|
291
|
+
break;
|
|
292
|
+
default:;
|
|
295
293
|
}
|
|
296
294
|
|
|
297
|
-
if(verbose) {
|
|
295
|
+
if (verbose) {
|
|
298
296
|
clus.verbose = true;
|
|
299
|
-
printf
|
|
297
|
+
printf("Training PQ slice %d/%zd\n", m, M);
|
|
300
298
|
}
|
|
301
|
-
IndexFlatL2 index
|
|
302
|
-
clus.train
|
|
303
|
-
set_params
|
|
299
|
+
IndexFlatL2 index(dsub);
|
|
300
|
+
clus.train(n, xslice, assign_index ? *assign_index : index);
|
|
301
|
+
set_params(clus.centroids.data(), m);
|
|
304
302
|
}
|
|
305
303
|
|
|
306
|
-
|
|
307
304
|
} else {
|
|
305
|
+
Clustering clus(dsub, ksub, cp);
|
|
308
306
|
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
if(verbose) {
|
|
307
|
+
if (verbose) {
|
|
312
308
|
clus.verbose = true;
|
|
313
|
-
printf
|
|
309
|
+
printf("Training all PQ slices at once\n");
|
|
314
310
|
}
|
|
315
311
|
|
|
316
|
-
IndexFlatL2 index
|
|
312
|
+
IndexFlatL2 index(dsub);
|
|
317
313
|
|
|
318
|
-
clus.train
|
|
314
|
+
clus.train(n * M, x, assign_index ? *assign_index : index);
|
|
319
315
|
for (int m = 0; m < M; m++) {
|
|
320
|
-
set_params
|
|
316
|
+
set_params(clus.centroids.data(), m);
|
|
321
317
|
}
|
|
322
|
-
|
|
323
318
|
}
|
|
324
319
|
}
|
|
325
320
|
|
|
326
|
-
template<class PQEncoder>
|
|
327
|
-
void compute_code(const ProductQuantizer& pq, const float
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
321
|
+
template <class PQEncoder>
|
|
322
|
+
void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
|
|
323
|
+
std::vector<float> distances(pq.ksub);
|
|
324
|
+
PQEncoder encoder(code, pq.nbits);
|
|
325
|
+
for (size_t m = 0; m < pq.M; m++) {
|
|
326
|
+
float mindis = 1e20;
|
|
327
|
+
uint64_t idxm = 0;
|
|
328
|
+
const float* xsub = x + m * pq.dsub;
|
|
329
|
+
|
|
330
|
+
fvec_L2sqr_ny(
|
|
331
|
+
distances.data(),
|
|
332
|
+
xsub,
|
|
333
|
+
pq.get_centroids(m, 0),
|
|
334
|
+
pq.dsub,
|
|
335
|
+
pq.ksub);
|
|
336
|
+
|
|
337
|
+
/* Find best centroid */
|
|
338
|
+
for (size_t i = 0; i < pq.ksub; i++) {
|
|
339
|
+
float dis = distances[i];
|
|
340
|
+
if (dis < mindis) {
|
|
341
|
+
mindis = dis;
|
|
342
|
+
idxm = i;
|
|
343
|
+
}
|
|
344
|
+
}
|
|
345
345
|
|
|
346
|
-
|
|
347
|
-
|
|
346
|
+
encoder.encode(idxm);
|
|
347
|
+
}
|
|
348
348
|
}
|
|
349
349
|
|
|
350
|
-
void ProductQuantizer::compute_code(const float
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
350
|
+
void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
|
|
351
|
+
switch (nbits) {
|
|
352
|
+
case 8:
|
|
353
|
+
faiss::compute_code<PQEncoder8>(*this, x, code);
|
|
354
|
+
break;
|
|
355
355
|
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
356
|
+
case 16:
|
|
357
|
+
faiss::compute_code<PQEncoder16>(*this, x, code);
|
|
358
|
+
break;
|
|
359
359
|
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
360
|
+
default:
|
|
361
|
+
faiss::compute_code<PQEncoderGeneric>(*this, x, code);
|
|
362
|
+
break;
|
|
363
|
+
}
|
|
364
364
|
}
|
|
365
365
|
|
|
366
|
-
template<class PQDecoder>
|
|
367
|
-
void decode(const ProductQuantizer& pq, const uint8_t
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
366
|
+
template <class PQDecoder>
|
|
367
|
+
void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) {
|
|
368
|
+
PQDecoder decoder(code, pq.nbits);
|
|
369
|
+
for (size_t m = 0; m < pq.M; m++) {
|
|
370
|
+
uint64_t c = decoder.decode();
|
|
371
|
+
memcpy(x + m * pq.dsub,
|
|
372
|
+
pq.get_centroids(m, c),
|
|
373
|
+
sizeof(float) * pq.dsub);
|
|
374
|
+
}
|
|
374
375
|
}
|
|
375
376
|
|
|
376
|
-
void ProductQuantizer::decode
|
|
377
|
-
{
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
break;
|
|
386
|
-
|
|
387
|
-
default:
|
|
388
|
-
faiss::decode<PQDecoderGeneric>(*this, code, x);
|
|
389
|
-
break;
|
|
390
|
-
}
|
|
391
|
-
}
|
|
377
|
+
void ProductQuantizer::decode(const uint8_t* code, float* x) const {
|
|
378
|
+
switch (nbits) {
|
|
379
|
+
case 8:
|
|
380
|
+
faiss::decode<PQDecoder8>(*this, code, x);
|
|
381
|
+
break;
|
|
382
|
+
|
|
383
|
+
case 16:
|
|
384
|
+
faiss::decode<PQDecoder16>(*this, code, x);
|
|
385
|
+
break;
|
|
392
386
|
|
|
387
|
+
default:
|
|
388
|
+
faiss::decode<PQDecoderGeneric>(*this, code, x);
|
|
389
|
+
break;
|
|
390
|
+
}
|
|
391
|
+
}
|
|
393
392
|
|
|
394
|
-
void ProductQuantizer::decode
|
|
395
|
-
{
|
|
393
|
+
void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
|
|
396
394
|
for (size_t i = 0; i < n; i++) {
|
|
397
|
-
this->decode
|
|
395
|
+
this->decode(code + code_size * i, x + d * i);
|
|
398
396
|
}
|
|
399
397
|
}
|
|
400
398
|
|
|
399
|
+
void ProductQuantizer::compute_code_from_distance_table(
|
|
400
|
+
const float* tab,
|
|
401
|
+
uint8_t* code) const {
|
|
402
|
+
PQEncoderGeneric encoder(code, nbits);
|
|
403
|
+
for (size_t m = 0; m < M; m++) {
|
|
404
|
+
float mindis = 1e20;
|
|
405
|
+
uint64_t idxm = 0;
|
|
406
|
+
|
|
407
|
+
/* Find best centroid */
|
|
408
|
+
for (size_t j = 0; j < ksub; j++) {
|
|
409
|
+
float dis = *tab++;
|
|
410
|
+
if (dis < mindis) {
|
|
411
|
+
mindis = dis;
|
|
412
|
+
idxm = j;
|
|
413
|
+
}
|
|
414
|
+
}
|
|
401
415
|
|
|
402
|
-
|
|
403
|
-
uint8_t *code) const
|
|
404
|
-
{
|
|
405
|
-
PQEncoderGeneric encoder(code, nbits);
|
|
406
|
-
for (size_t m = 0; m < M; m++) {
|
|
407
|
-
float mindis = 1e20;
|
|
408
|
-
uint64_t idxm = 0;
|
|
409
|
-
|
|
410
|
-
/* Find best centroid */
|
|
411
|
-
for (size_t j = 0; j < ksub; j++) {
|
|
412
|
-
float dis = *tab++;
|
|
413
|
-
if (dis < mindis) {
|
|
414
|
-
mindis = dis;
|
|
415
|
-
idxm = j;
|
|
416
|
-
}
|
|
416
|
+
encoder.encode(idxm);
|
|
417
417
|
}
|
|
418
|
-
|
|
419
|
-
encoder.encode(idxm);
|
|
420
|
-
}
|
|
421
418
|
}
|
|
422
419
|
|
|
423
|
-
void ProductQuantizer::compute_codes_with_assign_index
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
|
|
420
|
+
void ProductQuantizer::compute_codes_with_assign_index(
|
|
421
|
+
const float* x,
|
|
422
|
+
uint8_t* codes,
|
|
423
|
+
size_t n) {
|
|
424
|
+
FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);
|
|
429
425
|
|
|
430
426
|
for (size_t m = 0; m < M; m++) {
|
|
431
|
-
assign_index->reset
|
|
432
|
-
assign_index->add
|
|
427
|
+
assign_index->reset();
|
|
428
|
+
assign_index->add(ksub, get_centroids(m, 0));
|
|
433
429
|
size_t bs = 65536;
|
|
434
|
-
float
|
|
435
|
-
ScopeDeleter<float> del
|
|
436
|
-
idx_t
|
|
437
|
-
ScopeDeleter<idx_t> del2
|
|
430
|
+
float* xslice = new float[bs * dsub];
|
|
431
|
+
ScopeDeleter<float> del(xslice);
|
|
432
|
+
idx_t* assign = new idx_t[bs];
|
|
433
|
+
ScopeDeleter<idx_t> del2(assign);
|
|
438
434
|
|
|
439
435
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
440
436
|
size_t i1 = std::min(i0 + bs, n);
|
|
441
437
|
|
|
442
438
|
for (size_t i = i0; i < i1; i++) {
|
|
443
|
-
memcpy
|
|
444
|
-
|
|
445
|
-
|
|
439
|
+
memcpy(xslice + (i - i0) * dsub,
|
|
440
|
+
x + i * d + m * dsub,
|
|
441
|
+
dsub * sizeof(float));
|
|
446
442
|
}
|
|
447
443
|
|
|
448
|
-
assign_index->assign
|
|
444
|
+
assign_index->assign(i1 - i0, xslice, assign);
|
|
449
445
|
|
|
450
446
|
if (nbits == 8) {
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
447
|
+
uint8_t* c = codes + code_size * i0 + m;
|
|
448
|
+
for (size_t i = i0; i < i1; i++) {
|
|
449
|
+
*c = assign[i - i0];
|
|
450
|
+
c += M;
|
|
451
|
+
}
|
|
456
452
|
} else if (nbits == 16) {
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
453
|
+
uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
|
|
454
|
+
for (size_t i = i0; i < i1; i++) {
|
|
455
|
+
*c = assign[i - i0];
|
|
456
|
+
c += M;
|
|
457
|
+
}
|
|
462
458
|
} else {
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
459
|
+
for (size_t i = i0; i < i1; ++i) {
|
|
460
|
+
uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
|
|
461
|
+
uint8_t offset = (m * nbits) % 8;
|
|
462
|
+
uint64_t ass = assign[i - i0];
|
|
463
|
+
|
|
464
|
+
PQEncoderGeneric encoder(c, nbits, offset);
|
|
465
|
+
encoder.encode(ass);
|
|
466
|
+
}
|
|
471
467
|
}
|
|
472
|
-
|
|
473
468
|
}
|
|
474
469
|
}
|
|
475
|
-
|
|
476
470
|
}
|
|
477
471
|
|
|
478
|
-
void ProductQuantizer::compute_codes
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
{
|
|
482
|
-
// process by blocks to avoid using too much RAM
|
|
472
|
+
void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
473
|
+
const {
|
|
474
|
+
// process by blocks to avoid using too much RAM
|
|
483
475
|
size_t bs = 256 * 1024;
|
|
484
476
|
if (n > bs) {
|
|
485
477
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
486
478
|
size_t i1 = std::min(i0 + bs, n);
|
|
487
|
-
compute_codes
|
|
479
|
+
compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
|
|
488
480
|
}
|
|
489
481
|
return;
|
|
490
482
|
}
|
|
@@ -493,282 +485,300 @@ void ProductQuantizer::compute_codes (const float * x,
|
|
|
493
485
|
|
|
494
486
|
#pragma omp parallel for
|
|
495
487
|
for (int64_t i = 0; i < n; i++)
|
|
496
|
-
compute_code
|
|
488
|
+
compute_code(x + i * d, codes + i * code_size);
|
|
497
489
|
|
|
498
490
|
} else { // worthwile to use BLAS
|
|
499
|
-
float
|
|
500
|
-
ScopeDeleter<float> del
|
|
501
|
-
compute_distance_tables
|
|
491
|
+
float* dis_tables = new float[n * ksub * M];
|
|
492
|
+
ScopeDeleter<float> del(dis_tables);
|
|
493
|
+
compute_distance_tables(n, x, dis_tables);
|
|
502
494
|
|
|
503
495
|
#pragma omp parallel for
|
|
504
496
|
for (int64_t i = 0; i < n; i++) {
|
|
505
|
-
uint8_t
|
|
506
|
-
const float
|
|
507
|
-
compute_code_from_distance_table
|
|
497
|
+
uint8_t* code = codes + i * code_size;
|
|
498
|
+
const float* tab = dis_tables + i * ksub * M;
|
|
499
|
+
compute_code_from_distance_table(tab, code);
|
|
508
500
|
}
|
|
509
501
|
}
|
|
510
502
|
}
|
|
511
503
|
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
float * dis_table) const
|
|
515
|
-
{
|
|
504
|
+
void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
|
|
505
|
+
const {
|
|
516
506
|
size_t m;
|
|
517
507
|
|
|
518
508
|
for (m = 0; m < M; m++) {
|
|
519
|
-
fvec_L2sqr_ny
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
509
|
+
fvec_L2sqr_ny(
|
|
510
|
+
dis_table + m * ksub,
|
|
511
|
+
x + m * dsub,
|
|
512
|
+
get_centroids(m, 0),
|
|
513
|
+
dsub,
|
|
514
|
+
ksub);
|
|
524
515
|
}
|
|
525
516
|
}
|
|
526
517
|
|
|
527
|
-
void ProductQuantizer::compute_inner_prod_table
|
|
528
|
-
|
|
529
|
-
{
|
|
518
|
+
void ProductQuantizer::compute_inner_prod_table(
|
|
519
|
+
const float* x,
|
|
520
|
+
float* dis_table) const {
|
|
530
521
|
size_t m;
|
|
531
522
|
|
|
532
523
|
for (m = 0; m < M; m++) {
|
|
533
|
-
fvec_inner_products_ny
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
524
|
+
fvec_inner_products_ny(
|
|
525
|
+
dis_table + m * ksub,
|
|
526
|
+
x + m * dsub,
|
|
527
|
+
get_centroids(m, 0),
|
|
528
|
+
dsub,
|
|
529
|
+
ksub);
|
|
538
530
|
}
|
|
539
531
|
}
|
|
540
532
|
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
{
|
|
547
|
-
|
|
548
|
-
#ifdef __AVX2__
|
|
533
|
+
void ProductQuantizer::compute_distance_tables(
|
|
534
|
+
size_t nx,
|
|
535
|
+
const float* x,
|
|
536
|
+
float* dis_tables) const {
|
|
537
|
+
#if defined(__AVX2__) || defined(__aarch64__)
|
|
549
538
|
if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
|
|
550
539
|
compute_PQ_dis_tables_dsub2(
|
|
551
|
-
|
|
552
|
-
nx, x, false, dis_tables
|
|
553
|
-
);
|
|
540
|
+
d, ksub, centroids.data(), nx, x, false, dis_tables);
|
|
554
541
|
} else
|
|
555
542
|
#endif
|
|
556
|
-
|
|
543
|
+
if (dsub < 16) {
|
|
557
544
|
|
|
558
545
|
#pragma omp parallel for
|
|
559
546
|
for (int64_t i = 0; i < nx; i++) {
|
|
560
|
-
compute_distance_table
|
|
547
|
+
compute_distance_table(x + i * d, dis_tables + i * ksub * M);
|
|
561
548
|
}
|
|
562
549
|
|
|
563
550
|
} else { // use BLAS
|
|
564
551
|
|
|
565
552
|
for (int m = 0; m < M; m++) {
|
|
566
|
-
pairwise_L2sqr
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
553
|
+
pairwise_L2sqr(
|
|
554
|
+
dsub,
|
|
555
|
+
nx,
|
|
556
|
+
x + dsub * m,
|
|
557
|
+
ksub,
|
|
558
|
+
centroids.data() + m * dsub * ksub,
|
|
559
|
+
dis_tables + ksub * m,
|
|
560
|
+
d,
|
|
561
|
+
dsub,
|
|
562
|
+
ksub * M);
|
|
571
563
|
}
|
|
572
564
|
}
|
|
573
565
|
}
|
|
574
566
|
|
|
575
|
-
void ProductQuantizer::compute_inner_prod_tables
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
#ifdef __AVX2__
|
|
567
|
+
void ProductQuantizer::compute_inner_prod_tables(
|
|
568
|
+
size_t nx,
|
|
569
|
+
const float* x,
|
|
570
|
+
float* dis_tables) const {
|
|
571
|
+
#if defined(__AVX2__) || defined(__aarch64__)
|
|
581
572
|
if (dsub == 2 && nbits < 8) {
|
|
582
573
|
compute_PQ_dis_tables_dsub2(
|
|
583
|
-
|
|
584
|
-
nx, x, true, dis_tables
|
|
585
|
-
);
|
|
574
|
+
d, ksub, centroids.data(), nx, x, true, dis_tables);
|
|
586
575
|
} else
|
|
587
576
|
#endif
|
|
588
|
-
|
|
577
|
+
if (dsub < 16) {
|
|
589
578
|
|
|
590
579
|
#pragma omp parallel for
|
|
591
580
|
for (int64_t i = 0; i < nx; i++) {
|
|
592
|
-
compute_inner_prod_table
|
|
581
|
+
compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
|
|
593
582
|
}
|
|
594
583
|
|
|
595
584
|
} else { // use BLAS
|
|
596
585
|
|
|
597
586
|
// compute distance tables
|
|
598
587
|
for (int m = 0; m < M; m++) {
|
|
599
|
-
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
|
|
600
|
-
|
|
588
|
+
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
|
|
589
|
+
di = d;
|
|
601
590
|
float one = 1.0, zero = 0;
|
|
602
591
|
|
|
603
|
-
sgemm_
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
592
|
+
sgemm_("Transposed",
|
|
593
|
+
"Not transposed",
|
|
594
|
+
&ksubi,
|
|
595
|
+
&nxi,
|
|
596
|
+
&dsubi,
|
|
597
|
+
&one,
|
|
598
|
+
&centroids[m * dsub * ksub],
|
|
599
|
+
&dsubi,
|
|
600
|
+
x + dsub * m,
|
|
601
|
+
&di,
|
|
602
|
+
&zero,
|
|
603
|
+
dis_tables + ksub * m,
|
|
604
|
+
&ldc);
|
|
608
605
|
}
|
|
609
|
-
|
|
610
606
|
}
|
|
611
607
|
}
|
|
612
608
|
|
|
613
609
|
template <class C>
|
|
614
|
-
static void pq_knn_search_with_tables
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
{
|
|
610
|
+
static void pq_knn_search_with_tables(
|
|
611
|
+
const ProductQuantizer& pq,
|
|
612
|
+
size_t nbits,
|
|
613
|
+
const float* dis_tables,
|
|
614
|
+
const uint8_t* codes,
|
|
615
|
+
const size_t ncodes,
|
|
616
|
+
HeapArray<C>* res,
|
|
617
|
+
bool init_finalize_heap) {
|
|
623
618
|
size_t k = res->k, nx = res->nh;
|
|
624
619
|
size_t ksub = pq.ksub, M = pq.M;
|
|
625
620
|
|
|
626
|
-
|
|
627
621
|
#pragma omp parallel for
|
|
628
622
|
for (int64_t i = 0; i < nx; i++) {
|
|
629
623
|
/* query preparation for asymmetric search: compute look-up tables */
|
|
630
624
|
const float* dis_table = dis_tables + i * ksub * M;
|
|
631
625
|
|
|
632
626
|
/* Compute distances and keep smallest values */
|
|
633
|
-
int64_t
|
|
634
|
-
float
|
|
627
|
+
int64_t* __restrict heap_ids = res->ids + i * k;
|
|
628
|
+
float* __restrict heap_dis = res->val + i * k;
|
|
635
629
|
|
|
636
630
|
if (init_finalize_heap) {
|
|
637
|
-
heap_heapify<C>
|
|
631
|
+
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
638
632
|
}
|
|
639
633
|
|
|
640
634
|
switch (nbits) {
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
635
|
+
case 8:
|
|
636
|
+
pq_estimators_from_tables<uint8_t, C>(
|
|
637
|
+
pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
|
|
638
|
+
break;
|
|
639
|
+
|
|
640
|
+
case 16:
|
|
641
|
+
pq_estimators_from_tables<uint16_t, C>(
|
|
642
|
+
pq,
|
|
643
|
+
(uint16_t*)codes,
|
|
644
|
+
ncodes,
|
|
645
|
+
dis_table,
|
|
646
|
+
k,
|
|
647
|
+
heap_dis,
|
|
648
|
+
heap_ids);
|
|
649
|
+
break;
|
|
650
|
+
|
|
651
|
+
default:
|
|
652
|
+
pq_estimators_from_tables_generic<C>(
|
|
653
|
+
pq,
|
|
654
|
+
nbits,
|
|
655
|
+
codes,
|
|
656
|
+
ncodes,
|
|
657
|
+
dis_table,
|
|
658
|
+
k,
|
|
659
|
+
heap_dis,
|
|
660
|
+
heap_ids);
|
|
661
|
+
break;
|
|
662
662
|
}
|
|
663
663
|
|
|
664
664
|
if (init_finalize_heap) {
|
|
665
|
-
heap_reorder<C>
|
|
665
|
+
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
666
666
|
}
|
|
667
667
|
}
|
|
668
668
|
}
|
|
669
669
|
|
|
670
|
-
void ProductQuantizer::search
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
{
|
|
677
|
-
FAISS_THROW_IF_NOT
|
|
678
|
-
std::unique_ptr<float[]> dis_tables(new float
|
|
679
|
-
compute_distance_tables
|
|
680
|
-
|
|
681
|
-
pq_knn_search_with_tables<CMax<float, int64_t>>
|
|
682
|
-
|
|
670
|
+
void ProductQuantizer::search(
|
|
671
|
+
const float* __restrict x,
|
|
672
|
+
size_t nx,
|
|
673
|
+
const uint8_t* codes,
|
|
674
|
+
const size_t ncodes,
|
|
675
|
+
float_maxheap_array_t* res,
|
|
676
|
+
bool init_finalize_heap) const {
|
|
677
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
|
678
|
+
std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
|
|
679
|
+
compute_distance_tables(nx, x, dis_tables.get());
|
|
680
|
+
|
|
681
|
+
pq_knn_search_with_tables<CMax<float, int64_t>>(
|
|
682
|
+
*this,
|
|
683
|
+
nbits,
|
|
684
|
+
dis_tables.get(),
|
|
685
|
+
codes,
|
|
686
|
+
ncodes,
|
|
687
|
+
res,
|
|
688
|
+
init_finalize_heap);
|
|
683
689
|
}
|
|
684
690
|
|
|
685
|
-
void ProductQuantizer::search_ip
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
{
|
|
692
|
-
FAISS_THROW_IF_NOT
|
|
693
|
-
std::unique_ptr<float[]> dis_tables(new float
|
|
694
|
-
compute_inner_prod_tables
|
|
695
|
-
|
|
696
|
-
pq_knn_search_with_tables<CMin<float, int64_t
|
|
697
|
-
|
|
691
|
+
void ProductQuantizer::search_ip(
|
|
692
|
+
const float* __restrict x,
|
|
693
|
+
size_t nx,
|
|
694
|
+
const uint8_t* codes,
|
|
695
|
+
const size_t ncodes,
|
|
696
|
+
float_minheap_array_t* res,
|
|
697
|
+
bool init_finalize_heap) const {
|
|
698
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
|
699
|
+
std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
|
|
700
|
+
compute_inner_prod_tables(nx, x, dis_tables.get());
|
|
701
|
+
|
|
702
|
+
pq_knn_search_with_tables<CMin<float, int64_t>>(
|
|
703
|
+
*this,
|
|
704
|
+
nbits,
|
|
705
|
+
dis_tables.get(),
|
|
706
|
+
codes,
|
|
707
|
+
ncodes,
|
|
708
|
+
res,
|
|
709
|
+
init_finalize_heap);
|
|
698
710
|
}
|
|
699
711
|
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
static float sqr (float x) {
|
|
712
|
+
static float sqr(float x) {
|
|
703
713
|
return x * x;
|
|
704
714
|
}
|
|
705
715
|
|
|
706
|
-
void ProductQuantizer::compute_sdc_table
|
|
707
|
-
|
|
708
|
-
sdc_table.resize (M * ksub * ksub);
|
|
709
|
-
|
|
710
|
-
for (int m = 0; m < M; m++) {
|
|
716
|
+
void ProductQuantizer::compute_sdc_table() {
|
|
717
|
+
sdc_table.resize(M * ksub * ksub);
|
|
711
718
|
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
719
|
+
if (dsub < 4) {
|
|
720
|
+
#pragma omp parallel for
|
|
721
|
+
for (int mk = 0; mk < M * ksub; mk++) {
|
|
722
|
+
// allow omp to schedule in a more fine-grained way
|
|
723
|
+
// `collapse` is not supported in OpenMP 2.x
|
|
724
|
+
int m = mk / ksub;
|
|
725
|
+
int k = mk % ksub;
|
|
726
|
+
const float* cents = centroids.data() + m * ksub * dsub;
|
|
727
|
+
const float* centi = cents + k * dsub;
|
|
728
|
+
float* dis_tab = sdc_table.data() + m * ksub * ksub;
|
|
729
|
+
fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
|
|
730
|
+
}
|
|
731
|
+
} else {
|
|
732
|
+
// NOTE: it would disable the omp loop in pairwise_L2sqr
|
|
733
|
+
// but still accelerate especially when M >= 4
|
|
734
|
+
#pragma omp parallel for
|
|
735
|
+
for (int m = 0; m < M; m++) {
|
|
736
|
+
const float* cents = centroids.data() + m * ksub * dsub;
|
|
737
|
+
float* dis_tab = sdc_table.data() + m * ksub * ksub;
|
|
738
|
+
pairwise_L2sqr(
|
|
739
|
+
dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
|
|
725
740
|
}
|
|
726
741
|
}
|
|
727
742
|
}
|
|
728
743
|
|
|
729
|
-
void ProductQuantizer::search_sdc
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
{
|
|
736
|
-
FAISS_THROW_IF_NOT
|
|
737
|
-
FAISS_THROW_IF_NOT
|
|
744
|
+
void ProductQuantizer::search_sdc(
|
|
745
|
+
const uint8_t* qcodes,
|
|
746
|
+
size_t nq,
|
|
747
|
+
const uint8_t* bcodes,
|
|
748
|
+
const size_t nb,
|
|
749
|
+
float_maxheap_array_t* res,
|
|
750
|
+
bool init_finalize_heap) const {
|
|
751
|
+
FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
|
|
752
|
+
FAISS_THROW_IF_NOT(nbits == 8);
|
|
738
753
|
size_t k = res->k;
|
|
739
754
|
|
|
740
|
-
|
|
741
755
|
#pragma omp parallel for
|
|
742
756
|
for (int64_t i = 0; i < nq; i++) {
|
|
743
|
-
|
|
744
757
|
/* Compute distances and keep smallest values */
|
|
745
|
-
idx_t
|
|
746
|
-
float
|
|
747
|
-
const uint8_t
|
|
758
|
+
idx_t* heap_ids = res->ids + i * k;
|
|
759
|
+
float* heap_dis = res->val + i * k;
|
|
760
|
+
const uint8_t* qcode = qcodes + i * code_size;
|
|
748
761
|
|
|
749
762
|
if (init_finalize_heap)
|
|
750
|
-
maxheap_heapify
|
|
763
|
+
maxheap_heapify(k, heap_dis, heap_ids);
|
|
751
764
|
|
|
752
|
-
const uint8_t
|
|
765
|
+
const uint8_t* bcode = bcodes;
|
|
753
766
|
for (size_t j = 0; j < nb; j++) {
|
|
754
767
|
float dis = 0;
|
|
755
|
-
const float
|
|
768
|
+
const float* tab = sdc_table.data();
|
|
756
769
|
for (int m = 0; m < M; m++) {
|
|
757
770
|
dis += tab[bcode[m] + qcode[m] * ksub];
|
|
758
771
|
tab += ksub * ksub;
|
|
759
772
|
}
|
|
760
773
|
if (dis < heap_dis[0]) {
|
|
761
|
-
maxheap_replace_top
|
|
774
|
+
maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
|
|
762
775
|
}
|
|
763
776
|
bcode += code_size;
|
|
764
777
|
}
|
|
765
778
|
|
|
766
779
|
if (init_finalize_heap)
|
|
767
|
-
maxheap_reorder
|
|
780
|
+
maxheap_reorder(k, heap_dis, heap_ids);
|
|
768
781
|
}
|
|
769
|
-
|
|
770
782
|
}
|
|
771
783
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
} // namespace faiss
|
|
784
|
+
} // namespace faiss
|