faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/IndexLSH.h>
|
|
11
9
|
|
|
12
10
|
#include <cstdio>
|
|
@@ -14,10 +12,9 @@
|
|
|
14
12
|
|
|
15
13
|
#include <algorithm>
|
|
16
14
|
|
|
17
|
-
#include <faiss/utils/utils.h>
|
|
18
|
-
#include <faiss/utils/hamming.h>
|
|
19
15
|
#include <faiss/impl/FaissAssert.h>
|
|
20
|
-
|
|
16
|
+
#include <faiss/utils/hamming.h>
|
|
17
|
+
#include <faiss/utils/utils.h>
|
|
21
18
|
|
|
22
19
|
namespace faiss {
|
|
23
20
|
|
|
@@ -25,143 +22,117 @@ namespace faiss {
|
|
|
25
22
|
* IndexLSH
|
|
26
23
|
***************************************************************/
|
|
27
24
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
25
|
+
IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
|
|
26
|
+
: IndexFlatCodes((nbits + 7) / 8, d),
|
|
27
|
+
nbits(nbits),
|
|
28
|
+
rotate_data(rotate_data),
|
|
29
|
+
train_thresholds(train_thresholds),
|
|
30
|
+
rrot(d, nbits) {
|
|
33
31
|
is_trained = !train_thresholds;
|
|
34
32
|
|
|
35
|
-
bytes_per_vec = (nbits + 7) / 8;
|
|
36
|
-
|
|
37
33
|
if (rotate_data) {
|
|
38
34
|
rrot.init(5);
|
|
39
35
|
} else {
|
|
40
|
-
FAISS_THROW_IF_NOT
|
|
36
|
+
FAISS_THROW_IF_NOT(d >= nbits);
|
|
41
37
|
}
|
|
42
38
|
}
|
|
43
39
|
|
|
44
|
-
IndexLSH::IndexLSH ()
|
|
45
|
-
nbits (0), bytes_per_vec(0), rotate_data (false), train_thresholds (false)
|
|
46
|
-
{
|
|
47
|
-
}
|
|
48
|
-
|
|
40
|
+
IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
|
|
49
41
|
|
|
50
|
-
const float
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
float *xt = nullptr;
|
|
42
|
+
const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
|
|
43
|
+
float* xt = nullptr;
|
|
54
44
|
if (rotate_data) {
|
|
55
45
|
// also applies bias if exists
|
|
56
|
-
xt = rrot.apply
|
|
46
|
+
xt = rrot.apply(n, x);
|
|
57
47
|
} else if (d != nbits) {
|
|
58
|
-
assert
|
|
59
|
-
xt = new float
|
|
60
|
-
float
|
|
48
|
+
assert(nbits < d);
|
|
49
|
+
xt = new float[nbits * n];
|
|
50
|
+
float* xp = xt;
|
|
61
51
|
for (idx_t i = 0; i < n; i++) {
|
|
62
|
-
const float
|
|
52
|
+
const float* xl = x + i * d;
|
|
63
53
|
for (int j = 0; j < nbits; j++)
|
|
64
|
-
*xp++ = xl
|
|
54
|
+
*xp++ = xl[j];
|
|
65
55
|
}
|
|
66
56
|
}
|
|
67
57
|
|
|
68
58
|
if (train_thresholds) {
|
|
69
|
-
|
|
70
59
|
if (xt == NULL) {
|
|
71
|
-
xt = new float
|
|
72
|
-
memcpy
|
|
60
|
+
xt = new float[nbits * n];
|
|
61
|
+
memcpy(xt, x, sizeof(*x) * n * nbits);
|
|
73
62
|
}
|
|
74
63
|
|
|
75
|
-
float
|
|
64
|
+
float* xp = xt;
|
|
76
65
|
for (idx_t i = 0; i < n; i++)
|
|
77
66
|
for (int j = 0; j < nbits; j++)
|
|
78
|
-
*xp++ -= thresholds
|
|
67
|
+
*xp++ -= thresholds[j];
|
|
79
68
|
}
|
|
80
69
|
|
|
81
70
|
return xt ? xt : x;
|
|
82
71
|
}
|
|
83
72
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
void IndexLSH::train (idx_t n, const float *x)
|
|
87
|
-
{
|
|
73
|
+
void IndexLSH::train(idx_t n, const float* x) {
|
|
88
74
|
if (train_thresholds) {
|
|
89
|
-
thresholds.resize
|
|
75
|
+
thresholds.resize(nbits);
|
|
90
76
|
train_thresholds = false;
|
|
91
|
-
const float
|
|
92
|
-
ScopeDeleter<float> del
|
|
77
|
+
const float* xt = apply_preprocess(n, x);
|
|
78
|
+
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
|
93
79
|
train_thresholds = true;
|
|
94
80
|
|
|
95
|
-
float
|
|
96
|
-
ScopeDeleter<float> del2
|
|
81
|
+
float* transposed_x = new float[n * nbits];
|
|
82
|
+
ScopeDeleter<float> del2(transposed_x);
|
|
97
83
|
|
|
98
84
|
for (idx_t i = 0; i < n; i++)
|
|
99
85
|
for (idx_t j = 0; j < nbits; j++)
|
|
100
|
-
transposed_x
|
|
86
|
+
transposed_x[j * n + i] = xt[i * nbits + j];
|
|
101
87
|
|
|
102
88
|
for (idx_t i = 0; i < nbits; i++) {
|
|
103
|
-
float
|
|
89
|
+
float* xi = transposed_x + i * n;
|
|
104
90
|
// std::nth_element
|
|
105
|
-
std::sort
|
|
91
|
+
std::sort(xi, xi + n);
|
|
106
92
|
if (n % 2 == 1)
|
|
107
|
-
thresholds
|
|
93
|
+
thresholds[i] = xi[n / 2];
|
|
108
94
|
else
|
|
109
|
-
thresholds
|
|
110
|
-
|
|
95
|
+
thresholds[i] = (xi[n / 2 - 1] + xi[n / 2]) / 2;
|
|
111
96
|
}
|
|
112
97
|
}
|
|
113
98
|
is_trained = true;
|
|
114
99
|
}
|
|
115
100
|
|
|
116
|
-
|
|
117
|
-
void IndexLSH::add (idx_t n, const float *x)
|
|
118
|
-
{
|
|
119
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
120
|
-
codes.resize ((ntotal + n) * bytes_per_vec);
|
|
121
|
-
|
|
122
|
-
sa_encode (n, x, &codes[ntotal * bytes_per_vec]);
|
|
123
|
-
|
|
124
|
-
ntotal += n;
|
|
125
|
-
}
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
void IndexLSH::search (
|
|
101
|
+
void IndexLSH::search(
|
|
129
102
|
idx_t n,
|
|
130
|
-
const float
|
|
103
|
+
const float* x,
|
|
131
104
|
idx_t k,
|
|
132
|
-
float
|
|
133
|
-
idx_t
|
|
134
|
-
|
|
135
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
136
|
-
const float *xt = apply_preprocess (n, x);
|
|
137
|
-
ScopeDeleter<float> del (xt == x ? nullptr : xt);
|
|
105
|
+
float* distances,
|
|
106
|
+
idx_t* labels) const {
|
|
107
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
138
108
|
|
|
139
|
-
|
|
140
|
-
|
|
109
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
110
|
+
const float* xt = apply_preprocess(n, x);
|
|
111
|
+
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
|
141
112
|
|
|
142
|
-
|
|
113
|
+
uint8_t* qcodes = new uint8_t[n * code_size];
|
|
114
|
+
ScopeDeleter<uint8_t> del2(qcodes);
|
|
143
115
|
|
|
144
|
-
|
|
145
|
-
ScopeDeleter<int> del3 (idistances);
|
|
116
|
+
fvecs2bitvecs(xt, qcodes, nbits, n);
|
|
146
117
|
|
|
147
|
-
|
|
118
|
+
int* idistances = new int[n * k];
|
|
119
|
+
ScopeDeleter<int> del3(idistances);
|
|
148
120
|
|
|
149
|
-
|
|
150
|
-
ntotal, bytes_per_vec, true);
|
|
121
|
+
int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
|
|
151
122
|
|
|
123
|
+
hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
|
|
152
124
|
|
|
153
125
|
// convert distances to floats
|
|
154
126
|
for (int i = 0; i < k * n; i++)
|
|
155
127
|
distances[i] = idistances[i];
|
|
156
|
-
|
|
157
128
|
}
|
|
158
129
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
FAISS_THROW_IF_NOT
|
|
130
|
+
void IndexLSH::transfer_thresholds(LinearTransform* vt) {
|
|
131
|
+
if (!train_thresholds)
|
|
132
|
+
return;
|
|
133
|
+
FAISS_THROW_IF_NOT(nbits == vt->d_out);
|
|
163
134
|
if (!vt->have_bias) {
|
|
164
|
-
vt->b.resize
|
|
135
|
+
vt->b.resize(nbits, 0);
|
|
165
136
|
vt->have_bias = true;
|
|
166
137
|
}
|
|
167
138
|
for (int i = 0; i < nbits; i++)
|
|
@@ -170,56 +141,38 @@ void IndexLSH::transfer_thresholds (LinearTransform *vt) {
|
|
|
170
141
|
thresholds.clear();
|
|
171
142
|
}
|
|
172
143
|
|
|
173
|
-
void IndexLSH::
|
|
174
|
-
|
|
175
|
-
|
|
144
|
+
void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
|
145
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
146
|
+
const float* xt = apply_preprocess(n, x);
|
|
147
|
+
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
|
148
|
+
fvecs2bitvecs(xt, bytes, nbits, n);
|
|
176
149
|
}
|
|
177
150
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
{
|
|
181
|
-
return bytes_per_vec;
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
void IndexLSH::sa_encode (idx_t n, const float *x,
|
|
185
|
-
uint8_t *bytes) const
|
|
186
|
-
{
|
|
187
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
188
|
-
const float *xt = apply_preprocess (n, x);
|
|
189
|
-
ScopeDeleter<float> del (xt == x ? nullptr : xt);
|
|
190
|
-
fvecs2bitvecs (xt, bytes, nbits, n);
|
|
191
|
-
}
|
|
192
|
-
|
|
193
|
-
void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes,
|
|
194
|
-
float *x) const
|
|
195
|
-
{
|
|
196
|
-
float *xt = x;
|
|
151
|
+
void IndexLSH::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
152
|
+
float* xt = x;
|
|
197
153
|
ScopeDeleter<float> del;
|
|
198
154
|
if (rotate_data || nbits != d) {
|
|
199
|
-
xt = new float
|
|
155
|
+
xt = new float[n * nbits];
|
|
200
156
|
del.set(xt);
|
|
201
157
|
}
|
|
202
|
-
bitvecs2fvecs
|
|
158
|
+
bitvecs2fvecs(bytes, xt, nbits, n);
|
|
203
159
|
|
|
204
160
|
if (train_thresholds) {
|
|
205
|
-
float
|
|
161
|
+
float* xp = xt;
|
|
206
162
|
for (idx_t i = 0; i < n; i++) {
|
|
207
163
|
for (int j = 0; j < nbits; j++) {
|
|
208
|
-
*xp++ += thresholds
|
|
164
|
+
*xp++ += thresholds[j];
|
|
209
165
|
}
|
|
210
166
|
}
|
|
211
167
|
}
|
|
212
168
|
|
|
213
169
|
if (rotate_data) {
|
|
214
|
-
rrot.reverse_transform
|
|
170
|
+
rrot.reverse_transform(n, xt, x);
|
|
215
171
|
} else if (nbits != d) {
|
|
216
172
|
for (idx_t i = 0; i < n; i++) {
|
|
217
|
-
memcpy
|
|
218
|
-
nbits * sizeof(xt[0]));
|
|
173
|
+
memcpy(x + i * d, xt + i * nbits, nbits * sizeof(xt[0]));
|
|
219
174
|
}
|
|
220
175
|
}
|
|
221
176
|
}
|
|
222
177
|
|
|
223
|
-
|
|
224
|
-
|
|
225
178
|
} // namespace faiss
|
|
@@ -12,30 +12,24 @@
|
|
|
12
12
|
|
|
13
13
|
#include <vector>
|
|
14
14
|
|
|
15
|
-
#include <faiss/
|
|
15
|
+
#include <faiss/IndexFlatCodes.h>
|
|
16
16
|
#include <faiss/VectorTransform.h>
|
|
17
17
|
|
|
18
18
|
namespace faiss {
|
|
19
19
|
|
|
20
|
-
|
|
21
20
|
/** The sign of each vector component is put in a binary signature */
|
|
22
|
-
struct IndexLSH:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
int bytes_per_vec; ///< nb of 8-bits per encoded vector
|
|
27
|
-
bool rotate_data; ///< whether to apply a random rotation to input
|
|
28
|
-
bool train_thresholds; ///< whether we train thresholds or use 0
|
|
21
|
+
struct IndexLSH : IndexFlatCodes {
|
|
22
|
+
int nbits; ///< nb of bits per vector
|
|
23
|
+
bool rotate_data; ///< whether to apply a random rotation to input
|
|
24
|
+
bool train_thresholds; ///< whether we train thresholds or use 0
|
|
29
25
|
|
|
30
26
|
RandomRotationMatrix rrot; ///< optional random rotation
|
|
31
27
|
|
|
32
|
-
std::vector
|
|
33
|
-
|
|
34
|
-
/// encoded dataset
|
|
35
|
-
std::vector<uint8_t> codes;
|
|
28
|
+
std::vector<float> thresholds; ///< thresholds to compare with
|
|
36
29
|
|
|
37
|
-
IndexLSH
|
|
38
|
-
idx_t d,
|
|
30
|
+
IndexLSH(
|
|
31
|
+
idx_t d,
|
|
32
|
+
int nbits,
|
|
39
33
|
bool rotate_data = true,
|
|
40
34
|
bool train_thresholds = false);
|
|
41
35
|
|
|
@@ -46,45 +40,33 @@ struct IndexLSH:Index {
|
|
|
46
40
|
* @return output vectors, size n * bits. May be the same pointer
|
|
47
41
|
* as x, otherwise it should be deleted by the caller
|
|
48
42
|
*/
|
|
49
|
-
const float
|
|
43
|
+
const float* apply_preprocess(idx_t n, const float* x) const;
|
|
50
44
|
|
|
51
45
|
void train(idx_t n, const float* x) override;
|
|
52
46
|
|
|
53
|
-
void add(idx_t n, const float* x) override;
|
|
54
|
-
|
|
55
47
|
void search(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
void reset() override;
|
|
48
|
+
idx_t n,
|
|
49
|
+
const float* x,
|
|
50
|
+
idx_t k,
|
|
51
|
+
float* distances,
|
|
52
|
+
idx_t* labels) const override;
|
|
63
53
|
|
|
64
54
|
/// transfer the thresholds to a pre-processing stage (and unset
|
|
65
55
|
/// train_thresholds)
|
|
66
|
-
void transfer_thresholds
|
|
56
|
+
void transfer_thresholds(LinearTransform* vt);
|
|
67
57
|
|
|
68
58
|
~IndexLSH() override {}
|
|
69
59
|
|
|
70
|
-
IndexLSH
|
|
60
|
+
IndexLSH();
|
|
71
61
|
|
|
72
62
|
/* standalone codec interface.
|
|
73
63
|
*
|
|
74
64
|
* The vectors are decoded to +/- 1 (not 0, 1) */
|
|
65
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
75
66
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
void sa_encode (idx_t n, const float *x,
|
|
79
|
-
uint8_t *bytes) const override;
|
|
80
|
-
|
|
81
|
-
void sa_decode (idx_t n, const uint8_t *bytes,
|
|
82
|
-
float *x) const override;
|
|
83
|
-
|
|
67
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
84
68
|
};
|
|
85
69
|
|
|
86
|
-
|
|
87
|
-
}
|
|
88
|
-
|
|
70
|
+
} // namespace faiss
|
|
89
71
|
|
|
90
72
|
#endif
|
|
@@ -7,26 +7,23 @@
|
|
|
7
7
|
|
|
8
8
|
// -*- c++ -*-
|
|
9
9
|
|
|
10
|
-
|
|
11
10
|
#include <faiss/IndexLattice.h>
|
|
12
|
-
#include <faiss/utils/hamming.h> // for the bitstring routines
|
|
13
11
|
#include <faiss/impl/FaissAssert.h>
|
|
14
12
|
#include <faiss/utils/distances.h>
|
|
13
|
+
#include <faiss/utils/hamming.h> // for the bitstring routines
|
|
15
14
|
|
|
16
15
|
namespace faiss {
|
|
17
16
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
{
|
|
26
|
-
FAISS_THROW_IF_NOT (d % nsq == 0);
|
|
17
|
+
IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
|
|
18
|
+
: Index(d),
|
|
19
|
+
nsq(nsq),
|
|
20
|
+
dsq(d / nsq),
|
|
21
|
+
zn_sphere_codec(dsq, r2),
|
|
22
|
+
scale_nbit(scale_nbit) {
|
|
23
|
+
FAISS_THROW_IF_NOT(d % nsq == 0);
|
|
27
24
|
|
|
28
25
|
lattice_nbit = 0;
|
|
29
|
-
while (!(
|
|
26
|
+
while (!(((uint64_t)1 << lattice_nbit) >= zn_sphere_codec.nv)) {
|
|
30
27
|
lattice_nbit++;
|
|
31
28
|
}
|
|
32
29
|
|
|
@@ -37,12 +34,11 @@ IndexLattice::IndexLattice (idx_t d, int nsq, int scale_nbit, int r2):
|
|
|
37
34
|
is_trained = false;
|
|
38
35
|
}
|
|
39
36
|
|
|
40
|
-
void IndexLattice::train(idx_t n, const float* x)
|
|
41
|
-
{
|
|
37
|
+
void IndexLattice::train(idx_t n, const float* x) {
|
|
42
38
|
// compute ranges per sub-block
|
|
43
|
-
trained.resize
|
|
44
|
-
float
|
|
45
|
-
float
|
|
39
|
+
trained.resize(nsq * 2);
|
|
40
|
+
float* mins = trained.data();
|
|
41
|
+
float* maxs = trained.data() + nsq;
|
|
46
42
|
for (int sq = 0; sq < nsq; sq++) {
|
|
47
43
|
mins[sq] = HUGE_VAL;
|
|
48
44
|
maxs[sq] = -1;
|
|
@@ -50,45 +46,43 @@ void IndexLattice::train(idx_t n, const float* x)
|
|
|
50
46
|
|
|
51
47
|
for (idx_t i = 0; i < n; i++) {
|
|
52
48
|
for (int sq = 0; sq < nsq; sq++) {
|
|
53
|
-
float norm2 = fvec_norm_L2sqr
|
|
54
|
-
if (norm2 > maxs[sq])
|
|
55
|
-
|
|
49
|
+
float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
|
|
50
|
+
if (norm2 > maxs[sq])
|
|
51
|
+
maxs[sq] = norm2;
|
|
52
|
+
if (norm2 < mins[sq])
|
|
53
|
+
mins[sq] = norm2;
|
|
56
54
|
}
|
|
57
55
|
}
|
|
58
56
|
|
|
59
57
|
for (int sq = 0; sq < nsq; sq++) {
|
|
60
|
-
mins[sq] = sqrtf
|
|
61
|
-
maxs[sq] = sqrtf
|
|
58
|
+
mins[sq] = sqrtf(mins[sq]);
|
|
59
|
+
maxs[sq] = sqrtf(maxs[sq]);
|
|
62
60
|
}
|
|
63
61
|
|
|
64
62
|
is_trained = true;
|
|
65
63
|
}
|
|
66
64
|
|
|
67
65
|
/* The standalone codec interface */
|
|
68
|
-
size_t IndexLattice::sa_code_size
|
|
69
|
-
{
|
|
66
|
+
size_t IndexLattice::sa_code_size() const {
|
|
70
67
|
return code_size;
|
|
71
68
|
}
|
|
72
69
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
{
|
|
77
|
-
|
|
78
|
-
const float * mins = trained.data();
|
|
79
|
-
const float * maxs = mins + nsq;
|
|
70
|
+
void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
|
|
71
|
+
const float* mins = trained.data();
|
|
72
|
+
const float* maxs = mins + nsq;
|
|
80
73
|
int64_t sc = int64_t(1) << scale_nbit;
|
|
81
74
|
|
|
82
75
|
#pragma omp parallel for
|
|
83
76
|
for (idx_t i = 0; i < n; i++) {
|
|
84
77
|
BitstringWriter wr(codes + i * code_size, code_size);
|
|
85
|
-
const float
|
|
78
|
+
const float* xi = x + i * d;
|
|
86
79
|
for (int j = 0; j < nsq; j++) {
|
|
87
|
-
float nj =
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
if (nj >= sc)
|
|
80
|
+
float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
|
|
81
|
+
(maxs[j] - mins[j]);
|
|
82
|
+
if (nj < 0)
|
|
83
|
+
nj = 0;
|
|
84
|
+
if (nj >= sc)
|
|
85
|
+
nj = sc - 1;
|
|
92
86
|
wr.write((int64_t)nj, scale_nbit);
|
|
93
87
|
wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
|
|
94
88
|
xi += dsq;
|
|
@@ -96,23 +90,22 @@ void IndexLattice::sa_encode (idx_t n, const float *x, uint8_t *codes) const
|
|
|
96
90
|
}
|
|
97
91
|
}
|
|
98
92
|
|
|
99
|
-
void IndexLattice::sa_decode
|
|
100
|
-
|
|
101
|
-
const float
|
|
102
|
-
const float * maxs = mins + nsq;
|
|
93
|
+
void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
|
|
94
|
+
const float* mins = trained.data();
|
|
95
|
+
const float* maxs = mins + nsq;
|
|
103
96
|
float sc = int64_t(1) << scale_nbit;
|
|
104
97
|
float r = sqrtf(zn_sphere_codec.r2);
|
|
105
98
|
|
|
106
99
|
#pragma omp parallel for
|
|
107
100
|
for (idx_t i = 0; i < n; i++) {
|
|
108
101
|
BitstringReader rd(codes + i * code_size, code_size);
|
|
109
|
-
float
|
|
102
|
+
float* xi = x + i * d;
|
|
110
103
|
for (int j = 0; j < nsq; j++) {
|
|
111
104
|
float norm =
|
|
112
|
-
|
|
113
|
-
|
|
105
|
+
(rd.read(scale_nbit) + 0.5) * (maxs[j] - mins[j]) / sc +
|
|
106
|
+
mins[j];
|
|
114
107
|
norm /= r;
|
|
115
|
-
zn_sphere_codec.decode
|
|
108
|
+
zn_sphere_codec.decode(rd.read(lattice_nbit), xi);
|
|
116
109
|
for (int l = 0; l < dsq; l++) {
|
|
117
110
|
xi[l] *= norm;
|
|
118
111
|
}
|
|
@@ -121,23 +114,16 @@ void IndexLattice::sa_decode (idx_t n, const uint8_t *codes, float *x) const
|
|
|
121
114
|
}
|
|
122
115
|
}
|
|
123
116
|
|
|
124
|
-
void IndexLattice::add(idx_t
|
|
125
|
-
{
|
|
117
|
+
void IndexLattice::add(idx_t, const float*) {
|
|
126
118
|
FAISS_THROW_MSG("not implemented");
|
|
127
119
|
}
|
|
128
120
|
|
|
129
|
-
|
|
130
|
-
void IndexLattice::search(idx_t , const float* , idx_t ,
|
|
131
|
-
float* , idx_t* ) const
|
|
132
|
-
{
|
|
121
|
+
void IndexLattice::search(idx_t, const float*, idx_t, float*, idx_t*) const {
|
|
133
122
|
FAISS_THROW_MSG("not implemented");
|
|
134
123
|
}
|
|
135
124
|
|
|
136
|
-
|
|
137
|
-
void IndexLattice::reset()
|
|
138
|
-
{
|
|
125
|
+
void IndexLattice::reset() {
|
|
139
126
|
FAISS_THROW_MSG("not implemented");
|
|
140
127
|
}
|
|
141
128
|
|
|
142
|
-
|
|
143
|
-
} // namespace faiss
|
|
129
|
+
} // namespace faiss
|
|
@@ -10,7 +10,6 @@
|
|
|
10
10
|
#ifndef FAISS_INDEX_LATTICE_H
|
|
11
11
|
#define FAISS_INDEX_LATTICE_H
|
|
12
12
|
|
|
13
|
-
|
|
14
13
|
#include <vector>
|
|
15
14
|
|
|
16
15
|
#include <faiss/IndexIVF.h>
|
|
@@ -18,14 +17,9 @@
|
|
|
18
17
|
|
|
19
18
|
namespace faiss {
|
|
20
19
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
20
|
/** Index that encodes a vector with a series of Zn lattice quantizers
|
|
26
21
|
*/
|
|
27
|
-
struct IndexLattice: Index {
|
|
28
|
-
|
|
22
|
+
struct IndexLattice : Index {
|
|
29
23
|
/// number of sub-vectors
|
|
30
24
|
int nsq;
|
|
31
25
|
/// dimension of sub-vectors
|
|
@@ -42,25 +36,26 @@ struct IndexLattice: Index {
|
|
|
42
36
|
/// mins and maxes of the vector norms, per subquantizer
|
|
43
37
|
std::vector<float> trained;
|
|
44
38
|
|
|
45
|
-
IndexLattice
|
|
39
|
+
IndexLattice(idx_t d, int nsq, int scale_nbit, int r2);
|
|
46
40
|
|
|
47
41
|
void train(idx_t n, const float* x) override;
|
|
48
42
|
|
|
49
43
|
/* The standalone codec interface */
|
|
50
|
-
size_t sa_code_size
|
|
44
|
+
size_t sa_code_size() const override;
|
|
51
45
|
|
|
52
|
-
void sa_encode
|
|
53
|
-
uint8_t *bytes) const override;
|
|
46
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
|
54
47
|
|
|
55
|
-
void sa_decode
|
|
56
|
-
float *x) const override;
|
|
48
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
57
49
|
|
|
58
50
|
/// not implemented
|
|
59
51
|
void add(idx_t n, const float* x) override;
|
|
60
|
-
void search(
|
|
61
|
-
|
|
52
|
+
void search(
|
|
53
|
+
idx_t n,
|
|
54
|
+
const float* x,
|
|
55
|
+
idx_t k,
|
|
56
|
+
float* distances,
|
|
57
|
+
idx_t* labels) const override;
|
|
62
58
|
void reset() override;
|
|
63
|
-
|
|
64
59
|
};
|
|
65
60
|
|
|
66
61
|
} // namespace faiss
|