faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,142 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <cstring>
|
11
|
+
#include <cassert>
|
12
|
+
|
13
|
+
#include <faiss/impl/io.h>
|
14
|
+
#include <faiss/impl/FaissAssert.h>
|
15
|
+
|
16
|
+
|
17
|
+
namespace faiss {
|
18
|
+
|
19
|
+
|
20
|
+
/***********************************************************************
|
21
|
+
* IO functions
|
22
|
+
***********************************************************************/
|
23
|
+
|
24
|
+
|
25
|
+
int IOReader::fileno ()
|
26
|
+
{
|
27
|
+
FAISS_THROW_MSG ("IOReader does not support memory mapping");
|
28
|
+
}
|
29
|
+
|
30
|
+
int IOWriter::fileno ()
|
31
|
+
{
|
32
|
+
FAISS_THROW_MSG ("IOWriter does not support memory mapping");
|
33
|
+
}
|
34
|
+
|
35
|
+
/***********************************************************************
|
36
|
+
* IO Vector
|
37
|
+
***********************************************************************/
|
38
|
+
|
39
|
+
|
40
|
+
|
41
|
+
size_t VectorIOWriter::operator()(
|
42
|
+
const void *ptr, size_t size, size_t nitems)
|
43
|
+
{
|
44
|
+
size_t bytes = size * nitems;
|
45
|
+
if (bytes > 0) {
|
46
|
+
size_t o = data.size();
|
47
|
+
data.resize(o + bytes);
|
48
|
+
memcpy (&data[o], ptr, size * nitems);
|
49
|
+
}
|
50
|
+
return nitems;
|
51
|
+
}
|
52
|
+
|
53
|
+
size_t VectorIOReader::operator()(
|
54
|
+
void *ptr, size_t size, size_t nitems)
|
55
|
+
{
|
56
|
+
if (rp >= data.size()) return 0;
|
57
|
+
size_t nremain = (data.size() - rp) / size;
|
58
|
+
if (nremain < nitems) nitems = nremain;
|
59
|
+
if (size * nitems > 0) {
|
60
|
+
memcpy (ptr, &data[rp], size * nitems);
|
61
|
+
rp += size * nitems;
|
62
|
+
}
|
63
|
+
return nitems;
|
64
|
+
}
|
65
|
+
|
66
|
+
|
67
|
+
|
68
|
+
|
69
|
+
/***********************************************************************
|
70
|
+
* IO File
|
71
|
+
***********************************************************************/
|
72
|
+
|
73
|
+
|
74
|
+
|
75
|
+
// Wrap a caller-owned FILE: need_close keeps its default (false), so the
// destructor leaves the stream open.
FileIOReader::FileIOReader(FILE *rf)
    : f(rf)
{
}
|
76
|
+
|
77
|
+
// Open fname for binary reading; throws (with the OS error text) if the
// file cannot be opened. The stream is owned and closed by the destructor.
FileIOReader::FileIOReader(const char * fname)
{
    name = fname;   // kept for error messages
    FILE *handle = fopen(fname, "rb");
    FAISS_THROW_IF_NOT_FMT (handle, "could not open %s for reading: %s",
                            fname, strerror(errno));
    f = handle;
    need_close = true;
}
|
85
|
+
|
86
|
+
// Close the stream if this object opened it (need_close is only set by the
// filename constructor). Close errors are reported on stderr.
FileIOReader::~FileIOReader() {
    if (need_close) {
        int ret = fclose(f);
        if (ret != 0) {
            // we cannot raise an exception in the destructor; the message is
            // newline-terminated so it does not run into later output
            fprintf(stderr, "file %s close error: %s\n",
                    name.c_str(), strerror(errno));
        }
    }
}
|
95
|
+
|
96
|
+
// Thin wrapper over fread on f: same arguments, same return value.
size_t FileIOReader::operator()(void *ptr, size_t size, size_t nitems) {
    const size_t nread = fread(ptr, size, nitems, f);
    return nread;
}
|
99
|
+
|
100
|
+
int FileIOReader::fileno() {
|
101
|
+
return ::fileno (f);
|
102
|
+
}
|
103
|
+
|
104
|
+
|
105
|
+
// Wrap a caller-owned FILE: need_close keeps its default (false), so the
// destructor leaves the stream open.
FileIOWriter::FileIOWriter(FILE *wf)
    : f(wf)
{
}
|
106
|
+
|
107
|
+
// Create/truncate fname for binary writing; throws (with the OS error text)
// if it cannot be opened. The stream is owned and closed by the destructor.
FileIOWriter::FileIOWriter(const char * fname)
{
    name = fname;   // kept for error messages
    FILE *handle = fopen(fname, "wb");
    FAISS_THROW_IF_NOT_FMT (handle, "could not open %s for writing: %s",
                            fname, strerror(errno));
    f = handle;
    need_close = true;
}
|
115
|
+
|
116
|
+
// Close the stream if this object opened it (need_close is only set by the
// filename constructor). Close errors, which can include failed flushes of
// buffered data, are reported on stderr.
FileIOWriter::~FileIOWriter() {
    if (need_close) {
        int ret = fclose(f);
        if (ret != 0) {
            // we cannot raise an exception in the destructor; the message is
            // newline-terminated so it does not run into later output
            fprintf(stderr, "file %s close error: %s\n",
                    name.c_str(), strerror(errno));
        }
    }
}
|
126
|
+
|
127
|
+
// Thin wrapper over fwrite on f: same arguments, same return value.
size_t FileIOWriter::operator()(const void *ptr, size_t size, size_t nitems) {
    const size_t nwritten = fwrite(ptr, size, nitems, f);
    return nwritten;
}
|
130
|
+
|
131
|
+
int FileIOWriter::fileno() {
|
132
|
+
return ::fileno (f);
|
133
|
+
}
|
134
|
+
|
135
|
+
// Pack a 4-character string into a uint32_t, first character in the least
// significant byte.
uint32_t fourcc (const char sx[4]) {
    assert(4 == strlen(sx));
    const unsigned char *x = reinterpret_cast<const unsigned char *>(sx);
    // Widen every byte to uint32_t before shifting. Otherwise a byte >= 0x80
    // shifted by 24 overflows a signed int (undefined behaviour before C++20).
    return uint32_t(x[0]) |
           uint32_t(x[1]) << 8 |
           uint32_t(x[2]) << 16 |
           uint32_t(x[3]) << 24;
}
|
140
|
+
|
141
|
+
|
142
|
+
} // namespace faiss
|
@@ -0,0 +1,98 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
/***********************************************************
|
11
|
+
* Abstract I/O objects
|
12
|
+
***********************************************************/
|
13
|
+
|
14
|
+
#pragma once
|
15
|
+
|
16
|
+
#include <string>
|
17
|
+
#include <cstdio>
|
18
|
+
#include <vector>
|
19
|
+
|
20
|
+
#include <faiss/Index.h>
|
21
|
+
|
22
|
+
namespace faiss {
|
23
|
+
|
24
|
+
|
25
|
+
/// Abstract byte source with an fread-like interface.
struct IOReader {
    /// name that can be used in error messages
    std::string name;

