faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -0,0 +1,855 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/LocalSearchQuantizer.h>
|
|
9
|
+
|
|
10
|
+
#include <cstddef>
|
|
11
|
+
#include <cstdio>
|
|
12
|
+
#include <cstring>
|
|
13
|
+
#include <memory>
|
|
14
|
+
#include <random>
|
|
15
|
+
|
|
16
|
+
#include <algorithm>
|
|
17
|
+
|
|
18
|
+
#include <faiss/Clustering.h>
|
|
19
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
|
21
|
+
#include <faiss/utils/distances.h>
|
|
22
|
+
#include <faiss/utils/hamming.h> // BitstringWriter
|
|
23
|
+
#include <faiss/utils/utils.h>
|
|
24
|
+
|
|
25
|
+
// Fortran BLAS/LAPACK routines used by this file. Per the Fortran calling
// convention, every scalar argument is passed by pointer. FINTEGER is the
// Fortran integer type; it is defined elsewhere in the project
// (presumably faiss/impl/platform_macros.h — confirm if porting this file).
extern "C" {
// LU decomposition of a general matrix (single precision)
void sgetrf_(
        FINTEGER* m,
        FINTEGER* n,
        float* a,
        FINTEGER* lda,
        FINTEGER* ipiv,
        FINTEGER* info);

// generate inverse of a matrix given its LU decomposition (single precision)
void sgetri_(
        FINTEGER* n,
        float* a,
        FINTEGER* lda,
        FINTEGER* ipiv,
        float* work,
        FINTEGER* lwork,
        FINTEGER* info);

// general matrix multiplication (single precision)
int sgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const float* alpha,
        const float* a,
        FINTEGER* lda,
        const float* b,
        FINTEGER* ldb,
        float* beta,
        float* c,
        FINTEGER* ldc);

// LU decomposition of a general matrix (double precision)
void dgetrf_(
        FINTEGER* m,
        FINTEGER* n,
        double* a,
        FINTEGER* lda,
        FINTEGER* ipiv,
        FINTEGER* info);

// generate inverse of a matrix given its LU decomposition (double precision)
void dgetri_(
        FINTEGER* n,
        double* a,
        FINTEGER* lda,
        FINTEGER* ipiv,
        double* work,
        FINTEGER* lwork,
        FINTEGER* info);

