faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,19 +9,18 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/lattice_Zn.h>
|
|
11
11
|
|
|
12
|
-
#include <
|
|
12
|
+
#include <cassert>
|
|
13
13
|
#include <cmath>
|
|
14
|
+
#include <cstdlib>
|
|
14
15
|
#include <cstring>
|
|
15
|
-
#include <cassert>
|
|
16
16
|
|
|
17
|
+
#include <algorithm>
|
|
17
18
|
#include <queue>
|
|
18
|
-
#include <unordered_set>
|
|
19
19
|
#include <unordered_map>
|
|
20
|
-
#include <
|
|
20
|
+
#include <unordered_set>
|
|
21
21
|
|
|
22
|
-
#include <faiss/utils/distances.h>
|
|
23
22
|
#include <faiss/impl/platform_macros.h>
|
|
24
|
-
|
|
23
|
+
#include <faiss/utils/distances.h>
|
|
25
24
|
|
|
26
25
|
namespace faiss {
|
|
27
26
|
|
|
@@ -35,44 +34,41 @@ inline float sqr(float x) {
|
|
|
35
34
|
return x * x;
|
|
36
35
|
}
|
|
37
36
|
|
|
38
|
-
|
|
39
37
|
typedef std::vector<float> point_list_t;
|
|
40
38
|
|
|
41
39
|
struct Comb {
|
|
42
40
|
std::vector<uint64_t> tab; // Pascal's triangle
|
|
43
41
|
int nmax;
|
|
44
42
|
|
|
45
|
-
explicit Comb(int nmax): nmax(nmax) {
|
|
43
|
+
explicit Comb(int nmax) : nmax(nmax) {
|
|
46
44
|
tab.resize(nmax * nmax, 0);
|
|
47
45
|
tab[0] = 1;
|
|
48
|
-
for(int i = 1; i < nmax; i++) {
|
|
46
|
+
for (int i = 1; i < nmax; i++) {
|
|
49
47
|
tab[i * nmax] = 1;
|
|
50
|
-
for(int j = 1; j <= i; j++) {
|
|
48
|
+
for (int j = 1; j <= i; j++) {
|
|
51
49
|
tab[i * nmax + j] =
|
|
52
|
-
|
|
53
|
-
tab[(i - 1) * nmax + (j - 1)];
|
|
50
|
+
tab[(i - 1) * nmax + j] + tab[(i - 1) * nmax + (j - 1)];
|
|
54
51
|
}
|
|
55
|
-
|
|
56
52
|
}
|
|
57
53
|
}
|
|
58
54
|
|
|
59
55
|
uint64_t operator()(int n, int p) const {
|
|
60
|
-
assert
|
|
61
|
-
if (p > n)
|
|
56
|
+
assert(n < nmax && p < nmax);
|
|
57
|
+
if (p > n)
|
|
58
|
+
return 0;
|
|
62
59
|
return tab[n * nmax + p];
|
|
63
60
|
}
|
|
64
61
|
};
|
|
65
62
|
|
|
66
63
|
Comb comb(100);
|
|
67
64
|
|
|
68
|
-
|
|
69
|
-
|
|
70
65
|
// compute combinations of n integer values <= v that sum up to total (squared)
|
|
71
|
-
point_list_t sum_of_sq
|
|
66
|
+
point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
|
|
72
67
|
if (total < 0) {
|
|
73
68
|
return point_list_t();
|
|
74
69
|
} else if (n == 1) {
|
|
75
|
-
while (sqr(v + add) > total)
|
|
70
|
+
while (sqr(v + add) > total)
|
|
71
|
+
v--;
|
|
76
72
|
if (sqr(v + add) == total) {
|
|
77
73
|
return point_list_t(1, v + add);
|
|
78
74
|
} else {
|
|
@@ -82,9 +78,9 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
|
|
82
78
|
point_list_t res;
|
|
83
79
|
while (v >= 0) {
|
|
84
80
|
point_list_t sub_points =
|
|
85
|
-
|
|
81
|
+
sum_of_sq(total - sqr(v + add), v, n - 1, add);
|
|
86
82
|
for (size_t i = 0; i < sub_points.size(); i += n - 1) {
|
|
87
|
-
res.push_back
|
|
83
|
+
res.push_back(v + add);
|
|
88
84
|
for (int j = 0; j < n - 1; j++) {
|
|
89
85
|
res.push_back(sub_points[i + j]);
|
|
90
86
|
}
|
|
@@ -95,7 +91,7 @@ point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
|
|
95
91
|
}
|
|
96
92
|
}
|
|
97
93
|
|
|
98
|
-
int decode_comb_1
|
|
94
|
+
int decode_comb_1(uint64_t* n, int k1, int r) {
|
|
99
95
|
while (comb(r, k1) > *n) {
|
|
100
96
|
r--;
|
|
101
97
|
}
|
|
@@ -104,10 +100,10 @@ int decode_comb_1 (uint64_t *n, int k1, int r) {
|
|
|
104
100
|
}
|
|
105
101
|
|
|
106
102
|
// optimized version for < 64 bits
|
|
107
|
-
uint64_t repeats_encode_64
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
{
|
|
103
|
+
uint64_t repeats_encode_64(
|
|
104
|
+
const std::vector<Repeat>& repeats,
|
|
105
|
+
int dim,
|
|
106
|
+
const float* c) {
|
|
111
107
|
uint64_t coded = 0;
|
|
112
108
|
int nfree = dim;
|
|
113
109
|
uint64_t code = 0, shift = 1;
|
|
@@ -115,15 +111,16 @@ uint64_t repeats_encode_64 (
|
|
|
115
111
|
int rank = 0, occ = 0;
|
|
116
112
|
uint64_t code_comb = 0;
|
|
117
113
|
uint64_t tosee = ~coded;
|
|
118
|
-
for(;;) {
|
|
114
|
+
for (;;) {
|
|
119
115
|
// directly jump to next available slot.
|
|
120
116
|
int i = __builtin_ctzll(tosee);
|
|
121
|
-
tosee &= ~(uint64_t{1} << i)
|
|
117
|
+
tosee &= ~(uint64_t{1} << i);
|
|
122
118
|
if (c[i] == r->val) {
|
|
123
119
|
code_comb += comb(rank, occ + 1);
|
|
124
120
|
occ++;
|
|
125
121
|
coded |= uint64_t{1} << i;
|
|
126
|
-
if (occ == r->n)
|
|
122
|
+
if (occ == r->n)
|
|
123
|
+
break;
|
|
127
124
|
}
|
|
128
125
|
rank++;
|
|
129
126
|
}
|
|
@@ -135,11 +132,11 @@ uint64_t repeats_encode_64 (
|
|
|
135
132
|
return code;
|
|
136
133
|
}
|
|
137
134
|
|
|
138
|
-
|
|
139
135
|
void repeats_decode_64(
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
136
|
+
const std::vector<Repeat>& repeats,
|
|
137
|
+
int dim,
|
|
138
|
+
uint64_t code,
|
|
139
|
+
float* c) {
|
|
143
140
|
uint64_t decoded = 0;
|
|
144
141
|
int nfree = dim;
|
|
145
142
|
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
|
@@ -149,9 +146,9 @@ void repeats_decode_64(
|
|
|
149
146
|
|
|
150
147
|
int occ = 0;
|
|
151
148
|
int rank = nfree;
|
|
152
|
-
int next_rank = decode_comb_1
|
|
149
|
+
int next_rank = decode_comb_1(&code_comb, r->n, rank);
|
|
153
150
|
uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
|
|
154
|
-
for(;;) {
|
|
151
|
+
for (;;) {
|
|
155
152
|
int i = 63 - __builtin_clzll(tosee);
|
|
156
153
|
tosee &= ~(uint64_t{1} << i);
|
|
157
154
|
rank--;
|
|
@@ -159,25 +156,21 @@ void repeats_decode_64(
|
|
|
159
156
|
decoded |= uint64_t{1} << i;
|
|
160
157
|
c[i] = r->val;
|
|
161
158
|
occ++;
|
|
162
|
-
if (occ == r->n)
|
|
163
|
-
|
|
164
|
-
|
|
159
|
+
if (occ == r->n)
|
|
160
|
+
break;
|
|
161
|
+
next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
|
|
165
162
|
}
|
|
166
163
|
}
|
|
167
164
|
nfree -= r->n;
|
|
168
165
|
}
|
|
169
|
-
|
|
170
166
|
}
|
|
171
167
|
|
|
172
|
-
|
|
173
|
-
|
|
174
168
|
} // anonymous namespace
|
|
175
169
|
|
|
176
|
-
Repeats::Repeats
|
|
177
|
-
{
|
|
178
|
-
for(int i = 0; i < dim; i++) {
|
|
170
|
+
Repeats::Repeats(int dim, const float* c) : dim(dim) {
|
|
171
|
+
for (int i = 0; i < dim; i++) {
|
|
179
172
|
int j = 0;
|
|
180
|
-
for(;;) {
|
|
173
|
+
for (;;) {
|
|
181
174
|
if (j == repeats.size()) {
|
|
182
175
|
repeats.push_back(Repeat{c[i], 1});
|
|
183
176
|
break;
|
|
@@ -191,9 +184,7 @@ Repeats::Repeats (int dim, const float *c): dim(dim)
|
|
|
191
184
|
}
|
|
192
185
|
}
|
|
193
186
|
|
|
194
|
-
|
|
195
|
-
uint64_t Repeats::count () const
|
|
196
|
-
{
|
|
187
|
+
uint64_t Repeats::count() const {
|
|
197
188
|
uint64_t accu = 1;
|
|
198
189
|
int remain = dim;
|
|
199
190
|
for (int i = 0; i < repeats.size(); i++) {
|
|
@@ -203,13 +194,10 @@ uint64_t Repeats::count () const
|
|
|
203
194
|
return accu;
|
|
204
195
|
}
|
|
205
196
|
|
|
206
|
-
|
|
207
|
-
|
|
208
197
|
// version with a bool vector that works for > 64 dim
|
|
209
|
-
uint64_t Repeats::encode(const float
|
|
210
|
-
{
|
|
198
|
+
uint64_t Repeats::encode(const float* c) const {
|
|
211
199
|
if (dim < 64) {
|
|
212
|
-
return repeats_encode_64
|
|
200
|
+
return repeats_encode_64(repeats, dim, c);
|
|
213
201
|
}
|
|
214
202
|
std::vector<bool> coded(dim, false);
|
|
215
203
|
int nfree = dim;
|
|
@@ -223,7 +211,8 @@ uint64_t Repeats::encode(const float *c) const
|
|
|
223
211
|
code_comb += comb(rank, occ + 1);
|
|
224
212
|
occ++;
|
|
225
213
|
coded[i] = true;
|
|
226
|
-
if (occ == r->n)
|
|
214
|
+
if (occ == r->n)
|
|
215
|
+
break;
|
|
227
216
|
}
|
|
228
217
|
rank++;
|
|
229
218
|
}
|
|
@@ -236,12 +225,9 @@ uint64_t Repeats::encode(const float *c) const
|
|
|
236
225
|
return code;
|
|
237
226
|
}
|
|
238
227
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
void Repeats::decode(uint64_t code, float *c) const
|
|
242
|
-
{
|
|
228
|
+
void Repeats::decode(uint64_t code, float* c) const {
|
|
243
229
|
if (dim < 64) {
|
|
244
|
-
repeats_decode_64
|
|
230
|
+
repeats_decode_64(repeats, dim, code, c);
|
|
245
231
|
return;
|
|
246
232
|
}
|
|
247
233
|
|
|
@@ -254,7 +240,7 @@ void Repeats::decode(uint64_t code, float *c) const
|
|
|
254
240
|
|
|
255
241
|
int occ = 0;
|
|
256
242
|
int rank = nfree;
|
|
257
|
-
int next_rank = decode_comb_1
|
|
243
|
+
int next_rank = decode_comb_1(&code_comb, r->n, rank);
|
|
258
244
|
for (int i = dim - 1; i >= 0; i--) {
|
|
259
245
|
if (!decoded[i]) {
|
|
260
246
|
rank--;
|
|
@@ -262,65 +248,61 @@ void Repeats::decode(uint64_t code, float *c) const
|
|
|
262
248
|
decoded[i] = true;
|
|
263
249
|
c[i] = r->val;
|
|
264
250
|
occ++;
|
|
265
|
-
if (occ == r->n)
|
|
266
|
-
|
|
267
|
-
|
|
251
|
+
if (occ == r->n)
|
|
252
|
+
break;
|
|
253
|
+
next_rank =
|
|
254
|
+
decode_comb_1(&code_comb, r->n - occ, next_rank);
|
|
268
255
|
}
|
|
269
256
|
}
|
|
270
257
|
}
|
|
271
258
|
nfree -= r->n;
|
|
272
259
|
}
|
|
273
|
-
|
|
274
260
|
}
|
|
275
261
|
|
|
276
|
-
|
|
277
|
-
|
|
278
262
|
/********************************************
|
|
279
263
|
* EnumeratedVectors functions
|
|
280
264
|
********************************************/
|
|
281
265
|
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
uint64_t * codes) const
|
|
285
|
-
{
|
|
266
|
+
void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
|
|
267
|
+
const {
|
|
286
268
|
#pragma omp parallel if (n > 1000)
|
|
287
269
|
{
|
|
288
270
|
#pragma omp for
|
|
289
|
-
for(int i = 0; i < n; i++) {
|
|
271
|
+
for (int i = 0; i < n; i++) {
|
|
290
272
|
codes[i] = encode(c + i * dim);
|
|
291
273
|
}
|
|
292
274
|
}
|
|
293
275
|
}
|
|
294
276
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
float *c) const
|
|
298
|
-
{
|
|
277
|
+
void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
|
|
278
|
+
const {
|
|
299
279
|
#pragma omp parallel if (n > 1000)
|
|
300
280
|
{
|
|
301
281
|
#pragma omp for
|
|
302
|
-
for(int i = 0; i < n; i++) {
|
|
282
|
+
for (int i = 0; i < n; i++) {
|
|
303
283
|
decode(codes[i], c + i * dim);
|
|
304
284
|
}
|
|
305
285
|
}
|
|
306
286
|
}
|
|
307
287
|
|
|
308
|
-
void EnumeratedVectors::find_nn
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
288
|
+
void EnumeratedVectors::find_nn(
|
|
289
|
+
size_t nc,
|
|
290
|
+
const uint64_t* codes,
|
|
291
|
+
size_t nq,
|
|
292
|
+
const float* xq,
|
|
293
|
+
int64_t* labels,
|
|
294
|
+
float* distances) {
|
|
313
295
|
for (size_t i = 0; i < nq; i++) {
|
|
314
296
|
distances[i] = -1e20;
|
|
315
297
|
labels[i] = -1;
|
|
316
298
|
}
|
|
317
299
|
|
|
318
300
|
std::vector<float> c(dim);
|
|
319
|
-
for(size_t i = 0; i < nc; i++) {
|
|
301
|
+
for (size_t i = 0; i < nc; i++) {
|
|
320
302
|
uint64_t code = codes[nc];
|
|
321
303
|
decode(code, c.data());
|
|
322
304
|
for (size_t j = 0; j < nq; j++) {
|
|
323
|
-
const float
|
|
305
|
+
const float* x = xq + j * dim;
|
|
324
306
|
float dis = fvec_inner_product(x, c.data(), dim);
|
|
325
307
|
if (dis > distances[j]) {
|
|
326
308
|
distances[j] = dis;
|
|
@@ -328,45 +310,41 @@ void EnumeratedVectors::find_nn (
|
|
|
328
310
|
}
|
|
329
311
|
}
|
|
330
312
|
}
|
|
331
|
-
|
|
332
313
|
}
|
|
333
314
|
|
|
334
|
-
|
|
335
315
|
/**********************************************************
|
|
336
316
|
* ZnSphereSearch
|
|
337
317
|
**********************************************************/
|
|
338
318
|
|
|
339
|
-
|
|
340
|
-
ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
|
|
319
|
+
ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
|
|
341
320
|
voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
|
|
342
321
|
natom = voc.size() / dim;
|
|
343
322
|
}
|
|
344
323
|
|
|
345
|
-
float ZnSphereSearch::search(const float
|
|
324
|
+
float ZnSphereSearch::search(const float* x, float* c) const {
|
|
346
325
|
std::vector<float> tmp(dimS * 2);
|
|
347
326
|
std::vector<int> tmp_int(dimS);
|
|
348
327
|
return search(x, c, tmp.data(), tmp_int.data());
|
|
349
328
|
}
|
|
350
329
|
|
|
351
|
-
float ZnSphereSearch::search(
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
330
|
+
float ZnSphereSearch::search(
|
|
331
|
+
const float* x,
|
|
332
|
+
float* c,
|
|
333
|
+
float* tmp, // size 2 *dim
|
|
334
|
+
int* tmp_int, // size dim
|
|
335
|
+
int* ibest_out) const {
|
|
356
336
|
int dim = dimS;
|
|
357
|
-
assert
|
|
358
|
-
int
|
|
359
|
-
float
|
|
360
|
-
float
|
|
337
|
+
assert(natom > 0);
|
|
338
|
+
int* o = tmp_int;
|
|
339
|
+
float* xabs = tmp;
|
|
340
|
+
float* xperm = tmp + dim;
|
|
361
341
|
|
|
362
342
|
// argsort
|
|
363
343
|
for (int i = 0; i < dim; i++) {
|
|
364
344
|
o[i] = i;
|
|
365
345
|
xabs[i] = fabsf(x[i]);
|
|
366
346
|
}
|
|
367
|
-
std::sort(o, o + dim, [xabs](int a, int b) {
|
|
368
|
-
return xabs[a] > xabs[b];
|
|
369
|
-
});
|
|
347
|
+
std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
|
|
370
348
|
for (int i = 0; i < dim; i++) {
|
|
371
349
|
xperm[i] = xabs[o[i]];
|
|
372
350
|
}
|
|
@@ -374,16 +352,16 @@ float ZnSphereSearch::search(const float *x, float *c,
|
|
|
374
352
|
int ibest = -1;
|
|
375
353
|
float dpbest = -100;
|
|
376
354
|
for (int i = 0; i < natom; i++) {
|
|
377
|
-
float dp = fvec_inner_product
|
|
355
|
+
float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
|
|
378
356
|
if (dp > dpbest) {
|
|
379
357
|
dpbest = dp;
|
|
380
358
|
ibest = i;
|
|
381
359
|
}
|
|
382
360
|
}
|
|
383
361
|
// revert sort
|
|
384
|
-
const float
|
|
362
|
+
const float* cin = voc.data() + ibest * dim;
|
|
385
363
|
for (int i = 0; i < dim; i++) {
|
|
386
|
-
c[o[i]] = copysignf
|
|
364
|
+
c[o[i]] = copysignf(cin[i], x[o[i]]);
|
|
387
365
|
}
|
|
388
366
|
if (ibest_out) {
|
|
389
367
|
*ibest_out = ibest;
|
|
@@ -391,33 +369,32 @@ float ZnSphereSearch::search(const float *x, float *c,
|
|
|
391
369
|
return dpbest;
|
|
392
370
|
}
|
|
393
371
|
|
|
394
|
-
void ZnSphereSearch::search_multi(
|
|
395
|
-
|
|
396
|
-
|
|
372
|
+
void ZnSphereSearch::search_multi(
|
|
373
|
+
int n,
|
|
374
|
+
const float* x,
|
|
375
|
+
float* c_out,
|
|
376
|
+
float* dp_out) {
|
|
397
377
|
#pragma omp parallel if (n > 1000)
|
|
398
378
|
{
|
|
399
379
|
#pragma omp for
|
|
400
|
-
for(int i = 0; i < n; i++) {
|
|
380
|
+
for (int i = 0; i < n; i++) {
|
|
401
381
|
dp_out[i] = search(x + i * dimS, c_out + i * dimS);
|
|
402
382
|
}
|
|
403
383
|
}
|
|
404
384
|
}
|
|
405
385
|
|
|
406
|
-
|
|
407
386
|
/**********************************************************
|
|
408
387
|
* ZnSphereCodec
|
|
409
388
|
**********************************************************/
|
|
410
389
|
|
|
411
|
-
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
|
|
412
|
-
|
|
413
|
-
EnumeratedVectors(dim)
|
|
414
|
-
{
|
|
390
|
+
ZnSphereCodec::ZnSphereCodec(int dim, int r2)
|
|
391
|
+
: ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
|
|
415
392
|
nv = 0;
|
|
416
393
|
for (int i = 0; i < natom; i++) {
|
|
417
394
|
Repeats repeats(dim, &voc[i * dim]);
|
|
418
395
|
CodeSegment cs(repeats);
|
|
419
396
|
cs.c0 = nv;
|
|
420
|
-
Repeat
|
|
397
|
+
Repeat& br = repeats.repeats.back();
|
|
421
398
|
cs.signbits = br.val == 0 ? dim - br.n : dim;
|
|
422
399
|
code_segments.push_back(cs);
|
|
423
400
|
nv += repeats.count() << cs.signbits;
|
|
@@ -431,7 +408,7 @@ ZnSphereCodec::ZnSphereCodec(int dim, int r2):
|
|
|
431
408
|
}
|
|
432
409
|
}
|
|
433
410
|
|
|
434
|
-
uint64_t ZnSphereCodec::search_and_encode(const float
|
|
411
|
+
uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
|
|
435
412
|
std::vector<float> tmp(dim * 2);
|
|
436
413
|
std::vector<int> tmp_int(dim);
|
|
437
414
|
int ano; // atom number
|
|
@@ -446,30 +423,30 @@ uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
|
|
|
446
423
|
if (c[i] < 0) {
|
|
447
424
|
signs |= uint64_t{1} << nnz;
|
|
448
425
|
}
|
|
449
|
-
nnz
|
|
426
|
+
nnz++;
|
|
450
427
|
}
|
|
451
428
|
}
|
|
452
|
-
const CodeSegment
|
|
429
|
+
const CodeSegment& cs = code_segments[ano];
|
|
453
430
|
assert(nnz == cs.signbits);
|
|
454
431
|
uint64_t code = cs.c0 + signs;
|
|
455
432
|
code += cs.encode(cabs.data()) << cs.signbits;
|
|
456
433
|
return code;
|
|
457
434
|
}
|
|
458
435
|
|
|
459
|
-
uint64_t ZnSphereCodec::encode(const float
|
|
460
|
-
{
|
|
436
|
+
uint64_t ZnSphereCodec::encode(const float* x) const {
|
|
461
437
|
return search_and_encode(x);
|
|
462
438
|
}
|
|
463
439
|
|
|
464
|
-
|
|
465
|
-
void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
|
440
|
+
void ZnSphereCodec::decode(uint64_t code, float* c) const {
|
|
466
441
|
int i0 = 0, i1 = natom;
|
|
467
442
|
while (i0 + 1 < i1) {
|
|
468
443
|
int imed = (i0 + i1) / 2;
|
|
469
|
-
if (code_segments[imed].c0 <= code)
|
|
470
|
-
|
|
444
|
+
if (code_segments[imed].c0 <= code)
|
|
445
|
+
i0 = imed;
|
|
446
|
+
else
|
|
447
|
+
i1 = imed;
|
|
471
448
|
}
|
|
472
|
-
const CodeSegment
|
|
449
|
+
const CodeSegment& cs = code_segments[i0];
|
|
473
450
|
code -= cs.c0;
|
|
474
451
|
uint64_t signs = code;
|
|
475
452
|
code >>= cs.signbits;
|
|
@@ -481,42 +458,34 @@ void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
|
|
481
458
|
if (signs & (1UL << nnz)) {
|
|
482
459
|
c[i] = -c[i];
|
|
483
460
|
}
|
|
484
|
-
nnz
|
|
461
|
+
nnz++;
|
|
485
462
|
}
|
|
486
463
|
}
|
|
487
464
|
}
|
|
488
465
|
|
|
489
|
-
|
|
490
466
|
/**************************************************************
|
|
491
467
|
* ZnSphereCodecRec
|
|
492
468
|
**************************************************************/
|
|
493
469
|
|
|
494
|
-
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
|
|
495
|
-
{
|
|
470
|
+
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
|
|
496
471
|
return all_nv[ld * (r2 + 1) + r2a];
|
|
497
472
|
}
|
|
498
473
|
|
|
499
|
-
|
|
500
|
-
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
|
|
501
|
-
{
|
|
474
|
+
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
|
|
502
475
|
return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
|
|
503
476
|
}
|
|
504
477
|
|
|
505
|
-
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
|
|
506
|
-
{
|
|
478
|
+
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
|
|
507
479
|
all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
|
|
508
480
|
}
|
|
509
481
|
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
EnumeratedVectors(dim), r2(r2)
|
|
513
|
-
{
|
|
482
|
+
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
|
|
483
|
+
: EnumeratedVectors(dim), r2(r2) {
|
|
514
484
|
log2_dim = 0;
|
|
515
485
|
while (dim > (1 << log2_dim)) {
|
|
516
486
|
log2_dim++;
|
|
517
487
|
}
|
|
518
|
-
assert(dim == (1 << log2_dim) ||
|
|
519
|
-
!"dimension must be a power of 2");
|
|
488
|
+
assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
|
|
520
489
|
|
|
521
490
|
all_nv.resize((log2_dim + 1) * (r2 + 1));
|
|
522
491
|
all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
|
|
@@ -531,7 +500,6 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
|
|
|
531
500
|
}
|
|
532
501
|
|
|
533
502
|
for (int ld = 1; ld <= log2_dim; ld++) {
|
|
534
|
-
|
|
535
503
|
for (int r2sub = 0; r2sub <= r2; r2sub++) {
|
|
536
504
|
uint64_t nv = 0;
|
|
537
505
|
for (int r2a = 0; r2a <= r2sub; r2a++) {
|
|
@@ -559,33 +527,29 @@ ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
|
|
|
559
527
|
for (int r2sub = 0; r2sub <= r2; r2sub++) {
|
|
560
528
|
int ld = cache_level;
|
|
561
529
|
uint64_t nvi = get_nv(ld, r2sub);
|
|
562
|
-
std::vector<float
|
|
530
|
+
std::vector<float>& cache = decode_cache[r2sub];
|
|
563
531
|
int dimsub = (1 << cache_level);
|
|
564
|
-
cache.resize
|
|
532
|
+
cache.resize(nvi * dimsub);
|
|
565
533
|
std::vector<float> c(dim);
|
|
566
|
-
uint64_t code0 = get_nv_cum(cache_level + 1, r2,
|
|
567
|
-
r2 - r2sub);
|
|
534
|
+
uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
|
|
568
535
|
for (int i = 0; i < nvi; i++) {
|
|
569
536
|
decode(i + code0, c.data());
|
|
570
|
-
memcpy(&cache[i * dimsub],
|
|
537
|
+
memcpy(&cache[i * dimsub],
|
|
538
|
+
c.data() + dim - dimsub,
|
|
571
539
|
dimsub * sizeof(*c.data()));
|
|
572
540
|
}
|
|
573
541
|
}
|
|
574
542
|
decode_cache_ld = cache_level;
|
|
575
543
|
}
|
|
576
544
|
|
|
577
|
-
uint64_t ZnSphereCodecRec::encode(const float
|
|
578
|
-
{
|
|
545
|
+
uint64_t ZnSphereCodecRec::encode(const float* c) const {
|
|
579
546
|
return encode_centroid(c);
|
|
580
547
|
}
|
|
581
548
|
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
585
|
-
{
|
|
549
|
+
uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
|
|
586
550
|
std::vector<uint64_t> codes(dim);
|
|
587
551
|
std::vector<int> norm2s(dim);
|
|
588
|
-
for(int i = 0; i < dim; i++) {
|
|
552
|
+
for (int i = 0; i < dim; i++) {
|
|
589
553
|
if (c[i] == 0) {
|
|
590
554
|
codes[i] = 0;
|
|
591
555
|
norm2s[i] = 0;
|
|
@@ -596,7 +560,7 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
|
596
560
|
}
|
|
597
561
|
}
|
|
598
562
|
int dim2 = dim / 2;
|
|
599
|
-
for(int ld = 1; ld <= log2_dim; ld++) {
|
|
563
|
+
for (int ld = 1; ld <= log2_dim; ld++) {
|
|
600
564
|
for (int i = 0; i < dim2; i++) {
|
|
601
565
|
int r2a = norm2s[2 * i];
|
|
602
566
|
int r2b = norm2s[2 * i + 1];
|
|
@@ -604,10 +568,8 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
|
604
568
|
uint64_t code_a = codes[2 * i];
|
|
605
569
|
uint64_t code_b = codes[2 * i + 1];
|
|
606
570
|
|
|
607
|
-
codes[i] =
|
|
608
|
-
|
|
609
|
-
code_a * get_nv(ld - 1, r2b) +
|
|
610
|
-
code_b;
|
|
571
|
+
codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
|
|
572
|
+
code_a * get_nv(ld - 1, r2b) + code_b;
|
|
611
573
|
norm2s[i] = r2a + r2b;
|
|
612
574
|
}
|
|
613
575
|
dim2 /= 2;
|
|
@@ -615,23 +577,20 @@ uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
|
615
577
|
return codes[0];
|
|
616
578
|
}
|
|
617
579
|
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
621
|
-
{
|
|
580
|
+
void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
|
|
622
581
|
std::vector<uint64_t> codes(dim);
|
|
623
582
|
std::vector<int> norm2s(dim);
|
|
624
583
|
codes[0] = code;
|
|
625
584
|
norm2s[0] = r2;
|
|
626
585
|
|
|
627
586
|
int dim2 = 1;
|
|
628
|
-
for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
|
587
|
+
for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
|
629
588
|
for (int i = dim2 - 1; i >= 0; i--) {
|
|
630
589
|
int r2sub = norm2s[i];
|
|
631
590
|
int i0 = 0, i1 = r2sub + 1;
|
|
632
591
|
uint64_t codei = codes[i];
|
|
633
|
-
const uint64_t
|
|
634
|
-
|
|
592
|
+
const uint64_t* cum =
|
|
593
|
+
&all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
|
|
635
594
|
while (i1 > i0 + 1) {
|
|
636
595
|
int imed = (i0 + i1) / 2;
|
|
637
596
|
if (cum[imed] <= codei)
|
|
@@ -649,13 +608,12 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
|
649
608
|
|
|
650
609
|
codes[2 * i] = code_a;
|
|
651
610
|
codes[2 * i + 1] = code_b;
|
|
652
|
-
|
|
653
611
|
}
|
|
654
612
|
dim2 *= 2;
|
|
655
613
|
}
|
|
656
614
|
|
|
657
615
|
if (decode_cache_ld == 0) {
|
|
658
|
-
for(int i = 0; i < dim; i++) {
|
|
616
|
+
for (int i = 0; i < dim; i++) {
|
|
659
617
|
if (norm2s[i] == 0) {
|
|
660
618
|
c[i] = 0;
|
|
661
619
|
} else {
|
|
@@ -666,49 +624,42 @@ void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
|
666
624
|
}
|
|
667
625
|
} else {
|
|
668
626
|
int subdim = 1 << decode_cache_ld;
|
|
669
|
-
assert
|
|
670
|
-
|
|
671
|
-
for(int i = 0; i < dim2; i++) {
|
|
627
|
+
assert((dim2 * subdim) == dim);
|
|
672
628
|
|
|
673
|
-
|
|
674
|
-
|
|
629
|
+
for (int i = 0; i < dim2; i++) {
|
|
630
|
+
const std::vector<float>& cache = decode_cache[norm2s[i]];
|
|
675
631
|
assert(codes[i] < cache.size());
|
|
676
632
|
memcpy(c + i * subdim,
|
|
677
633
|
&cache[codes[i] * subdim],
|
|
678
|
-
sizeof(*c)* subdim);
|
|
634
|
+
sizeof(*c) * subdim);
|
|
679
635
|
}
|
|
680
636
|
}
|
|
681
637
|
}
|
|
682
638
|
|
|
683
639
|
// if not use_rec, instanciate an arbitrary harmless znc_rec
|
|
684
|
-
ZnSphereCodecAlt::ZnSphereCodecAlt
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
{
|
|
690
|
-
|
|
691
|
-
uint64_t ZnSphereCodecAlt::encode(const float *x) const
|
|
692
|
-
{
|
|
640
|
+
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
|
|
641
|
+
: ZnSphereCodec(dim, r2),
|
|
642
|
+
use_rec((dim & (dim - 1)) == 0),
|
|
643
|
+
znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
|
|
644
|
+
|
|
645
|
+
uint64_t ZnSphereCodecAlt::encode(const float* x) const {
|
|
693
646
|
if (!use_rec) {
|
|
694
647
|
// it's ok if the vector is not normalized
|
|
695
648
|
return ZnSphereCodec::encode(x);
|
|
696
649
|
} else {
|
|
697
650
|
// find nearest centroid
|
|
698
651
|
std::vector<float> centroid(dim);
|
|
699
|
-
search
|
|
652
|
+
search(x, centroid.data());
|
|
700
653
|
return znc_rec.encode(centroid.data());
|
|
701
654
|
}
|
|
702
655
|
}
|
|
703
656
|
|
|
704
|
-
void ZnSphereCodecAlt::decode(uint64_t code, float
|
|
705
|
-
{
|
|
657
|
+
void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
|
|
706
658
|
if (!use_rec) {
|
|
707
|
-
ZnSphereCodec::decode
|
|
659
|
+
ZnSphereCodec::decode(code, c);
|
|
708
660
|
} else {
|
|
709
|
-
znc_rec.decode
|
|
661
|
+
znc_rec.decode(code, c);
|
|
710
662
|
}
|
|
711
663
|
}
|
|
712
664
|
|
|
713
|
-
|
|
714
665
|
} // namespace faiss
|