    /// fread: read up to nitems items of `size` bytes each into ptr and
    /// return the number of items read
    virtual size_t operator()(
            void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    /// (the base implementation throws)
    virtual int fileno ();

    virtual ~IOReader() = default;
};
|
38
|
+
|
39
|
+
/// Abstract byte sink with an fwrite-like interface.
struct IOWriter {
    /// name that can be used in error messages
    std::string name;

    /// fwrite: write nitems items of `size` bytes each from ptr and
    /// return the number of items written
    virtual size_t operator()(
            const void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    /// (the base implementation throws)
    virtual int fileno ();

    virtual ~IOWriter() = default;
};
|
52
|
+
|
53
|
+
|
54
|
+
struct VectorIOReader:IOReader {
|
55
|
+
std::vector<uint8_t> data;
|
56
|
+
size_t rp = 0;
|
57
|
+
size_t operator()(void *ptr, size_t size, size_t nitems) override;
|
58
|
+
};
|
59
|
+
|
60
|
+
struct VectorIOWriter:IOWriter {
|
61
|
+
std::vector<uint8_t> data;
|
62
|
+
size_t operator()(const void *ptr, size_t size, size_t nitems) override;
|
63
|
+
};
|
64
|
+
|
65
|
+
struct FileIOReader: IOReader {
|
66
|
+
FILE *f = nullptr;
|
67
|
+
bool need_close = false;
|
68
|
+
|
69
|
+
FileIOReader(FILE *rf);
|
70
|
+
|
71
|
+
FileIOReader(const char * fname);
|
72
|
+
|
73
|
+
~FileIOReader() override;
|
74
|
+
|
75
|
+
size_t operator()(void *ptr, size_t size, size_t nitems) override;
|
76
|
+
|
77
|
+
int fileno() override;
|
78
|
+
};
|
79
|
+
|
80
|
+
struct FileIOWriter: IOWriter {
|
81
|
+
FILE *f = nullptr;
|
82
|
+
bool need_close = false;
|
83
|
+
|
84
|
+
FileIOWriter(FILE *wf);
|
85
|
+
|
86
|
+
FileIOWriter(const char * fname);
|
87
|
+
|
88
|
+
~FileIOWriter() override;
|
89
|
+
|
90
|
+
size_t operator()(const void *ptr, size_t size, size_t nitems) override;
|
91
|
+
|
92
|
+
int fileno() override;
|
93
|
+
};
|
94
|
+
|
95
|
+
/// Pack a 4-character string into a uint32_t (first character in the least
/// significant byte) so it can be written and read easily as a tag.
uint32_t fourcc(const char sx[4]);
|
97
|
+
|
98
|
+
} // namespace faiss
|
@@ -0,0 +1,712 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/impl/lattice_Zn.h>
|
11
|
+
|
12
|
+
#include <cstdlib>
|
13
|
+
#include <cmath>
|
14
|
+
#include <cstring>
|
15
|
+
#include <cassert>
|
16
|
+
|
17
|
+
#include <queue>
|
18
|
+
#include <unordered_set>
|
19
|
+
#include <unordered_map>
|
20
|
+
#include <algorithm>
|
21
|
+
|
22
|
+
#include <faiss/utils/distances.h>
|
23
|
+
|
24
|
+
namespace faiss {
|
25
|
+
|
26
|
+
/********************************************
|
27
|
+
* small utility functions
|
28
|
+
********************************************/
|
29
|
+
|
30
|
+
namespace {
|
31
|
+
|
32
|
+
/// square of a float
inline float sqr(float x) {
    const float squared = x * x;
    return squared;
}
|
35
|
+
|
36
|
+
|
37
|
+
// flat list of points: each point is stored as consecutive floats
using point_list_t = std::vector<float>;
|
38
|
+
|
39
|
+
/// Table of binomial coefficients C(n, p) for 0 <= n, p < nmax, filled with
/// Pascal's rule C(n, p) = C(n-1, p) + C(n-1, p-1). Arithmetic is unsigned
/// 64-bit, so very large coefficients wrap around.
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle, row-major nmax x nmax
    int nmax;