// general matrix multiplication (double precision)
int dgemm_(
        const char* transa,
        const char* transb,
        FINTEGER* m,
        FINTEGER* n,
        FINTEGER* k,
        const double* alpha,
        const double* a,
        FINTEGER* lda,
        const double* b,
        FINTEGER* ldb,
        double* beta,
        double* c,
        FINTEGER* ldc);
}
|
|
96
|
+
|
|
97
|
+
namespace {
|
|
98
|
+
|
|
99
|
+
/// Invert the n x n float matrix `a` in place using LAPACK
/// (LU factorization via sgetrf_, then inversion via sgetri_).
/// Throws if LAPACK reports failure (e.g. a singular matrix).
void fmat_inverse(float* a, int n) {
    std::vector<int> pivots(n);
    int lwork = n * n;
    std::vector<float> work(lwork);
    int info = 0;

    sgetrf_(&n, &n, a, &n, pivots.data(), &info);
    FAISS_THROW_IF_NOT(info == 0);
    sgetri_(&n, a, &n, pivots.data(), work.data(), &lwork, &info);
    FAISS_THROW_IF_NOT(info == 0);
}
|
|
110
|
+
|
|
111
|
+
// c and a and b can overlap
|
|
112
|
+
/// Element-wise sum c[i] = a[i] + b[i] over d entries, widening the
/// float input to double. The output `c` may alias either input.
void dfvec_add(size_t d, const double* a, const float* b, double* c) {
    size_t i = 0;
    while (i < d) {
        c[i] = a[i] + b[i];
        ++i;
    }
}
|
|
117
|
+
|
|
118
|
+
/// Invert the n x n double matrix `a` in place using LAPACK
/// (LU factorization via dgetrf_, then inversion via dgetri_).
/// Throws if LAPACK reports failure (e.g. a singular matrix).
void dmat_inverse(double* a, int n) {
    std::vector<int> pivots(n);
    int lwork = n * n;
    std::vector<double> work(lwork);
    int info = 0;

    dgetrf_(&n, &n, a, &n, pivots.data(), &info);
    FAISS_THROW_IF_NOT(info == 0);
    dgetri_(&n, a, &n, pivots.data(), work.data(), &lwork, &info);
    FAISS_THROW_IF_NOT(info == 0);
}
|
|
129
|
+
|
|
130
|
+
void random_int32(
|
|
131
|
+
std::vector<int32_t>& x,
|
|
132
|
+
int32_t min,
|
|
133
|
+
int32_t max,
|
|
134
|
+
std::mt19937& gen) {
|
|
135
|
+
std::uniform_int_distribution<int32_t> distrib(min, max);
|
|
136
|
+
for (size_t i = 0; i < x.size(); i++) {
|
|
137
|
+
x[i] = distrib(gen);
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
} // anonymous namespace
|
|
142
|
+
|
|
143
|
+
namespace faiss {
|
|
144
|
+
|
|
145
|
+
lsq::LSQTimer lsq_timer;
|
|
146
|
+
using lsq::LSQTimerScope;
|
|
147
|
+
|
|
148
|
+
/// Construct an LSQ quantizer with M codebooks of 2^nbits codewords each,
/// over d-dimensional vectors. The quantizer starts untrained; all
/// hyperparameters below are defaults that callers may override before
/// train().
LocalSearchQuantizer::LocalSearchQuantizer(
        size_t d,
        size_t M,
        size_t nbits,
        Search_type_t search_type)
        : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
    is_trained = false;
    verbose = false;

    // number of codewords per codebook
    K = (1 << nbits);

    // outer training iterations and ILS/ICM refinement depths
    train_iters = 25;
    train_ils_iters = 8;
    icm_iters = 4;

    // ILS iterations used at encoding time (compute_codes)
    encode_ils_iters = 16;

    // SR-D temperature exponent and ridge regularizer for codebook update
    p = 0.5f;
    lambd = 1e-2f;

    // encoding is chunked to bound memory usage
    chunk_size = 10000;
    // number of subcodes perturbed per vector during ILS
    nperts = 4;

    random_seed = 0x12345;
    // NOTE(review): seeds the global C RNG as a constructor side effect —
    // affects any other code using rand(); confirm this is intentional.
    std::srand(random_seed);

    // optional user-provided ICM encoder factory (owned; freed in dtor)
    icm_encoder_factory = nullptr;
}
|
|
176
|
+
|
|
177
|
+
LocalSearchQuantizer::~LocalSearchQuantizer() {
    // The quantizer owns the optional encoder factory; deleting a null
    // pointer is a no-op, so no check is needed.
    delete icm_encoder_factory;
}
|
|
180
|
+
|
|
181
|
+
// Delegating default constructor: builds an empty (d=0, M=0, nbits=0),
// untrained quantizer whose parameters are expected to be set later.
LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
|
|
182
|
+
|
|
183
|
+
/// Train the LSQ codebooks on n d-dimensional vectors x.
///
/// Alternates between (1) a closed-form codebook update, (2) SR-D
/// codebook perturbation with a decaying temperature, and (3) ICM code
/// refinement. Afterwards it records the min/max reconstruction norms
/// and, for cqint search types, trains a 1D norm quantizer.
void LocalSearchQuantizer::train(size_t n, const float* x) {
    FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
    FAISS_THROW_IF_NOT(nperts <= M);

    lsq_timer.reset();
    LSQTimerScope scope(&lsq_timer, "train");
    if (verbose) {
        printf("Training LSQ, with %zd subcodes on %zd %zdD vectors\n",
               M,
               n,
               d);
    }

    // allocate memory for codebooks, size [M, K, d]
    codebooks.resize(M * K * d);

    // randomly initialize codes
    std::mt19937 gen(random_seed);
    std::vector<int32_t> codes(n * M); // [n, M]
    random_int32(codes, 0, K - 1, gen);

    // compute the standard deviation of each dimension; used to scale the
    // SR-D perturbation noise per dimension
    std::vector<float> stddev(d, 0);

#pragma omp parallel for
    for (int64_t i = 0; i < d; i++) {
        float mean = 0;
        for (size_t j = 0; j < n; j++) {
            mean += x[j * d + i];
        }
        mean = mean / n;

        float sum = 0;
        for (size_t j = 0; j < n; j++) {
            float xi = x[j * d + i] - mean;
            sum += xi * xi;
        }
        stddev[i] = sqrtf(sum / n);
    }

    if (verbose) {
        float obj = evaluate(codes.data(), x, n);
        printf("Before training: obj = %lf\n", obj);
    }

    for (size_t i = 0; i < train_iters; i++) {
        // 1. update codebooks given x and codes
        // 2. add perturbation to codebooks (SR-D)
        // 3. refine codes given x and codebooks using icm

        // update codebooks
        update_codebooks(x, codes.data(), n);

        if (verbose) {
            float obj = evaluate(codes.data(), x, n);
            printf("iter %zd:\n", i);
            printf("\tafter updating codebooks: obj = %lf\n", obj);
        }

        // SR-D: perturb codebooks with a temperature that decays to 0
        // over the training iterations (exponent p controls the schedule)
        float T = pow((1.0f - (i + 1.0f) / train_iters), p);
        perturb_codebooks(T, stddev, gen);

        if (verbose) {
            float obj = evaluate(codes.data(), x, n);
            printf("\tafter perturbing codebooks: obj = %lf\n", obj);
        }

        // refine codes
        icm_encode(codes.data(), x, n, train_ils_iters, gen);

        if (verbose) {
            float obj = evaluate(codes.data(), x, n);
            printf("\tafter updating codes: obj = %lf\n", obj);
        }
    }

    is_trained = true;
    {
        // record the range of reconstruction norms; needed by the
        // norm-quantizing search types
        std::vector<float> x_recons(n * d);
        std::vector<float> norms(n);
        decode_unpacked(codes.data(), x_recons.data(), n);
        fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);

        norm_min = HUGE_VALF;
        norm_max = -HUGE_VALF;
        for (idx_t i = 0; i < n; i++) {
            if (norms[i] < norm_min) {
                norm_min = norms[i];
            }
            if (norms[i] > norm_max) {
                norm_max = norms[i];
            }
        }

        // for cqint search types, quantize the norms with an exact 1D
        // k-means (256 or 16 centroids)
        if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
            size_t k = (1 << 8);
            if (search_type == ST_norm_cqint4) {
                k = (1 << 4);
            }
            Clustering1D clus(k);
            clus.train_exact(n, norms.data());
            qnorm.add(clus.k, clus.centroids.data());
        }
    }

