faiss 0.2.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +20 -2
@@ -15,43 +15,54 @@
|
|
15
15
|
* always called f and thus is not passed in as a macro parameter.
|
16
16
|
**************************************************************/
|
17
17
|
|
18
|
-
|
19
|
-
|
20
|
-
size_t ret = (*f)(ptr, sizeof(*(ptr)), n);
|
21
|
-
FAISS_THROW_IF_NOT_FMT(
|
22
|
-
|
23
|
-
|
18
|
+
#define READANDCHECK(ptr, n) \
|
19
|
+
{ \
|
20
|
+
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
|
21
|
+
FAISS_THROW_IF_NOT_FMT( \
|
22
|
+
ret == (n), \
|
23
|
+
"read error in %s: %zd != %zd (%s)", \
|
24
|
+
f->name.c_str(), \
|
25
|
+
ret, \
|
26
|
+
size_t(n), \
|
27
|
+
strerror(errno)); \
|
24
28
|
}
|
25
29
|
|
26
|
-
#define READ1(x)
|
30
|
+
#define READ1(x) READANDCHECK(&(x), 1)
|
27
31
|
|
28
32
|
// will fail if we write 256G of data at once...
|
29
|
-
#define READVECTOR(vec)
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
#define READSTRING(s)
|
39
|
-
|
40
|
-
|
41
|
-
WRITEANDCHECK
|
33
|
+
#define READVECTOR(vec) \
|
34
|
+
{ \
|
35
|
+
size_t size; \
|
36
|
+
READANDCHECK(&size, 1); \
|
37
|
+
FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \
|
38
|
+
(vec).resize(size); \
|
39
|
+
READANDCHECK((vec).data(), size); \
|
40
|
+
}
|
41
|
+
|
42
|
+
#define READSTRING(s) \
|
43
|
+
{ \
|
44
|
+
size_t size = (s).size(); \
|
45
|
+
WRITEANDCHECK(&size, 1); \
|
46
|
+
WRITEANDCHECK((s).c_str(), size); \
|
42
47
|
}
|
43
48
|
|
44
|
-
#define WRITEANDCHECK(ptr, n)
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
+
#define WRITEANDCHECK(ptr, n) \
|
50
|
+
{ \
|
51
|
+
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
|
52
|
+
FAISS_THROW_IF_NOT_FMT( \
|
53
|
+
ret == (n), \
|
54
|
+
"write error in %s: %zd != %zd (%s)", \
|
55
|
+
f->name.c_str(), \
|
56
|
+
ret, \
|
57
|
+
size_t(n), \
|
58
|
+
strerror(errno)); \
|
49
59
|
}
|
50
60
|
|
51
61
|
#define WRITE1(x) WRITEANDCHECK(&(x), 1)
|
52
62
|
|
53
|
-
#define WRITEVECTOR(vec)
|
54
|
-
|
55
|
-
|
56
|
-
WRITEANDCHECK
|
63
|
+
#define WRITEVECTOR(vec) \
|
64
|
+
{ \
|
65
|
+
size_t size = (vec).size(); \
|
66
|
+
WRITEANDCHECK(&size, 1); \
|
67
|
+
WRITEANDCHECK((vec).data(), size); \
|
57
68
|
}
|
@@ -9,19 +9,18 @@
|
|
9
9
|
|
10
10
|
#include <faiss/impl/lattice_Zn.h>
|
11
11
|
|
12
|
-
#include <
|
12
|
+
#include <cassert>
|
13
13
|
#include <cmath>
|
14
|
+
#include <cstdlib>
|
14
15
|
#include <cstring>
|
15
|
-
#include <cassert>
|
16
16
|
|
17
|
+
#include <algorithm>
|
17
18
|
#include <queue>
|
18
|
-
#include <unordered_set>
|
19
19
|
#include <unordered_map>
|
20
|
-
#include <
|
20
|
+
#include <unordered_set>
|
21
21
|
|
22
|
-
#include <faiss/utils/distances.h>
|
23
22
|
#include <faiss/impl/platform_macros.h>
|
24
|
-
|
23
|
+
#include <faiss/utils/distances.h>
|
25
24
|
|
26
25
|
namespace faiss {
|
27
26
|
|
@@ -35,44 +34,41 @@ inline float sqr(float x) {
|
|
35
34
|
return x * x;
|
36
35
|
}
|
37
36
|
|
38
|
-
|
39
37
|
typedef std::vector<float> point_list_t;
|
40
38
|
|
41
39
|
struct Comb {
|
42
40
|
std::vector<uint64_t> tab; // Pascal's triangle
|
43
41
|
int nmax;
|
44
42
|
|
45
|
-
explicit Comb(int nmax): nmax(nmax) {
|
43
|
+
explicit Comb(int nmax) : nmax(nmax) {
|
46
44
|
tab.resize(nmax * nmax, 0);
|
47
45
|
tab[0] = 1;
|
48
|
-
for(int i = 1; i < nmax; i++) {
|
46
|
+
for (int i = 1; i < nmax; i++) {
|
49
47
|
tab[i * nmax] = 1;
|
50
|
-
for(int j = 1; j <= i; j++) {
|
48
|
+
for (int j = 1; j <= i; j++) {
|
51
49
|
tab[i * nmax + j] =
|
52
|
-
|
53
|
-
tab[(i - 1) * nmax + (j - 1)];
|
50
|
+
tab[(i - 1) * nmax + j] + tab[(i - 1) * nmax + (j - 1)];
|
54
51
|
}
|
55
|
-
|
56
52
|
}
|
57
53
|
}
|
58
54
|
|
59
55
|
uint64_t operator()(int n, int p) const {
|
60
|
-
assert
|
61
|
-
if (p > n)
|
56
|
+
assert(n < nmax && p < nmax);
|
57
|
+
if (p > n)
|
58
|
+
return 0;
|
62
59
|
return tab[n * nmax + p];
|
63
60
|
}
|
64
61
|
};
|
65
62
|
|
66
63
|
Comb comb(100);
|
67
64
|
|
68
|
-
|
69
|
-
|
70
65
|
// compute combinations of n integer values <= v that sum up to total (squared)
|
71
|
-
point_list_t sum_of_sq
|
66
|
+
point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
|
72
67
|
if (total < 0) {
|
73
68
|
return point_list_t();
|
74
69
|
} else if (n == 1) {
|
75
|
-
while (sqr(v + add) > total)
|
70
|
+
while (sqr(v + add) > total)
|
71
|
+
v--;
|
76
72
|
if (sqr(v + add) == total) {
|
77
73
|
return point_list_t(1, v + add);
|
78
74
|
} else {
|
@@ -82,9 +78,9 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
|
82
78
|
point_list_t res;
|
83
79
|
while (v >= 0) {
|
84
80
|
point_list_t sub_points =
|
85
|
-
|
81
|
+
sum_of_sq(total - sqr(v + add), v, n - 1, add);
|
86
82
|
for (size_t i = 0; i < sub_points.size(); i += n - 1) {
|
87
|
-
res.push_back
|
83
|
+
res.push_back(v + add);
|
88
84
|
for (int j = 0; j < n - 1; j++) {
|
89
85
|
res.push_back(sub_points[i + j]);
|
90
86
|
}
|
@@ -95,7 +91,7 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
|
95
91
|
}
|
96
92
|
}
|
97
93
|
|
98
|
-
int decode_comb_1
|
94
|
+
int decode_comb_1(uint64_t* n, int k1, int r) {
|
99
95
|
while (comb(r, k1) > *n) {
|
100
96
|
r--;
|
101
97
|
}
|
@@ -104,10 +100,10 @@ int decode_comb_1 (uint64_t *n, int k1, int r) {
|
|
104
100
|
}
|
105
101
|
|
106
102
|
// optimized version for < 64 bits
|
107
|
-
uint64_t repeats_encode_64
|
108
|
-
|
109
|
-
|
110
|
-
{
|
103
|
+
uint64_t repeats_encode_64(
|
104
|
+
const std::vector<Repeat>& repeats,
|
105
|
+
int dim,
|
106
|
+
const float* c) {
|
111
107
|
uint64_t coded = 0;
|
112
108
|
int nfree = dim;
|
113
109
|
uint64_t code = 0, shift = 1;
|
@@ -115,15 +111,16 @@ uint64_t repeats_encode_64 (
|
|
115
111
|
int rank = 0, occ = 0;
|
116
112
|
uint64_t code_comb = 0;
|
117
113
|
uint64_t tosee = ~coded;
|
118
|
-
for(;;) {
|
114
|
+
for (;;) {
|
119
115
|
// directly jump to next available slot.
|
120
116
|
int i = __builtin_ctzll(tosee);
|
121
|
-
tosee &= ~(uint64_t{1} << i)
|
117
|
+
tosee &= ~(uint64_t{1} << i);
|
122
118
|
if (c[i] == r->val) {
|
123
119
|
code_comb += comb(rank, occ + 1);
|
124
120
|
occ++;
|
125
121
|
coded |= uint64_t{1} << i;
|
126
|
-
if (occ == r->n)
|
122
|
+
if (occ == r->n)
|
123
|
+
break;
|
127
124
|
}
|
128
125
|
rank++;
|
129
126
|
}
|
@@ -135,11 +132,11 @@ uint64_t repeats_encode_64 (
|
|
135
132
|
return code;
|
136
133
|
}
|
137
134
|
|
138
|
-
|
139
135
|
void repeats_decode_64(
|
140
|
-
|
141
|
-
|
142
|
-
|
136
|
+
const std::vector<Repeat>& repeats,
|
137
|
+
int dim,
|
138
|
+
uint64_t code,
|
139
|
+
float* c) {
|
143
140
|
uint64_t decoded = 0;
|
144
141
|
int nfree = dim;
|
145
142
|
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
@@ -149,9 +146,9 @@ void repeats_decode_64(
|
|
149
146
|
|
150
147
|
int occ = 0;
|
151
148
|
int rank = nfree;
|
152
|
-
int next_rank = decode_comb_1
|
149
|
+
int next_rank = decode_comb_1(&code_comb, r->n, rank);
|
153
150
|
uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
|
154
|
-
for(;;) {
|
151
|
+
for (;;) {
|
155
152
|
int i = 63 - __builtin_clzll(tosee);
|
156
153
|
tosee &= ~(uint64_t{1} << i);
|
157
154
|
rank--;
|
@@ -159,25 +156,21 @@ void repeats_decode_64(
|
|
159
156
|
decoded |= uint64_t{1} << i;
|
160
157
|
c[i] = r->val;
|
161
158
|
occ++;
|
162
|
-
if (occ == r->n)
|
163
|
-
|
164
|
-
|
159
|
+
if (occ == r->n)
|
160
|
+
break;
|
161
|
+
next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
|
165
162
|
}
|
166
163
|
}
|
167
164
|
nfree -= r->n;
|
168
165
|
}
|
169
|
-
|
170
166
|
}
|
171
167
|
|
172
|
-
|
173
|
-
|
174
168
|
} // anonymous namespace
|
175
169
|
|
176
|
-
Repeats::Repeats
|
177
|
-
{
|
178
|
-
for(int i = 0; i < dim; i++) {
|
170
|
+
Repeats::Repeats(int dim, const float* c) : dim(dim) {
|
171
|
+
for (int i = 0; i < dim; i++) {
|
179
172
|
int j = 0;
|
180
|
-
for(;;) {
|
173
|
+
for (;;) {
|
181
174
|
if (j == repeats.size()) {
|
182
175
|
repeats.push_back(Repeat{c[i], 1});
|
183
176
|
break;
|
@@ -191,9 +184,7 @@ Repeats::Repeats (int dim, const float *c): dim(dim)
|
|
191
184
|
}
|
192
185
|
}
|
193
186
|
|
194
|
-
|
195
|
-
uint64_t Repeats::count () const
|
196
|
-
{
|
187
|
+
uint64_t Repeats::count() const {
|
197
188
|
uint64_t accu = 1;
|
198
189
|
int remain = dim;
|
199
190
|
for (int i = 0; i < repeats.size(); i++) {
|
@@ -203,13 +194,10 @@ uint64_t Repeats::count () const
|
|
203
194
|
return accu;
|
204
195
|
}
|
205
196
|
|
206
|
-
|
207
|
-
|
208
197
|
// version with a bool vector that works for > 64 dim
|
209
|
-
uint64_t Repeats::encode(const float
|
210
|
-
{
|
198
|
+
uint64_t Repeats::encode(const float* c) const {
|
211
199
|
if (dim < 64) {
|
212
|
-
return repeats_encode_64
|
200
|
+
return repeats_encode_64(repeats, dim, c);
|
213
201
|
}
|
214
202
|
std::vector<bool> coded(dim, false);
|
215
203
|
int nfree = dim;
|
@@ -223,7 +211,8 @@ uint64_t Repeats::encode(const float *c) const
|
|
223
211
|
code_comb += comb(rank, occ + 1);
|
224
212
|
occ++;
|
225
213
|
coded[i] = true;
|
226
|
-
if (occ == r->n)
|
214
|
+
if (occ == r->n)
|
215
|
+
break;
|
227
216
|
}
|
228
217
|
rank++;
|
229
218
|
}
|
@@ -236,12 +225,9 @@ uint64_t Repeats::encode(const float *c) const
|
|
236
225
|
return code;
|
237
226
|
}
|
238
227
|
|
239
|
-
|
240
|
-
|
241
|
-
void Repeats::decode(uint64_t code, float *c) const
|
242
|
-
{
|
228
|
+
void Repeats::decode(uint64_t code, float* c) const {
|
243
229
|
if (dim < 64) {
|
244
|
-
repeats_decode_64
|
230
|
+
repeats_decode_64(repeats, dim, code, c);
|
245
231
|
return;
|
246
232
|
}
|
247
233
|
|
@@ -254,7 +240,7 @@ void Repeats::decode(uint64_t code, float *c) const
|
|
254
240
|
|
255
241
|
int occ = 0;
|
256
242
|
int rank = nfree;
|
257
|
-
int next_rank = decode_comb_1
|
243
|
+
int next_rank = decode_comb_1(&code_comb, r->n, rank);
|
258
244
|
for (int i = dim - 1; i >= 0; i--) {
|
259
245
|
if (!decoded[i]) {
|
260
246
|
rank--;
|
@@ -262,65 +248,61 @@ void Repeats::decode(uint64_t code, float *c) const
|
|
262
248
|
decoded[i] = true;
|
263
249
|
c[i] = r->val;
|
264
250
|
occ++;
|
265
|
-
if (occ == r->n)
|
266
|
-
|
267
|
-
|
251
|
+
if (occ == r->n)
|
252
|
+
break;
|
253
|
+
next_rank =
|
254
|
+
decode_comb_1(&code_comb, r->n - occ, next_rank);
|
268
255
|
}
|
269
256
|
}
|
270
257
|
}
|
271
258
|
nfree -= r->n;
|
272
259
|
}
|
273
|
-
|
274
260
|
}
|
275
261
|
|
276
|
-
|
277
|
-
|
278
262
|
/********************************************
|
279
263
|
* EnumeratedVectors functions
|
280
264
|
********************************************/
|
281
265
|
|
282
|
-
|
283
|
-
|
284
|
-
uint64_t * codes) const
|
285
|
-
{
|
266
|
+
void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
|
267
|
+
const {
|
286
268
|
#pragma omp parallel if (n > 1000)
|
287
269
|
{
|
288
270
|
#pragma omp for
|
289
|
-
for(int i = 0; i < n; i++) {
|
271
|
+
for (int i = 0; i < n; i++) {
|
290
272
|
codes[i] = encode(c + i * dim);
|
291
273
|
}
|
292
274
|
}
|
293
275
|
}
|
294
276
|
|
295
|
-
|
296
|
-
|
297
|
-
float *c) const
|
298
|
-
{
|
277
|
+
void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
|
278
|
+
const {
|
299
279
|
#pragma omp parallel if (n > 1000)
|
300
280
|
{
|
301
281
|
#pragma omp for
|
302
|
-
for(int i = 0; i < n; i++) {
|
282
|
+
for (int i = 0; i < n; i++) {
|
303
283
|
decode(codes[i], c + i * dim);
|
304
284
|
}
|
305
285
|
}
|
306
286
|
}
|
307
287
|
|
308
|
-
void EnumeratedVectors::find_nn
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
288
|
+
void EnumeratedVectors::find_nn(
|
289
|
+
size_t nc,
|
290
|
+
const uint64_t* codes,
|
291
|
+
size_t nq,
|
292
|
+
const float* xq,
|
293
|
+
int64_t* labels,
|
294
|
+
float* distances) {
|
313
295
|
for (size_t i = 0; i < nq; i++) {
|
314
296
|
distances[i] = -1e20;
|
315
297
|
labels[i] = -1;
|
316
298
|
}
|
317
299
|
|
318
300
|
std::vector<float> c(dim);
|
319
|
-
for(size_t i = 0; i < nc; i++) {
|
301
|
+
for (size_t i = 0; i < nc; i++) {
|
320
302
|
uint64_t code = codes[nc];
|
321
303
|
decode(code, c.data());
|
322
304
|
for (size_t j = 0; j < nq; j++) {
|
323
|
-
const float
|
305
|
+
const float* x = xq + j * dim;
|
324
306
|
float dis = fvec_inner_product(x, c.data(), dim);
|
325
307
|
if (dis > distances[j]) {
|
326
308
|
distances[j] = dis;
|
@@ -328,45 +310,41 @@ void EnumeratedVectors::find_nn (
|
|
328
310
|
}
|
329
311
|
}
|
330
312
|
}
|
331
|
-
|
332
313
|
}
|
333
314
|
|
334
|
-
|
335
315
|
/**********************************************************
|
336
316
|
* ZnSphereSearch
|
337
317
|
**********************************************************/
|
338
318
|
|
339
|
-
|
340
|
-
ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
|
319
|
+
ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
|
341
320
|
voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
|
342
321
|
natom = voc.size() / dim;
|
343
322
|
}
|
344
323
|
|
345
|
-
float ZnSphereSearch::search(const float
|
324
|
+
float ZnSphereSearch::search(const float* x, float* c) const {
|
346
325
|
std::vector<float> tmp(dimS * 2);
|
347
326
|
std::vector<int> tmp_int(dimS);
|
348
327
|
return search(x, c, tmp.data(), tmp_int.data());
|
349
328
|
}
|
350
329
|
|
351
|
-
float ZnSphereSearch::search(
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
330
|
+
float ZnSphereSearch::search(
|
331
|
+
const float* x,
|
332
|
+
float* c,
|
333
|
+
float* tmp, // size 2 *dim
|
334
|
+
int* tmp_int, // size dim
|
335
|
+
int* ibest_out) const {
|
356
336
|
int dim = dimS;
|
357
|
-
assert
|
358
|
-
int
|
359
|
-
float
|
360
|
-
float
|
337
|
+
assert(natom > 0);
|
338
|
+
int* o = tmp_int;
|
339
|
+
float* xabs = tmp;
|
340
|
+
float* xperm = tmp + dim;
|
361
341
|
|
362
342
|
// argsort
|
363
343
|
for (int i = 0; i < dim; i++) {
|
364
344
|
o[i] = i;
|
365
345
|
xabs[i] = fabsf(x[i]);
|
366
346
|
}
|
367
|
-
std::sort(o, o + dim, [xabs](int a, int b) {
|
368
|
-
return xabs[a] > xabs[b];
|
369
|
-
});
|
347
|
+
std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
|
370
348
|
for (int i = 0; i < dim; i++) {
|
371
349
|
xperm[i] = xabs[o[i]];
|
372
350
|
}
|
@@ -374,16 +352,16 @@ float ZnSphereSearch::search(const float *x, float *c,
|
|
374
352
|
int ibest = -1;
|
375
353
|
float dpbest = -100;
|
376
354
|
for (int i = 0; i < natom; i++) {
|
377
|
-
float dp = fvec_inner_product
|
355
|
+
float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
|
378
356
|
if (dp > dpbest) {
|
379
357
|
dpbest = dp;
|
380
358
|
ibest = i;
|
381
359
|
}
|
382
360
|
}
|
383
361
|
// revert sort
|
384
|
-
const float
|
362
|
+
const float* cin = voc.data() + ibest * dim;
|
385
363
|
for (int i = 0; i < dim; i++) {
|
386
|
-
c[o[i]] = copysignf
|
364
|
+
c[o[i]] = copysignf(cin[i], x[o[i]]);
|
387
365
|
}
|
388
366
|
if (ibest_out) {
|
389
367
|
*ibest_out = ibest;
|
@@ -391,33 +369,32 @@ float ZnSphereSearch::search(const float *x, float *c,
|
|
391
369
|
return dpbest;
|
392
370
|
}
|
393
371
|
|
394
|
-
void ZnSphereSearch::search_multi(
|
395
|
-
|
396
|
-
|
372
|
+
void ZnSphereSearch::search_multi(
|
373
|
+
int n,
|
374
|
+
const float* x,
|
375
|
+
float* c_out,
|
376
|
+
float* dp_out) {
|
397
377
|
#pragma omp parallel if (n > 1000)
|
398
378
|
{
|
399
379
|
#pragma omp for
|
400
|
-
for(int i = 0; i < n; i++) {
|
380
|
+
for (int i = 0; i < n; i++) {
|
401
381
|
dp_out[i] = search(x + i * dimS, c_out + i * dimS);
|
402
382
|
}
|
403
383
|
}
|
404
384
|
}
|
405
385
|
|
406
|
-
|
407
386
|
/**********************************************************
|
408
387
|
* ZnSphereCodec
|
409
388
|
**********************************************************/
|
410
389
|
|
411
|
-
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
|
412
|
-
|
413
|
-
EnumeratedVectors(dim)
|
414
|
-
{
|
390
|
+
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
|
391
|
+
: ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
|
415
392
|
nv = 0;
|
416
393
|
for (int i = 0; i < natom; i++) {
|
417
394
|
Repeats repeats(dim, &voc[i * dim]);
|
418
395
|
CodeSegment cs(repeats);
|
419
396
|
cs.c0 = nv;
|
420
|
-
Repeat
|
397
|
+
Repeat& br = repeats.repeats.back();
|
421
398
|
cs.signbits = br.val == 0 ? dim - br.n : dim;
|
422
399
|
code_segments.push_back(cs);
|
423
400
|
nv += repeats.count() << cs.signbits;
|
@@ -431,7 +408,7 @@ ZnSphereCodec::ZnSphereCodec(int dim, int r2):
|
|
431
408
|
}
|
432
409
|
}
|
433
410
|
|
434
|
-
uint64_t ZnSphereCodec::search_and_encode(const float
|
411
|
+
uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
|
435
412
|
std::vector<float> tmp(dim * 2);
|
436
413
|
std::vector<int> tmp_int(dim);
|
437
414
|
int ano; // atom number
|
@@ -446,30 +423,30 @@ uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
|
|
446
423
|
if (c[i] < 0) {
|
447
424
|
signs |= uint64_t{1} << nnz;
|
448
425
|
}
|
449
|
-
nnz
|
426
|
+
nnz++;
|
450
427
|
}
|
451
428
|
}
|
452
|
-
const CodeSegment
|
429
|
+
const CodeSegment& cs = code_segments[ano];
|
453
430
|
assert(nnz == cs.signbits);
|
454
431
|
uint64_t code = cs.c0 + signs;
|
455
432
|
code += cs.encode(cabs.data()) << cs.signbits;
|
456
433
|
return code;
|
457
434
|
}
|
458
435
|
|
459
|
-
uint64_t ZnSphereCodec::encode(const float
|
460
|
-
{
|
436
|
+
uint64_t ZnSphereCodec::encode(const float* x) const {
|
461
437
|
return search_and_encode(x);
|
462
438
|
}
|
463
439
|
|
464
|
-
|
465
|
-
void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
440
|
+
void ZnSphereCodec::decode(uint64_t code, float* c) const {
|
466
441
|
int i0 = 0, i1 = natom;
|
467
442
|
while (i0 + 1 < i1) {
|
468
443
|
int imed = (i0 + i1) / 2;
|
469
|
-
if (code_segments[imed].c0 <= code)
|
470
|
-
|
444
|
+
if (code_segments[imed].c0 <= code)
|
445
|
+
i0 = imed;
|
446
|
+
else
|
447
|
+
i1 = imed;
|
471
448
|
}
|
472
|
-
const CodeSegment
|
449
|
+
const CodeSegment& cs = code_segments[i0];
|
473
450
|
code -= cs.c0;
|
474
451
|
uint64_t signs = code;
|
475
452
|
code >>= cs.signbits;
|
@@ -481,42 +458,34 @@ void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
|
481
458
|
if (signs & (1UL << nnz)) {
|
482
459
|
c[i] = -c[i];
|
483
460
|
}
|
484
|
-
nnz
|
461
|
+
nnz++;
|
485
462
|
}
|
486
463
|
}
|
487
464
|
}
|
488
465
|
|
489
|
-
|
490
466
|
/**************************************************************
|
491
467
|
* ZnSphereCodecRec
|
492
468
|
**************************************************************/
|
493
469
|
|
494
|
-
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
|
495
|
-
{
|
470
|
+
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
|
496
471
|
return all_nv[ld * (r2 + 1) + r2a];
|
497
472
|
}
|
498
473
|
|
499
|
-
|
500
|
-
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
|
501
|
-
{
|
474
|
+
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
|
502
475
|
return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
|
503
476
|
}
|
504
477
|
|
505
|
-
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
|
506
|
-
{
|
478
|
+
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
|
507
479
|
all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
|
508
480
|
}
|
509
481
|
|
510
|
-
|
511
|
-
|
512
|
-
EnumeratedVectors(dim), r2(r2)
|
513
|
-
{
|
482
|
+
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
|
483
|
+
: EnumeratedVectors(dim), r2(r2) {
|
514
484
|
log2_dim = 0;
|
515
485
|
while (dim > (1 << log2_dim)) {
|
516
486
|
log2_dim++;
|
517
487
|
}
|
518
|
-
assert(dim == (1 << log2_dim) ||
|
519
|
-
!"dimension must be a power of 2");
|
488
|
+
assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
|
520
489
|
|
521
490
|
all_nv.resize((log2_dim + 1) * (r2 + 1));
|
522
491
|
all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
|
@@ -531,7 +500,6 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
|
|
531
500
|
}
|
532
501
|
|
533
502
|
for (int ld = 1; ld <= log2_dim; ld++) {
|
534
|
-
|
535
503
|
for (int r2sub = 0; r2sub <= r2; r2sub++) {
|
536
504
|
uint64_t nv = 0;
|
537
505
|
for (int r2a = 0; r2a <= r2sub; r2a++) {
|
@@ -559,33 +527,29 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
|
|
559
527
|
for (int r2sub = 0; r2sub <= r2; r2sub++) {
|
560
528
|
int ld = cache_level;
|
561
529
|
uint64_t nvi = get_nv(ld, r2sub);
|
562
|
-
std::vector<float
|
530
|
+
std::vector<float>& cache = decode_cache[r2sub];
|
563
531
|
int dimsub = (1 << cache_level);
|
564
|
-
cache.resize
|
532
|
+
cache.resize(nvi * dimsub);
|
565
533
|
std::vector<float> c(dim);
|
566
|
-
uint64_t code0 = get_nv_cum(cache_level + 1, r2,
|
567
|
-
r2 - r2sub);
|
534
|
+
uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
|
568
535
|
for (int i = 0; i < nvi; i++) {
|
569
536
|
decode(i + code0, c.data());
|
570
|
-
memcpy(&cache[i * dimsub],
|
537
|
+
memcpy(&cache[i * dimsub],
|
538
|
+
c.data() + dim - dimsub,
|
571
539
|
dimsub * sizeof(*c.data()));
|
572
540
|
}
|
573
541
|
}
|
574
542
|
decode_cache_ld = cache_level;
|
575
543
|
}
|
576
544
|
|
577
|
-
uint64_t ZnSphereCodecRec::encode(const float
|
578
|
-
{
|
545
|
+
uint64_t ZnSphereCodecRec::encode(const float* c) const {
|
579
546
|
return encode_centroid(c);
|
580
547
|
}
|
581
548
|
|
582
|
-
|
583
|
-
|
584
|
-
uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
585
|
-
{
|
549
|
+
uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
|
586
550
|
std::vector<uint64_t> codes(dim);
|
587
551
|
std::vector<int> norm2s(dim);
|
588
|
-
for(int i = 0; i < dim; i++) {
|
552
|
+
for (int i = 0; i < dim; i++) {
|
589
553
|
if (c[i] == 0) {
|
590
554
|
codes[i] = 0;
|
591
555
|
norm2s[i] = 0;
|
@@ -596,7 +560,7 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
596
560
|
}
|
597
561
|
}
|
598
562
|
int dim2 = dim / 2;
|
599
|
-
for(int ld = 1; ld <= log2_dim; ld++) {
|
563
|
+
for (int ld = 1; ld <= log2_dim; ld++) {
|
600
564
|
for (int i = 0; i < dim2; i++) {
|
601
565
|
int r2a = norm2s[2 * i];
|
602
566
|
int r2b = norm2s[2 * i + 1];
|
@@ -604,10 +568,8 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
604
568
|
uint64_t code_a = codes[2 * i];
|
605
569
|
uint64_t code_b = codes[2 * i + 1];
|
606
570
|
|
607
|
-
codes[i] =
|
608
|
-
|
609
|
-
code_a * get_nv(ld - 1, r2b) +
|
610
|
-
code_b;
|
571
|
+
codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
|
572
|
+
code_a * get_nv(ld - 1, r2b) + code_b;
|
611
573
|
norm2s[i] = r2a + r2b;
|
612
574
|
}
|
613
575
|
dim2 /= 2;
|
@@ -615,23 +577,20 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
615
577
|
return codes[0];
|
616
578
|
}
|
617
579
|
|
618
|
-
|
619
|
-
|
620
|
-
void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
621
|
-
{
|
580
|
+
void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
|
622
581
|
std::vector<uint64_t> codes(dim);
|
623
582
|
std::vector<int> norm2s(dim);
|
624
583
|
codes[0] = code;
|
625
584
|
norm2s[0] = r2;
|
626
585
|
|
627
586
|
int dim2 = 1;
|
628
|
-
for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
587
|
+
for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
629
588
|
for (int i = dim2 - 1; i >= 0; i--) {
|
630
589
|
int r2sub = norm2s[i];
|
631
590
|
int i0 = 0, i1 = r2sub + 1;
|
632
591
|
uint64_t codei = codes[i];
|
633
|
-
const uint64_t
|
634
|
-
|
592
|
+
const uint64_t* cum =
|
593
|
+
&all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
|
635
594
|
while (i1 > i0 + 1) {
|
636
595
|
int imed = (i0 + i1) / 2;
|
637
596
|
if (cum[imed] <= codei)
|
@@ -649,13 +608,12 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
649
608
|
|
650
609
|
codes[2 * i] = code_a;
|
651
610
|
codes[2 * i + 1] = code_b;
|
652
|
-
|
653
611
|
}
|
654
612
|
dim2 *= 2;
|
655
613
|
}
|
656
614
|
|
657
615
|
if (decode_cache_ld == 0) {
|
658
|
-
for(int i = 0; i < dim; i++) {
|
616
|
+
for (int i = 0; i < dim; i++) {
|
659
617
|
if (norm2s[i] == 0) {
|
660
618
|
c[i] = 0;
|
661
619
|
} else {
|
@@ -666,49 +624,42 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
666
624
|
}
|
667
625
|
} else {
|
668
626
|
int subdim = 1 << decode_cache_ld;
|
669
|
-
assert
|
670
|
-
|
671
|
-
for(int i = 0; i < dim2; i++) {
|
627
|
+
assert((dim2 * subdim) == dim);
|
672
628
|
|
673
|
-
|
674
|
-
|
629
|
+
for (int i = 0; i < dim2; i++) {
|
630
|
+
const std::vector<float>& cache = decode_cache[norm2s[i]];
|
675
631
|
assert(codes[i] < cache.size());
|
676
632
|
memcpy(c + i * subdim,
|
677
633
|
&cache[codes[i] * subdim],
|
678
|
-
sizeof(*c)* subdim);
|
634
|
+
sizeof(*c) * subdim);
|
679
635
|
}
|
680
636
|
}
|
681
637
|
}
|
682
638
|
|
683
639
|
// if not use_rec, instanciate an arbitrary harmless znc_rec
|
684
|
-
ZnSphereCodecAlt::ZnSphereCodecAlt
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
{
|
690
|
-
|
691
|
-
uint64_t ZnSphereCodecAlt::encode(const float *x) const
|
692
|
-
{
|
640
|
+
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
|
641
|
+
: ZnSphereCodec(dim, r2),
|
642
|
+
use_rec((dim & (dim - 1)) == 0),
|
643
|
+
znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
|
644
|
+
|
645
|
+
uint64_t ZnSphereCodecAlt::encode(const float* x) const {
|
693
646
|
if (!use_rec) {
|
694
647
|
// it's ok if the vector is not normalized
|
695
648
|
return ZnSphereCodec::encode(x);
|
696
649
|
} else {
|
697
650
|
// find nearest centroid
|
698
651
|
std::vector<float> centroid(dim);
|
699
|
-
search
|
652
|
+
search(x, centroid.data());
|
700
653
|
return znc_rec.encode(centroid.data());
|
701
654
|
}
|
702
655
|
}
|
703
656
|
|
704
|
-
void ZnSphereCodecAlt::decode(uint64_t code, float
|
705
|
-
{
|
657
|
+
void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
|
706
658
|
if (!use_rec) {
|
707
|
-
ZnSphereCodec::decode
|
659
|
+
ZnSphereCodec::decode(code, c);
|
708
660
|
} else {
|
709
|
-
znc_rec.decode
|
661
|
+
znc_rec.decode(code, c);
|
710
662
|
}
|
711
663
|
}
|
712
664
|
|
713
|
-
|
714
665
|
} // namespace faiss
|