    explicit Comb(int nmax): nmax(nmax) {
        tab.resize(nmax * nmax, 0);
        tab[0] = 1;
        for (int row = 1; row < nmax; row++) {
            uint64_t *cur = &tab[row * nmax];
            const uint64_t *prev = cur - nmax;
            cur[0] = 1;
            for (int col = 1; col <= row; col++) {
                cur[col] = prev[col] + prev[col - 1];
            }
        }
    }

    /// C(n, p); 0 when p > n. Both arguments must be < nmax.
    uint64_t operator()(int n, int p) const {
        assert (n < nmax && p < nmax);
        return p > n ? 0 : tab[n * nmax + p];
    }
};
|
63
|
+
|
64
|
+
// shared binomial-coefficient table; arguments must stay below 100
Comb comb{100};
|
65
|
+
|
66
|
+
|
67
|
+
|
68
|
+
// compute combinations of n integer values <= v that sum up to total (squared)
|
69
|
+
point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
70
|
+
if (total < 0) {
|
71
|
+
return point_list_t();
|
72
|
+
} else if (n == 1) {
|
73
|
+
while (sqr(v + add) > total) v--;
|
74
|
+
if (sqr(v + add) == total) {
|
75
|
+
return point_list_t(1, v + add);
|
76
|
+
} else {
|
77
|
+
return point_list_t();
|
78
|
+
}
|
79
|
+
} else {
|
80
|
+
point_list_t res;
|
81
|
+
while (v >= 0) {
|
82
|
+
point_list_t sub_points =
|
83
|
+
sum_of_sq (total - sqr(v + add), v, n - 1, add);
|
84
|
+
for (size_t i = 0; i < sub_points.size(); i += n - 1) {
|
85
|
+
res.push_back (v + add);
|
86
|
+
for (int j = 0; j < n - 1; j++) {
|
87
|
+
res.push_back(sub_points[i + j]);
|
88
|
+
}
|
89
|
+
}
|
90
|
+
v--;
|
91
|
+
}
|
92
|
+
return res;
|
93
|
+
}
|
94
|
+
}
|
95
|
+
|
96
|
+
// Find the largest r' <= r with C(r', k1) <= *n, subtract C(r', k1) from *n
// and return r' (one step of combinatorial-number-system decoding).
int decode_comb_1 (uint64_t *n, int k1, int r) {
    int top = r;
    while (comb(top, k1) > *n) {
        --top;
    }
    const uint64_t step = comb(top, k1);
    *n -= step;
    return top;
}
|
103
|
+
|
104
|
+
// optimized version for < 64 bits
|
105
|
+
long repeats_encode_64 (
|
106
|
+
const std::vector<Repeat> & repeats,
|
107
|
+
int dim, const float *c)
|
108
|
+
{
|
109
|
+
uint64_t coded = 0;
|
110
|
+
int nfree = dim;
|
111
|
+
uint64_t code = 0, shift = 1;
|
112
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
113
|
+
int rank = 0, occ = 0;
|
114
|
+
uint64_t code_comb = 0;
|
115
|
+
uint64_t tosee = ~coded;
|
116
|
+
for(;;) {
|
117
|
+
// directly jump to next available slot.
|
118
|
+
int i = __builtin_ctzl(tosee);
|
119
|
+
tosee &= ~(1UL << i) ;
|
120
|
+
if (c[i] == r->val) {
|
121
|
+
code_comb += comb(rank, occ + 1);
|
122
|
+
occ++;
|
123
|
+
coded |= 1UL << i;
|
124
|
+
if (occ == r->n) break;
|
125
|
+
}
|
126
|
+
rank++;
|
127
|
+
}
|
128
|
+
uint64_t max_comb = comb(nfree, r->n);
|
129
|
+
code += shift * code_comb;
|
130
|
+
shift *= max_comb;
|
131
|
+
nfree -= r->n;
|
132
|
+
}
|
133
|
+
return code;
|
134
|
+
}
|
135
|
+
|
136
|
+
|
137
|
+
void repeats_decode_64(
|
138
|
+
const std::vector<Repeat> & repeats,
|
139
|
+
int dim, uint64_t code, float *c)
|
140
|
+
{
|
141
|
+
uint64_t decoded = 0;
|
142
|
+
int nfree = dim;
|
143
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
144
|
+
uint64_t max_comb = comb(nfree, r->n);
|
145
|
+
uint64_t code_comb = code % max_comb;
|
146
|
+
code /= max_comb;
|
147
|
+
|
148
|
+
int occ = 0;
|
149
|
+
int rank = nfree;
|
150
|
+
int next_rank = decode_comb_1 (&code_comb, r->n, rank);
|
151
|
+
uint64_t tosee = ((1UL << dim) - 1) ^ decoded;
|
152
|
+
for(;;) {
|
153
|
+
int i = 63 - __builtin_clzl(tosee);
|
154
|
+
tosee &= ~(1UL << i);
|
155
|
+
rank--;
|
156
|
+
if (rank == next_rank) {
|
157
|
+
decoded |= 1UL << i;
|
158
|
+
c[i] = r->val;
|
159
|
+
occ++;
|
160
|
+
if (occ == r->n) break;
|
161
|
+
next_rank = decode_comb_1 (
|
162
|
+
&code_comb, r->n - occ, next_rank);
|
163
|
+
}
|
164
|
+
}
|
165
|
+
nfree -= r->n;
|
166
|
+
}
|
167
|
+
|
168
|
+
}
|
169
|
+
|
170
|
+
|
171
|
+
|
172
|
+
} // anonymous namespace
|
173
|
+
|
174
|
+
// Collect the distinct values of c[0..dim) in order of first appearance,
// together with how many times each one occurs.
Repeats::Repeats (int dim, const float *c): dim(dim)
{
    for (int i = 0; i < dim; i++) {
        const float v = c[i];
        auto found = std::find_if(
                repeats.begin(), repeats.end(),
                [v](const Repeat &rep) { return rep.val == v; });
        if (found != repeats.end()) {
            found->n++;
        } else {
            repeats.push_back(Repeat{v, 1});
        }
    }
}
|
191
|
+
|
192
|
+
|
193
|
+
long Repeats::count () const
|
194
|
+
{
|
195
|
+
long accu = 1;
|
196
|
+
int remain = dim;
|
197
|
+
for (int i = 0; i < repeats.size(); i++) {
|
198
|
+
accu *= comb(remain, repeats[i].n);
|
199
|
+
remain -= repeats[i].n;
|
200
|
+
}
|
201
|
+
return accu;
|
202
|
+
}
|
203
|
+
|
204
|
+
|
205
|
+
|
206
|
+
// version with a bool vector that works for > 64 dim
|
207
|
+
long Repeats::encode(const float *c) const
|
208
|
+
{
|
209
|
+
if (dim < 64) {
|
210
|
+
return repeats_encode_64 (repeats, dim, c);
|
211
|
+
}
|
212
|
+
std::vector<bool> coded(dim, false);
|
213
|
+
int nfree = dim;
|
214
|
+
uint64_t code = 0, shift = 1;
|
215
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
216
|
+
int rank = 0, occ = 0;
|
217
|
+
uint64_t code_comb = 0;
|
218
|
+
for (int i = 0; i < dim; i++) {
|
219
|
+
if (!coded[i]) {
|
220
|
+
if (c[i] == r->val) {
|
221
|
+
code_comb += comb(rank, occ + 1);
|
222
|
+
occ++;
|
223
|
+
coded[i] = true;
|
224
|
+
if (occ == r->n) break;
|
225
|
+
}
|
226
|
+
rank++;
|
227
|
+
}
|
228
|
+
}
|
229
|
+
uint64_t max_comb = comb(nfree, r->n);
|
230
|
+
code += shift * code_comb;
|
231
|
+
shift *= max_comb;
|
232
|
+
nfree -= r->n;
|
233
|
+
}
|
234
|
+
return code;
|
235
|
+
}
|
236
|
+
|
237
|
+
|
238
|
+
|
239
|
+
void Repeats::decode(uint64_t code, float *c) const
|
240
|
+
{
|
241
|
+
if (dim < 64) {
|
242
|
+
repeats_decode_64 (repeats, dim, code, c);
|
243
|
+
return;
|
244
|
+
}
|
245
|
+
|
246
|
+
std::vector<bool> decoded(dim, false);
|
247
|
+
int nfree = dim;
|
248
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
249
|
+
uint64_t max_comb = comb(nfree, r->n);
|
250
|
+
uint64_t code_comb = code % max_comb;
|
251
|
+
code /= max_comb;
|
252
|
+
|
253
|
+
int occ = 0;
|
254
|
+
int rank = nfree;
|
255
|
+
int next_rank = decode_comb_1 (&code_comb, r->n, rank);
|
256
|
+
for (int i = dim - 1; i >= 0; i--) {
|
257
|
+
if (!decoded[i]) {
|
258
|
+
rank--;
|
259
|
+
if (rank == next_rank) {
|
260
|
+
decoded[i] = true;
|
261
|
+
c[i] = r->val;
|
262
|
+
occ++;
|
263
|
+
if (occ == r->n) break;
|
264
|
+
next_rank = decode_comb_1 (
|
265
|
+
&code_comb, r->n - occ, next_rank);
|
266
|
+
}
|
267
|
+
}
|
268
|
+
}
|
269
|
+
nfree -= r->n;
|
270
|
+
}
|
271
|
+
|
272
|
+
}
|
273
|
+
|
274
|
+
|
275
|
+
|
276
|
+
/********************************************
|
277
|
+
* EnumeratedVectors functions
|
278
|
+
********************************************/
|
279
|
+
|
280
|
+
|
281
|
+
void EnumeratedVectors::encode_multi(size_t n, const float *c,
|
282
|
+
uint64_t * codes) const
|
283
|
+
{
|
284
|
+
#pragma omp parallel if (n > 1000)
|
285
|
+
{
|
286
|
+
#pragma omp for
|
287
|
+
for(int i = 0; i < n; i++) {
|
288
|
+
codes[i] = encode(c + i * dim);
|
289
|
+
}
|
290
|
+
}
|
291
|
+
}
|
292
|
+
|
293
|
+
|
294
|
+
void EnumeratedVectors::decode_multi(size_t n, const uint64_t * codes,
|
295
|
+
float *c) const
|
296
|
+
{
|
297
|
+
#pragma omp parallel if (n > 1000)
|
298
|
+
{
|
299
|
+
#pragma omp for
|
300
|
+
for(int i = 0; i < n; i++) {
|
301
|
+
decode(codes[i], c + i * dim);
|
302
|
+
}
|
303
|
+
}
|
304
|
+
}
|
305
|
+
|
306
|
+
void EnumeratedVectors::find_nn (
|
307
|
+
size_t nc, const uint64_t * codes,
|
308
|
+
size_t nq, const float *xq,
|
309
|
+
long *labels, float *distances)
|
310
|
+
{
|
311
|
+
for (long i = 0; i < nq; i++) {
|
312
|
+
distances[i] = -1e20;
|
313
|
+
labels[i] = -1;
|
314
|
+
}
|
315
|
+
|
316
|
+
float c[dim];
|
317
|
+
for(long i = 0; i < nc; i++) {
|
318
|
+
uint64_t code = codes[nc];
|
319
|
+
decode(code, c);
|
320
|
+
for (long j = 0; j < nq; j++) {
|
321
|
+
const float *x = xq + j * dim;
|
322
|
+
float dis = fvec_inner_product(x, c, dim);
|
323
|
+
if (dis > distances[j]) {
|
324
|
+
distances[j] = dis;
|
325
|
+
labels[j] = i;
|
326
|
+
}
|
327
|
+
}
|
328
|
+
}
|
329
|
+
|
330
|
+
}
|
331
|
+
|
332
|
+
|
333
|
+
/**********************************************************
|
334
|
+
* ZnSphereSearch
|
335
|
+
**********************************************************/
|
336
|
+
|
337
|
+
|
338
|
+
// Enumerate the "atoms": the non-increasing, non-negative integer vectors
// of squared norm r2. Every point of the sphere is a signed permutation
// of one atom.
ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
    // bound on any coordinate magnitude
    const int vmax = int(ceil(sqrt(r2)) + 1);
    voc = sum_of_sq(r2, vmax, dim);
    natom = voc.size() / dim;
}
|
342
|
+
|
343
|
+
float ZnSphereSearch::search(const float *x, float *c) const {
|
344
|
+
float tmp[dimS * 2];
|
345
|
+
int tmp_int[dimS];
|
346
|
+
return search(x, c, tmp, tmp_int);
|
347
|
+
}
|
348
|
+
|
349
|
+
float ZnSphereSearch::search(const float *x, float *c,
|
350
|
+
float *tmp, // size 2 *dim
|
351
|
+
int *tmp_int, // size dim
|
352
|
+
int *ibest_out
|
353
|
+
) const {
|
354
|
+
int dim = dimS;
|
355
|
+
assert (natom > 0);
|
356
|
+
int *o = tmp_int;
|
357
|
+
float *xabs = tmp;
|
358
|
+
float *xperm = tmp + dim;
|
359
|
+
|
360
|
+
// argsort
|
361
|
+
for (int i = 0; i < dim; i++) {
|
362
|
+
o[i] = i;
|
363
|
+
xabs[i] = fabsf(x[i]);
|
364
|
+
}
|
365
|
+
std::sort(o, o + dim, [xabs](int a, int b) {
|
366
|
+
return xabs[a] > xabs[b];
|
367
|
+
});
|
368
|
+
for (int i = 0; i < dim; i++) {
|
369
|
+
xperm[i] = xabs[o[i]];
|
370
|
+
}
|
371
|
+
// find best
|
372
|
+
int ibest = -1;
|
373
|
+
float dpbest = -100;
|
374
|
+
for (int i = 0; i < natom; i++) {
|
375
|
+
float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim);
|
376
|
+
if (dp > dpbest) {
|
377
|
+
dpbest = dp;
|
378
|
+
ibest = i;
|
379
|
+
}
|
380
|
+
}
|
381
|
+
// revert sort
|
382
|
+
const float *cin = voc.data() + ibest * dim;
|
383
|
+
for (int i = 0; i < dim; i++) {
|
384
|
+
c[o[i]] = copysignf (cin[i], x[o[i]]);
|
385
|
+
}
|
386
|
+
if (ibest_out) {
|
387
|
+
*ibest_out = ibest;
|
388
|
+
}
|
389
|
+
return dpbest;
|
390
|
+
}
|
391
|
+
|
392
|
+
void ZnSphereSearch::search_multi(int n, const float *x,
|
393
|
+
float *c_out,
|
394
|
+
float *dp_out) {
|
395
|
+
#pragma omp parallel if (n > 1000)
|
396
|
+
{
|
397
|
+
#pragma omp for
|
398
|
+
for(int i = 0; i < n; i++) {
|
399
|
+
dp_out[i] = search(x + i * dimS, c_out + i * dimS);
|
400
|
+
}
|
401
|
+
}
|
402
|
+
}
|
403
|
+
|
404
|
+
|
405
|
+
/**********************************************************
|
406
|
+
* ZnSphereCodec
|
407
|
+
**********************************************************/
|
408
|
+
|
409
|
+
// Assign each atom a contiguous range of codes starting at c0, of size
// count(atom) << signbits. nv ends up as the total number of codes, and
// code_size as the number of bytes needed to represent nv.
ZnSphereCodec::ZnSphereCodec(int dim, int r2):
    ZnSphereSearch(dim, r2),
    EnumeratedVectors(dim)
{
    nv = 0;
    for (int ia = 0; ia < natom; ia++) {
        Repeats atom_repeats(dim, &voc[ia * dim]);
        CodeSegment seg(atom_repeats);
        seg.c0 = nv;
        // zero coordinates need no sign bit; atoms from sum_of_sq are
        // non-increasing, so a zero value, if present, is the last repeat
        const Repeat &last = atom_repeats.repeats.back();
        if (last.val == 0) {
            seg.signbits = dim - last.n;
        } else {
            seg.signbits = dim;
        }
        code_segments.push_back(seg);
        nv += atom_repeats.count() << seg.signbits;
    }

    code_size = 0;
    for (uint64_t rest = nv; rest > 0; rest >>= 8) {
        code_size++;
    }
}
|
431
|
+
|
432
|
+
uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
|
433
|
+
float tmp[dim * 2];
|
434
|
+
int tmp_int[dim];
|
435
|
+
int ano; // atom number
|
436
|
+
float c[dim];
|
437
|
+
search(x, c, tmp, tmp_int, &ano);
|
438
|
+
uint64_t signs = 0;
|
439
|
+
float cabs[dim];
|
440
|
+
int nnz = 0;
|
441
|
+
for (int i = 0; i < dim; i++) {
|
442
|
+
cabs[i] = fabs(c[i]);
|
443
|
+
if (c[i] != 0) {
|
444
|
+
if (c[i] < 0) {
|
445
|
+
signs |= 1UL << nnz;
|
446
|
+
}
|
447
|
+
nnz ++;
|
448
|
+
}
|
449
|
+
}
|
450
|
+
const CodeSegment &cs = code_segments[ano];
|
451
|
+
assert(nnz == cs.signbits);
|
452
|
+
uint64_t code = cs.c0 + signs;
|
453
|
+
code += cs.encode(cabs) << cs.signbits;
|
454
|
+
return code;
|
455
|
+
}
|
456
|
+
|
457
|
+
uint64_t ZnSphereCodec::encode(const float *x) const
|
458
|
+
{
|
459
|
+
return search_and_encode(x);
|
460
|
+
}
|
461
|
+
|
462
|
+
|
463
|
+
void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
464
|
+
int i0 = 0, i1 = natom;
|
465
|
+
while (i0 + 1 < i1) {
|
466
|
+
int imed = (i0 + i1) / 2;
|
467
|
+
if (code_segments[imed].c0 <= code) i0 = imed;
|
468
|
+
else i1 = imed;
|
469
|
+
}
|
470
|
+
const CodeSegment &cs = code_segments[i0];
|
471
|
+
code -= cs.c0;
|
472
|
+
uint64_t signs = code;
|
473
|
+
code >>= cs.signbits;
|
474
|
+
cs.decode(code, c);
|
475
|
+
|
476
|
+
int nnz = 0;
|
477
|
+
for (int i = 0; i < dim; i++) {
|
478
|
+
if (c[i] != 0) {
|
479
|
+
if (signs & (1UL << nnz)) {
|
480
|
+
c[i] = -c[i];
|
481
|
+
}
|
482
|
+
nnz ++;
|
483
|
+
}
|
484
|
+
}
|
485
|
+
}
|
486
|
+
|
487
|
+
|
488
|
+
/**************************************************************
|
489
|
+
* ZnSphereCodecRec
|
490
|
+
**************************************************************/
|
491
|
+
|
492
|
+
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
|
493
|
+
{
|
494
|
+
return all_nv[ld * (r2 + 1) + r2a];
|
495
|
+
}
|
496
|
+
|
497
|
+
|
498
|
+
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
|
499
|
+
{
|
500
|
+
return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
|
501
|
+
}
|
502
|
+
|
503
|
+
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
|
504
|
+
{
|
505
|
+
all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
|
506
|
+
}
|
507
|
+
|
508
|
+
|
509
|
+
/** Recursive codec for the integer vectors of dimension `dim` (must be a
 *  power of 2) with squared norm `r2`.
 *
 *  all_nv[ld * (r2 + 1) + k] counts the integer vectors of dimension 2^ld
 *  with squared norm k. A vector of level ld is the concatenation of two
 *  level ld-1 halves, and all_nv_cum stores, for each (ld, total norm), the
 *  running code offsets of the possible (r2a, total - r2a) splits.
 *  A table of decoded sub-vectors (decode_cache) is also precomputed so
 *  that decode() can stop recursing at level decode_cache_ld.
 */
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
    EnumeratedVectors(dim), r2(r2)
{
    log2_dim = 0;
    while (dim > (1 << log2_dim)) {
        log2_dim++;
    }
    assert(dim == (1 << log2_dim) ||
           !"dimension must be a power of 2");

    all_nv.resize((log2_dim + 1) * (r2 + 1));
    all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));

    // level 0 (scalars): k has 1 representation if k == 0, 2 (+-sqrt(k))
    // if k is a non-zero perfect square, none otherwise
    for (int r2a = 0; r2a <= r2; r2a++) {
        int r = int(sqrt(r2a));
        if (r * r == r2a) {
            all_nv[r2a] = r == 0 ? 1 : 2;
        } else {
            all_nv[r2a] = 0;
        }
    }

    // level ld: count pairs of halves, recording the offset of each split.
    // (local named `count` so it does not shadow the member `nv`)
    for (int ld = 1; ld <= log2_dim; ld++) {
        for (int r2sub = 0; r2sub <= r2; r2sub++) {
            uint64_t count = 0;
            for (int r2a = 0; r2a <= r2sub; r2a++) {
                int r2b = r2sub - r2a;
                set_nv_cum(ld, r2sub, r2a, count);
                count += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
            }
            all_nv[ld * (r2 + 1) + r2sub] = count;
        }
    }
    nv = get_nv(log2_dim, r2);

    // number of bytes needed to hold a code in [0, nv)
    uint64_t nvx = nv;
    code_size = 0;
    while (nvx > 0) {
        nvx >>= 8;
        code_size++;
    }

    decode_cache_ld = 0;
    decode_cache.resize(r2 + 1);

    // For dim == 1 (log2_dim == 0) the cache level would be -1: both the
    // table lookups below and `1 << cache_level` would be undefined. The
    // cache is not needed there (decode() with decode_cache_ld == 0 rebuilds
    // coordinates directly), so leave it empty.
    int cache_level = std::min(3, log2_dim - 1);
    if (cache_level < 0) {
        return;
    }

    const int dimsub = 1 << cache_level;
    std::vector<float> c(dim);  // scratch buffer for one full decoded vector
    for (int r2sub = 0; r2sub <= r2; r2sub++) {
        uint64_t nvi = get_nv(cache_level, r2sub);
        std::vector<float> &cache = decode_cache[r2sub];
        cache.resize(nvi * dimsub);
        // Decode consecutive codes starting at the level-(cache_level + 1)
        // offset of the split (r2 - r2sub, r2sub) and keep the trailing
        // dimsub coordinates. decode_cache_ld is still 0 here, so decode()
        // performs the full recursion.
        uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
        for (uint64_t i = 0; i < nvi; i++) {
            decode(i + code0, c.data());
            memcpy(&cache[i * dimsub], c.data() + dim - dimsub,
                   dimsub * sizeof(float));
        }
    }
    decode_cache_ld = cache_level;
}
|
574
|
+
|
575
|
+
uint64_t ZnSphereCodecRec::encode(const float *c) const
|
576
|
+
{
|
577
|
+
return encode_centroid(c);
|
578
|
+
}
|
579
|
+
|
580
|
+
|
581
|
+
|
582
|
+
uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
583
|
+
{
|
584
|
+
uint64_t codes[dim];
|
585
|
+
int norm2s[dim];
|
586
|
+
for(int i = 0; i < dim; i++) {
|
587
|
+
if (c[i] == 0) {
|
588
|
+
codes[i] = 0;
|
589
|
+
norm2s[i] = 0;
|
590
|
+
} else {
|
591
|
+
int r2i = int(c[i] * c[i]);
|
592
|
+
norm2s[i] = r2i;
|
593
|
+
codes[i] = c[i] >= 0 ? 0 : 1;
|
594
|
+
}
|
595
|
+
}
|
596
|
+
int dim2 = dim / 2;
|
597
|
+
for(int ld = 1; ld <= log2_dim; ld++) {
|
598
|
+
for (int i = 0; i < dim2; i++) {
|
599
|
+
int r2a = norm2s[2 * i];
|
600
|
+
int r2b = norm2s[2 * i + 1];
|
601
|
+
|
602
|
+
uint64_t code_a = codes[2 * i];
|
603
|
+
uint64_t code_b = codes[2 * i + 1];
|
604
|
+
|
605
|
+
codes[i] =
|
606
|
+
get_nv_cum(ld, r2a + r2b, r2a) +
|
607
|
+
code_a * get_nv(ld - 1, r2b) +
|
608
|
+
code_b;
|
609
|
+
norm2s[i] = r2a + r2b;
|
610
|
+
}
|
611
|
+
dim2 /= 2;
|
612
|
+
}
|
613
|
+
return codes[0];
|
614
|
+
}
|
615
|
+
|
616
|
+
|
617
|
+
|
618
|
+
void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
619
|
+
{
|
620
|
+
uint64_t codes[dim];
|
621
|
+
int norm2s[dim];
|
622
|
+
codes[0] = code;
|
623
|
+
norm2s[0] = r2;
|
624
|
+
|
625
|
+
int dim2 = 1;
|
626
|
+
for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
627
|
+
for (int i = dim2 - 1; i >= 0; i--) {
|
628
|
+
int r2sub = norm2s[i];
|
629
|
+
int i0 = 0, i1 = r2sub + 1;
|
630
|
+
uint64_t codei = codes[i];
|
631
|
+
const uint64_t *cum =
|
632
|
+
&all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
|
633
|
+
while (i1 > i0 + 1) {
|
634
|
+
int imed = (i0 + i1) / 2;
|
635
|
+
if (cum[imed] <= codei)
|
636
|
+
i0 = imed;
|
637
|
+
else
|
638
|
+
i1 = imed;
|
639
|
+
}
|
640
|
+
int r2a = i0, r2b = r2sub - i0;
|
641
|
+
codei -= cum[r2a];
|
642
|
+
norm2s[2 * i] = r2a;
|
643
|
+
norm2s[2 * i + 1] = r2b;
|
644
|
+
|
645
|
+
uint64_t code_a = codei / get_nv(ld - 1, r2b);
|
646
|
+
uint64_t code_b = codei % get_nv(ld - 1, r2b);
|
647
|
+
|
648
|
+
codes[2 * i] = code_a;
|
649
|
+
codes[2 * i + 1] = code_b;
|
650
|
+
|
651
|
+
}
|
652
|
+
dim2 *= 2;
|
653
|
+
}
|
654
|
+
|
655
|
+
if (decode_cache_ld == 0) {
|
656
|
+
for(int i = 0; i < dim; i++) {
|
657
|
+
if (norm2s[i] == 0) {
|
658
|
+
c[i] = 0;
|
659
|
+
} else {
|
660
|
+
float r = sqrt(norm2s[i]);
|
661
|
+
assert(r * r == norm2s[i]);
|
662
|
+
c[i] = codes[i] == 0 ? r : -r;
|
663
|
+
}
|
664
|
+
}
|
665
|
+
} else {
|
666
|
+
int subdim = 1 << decode_cache_ld;
|
667
|
+
assert ((dim2 * subdim) == dim);
|
668
|
+
|
669
|
+
for(int i = 0; i < dim2; i++) {
|
670
|
+
|
671
|
+
const std::vector<float> & cache =
|
672
|
+
decode_cache[norm2s[i]];
|
673
|
+
assert(codes[i] < cache.size());
|
674
|
+
memcpy(c + i * subdim,
|
675
|
+
&cache[codes[i] * subdim],
|
676
|
+
sizeof(*c)* subdim);
|
677
|
+
}
|
678
|
+
}
|
679
|
+
}
|
680
|
+
|
681
|
+
// if not use_rec, instanciate an arbitrary harmless znc_rec
|
682
|
+
ZnSphereCodecAlt::ZnSphereCodecAlt (int dim, int r2):
|
683
|
+
ZnSphereCodec (dim, r2),
|
684
|
+
use_rec ((dim & (dim - 1)) == 0),
|
685
|
+
znc_rec (use_rec ? dim : 8,
|
686
|
+
use_rec ? r2 : 14)
|
687
|
+
{}
|
688
|
+
|
689
|
+
uint64_t ZnSphereCodecAlt::encode(const float *x) const
|
690
|
+
{
|
691
|
+
if (!use_rec) {
|
692
|
+
// it's ok if the vector is not normalized
|
693
|
+
return ZnSphereCodec::encode(x);
|
694
|
+
} else {
|
695
|
+
// find nearest centroid
|
696
|
+
std::vector<float> centroid(dim);
|
697
|
+
search (x, centroid.data());
|
698
|
+
return znc_rec.encode(centroid.data());
|
699
|
+
}
|
700
|
+
}
|
701
|
+
|
702
|
+
void ZnSphereCodecAlt::decode(uint64_t code, float *c) const
|
703
|
+
{
|
704
|
+
if (!use_rec) {
|
705
|
+
ZnSphereCodec::decode (code, c);
|
706
|
+
} else {
|
707
|
+
znc_rec.decode (code, c);
|
708
|
+
}
|
709
|
+
}
|
710
|
+
|
711
|
+
|
712
|
+
} // namespace faiss
|