    if (verbose) {
        float obj = evaluate(codes.data(), x, n);
        scope.finish();
        printf("After training: obj = %lf\n", obj);

        printf("Time statistic:\n");
        for (const auto& it : lsq_timer.t) {
            printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
        }
    }
}
|
|
300
|
+
|
|
301
|
+
void LocalSearchQuantizer::perturb_codebooks(
|
|
302
|
+
float T,
|
|
303
|
+
const std::vector<float>& stddev,
|
|
304
|
+
std::mt19937& gen) {
|
|
305
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codebooks");
|
|
306
|
+
|
|
307
|
+
std::vector<std::normal_distribution<float>> distribs;
|
|
308
|
+
for (size_t i = 0; i < d; i++) {
|
|
309
|
+
distribs.emplace_back(0.0f, stddev[i]);
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
for (size_t m = 0; m < M; m++) {
|
|
313
|
+
for (size_t k = 0; k < K; k++) {
|
|
314
|
+
for (size_t i = 0; i < d; i++) {
|
|
315
|
+
codebooks[m * K * d + k * d + i] += T * distribs[i](gen) / M;
|
|
316
|
+
}
|
|
317
|
+
}
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
void LocalSearchQuantizer::compute_codes(
|
|
322
|
+
const float* x,
|
|
323
|
+
uint8_t* codes_out,
|
|
324
|
+
size_t n) const {
|
|
325
|
+
FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
|
|
326
|
+
|
|
327
|
+
lsq_timer.reset();
|
|
328
|
+
LSQTimerScope scope(&lsq_timer, "encode");
|
|
329
|
+
if (verbose) {
|
|
330
|
+
printf("Encoding %zd vectors...\n", n);
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
std::vector<int32_t> codes(n * M);
|
|
334
|
+
std::mt19937 gen(random_seed);
|
|
335
|
+
random_int32(codes, 0, K - 1, gen);
|
|
336
|
+
|
|
337
|
+
icm_encode(codes.data(), x, n, encode_ils_iters, gen);
|
|
338
|
+
pack_codes(n, codes.data(), codes_out);
|
|
339
|
+
|
|
340
|
+
if (verbose) {
|
|
341
|
+
scope.finish();
|
|
342
|
+
printf("Time statistic:\n");
|
|
343
|
+
for (const auto& it : lsq_timer.t) {
|
|
344
|
+
printf("\t%s time: %lf s\n", it.first.data(), it.second / 1000);
|
|
345
|
+
}
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
/** update codebooks given x and codes
 *
 * Let B denote the sparse matrix of codes, size [n, M * K].
 * Let C denote the codebooks, size [M * K, d].
 * Let X denote the training vectors, size [n, d]
 *
 * objective function:
 *     L = (X - BC)^2
 *
 * To minimize L, we have:
 *     C = (B'B)^(-1)B'X
 * where ' denotes transpose
 *
 * Add a regularization term to make B'B invertible:
 *     C = (B'B + lambd * I)^(-1)B'X
 *
 * The two branches below are identical except for precision: the float
 * path is the default, the double path (update_codebooks_with_double)
 * trades speed for numerical stability.
 */
void LocalSearchQuantizer::update_codebooks(
        const float* x,
        const int32_t* codes,
        size_t n) {
    LSQTimerScope scope(&lsq_timer, "update_codebooks");

    if (!update_codebooks_with_double) {
        // allocate memory
        // bb = B'B, bx = BX
        std::vector<float> bb(M * K * M * K, 0.0f); // [M * K, M * K]
        std::vector<float> bx(M * K * d, 0.0f);     // [M * K, d]

        // compute B'B: since each row of B is a 0/1 indicator with one
        // set entry per codebook, B'B is built by counting co-occurrences
        // of the selected codewords
        for (size_t i = 0; i < n; i++) {
            for (size_t m = 0; m < M; m++) {
                int32_t code1 = codes[i * M + m];
                int32_t idx1 = m * K + code1;
                bb[idx1 * M * K + idx1] += 1;

                for (size_t m2 = m + 1; m2 < M; m2++) {
                    int32_t code2 = codes[i * M + m2];
                    int32_t idx2 = m2 * K + code2;
                    bb[idx1 * M * K + idx2] += 1;
                    bb[idx2 * M * K + idx1] += 1;
                }
            }
        }

        // add a regularization term to B'B
        for (int64_t i = 0; i < M * K; i++) {
            bb[i * (M * K) + i] += lambd;
        }

        // compute (B'B)^(-1)
        fmat_inverse(bb.data(), M * K); // [M*K, M*K]

        // compute BX: accumulate each training vector into the row of
        // the codeword it is assigned to
        for (size_t i = 0; i < n; i++) {
            for (size_t m = 0; m < M; m++) {
                int32_t code = codes[i * M + m];
                float* data = bx.data() + (m * K + code) * d;
                fvec_add(d, data, x + i * d, data);
            }
        }

        // compute C = (B'B)^(-1) @ BX
        //
        // NOTE: LAPACK use column major order
        // out = alpha * op(A) * op(B) + beta * C
        FINTEGER nrows_A = d;
        FINTEGER ncols_A = M * K;

        FINTEGER nrows_B = M * K;
        FINTEGER ncols_B = M * K;

        float alpha = 1.0f;
        float beta = 0.0f;
        sgemm_("Not Transposed",
               "Not Transposed",
               &nrows_A, // nrows of op(A)
               &ncols_B, // ncols of op(B)
               &ncols_A, // ncols of op(A)
               &alpha,
               bx.data(),
               &nrows_A, // nrows of A
               bb.data(),
               &nrows_B, // nrows of B
               &beta,
               codebooks.data(),
               &nrows_A); // nrows of output

    } else {
        // allocate memory
        // bb = B'B, bx = BX
        std::vector<double> bb(M * K * M * K, 0.0f); // [M * K, M * K]
        std::vector<double> bx(M * K * d, 0.0f);     // [M * K, d]

        // compute B'B (same counting scheme as the float branch)
        for (size_t i = 0; i < n; i++) {
            for (size_t m = 0; m < M; m++) {
                int32_t code1 = codes[i * M + m];
                int32_t idx1 = m * K + code1;
                bb[idx1 * M * K + idx1] += 1;

                for (size_t m2 = m + 1; m2 < M; m2++) {
                    int32_t code2 = codes[i * M + m2];
                    int32_t idx2 = m2 * K + code2;
                    bb[idx1 * M * K + idx2] += 1;
                    bb[idx2 * M * K + idx1] += 1;
                }
            }
        }

        // add a regularization term to B'B
        for (int64_t i = 0; i < M * K; i++) {
            bb[i * (M * K) + i] += lambd;
        }

        // compute (B'B)^(-1)
        dmat_inverse(bb.data(), M * K); // [M*K, M*K]

        // compute BX
        for (size_t i = 0; i < n; i++) {
            for (size_t m = 0; m < M; m++) {
                int32_t code = codes[i * M + m];
                double* data = bx.data() + (m * K + code) * d;
                dfvec_add(d, data, x + i * d, data);
            }
        }

        // compute C = (B'B)^(-1) @ BX
        //
        // NOTE: LAPACK use column major order
        // out = alpha * op(A) * op(B) + beta * C
        FINTEGER nrows_A = d;
        FINTEGER ncols_A = M * K;

        FINTEGER nrows_B = M * K;
        FINTEGER ncols_B = M * K;

        // compute into a double buffer, then narrow to float codebooks
        std::vector<double> d_codebooks(M * K * d);

        double alpha = 1.0f;
        double beta = 0.0f;
        dgemm_("Not Transposed",
               "Not Transposed",
               &nrows_A, // nrows of op(A)
               &ncols_B, // ncols of op(B)
               &ncols_A, // ncols of op(A)
               &alpha,
               bx.data(),
               &nrows_A, // nrows of A
               bb.data(),
               &nrows_B, // nrows of B
               &beta,
               d_codebooks.data(),
               &nrows_A); // nrows of output

        for (size_t i = 0; i < M * K * d; i++) {
            codebooks[i] = (float)d_codebooks[i];
        }
    }
}
|
|
508
|
+
|
|
509
|
+
/** encode using iterative conditional mode
|
|
510
|
+
*
|
|
511
|
+
* iterative conditional mode:
|
|
512
|
+
* For every subcode ci (i = 1, ..., M) of a vector, we fix the other
|
|
513
|
+
* subcodes cj (j != i) and then find the optimal value of ci such
|
|
514
|
+
* that minimizing the objective function.
|
|
515
|
+
|
|
516
|
+
* objective function:
|
|
517
|
+
* L = (X - \sum cj)^2, j = 1, ..., M
|
|
518
|
+
* L = X^2 - 2X * \sum cj + (\sum cj)^2
|
|
519
|
+
*
|
|
520
|
+
* X^2 is negligable since it is the same for all possible value
|
|
521
|
+
* k of the m-th subcode.
|
|
522
|
+
*
|
|
523
|
+
* 2X * \sum cj is the unary term
|
|
524
|
+
* (\sum cj)^2 is the binary term
|
|
525
|
+
* These two terms can be precomputed and store in a look up table.
|
|
526
|
+
*/
|
|
527
|
+
void LocalSearchQuantizer::icm_encode(
|
|
528
|
+
int32_t* codes,
|
|
529
|
+
const float* x,
|
|
530
|
+
size_t n,
|
|
531
|
+
size_t ils_iters,
|
|
532
|
+
std::mt19937& gen) const {
|
|
533
|
+
LSQTimerScope scope(&lsq_timer, "icm_encode");
|
|
534
|
+
|
|
535
|
+
auto factory = icm_encoder_factory;
|
|
536
|
+
std::unique_ptr<lsq::IcmEncoder> icm_encoder;
|
|
537
|
+
if (factory == nullptr) {
|
|
538
|
+
icm_encoder.reset(lsq::IcmEncoderFactory().get(this));
|
|
539
|
+
} else {
|
|
540
|
+
icm_encoder.reset(factory->get(this));
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
// precompute binary terms for all chunks
|
|
544
|
+
icm_encoder->set_binary_term();
|
|
545
|
+
|
|
546
|
+
const size_t n_chunks = (n + chunk_size - 1) / chunk_size;
|
|
547
|
+
for (size_t i = 0; i < n_chunks; i++) {
|
|
548
|
+
size_t ni = std::min(chunk_size, n - i * chunk_size);
|
|
549
|
+
|
|
550
|
+
if (verbose) {
|
|
551
|
+
printf("\r\ticm encoding %zd/%zd ...", i * chunk_size + ni, n);
|
|
552
|
+
fflush(stdout);
|
|
553
|
+
if (i == n_chunks - 1 || i == 0) {
|
|
554
|
+
printf("\n");
|
|
555
|
+
}
|
|
556
|
+
}
|
|
557
|
+
|
|
558
|
+
const float* xi = x + i * chunk_size * d;
|
|
559
|
+
int32_t* codesi = codes + i * chunk_size * M;
|
|
560
|
+
icm_encoder->verbose = (verbose && i == 0);
|
|
561
|
+
icm_encoder->encode(codesi, xi, gen, ni, ils_iters);
|
|
562
|
+
}
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
/// Default (CPU) ILS/ICM encoding for one chunk of n vectors.
///
/// Repeats ils_iters times: perturb the current codes, run icm_iters
/// rounds of ICM using precomputed unary and binary terms, and keep,
/// per vector, whichever codes give the lower objective.
/// `binaries` are the precomputed codeword inner products,
/// presumably laid out [M, M, K, K] — confirm against set_binary_term.
void LocalSearchQuantizer::icm_encode_impl(
        int32_t* codes,
        const float* x,
        const float* binaries,
        std::mt19937& gen,
        size_t n,
        size_t ils_iters,
        bool verbose) const {
    std::vector<float> unaries(n * M * K); // [M, n, K]
    compute_unary_terms(x, unaries.data(), n);

    // best codes / objectives seen so far, per vector
    std::vector<int32_t> best_codes;
    best_codes.assign(codes, codes + n * M);

    std::vector<float> best_objs(n, 0.0f);
    evaluate(codes, x, n, best_objs.data());

    FAISS_THROW_IF_NOT(nperts <= M);
    for (size_t iter1 = 0; iter1 < ils_iters; iter1++) {
        // add perturbation to codes (escape local minima)
        perturb_codes(codes, n, gen);

        icm_encode_step(codes, unaries.data(), binaries, n, icm_iters);

        std::vector<float> icm_objs(n, 0.0f);
        evaluate(codes, x, n, icm_objs.data());
        size_t n_betters = 0;
        float mean_obj = 0.0f;

        // select the best code for every vector xi
#pragma omp parallel for reduction(+ : n_betters, mean_obj)
        for (int64_t i = 0; i < n; i++) {
            if (icm_objs[i] < best_objs[i]) {
                best_objs[i] = icm_objs[i];
                memcpy(best_codes.data() + i * M,
                       codes + i * M,
                       sizeof(int32_t) * M);
                n_betters += 1;
            }
            mean_obj += best_objs[i];
        }
        mean_obj /= n;

        // restart the next ILS iteration from the best codes found
        memcpy(codes, best_codes.data(), sizeof(int32_t) * n * M);

        if (verbose) {
            printf("\tils_iter %zd: obj = %lf, n_betters/n = %zd/%zd\n",
                   iter1,
                   mean_obj,
                   n_betters,
                   n);
        }
    } // loop ils_iters
}
|
|
619
|
+
|
|
620
|
+
void LocalSearchQuantizer::icm_encode_step(
|
|
621
|
+
int32_t* codes,
|
|
622
|
+
const float* unaries,
|
|
623
|
+
const float* binaries,
|
|
624
|
+
size_t n,
|
|
625
|
+
size_t n_iters) const {
|
|
626
|
+
FAISS_THROW_IF_NOT(M != 0 && K != 0);
|
|
627
|
+
FAISS_THROW_IF_NOT(binaries != nullptr);
|
|
628
|
+
|
|
629
|
+
for (size_t iter = 0; iter < n_iters; iter++) {
|
|
630
|
+
// condition on the m-th subcode
|
|
631
|
+
for (size_t m = 0; m < M; m++) {
|
|
632
|
+
std::vector<float> objs(n * K);
|
|
633
|
+
#pragma omp parallel for
|
|
634
|
+
for (int64_t i = 0; i < n; i++) {
|
|
635
|
+
auto u = unaries + m * n * K + i * K;
|
|
636
|
+
memcpy(objs.data() + i * K, u, sizeof(float) * K);
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
// compute objective function by adding unary
|
|
640
|
+
// and binary terms together
|
|
641
|
+
for (size_t other_m = 0; other_m < M; other_m++) {
|
|
642
|
+
if (other_m == m) {
|
|
643
|
+
continue;
|
|
644
|
+
}
|
|
645
|
+
|
|
646
|
+
#pragma omp parallel for
|
|
647
|
+
for (int64_t i = 0; i < n; i++) {
|
|
648
|
+
for (int32_t code = 0; code < K; code++) {
|
|
649
|
+
int32_t code2 = codes[i * M + other_m];
|
|
650
|
+
size_t binary_idx = m * M * K * K + other_m * K * K +
|
|
651
|
+
code * K + code2;
|
|
652
|
+
// binaries[m, other_m, code, code2]
|
|
653
|
+
objs[i * K + code] += binaries[binary_idx];
|
|
654
|
+
}
|
|
655
|
+
}
|
|
656
|
+
}
|
|
657
|
+
|
|
658
|
+
// find the optimal value of the m-th subcode
|
|
659
|
+
#pragma omp parallel for
|
|
660
|
+
for (int64_t i = 0; i < n; i++) {
|
|
661
|
+
float best_obj = HUGE_VALF;
|
|
662
|
+
int32_t best_code = 0;
|
|
663
|
+
for (size_t code = 0; code < K; code++) {
|
|
664
|
+
float obj = objs[i * K + code];
|
|
665
|
+
if (obj < best_obj) {
|
|
666
|
+
best_obj = obj;
|
|
667
|
+
best_code = code;
|
|
668
|
+
}
|
|
669
|
+
}
|
|
670
|
+
codes[i * M + m] = best_code;
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
} // loop M
|
|
674
|
+
}
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
void LocalSearchQuantizer::perturb_codes(
|
|
678
|
+
int32_t* codes,
|
|
679
|
+
size_t n,
|
|
680
|
+
std::mt19937& gen) const {
|
|
681
|
+
LSQTimerScope scope(&lsq_timer, "perturb_codes");
|
|
682
|
+
|
|
683
|
+
std::uniform_int_distribution<size_t> m_distrib(0, M - 1);
|
|
684
|
+
std::uniform_int_distribution<int32_t> k_distrib(0, K - 1);
|
|
685
|
+
|
|
686
|
+
for (size_t i = 0; i < n; i++) {
|
|
687
|
+
for (size_t j = 0; j < nperts; j++) {
|
|
688
|
+
size_t m = m_distrib(gen);
|
|
689
|
+
codes[i * M + m] = k_distrib(gen);
|
|
690
|
+
}
|
|
691
|
+
}
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
void LocalSearchQuantizer::compute_binary_terms(float* binaries) const {
|
|
695
|
+
LSQTimerScope scope(&lsq_timer, "compute_binary_terms");
|
|
696
|
+
|
|
697
|
+
#pragma omp parallel for
|
|
698
|
+
for (int64_t m12 = 0; m12 < M * M; m12++) {
|
|
699
|
+
size_t m1 = m12 / M;
|
|
700
|
+
size_t m2 = m12 % M;
|
|
701
|
+
|
|
702
|
+
for (size_t code1 = 0; code1 < K; code1++) {
|
|
703
|
+
for (size_t code2 = 0; code2 < K; code2++) {
|
|
704
|
+
const float* c1 = codebooks.data() + m1 * K * d + code1 * d;
|
|
705
|
+
const float* c2 = codebooks.data() + m2 * K * d + code2 * d;
|
|
706
|
+
float ip = fvec_inner_product(c1, c2, d);
|
|
707
|
+
// binaries[m1, m2, code1, code2] = ip * 2
|
|
708
|
+
binaries[m1 * M * K * K + m2 * K * K + code1 * K + code2] =
|
|
709
|
+
ip * 2;
|
|
710
|
+
}
|
|
711
|
+
}
|
|
712
|
+
}
|
|
713
|
+
}
|
|
714
|
+
|
|
715
|
+
/** Precompute the unary terms of the MRF objective:
 * unaries[m, i, k] = -2 * <x_i, codebooks[m, k]> + ||codebooks[m, k]||^2
 * (the -2 inner product comes from the sgemm with alpha = -2, the squared
 * norms are added in the second pass below).
 *
 * @param x       input vectors, size n * d
 * @param unaries output buffer, layout [M, n, K]
 * @param n       number of vectors
 */
void LocalSearchQuantizer::compute_unary_terms(
        const float* x,
        float* unaries, // [M, n, K]
        size_t n) const {
    LSQTimerScope scope(&lsq_timer, "compute_unary_terms");

    // compute x * codebook^T for each codebook
    //
    // NOTE: LAPACK use column major order
    // out = alpha * op(A) * op(B) + beta * C
    for (size_t m = 0; m < M; m++) {
        FINTEGER nrows_A = K;
        FINTEGER ncols_A = d;

        FINTEGER nrows_B = d;
        FINTEGER ncols_B = n;

        // alpha = -2 directly produces the -2 <x, c> part of the unary term
        float alpha = -2.0f;
        float beta = 0.0f;
        sgemm_("Transposed",
               "Not Transposed",
               &nrows_A, // nrows of op(A)
               &ncols_B, // ncols of op(B)
               &ncols_A, // ncols of op(A)
               &alpha,
               codebooks.data() + m * K * d,
               &ncols_A, // nrows of A
               x,
               &nrows_B, // nrows of B
               &beta,
               unaries + m * n * K,
               &nrows_A); // nrows of output
    }

    // squared L2 norms of all M * K codewords
    std::vector<float> norms(M * K);
    fvec_norms_L2sqr(norms.data(), codebooks.data(), d, M * K);

    // add ||c||^2 to every -2 <x, c> entry
#pragma omp parallel for
    for (int64_t i = 0; i < n; i++) {
        for (size_t m = 0; m < M; m++) {
            float* u = unaries + m * n * K + i * K;
            fvec_add(K, u, norms.data() + m * K, u);
        }
    }
}
|
|
761
|
+
|
|
762
|
+
float LocalSearchQuantizer::evaluate(
|
|
763
|
+
const int32_t* codes,
|
|
764
|
+
const float* x,
|
|
765
|
+
size_t n,
|
|
766
|
+
float* objs) const {
|
|
767
|
+
LSQTimerScope scope(&lsq_timer, "evaluate");
|
|
768
|
+
|
|
769
|
+
// decode
|
|
770
|
+
std::vector<float> decoded_x(n * d, 0.0f);
|
|
771
|
+
float obj = 0.0f;
|
|
772
|
+
|
|
773
|
+
#pragma omp parallel for reduction(+ : obj)
|
|
774
|
+
for (int64_t i = 0; i < n; i++) {
|
|
775
|
+
const auto code = codes + i * M;
|
|
776
|
+
const auto decoded_i = decoded_x.data() + i * d;
|
|
777
|
+
for (size_t m = 0; m < M; m++) {
|
|
778
|
+
// c = codebooks[m, code[m]]
|
|
779
|
+
const auto c = codebooks.data() + m * K * d + code[m] * d;
|
|
780
|
+
fvec_add(d, decoded_i, c, decoded_i);
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
float err = faiss::fvec_L2sqr(x + i * d, decoded_i, d);
|
|
784
|
+
obj += err;
|
|
785
|
+
|
|
786
|
+
if (objs) {
|
|
787
|
+
objs[i] = err;
|
|
788
|
+
}
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
obj = obj / n;
|
|
792
|
+
return obj;
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
namespace lsq {
|
|
796
|
+
|
|
797
|
+
/// Construct an ICM encoder working on behalf of the given quantizer.
/// The encoder only observes `lsq`; it does not take ownership.
IcmEncoder::IcmEncoder(const LocalSearchQuantizer* lsq)
        : verbose(false), lsq(lsq) {}
|
|
799
|
+
|
|
800
|
+
/// Cache the quantizer's binary terms (size M * M * K * K) so that
/// subsequent encode() calls can reuse them.
void IcmEncoder::set_binary_term() {
    const auto M = lsq->M;
    const auto K = lsq->K;
    binaries.resize(M * M * K * K);
    lsq->compute_binary_terms(binaries.data());
}
|
|
806
|
+
|
|
807
|
+
/** Encode n vectors with ICM.
 *
 * Thin wrapper around LocalSearchQuantizer::icm_encode_impl, passing the
 * binary terms cached by set_binary_term().
 *
 * @param codes     in/out codes, size n * M
 * @param x         vectors to encode, size n * d
 * @param gen       random generator
 * @param n         number of vectors
 * @param ils_iters number of iterated local search iterations
 */
void IcmEncoder::encode(
        int32_t* codes,
        const float* x,
        std::mt19937& gen,
        size_t n,
        size_t ils_iters) const {
    lsq->icm_encode_impl(codes, x, binaries.data(), gen, n, ils_iters, verbose);
}
|
|
815
|
+
|
|
816
|
+
double LSQTimer::get(const std::string& name) {
|
|
817
|
+
if (t.count(name) == 0) {
|
|
818
|
+
return 0.0;
|
|
819
|
+
} else {
|
|
820
|
+
return t[name];
|
|
821
|
+
}
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
/** Accumulate `delta` under `name`, creating the entry if needed.
 */
void LSQTimer::add(const std::string& name, double delta) {
    // operator[] value-initializes a missing entry to 0.0, so this is the
    // single-lookup equivalent of the old count() + branch code
    t[name] += delta;
}
|
|
831
|
+
|
|
832
|
+
/// Forget all recorded timings.
void LSQTimer::reset() {
    t.clear();
}
|
|
835
|
+
|
|
836
|
+
/// Start timing; the elapsed time is credited to `name` by finish()
/// (called at the latest from the destructor).
LSQTimerScope::LSQTimerScope(LSQTimer* timer, std::string name)
        : timer(timer), name(std::move(name)), finished(false) {
    // `name` is a sink parameter taken by value: move it instead of copying.
    // Record the start time last so construction cost is not measured.
    t0 = getmillisecs();
}
|
|
840
|
+
|
|
841
|
+
void LSQTimerScope::finish() {
|
|
842
|
+
if (!finished) {
|
|
843
|
+
auto delta = getmillisecs() - t0;
|
|
844
|
+
timer->add(name, delta);
|
|
845
|
+
finished = true;
|
|
846
|
+
}
|
|
847
|
+
}
|
|
848
|
+
|
|
849
|
+
// Record the elapsed time if finish() has not been called explicitly.
LSQTimerScope::~LSQTimerScope() {
    finish();
}
|
|
852
|
+
|
|
853
|
+
} // namespace lsq
|
|
854
|
+
|
|
855
|
+
} // namespace faiss
|