faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <cstring>
|
|
11
|
+
#include <cassert>
|
|
12
|
+
|
|
13
|
+
#include <faiss/impl/io.h>
|
|
14
|
+
#include <faiss/impl/FaissAssert.h>
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
namespace faiss {
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
/***********************************************************************
|
|
21
|
+
* IO functions
|
|
22
|
+
***********************************************************************/
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
int IOReader::fileno ()
|
|
26
|
+
{
|
|
27
|
+
FAISS_THROW_MSG ("IOReader does not support memory mapping");
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
int IOWriter::fileno ()
|
|
31
|
+
{
|
|
32
|
+
FAISS_THROW_MSG ("IOWriter does not support memory mapping");
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
/***********************************************************************
|
|
36
|
+
* IO Vector
|
|
37
|
+
***********************************************************************/
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
size_t VectorIOWriter::operator()(
|
|
42
|
+
const void *ptr, size_t size, size_t nitems)
|
|
43
|
+
{
|
|
44
|
+
size_t bytes = size * nitems;
|
|
45
|
+
if (bytes > 0) {
|
|
46
|
+
size_t o = data.size();
|
|
47
|
+
data.resize(o + bytes);
|
|
48
|
+
memcpy (&data[o], ptr, size * nitems);
|
|
49
|
+
}
|
|
50
|
+
return nitems;
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
size_t VectorIOReader::operator()(
|
|
54
|
+
void *ptr, size_t size, size_t nitems)
|
|
55
|
+
{
|
|
56
|
+
if (rp >= data.size()) return 0;
|
|
57
|
+
size_t nremain = (data.size() - rp) / size;
|
|
58
|
+
if (nremain < nitems) nitems = nremain;
|
|
59
|
+
if (size * nitems > 0) {
|
|
60
|
+
memcpy (ptr, &data[rp], size * nitems);
|
|
61
|
+
rp += size * nitems;
|
|
62
|
+
}
|
|
63
|
+
return nitems;
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
/***********************************************************************
|
|
70
|
+
* IO File
|
|
71
|
+
***********************************************************************/
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
/// Wrap an already-open stream. need_close keeps its default (false), so
/// the destructor will not close rf.
FileIOReader::FileIOReader(FILE *rf)
    : f(rf)
{
}
|
|
76
|
+
|
|
77
|
+
/// Open fname for binary reading and take ownership of the stream.
/// Raises a FAISS exception (carrying the strerror text) when fopen fails.
FileIOReader::FileIOReader(const char * fname)
{
    name = fname;  // kept for error messages
    f = fopen(fname, "rb");
    FAISS_THROW_IF_NOT_FMT (
        f, "could not open %s for reading: %s",
        fname, strerror(errno));
    // the stream was opened here, so the destructor has to close it
    need_close = true;
}
|
|
85
|
+
|
|
86
|
+
/// Close the stream when this object opened it (need_close is set).
FileIOReader::~FileIOReader() {
    if (!need_close) {
        return;
    }
    if (fclose(f) != 0) {
        // we cannot raise an exception in the destructor, so only report it
        fprintf(stderr, "file %s close error: %s",
                name.c_str(), strerror(errno));
    }
}
|
|
95
|
+
|
|
96
|
+
/// Forward to fread on the wrapped stream; returns fread's item count.
size_t FileIOReader::operator()(void *ptr, size_t size, size_t nitems) {
    const size_t nread = fread(ptr, size, nitems, f);
    return nread;
}
|
|
99
|
+
|
|
100
|
+
int FileIOReader::fileno() {
|
|
101
|
+
return ::fileno (f);
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
/// Wrap an already-open stream. need_close keeps its default (false), so
/// the destructor will not close wf.
FileIOWriter::FileIOWriter(FILE *wf)
    : f(wf)
{
}
|
|
106
|
+
|
|
107
|
+
/// Open fname for binary writing and take ownership of the stream.
/// Raises a FAISS exception (carrying the strerror text) when fopen fails.
FileIOWriter::FileIOWriter(const char * fname)
{
    name = fname;  // kept for error messages
    f = fopen(fname, "wb");
    FAISS_THROW_IF_NOT_FMT (
        f, "could not open %s for writing: %s",
        fname, strerror(errno));
    // the stream was opened here, so the destructor has to close it
    need_close = true;
}
|
|
115
|
+
|
|
116
|
+
/// Close the stream when this object opened it (need_close is set).
FileIOWriter::~FileIOWriter() {
    if (!need_close) {
        return;
    }
    if (fclose(f) != 0) {
        // we cannot raise an exception in the destructor, so only report it
        fprintf(stderr, "file %s close error: %s",
                name.c_str(), strerror(errno));
    }
}
|
|
126
|
+
|
|
127
|
+
/// Forward to fwrite on the wrapped stream; returns fwrite's item count.
size_t FileIOWriter::operator()(const void *ptr, size_t size, size_t nitems) {
    const size_t nwritten = fwrite(ptr, size, nitems, f);
    return nwritten;
}
|
|
130
|
+
|
|
131
|
+
int FileIOWriter::fileno() {
|
|
132
|
+
return ::fileno (f);
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/// Pack a 4-character, NUL-terminated string into a uint32_t.
///
/// sx[0] ends up in the least-significant byte and sx[3] in the
/// most-significant one.
///
/// @param sx  string of exactly 4 characters (checked with assert)
/// @return    the packed 32-bit value
uint32_t fourcc (const char sx[4]) {
    assert(4 == strlen(sx));
    const unsigned char *x = reinterpret_cast<const unsigned char*>(sx);
    // Widen each byte to uint32_t before shifting: shifting the promoted
    // (signed) int by 24 overflows for bytes >= 0x80.
    return  uint32_t(x[0])        |
           (uint32_t(x[1]) << 8)  |
           (uint32_t(x[2]) << 16) |
           (uint32_t(x[3]) << 24);
}
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
} // namespace faiss
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
/***********************************************************
|
|
11
|
+
* Abstract I/O objects
|
|
12
|
+
***********************************************************/
|
|
13
|
+
|
|
14
|
+
#pragma once
|
|
15
|
+
|
|
16
|
+
#include <string>
|
|
17
|
+
#include <cstdio>
|
|
18
|
+
#include <vector>
|
|
19
|
+
|
|
20
|
+
#include <faiss/Index.h>
|
|
21
|
+
|
|
22
|
+
namespace faiss {
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
/// Abstract byte source with an fread-like call operator.
struct IOReader {
    /// name that can be used in error messages
    std::string name;

    /// fread-like: arguments are (destination, item size, item count)
    virtual size_t operator()(void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    virtual int fileno ();

    virtual ~IOReader() = default;
};
|
|
38
|
+
|
|
39
|
+
/// Abstract byte sink with an fwrite-like call operator.
struct IOWriter {
    /// name that can be used in error messages
    std::string name;

    /// fwrite-like: arguments are (source, item size, item count)
    virtual size_t operator()(const void *ptr, size_t size, size_t nitems) = 0;

    /// return a file number that can be memory-mapped
    virtual int fileno ();

    virtual ~IOWriter() = default;
};
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
struct VectorIOReader:IOReader {
|
|
55
|
+
std::vector<uint8_t> data;
|
|
56
|
+
size_t rp = 0;
|
|
57
|
+
size_t operator()(void *ptr, size_t size, size_t nitems) override;
|
|
58
|
+
};
|
|
59
|
+
|
|
60
|
+
struct VectorIOWriter:IOWriter {
|
|
61
|
+
std::vector<uint8_t> data;
|
|
62
|
+
size_t operator()(const void *ptr, size_t size, size_t nitems) override;
|
|
63
|
+
};
|
|
64
|
+
|
|
65
|
+
struct FileIOReader: IOReader {
|
|
66
|
+
FILE *f = nullptr;
|
|
67
|
+
bool need_close = false;
|
|
68
|
+
|
|
69
|
+
FileIOReader(FILE *rf);
|
|
70
|
+
|
|
71
|
+
FileIOReader(const char * fname);
|
|
72
|
+
|
|
73
|
+
~FileIOReader() override;
|
|
74
|
+
|
|
75
|
+
size_t operator()(void *ptr, size_t size, size_t nitems) override;
|
|
76
|
+
|
|
77
|
+
int fileno() override;
|
|
78
|
+
};
|
|
79
|
+
|
|
80
|
+
struct FileIOWriter: IOWriter {
|
|
81
|
+
FILE *f = nullptr;
|
|
82
|
+
bool need_close = false;
|
|
83
|
+
|
|
84
|
+
FileIOWriter(FILE *wf);
|
|
85
|
+
|
|
86
|
+
FileIOWriter(const char * fname);
|
|
87
|
+
|
|
88
|
+
~FileIOWriter() override;
|
|
89
|
+
|
|
90
|
+
size_t operator()(const void *ptr, size_t size, size_t nitems) override;
|
|
91
|
+
|
|
92
|
+
int fileno() override;
|
|
93
|
+
};
|
|
94
|
+
|
|
95
|
+
/// pack a 4-character string into a uint32_t (first character in the
/// least-significant byte) so that it can be written and read easily
|
|
96
|
+
uint32_t fourcc (const char sx[4]);
|
|
97
|
+
|
|
98
|
+
} // namespace faiss
|
|
@@ -0,0 +1,712 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/lattice_Zn.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdlib>
|
|
13
|
+
#include <cmath>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <cassert>
|
|
16
|
+
|
|
17
|
+
#include <queue>
|
|
18
|
+
#include <unordered_set>
|
|
19
|
+
#include <unordered_map>
|
|
20
|
+
#include <algorithm>
|
|
21
|
+
|
|
22
|
+
#include <faiss/utils/distances.h>
|
|
23
|
+
|
|
24
|
+
namespace faiss {
|
|
25
|
+
|
|
26
|
+
/********************************************
|
|
27
|
+
* small utility functions
|
|
28
|
+
********************************************/
|
|
29
|
+
|
|
30
|
+
namespace {
|
|
31
|
+
|
|
32
|
+
/// Square of a single-precision value.
inline float sqr(float x) {
    const float squared = x * x;
    return squared;
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
/// Flat list of point coordinates (points are stored back to back).
using point_list_t = std::vector<float>;
|
|
38
|
+
|
|
39
|
+
/// Table of binomial coefficients C(n, p) for 0 <= n, p < nmax.
///
/// The table is filled once with Pascal's rule. Entries are uint64_t, so
/// coefficients that do not fit in 64 bits wrap around.
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle, row-major nmax x nmax
    int nmax;

    explicit Comb(int nmax): nmax(nmax) {
        tab.resize(nmax * nmax, 0);
        tab[0] = 1;
        for (int row = 1; row < nmax; row++) {
            const uint64_t *above = &tab[(row - 1) * nmax];
            uint64_t *current = &tab[row * nmax];
            current[0] = 1;
            for (int col = 1; col <= row; col++) {
                // C(row, col) = C(row - 1, col) + C(row - 1, col - 1)
                current[col] = above[col] + above[col - 1];
            }
        }
    }

    /// C(n, p); 0 when p > n. Both arguments must be < nmax (asserted).
    uint64_t operator()(int n, int p) const {
        assert (n < nmax && p < nmax);
        return p > n ? 0 : tab[n * nmax + p];
    }
};

/// shared table used by the encoding / decoding routines of this file
Comb comb(100);
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
// compute combinations of n integer values <= v that sum up to total (squared)
|
|
69
|
+
point_list_t sum_of_sq (float total, int v, int n, float add = 0) {
|
|
70
|
+
if (total < 0) {
|
|
71
|
+
return point_list_t();
|
|
72
|
+
} else if (n == 1) {
|
|
73
|
+
while (sqr(v + add) > total) v--;
|
|
74
|
+
if (sqr(v + add) == total) {
|
|
75
|
+
return point_list_t(1, v + add);
|
|
76
|
+
} else {
|
|
77
|
+
return point_list_t();
|
|
78
|
+
}
|
|
79
|
+
} else {
|
|
80
|
+
point_list_t res;
|
|
81
|
+
while (v >= 0) {
|
|
82
|
+
point_list_t sub_points =
|
|
83
|
+
sum_of_sq (total - sqr(v + add), v, n - 1, add);
|
|
84
|
+
for (size_t i = 0; i < sub_points.size(); i += n - 1) {
|
|
85
|
+
res.push_back (v + add);
|
|
86
|
+
for (int j = 0; j < n - 1; j++) {
|
|
87
|
+
res.push_back(sub_points[i + j]);
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
v--;
|
|
91
|
+
}
|
|
92
|
+
return res;
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
/// Lower r until comb(r, k1) <= *n, subtract that coefficient from *n and
/// return the resulting r.
int decode_comb_1 (uint64_t *n, int k1, int r) {
    uint64_t coef = comb(r, k1);
    while (coef > *n) {
        r--;
        coef = comb(r, k1);
    }
    *n -= coef;
    return r;
}
|
|
103
|
+
|
|
104
|
+
// optimized version for < 64 bits
|
|
105
|
+
long repeats_encode_64 (
|
|
106
|
+
const std::vector<Repeat> & repeats,
|
|
107
|
+
int dim, const float *c)
|
|
108
|
+
{
|
|
109
|
+
uint64_t coded = 0;
|
|
110
|
+
int nfree = dim;
|
|
111
|
+
uint64_t code = 0, shift = 1;
|
|
112
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
|
113
|
+
int rank = 0, occ = 0;
|
|
114
|
+
uint64_t code_comb = 0;
|
|
115
|
+
uint64_t tosee = ~coded;
|
|
116
|
+
for(;;) {
|
|
117
|
+
// directly jump to next available slot.
|
|
118
|
+
int i = __builtin_ctzl(tosee);
|
|
119
|
+
tosee &= ~(1UL << i) ;
|
|
120
|
+
if (c[i] == r->val) {
|
|
121
|
+
code_comb += comb(rank, occ + 1);
|
|
122
|
+
occ++;
|
|
123
|
+
coded |= 1UL << i;
|
|
124
|
+
if (occ == r->n) break;
|
|
125
|
+
}
|
|
126
|
+
rank++;
|
|
127
|
+
}
|
|
128
|
+
uint64_t max_comb = comb(nfree, r->n);
|
|
129
|
+
code += shift * code_comb;
|
|
130
|
+
shift *= max_comb;
|
|
131
|
+
nfree -= r->n;
|
|
132
|
+
}
|
|
133
|
+
return code;
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
void repeats_decode_64(
|
|
138
|
+
const std::vector<Repeat> & repeats,
|
|
139
|
+
int dim, uint64_t code, float *c)
|
|
140
|
+
{
|
|
141
|
+
uint64_t decoded = 0;
|
|
142
|
+
int nfree = dim;
|
|
143
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
|
144
|
+
uint64_t max_comb = comb(nfree, r->n);
|
|
145
|
+
uint64_t code_comb = code % max_comb;
|
|
146
|
+
code /= max_comb;
|
|
147
|
+
|
|
148
|
+
int occ = 0;
|
|
149
|
+
int rank = nfree;
|
|
150
|
+
int next_rank = decode_comb_1 (&code_comb, r->n, rank);
|
|
151
|
+
uint64_t tosee = ((1UL << dim) - 1) ^ decoded;
|
|
152
|
+
for(;;) {
|
|
153
|
+
int i = 63 - __builtin_clzl(tosee);
|
|
154
|
+
tosee &= ~(1UL << i);
|
|
155
|
+
rank--;
|
|
156
|
+
if (rank == next_rank) {
|
|
157
|
+
decoded |= 1UL << i;
|
|
158
|
+
c[i] = r->val;
|
|
159
|
+
occ++;
|
|
160
|
+
if (occ == r->n) break;
|
|
161
|
+
next_rank = decode_comb_1 (
|
|
162
|
+
&code_comb, r->n - occ, next_rank);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
nfree -= r->n;
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
} // anonymous namespace
|
|
173
|
+
|
|
174
|
+
/// Build the list of distinct values of c[0..dim) together with their
/// number of occurrences, in order of first appearance.
Repeats::Repeats (int dim, const float *c): dim(dim)
{
    for (int i = 0; i < dim; i++) {
        const float value = c[i];
        auto hit = std::find_if (
            repeats.begin(), repeats.end(),
            [value](const Repeat &r) { return r.val == value; });
        if (hit == repeats.end()) {
            // first time this value is seen
            repeats.push_back(Repeat{value, 1});
        } else {
            hit->n++;
        }
    }
}
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
long Repeats::count () const
|
|
194
|
+
{
|
|
195
|
+
long accu = 1;
|
|
196
|
+
int remain = dim;
|
|
197
|
+
for (int i = 0; i < repeats.size(); i++) {
|
|
198
|
+
accu *= comb(remain, repeats[i].n);
|
|
199
|
+
remain -= repeats[i].n;
|
|
200
|
+
}
|
|
201
|
+
return accu;
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
// version with a bool vector that works for > 64 dim
|
|
207
|
+
long Repeats::encode(const float *c) const
|
|
208
|
+
{
|
|
209
|
+
if (dim < 64) {
|
|
210
|
+
return repeats_encode_64 (repeats, dim, c);
|
|
211
|
+
}
|
|
212
|
+
std::vector<bool> coded(dim, false);
|
|
213
|
+
int nfree = dim;
|
|
214
|
+
uint64_t code = 0, shift = 1;
|
|
215
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
|
216
|
+
int rank = 0, occ = 0;
|
|
217
|
+
uint64_t code_comb = 0;
|
|
218
|
+
for (int i = 0; i < dim; i++) {
|
|
219
|
+
if (!coded[i]) {
|
|
220
|
+
if (c[i] == r->val) {
|
|
221
|
+
code_comb += comb(rank, occ + 1);
|
|
222
|
+
occ++;
|
|
223
|
+
coded[i] = true;
|
|
224
|
+
if (occ == r->n) break;
|
|
225
|
+
}
|
|
226
|
+
rank++;
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
uint64_t max_comb = comb(nfree, r->n);
|
|
230
|
+
code += shift * code_comb;
|
|
231
|
+
shift *= max_comb;
|
|
232
|
+
nfree -= r->n;
|
|
233
|
+
}
|
|
234
|
+
return code;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
void Repeats::decode(uint64_t code, float *c) const
|
|
240
|
+
{
|
|
241
|
+
if (dim < 64) {
|
|
242
|
+
repeats_decode_64 (repeats, dim, code, c);
|
|
243
|
+
return;
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
std::vector<bool> decoded(dim, false);
|
|
247
|
+
int nfree = dim;
|
|
248
|
+
for (auto r = repeats.begin(); r != repeats.end(); ++r) {
|
|
249
|
+
uint64_t max_comb = comb(nfree, r->n);
|
|
250
|
+
uint64_t code_comb = code % max_comb;
|
|
251
|
+
code /= max_comb;
|
|
252
|
+
|
|
253
|
+
int occ = 0;
|
|
254
|
+
int rank = nfree;
|
|
255
|
+
int next_rank = decode_comb_1 (&code_comb, r->n, rank);
|
|
256
|
+
for (int i = dim - 1; i >= 0; i--) {
|
|
257
|
+
if (!decoded[i]) {
|
|
258
|
+
rank--;
|
|
259
|
+
if (rank == next_rank) {
|
|
260
|
+
decoded[i] = true;
|
|
261
|
+
c[i] = r->val;
|
|
262
|
+
occ++;
|
|
263
|
+
if (occ == r->n) break;
|
|
264
|
+
next_rank = decode_comb_1 (
|
|
265
|
+
&code_comb, r->n - occ, next_rank);
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
nfree -= r->n;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
/********************************************
|
|
277
|
+
* EnumeratedVectors functions
|
|
278
|
+
********************************************/
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
void EnumeratedVectors::encode_multi(size_t n, const float *c,
|
|
282
|
+
uint64_t * codes) const
|
|
283
|
+
{
|
|
284
|
+
#pragma omp parallel if (n > 1000)
|
|
285
|
+
{
|
|
286
|
+
#pragma omp for
|
|
287
|
+
for(int i = 0; i < n; i++) {
|
|
288
|
+
codes[i] = encode(c + i * dim);
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
void EnumeratedVectors::decode_multi(size_t n, const uint64_t * codes,
|
|
295
|
+
float *c) const
|
|
296
|
+
{
|
|
297
|
+
#pragma omp parallel if (n > 1000)
|
|
298
|
+
{
|
|
299
|
+
#pragma omp for
|
|
300
|
+
for(int i = 0; i < n; i++) {
|
|
301
|
+
decode(codes[i], c + i * dim);
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
void EnumeratedVectors::find_nn (
|
|
307
|
+
size_t nc, const uint64_t * codes,
|
|
308
|
+
size_t nq, const float *xq,
|
|
309
|
+
long *labels, float *distances)
|
|
310
|
+
{
|
|
311
|
+
for (long i = 0; i < nq; i++) {
|
|
312
|
+
distances[i] = -1e20;
|
|
313
|
+
labels[i] = -1;
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
float c[dim];
|
|
317
|
+
for(long i = 0; i < nc; i++) {
|
|
318
|
+
uint64_t code = codes[nc];
|
|
319
|
+
decode(code, c);
|
|
320
|
+
for (long j = 0; j < nq; j++) {
|
|
321
|
+
const float *x = xq + j * dim;
|
|
322
|
+
float dis = fvec_inner_product(x, c, dim);
|
|
323
|
+
if (dis > distances[j]) {
|
|
324
|
+
distances[j] = dis;
|
|
325
|
+
labels[j] = i;
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
/**********************************************************
|
|
334
|
+
* ZnSphereSearch
|
|
335
|
+
**********************************************************/
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
/// Enumerate the "atoms": the dim-tuples returned by sum_of_sq for a
/// squared norm of r2. They are stored back to back in voc, and natom is
/// their number.
ZnSphereSearch::ZnSphereSearch(int dim, int r2): dimS(dim), r2(r2) {
    // largest coordinate value handed to the enumeration
    const int vmax = int(ceil(sqrt(r2)) + 1);
    voc = sum_of_sq(r2, vmax, dim);
    natom = voc.size() / dim;
}
|
|
342
|
+
|
|
343
|
+
float ZnSphereSearch::search(const float *x, float *c) const {
|
|
344
|
+
float tmp[dimS * 2];
|
|
345
|
+
int tmp_int[dimS];
|
|
346
|
+
return search(x, c, tmp, tmp_int);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
float ZnSphereSearch::search(const float *x, float *c,
|
|
350
|
+
float *tmp, // size 2 *dim
|
|
351
|
+
int *tmp_int, // size dim
|
|
352
|
+
int *ibest_out
|
|
353
|
+
) const {
|
|
354
|
+
int dim = dimS;
|
|
355
|
+
assert (natom > 0);
|
|
356
|
+
int *o = tmp_int;
|
|
357
|
+
float *xabs = tmp;
|
|
358
|
+
float *xperm = tmp + dim;
|
|
359
|
+
|
|
360
|
+
// argsort
|
|
361
|
+
for (int i = 0; i < dim; i++) {
|
|
362
|
+
o[i] = i;
|
|
363
|
+
xabs[i] = fabsf(x[i]);
|
|
364
|
+
}
|
|
365
|
+
std::sort(o, o + dim, [xabs](int a, int b) {
|
|
366
|
+
return xabs[a] > xabs[b];
|
|
367
|
+
});
|
|
368
|
+
for (int i = 0; i < dim; i++) {
|
|
369
|
+
xperm[i] = xabs[o[i]];
|
|
370
|
+
}
|
|
371
|
+
// find best
|
|
372
|
+
int ibest = -1;
|
|
373
|
+
float dpbest = -100;
|
|
374
|
+
for (int i = 0; i < natom; i++) {
|
|
375
|
+
float dp = fvec_inner_product (voc.data() + i * dim, xperm, dim);
|
|
376
|
+
if (dp > dpbest) {
|
|
377
|
+
dpbest = dp;
|
|
378
|
+
ibest = i;
|
|
379
|
+
}
|
|
380
|
+
}
|
|
381
|
+
// revert sort
|
|
382
|
+
const float *cin = voc.data() + ibest * dim;
|
|
383
|
+
for (int i = 0; i < dim; i++) {
|
|
384
|
+
c[o[i]] = copysignf (cin[i], x[o[i]]);
|
|
385
|
+
}
|
|
386
|
+
if (ibest_out) {
|
|
387
|
+
*ibest_out = ibest;
|
|
388
|
+
}
|
|
389
|
+
return dpbest;
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
void ZnSphereSearch::search_multi(int n, const float *x,
|
|
393
|
+
float *c_out,
|
|
394
|
+
float *dp_out) {
|
|
395
|
+
#pragma omp parallel if (n > 1000)
|
|
396
|
+
{
|
|
397
|
+
#pragma omp for
|
|
398
|
+
for(int i = 0; i < n; i++) {
|
|
399
|
+
dp_out[i] = search(x + i * dimS, c_out + i * dimS);
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
/**********************************************************
|
|
406
|
+
* ZnSphereCodec
|
|
407
|
+
**********************************************************/
|
|
408
|
+
|
|
409
|
+
/// Build one code segment per atom and compute the total number of codes
/// (nv) and the number of bytes needed to represent nv (code_size).
ZnSphereCodec::ZnSphereCodec(int dim, int r2):
    ZnSphereSearch(dim, r2),
    EnumeratedVectors(dim)
{
    nv = 0;
    for (int atom = 0; atom < natom; atom++) {
        Repeats repeats(dim, &voc[atom * dim]);
        CodeSegment cs(repeats);
        cs.c0 = nv;  // first code of this segment
        // one sign bit per component, except the components of the last
        // repeated value when that value is 0
        const Repeat &last = repeats.repeats.back();
        if (last.val == 0) {
            cs.signbits = dim - last.n;
        } else {
            cs.signbits = dim;
        }
        code_segments.push_back(cs);
        nv += repeats.count() << cs.signbits;
    }

    // number of bytes needed to represent nv
    code_size = 0;
    for (uint64_t rest = nv; rest > 0; rest >>= 8) {
        code_size++;
    }
}
|
|
431
|
+
|
|
432
|
+
uint64_t ZnSphereCodec::search_and_encode(const float *x) const {
|
|
433
|
+
float tmp[dim * 2];
|
|
434
|
+
int tmp_int[dim];
|
|
435
|
+
int ano; // atom number
|
|
436
|
+
float c[dim];
|
|
437
|
+
search(x, c, tmp, tmp_int, &ano);
|
|
438
|
+
uint64_t signs = 0;
|
|
439
|
+
float cabs[dim];
|
|
440
|
+
int nnz = 0;
|
|
441
|
+
for (int i = 0; i < dim; i++) {
|
|
442
|
+
cabs[i] = fabs(c[i]);
|
|
443
|
+
if (c[i] != 0) {
|
|
444
|
+
if (c[i] < 0) {
|
|
445
|
+
signs |= 1UL << nnz;
|
|
446
|
+
}
|
|
447
|
+
nnz ++;
|
|
448
|
+
}
|
|
449
|
+
}
|
|
450
|
+
const CodeSegment &cs = code_segments[ano];
|
|
451
|
+
assert(nnz == cs.signbits);
|
|
452
|
+
uint64_t code = cs.c0 + signs;
|
|
453
|
+
code += cs.encode(cabs) << cs.signbits;
|
|
454
|
+
return code;
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
uint64_t ZnSphereCodec::encode(const float *x) const
|
|
458
|
+
{
|
|
459
|
+
return search_and_encode(x);
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
void ZnSphereCodec::decode(uint64_t code, float *c) const {
|
|
464
|
+
int i0 = 0, i1 = natom;
|
|
465
|
+
while (i0 + 1 < i1) {
|
|
466
|
+
int imed = (i0 + i1) / 2;
|
|
467
|
+
if (code_segments[imed].c0 <= code) i0 = imed;
|
|
468
|
+
else i1 = imed;
|
|
469
|
+
}
|
|
470
|
+
const CodeSegment &cs = code_segments[i0];
|
|
471
|
+
code -= cs.c0;
|
|
472
|
+
uint64_t signs = code;
|
|
473
|
+
code >>= cs.signbits;
|
|
474
|
+
cs.decode(code, c);
|
|
475
|
+
|
|
476
|
+
int nnz = 0;
|
|
477
|
+
for (int i = 0; i < dim; i++) {
|
|
478
|
+
if (c[i] != 0) {
|
|
479
|
+
if (signs & (1UL << nnz)) {
|
|
480
|
+
c[i] = -c[i];
|
|
481
|
+
}
|
|
482
|
+
nnz ++;
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
/**************************************************************
|
|
489
|
+
* ZnSphereCodecRec
|
|
490
|
+
**************************************************************/
|
|
491
|
+
|
|
492
|
+
uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const
|
|
493
|
+
{
|
|
494
|
+
return all_nv[ld * (r2 + 1) + r2a];
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const
|
|
499
|
+
{
|
|
500
|
+
return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum)
|
|
504
|
+
{
|
|
505
|
+
all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
/// Precomputes the counting tables used by encode_centroid() and decode(),
/// then fills the decode cache.
///
/// @param dim  dimension of the vectors; asserted to be a power of 2
/// @param r2   squared norm of the encoded vectors
ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2):
    EnumeratedVectors(dim), r2(r2)
{
    log2_dim = 0;
    while (dim > (1 << log2_dim)) {
        log2_dim++;
    }
    assert(dim == (1 << log2_dim) ||
           !"dimension must be a power of 2");

    all_nv.resize((log2_dim + 1) * (r2 + 1));
    all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));

    // level 0 (a single coordinate): 1 solution for squared norm 0, 2
    // solutions (+r and -r) for a non-zero perfect square, none otherwise
    for (int r2a = 0; r2a <= r2; r2a++) {
        int r = int(sqrt(r2a));
        if (r * r == r2a) {
            all_nv[r2a] = r == 0 ? 1 : 2;
        } else {
            all_nv[r2a] = 0;
        }
    }

    // level ld: combine two level ld-1 halves whose squared norms r2a and
    // r2b sum to r2sub
    for (int ld = 1; ld <= log2_dim; ld++) {

        for (int r2sub = 0; r2sub <= r2; r2sub++) {
            uint64_t count = 0;
            for (int r2a = 0; r2a <= r2sub; r2a++) {
                int r2b = r2sub - r2a;
                // cumulative count *before* adding the (r2a, r2b) split
                set_nv_cum(ld, r2sub, r2a, count);
                count += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
            }
            all_nv[ld * (r2 + 1) + r2sub] = count;
        }
    }
    nv = get_nv(log2_dim, r2);

    // count how many 8-bit shifts are needed to exhaust nv
    uint64_t nvx = nv;
    code_size = 0;
    while (nvx > 0) {
        nvx >>= 8;
        code_size++;
    }

    // NOTE(review): for dim == 1, log2_dim is 0 and cache_level becomes -1,
    // which makes the table lookups below index out of range -- confirm
    // that dim == 1 is not a supported input.
    int cache_level = std::min(3, log2_dim - 1);
    // decode() must not consult the cache while it is being built
    decode_cache_ld = 0;
    assert(cache_level <= log2_dim);
    decode_cache.resize((r2 + 1));

    // Scratch output for decode(). std::vector replaces a variable-length
    // array, which is a compiler extension and not standard C++.
    std::vector<float> c(dim);

    for (int r2sub = 0; r2sub <= r2; r2sub++) {
        int ld = cache_level;
        uint64_t nvi = get_nv(ld, r2sub);
        std::vector<float> &cache = decode_cache[r2sub];
        int dimsub = (1 << cache_level);
        cache.resize (nvi * dimsub);
        uint64_t code0 = get_nv_cum(cache_level + 1, r2,
                                    r2 - r2sub);
        // same unsigned type as nvi to avoid a signed/unsigned comparison
        for (uint64_t i = 0; i < nvi; i++) {
            decode(i + code0, c.data());
            // keep the last dimsub coordinates of the decoded vector
            memcpy(&cache[i * dimsub], c.data() + dim - dimsub,
                   dimsub * sizeof(float));
        }
    }
    decode_cache_ld = cache_level;
}
|
|
574
|
+
|
|
575
|
+
uint64_t ZnSphereCodecRec::encode(const float *c) const
|
|
576
|
+
{
|
|
577
|
+
return encode_centroid(c);
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
uint64_t ZnSphereCodecRec::encode_centroid(const float *c) const
|
|
583
|
+
{
|
|
584
|
+
uint64_t codes[dim];
|
|
585
|
+
int norm2s[dim];
|
|
586
|
+
for(int i = 0; i < dim; i++) {
|
|
587
|
+
if (c[i] == 0) {
|
|
588
|
+
codes[i] = 0;
|
|
589
|
+
norm2s[i] = 0;
|
|
590
|
+
} else {
|
|
591
|
+
int r2i = int(c[i] * c[i]);
|
|
592
|
+
norm2s[i] = r2i;
|
|
593
|
+
codes[i] = c[i] >= 0 ? 0 : 1;
|
|
594
|
+
}
|
|
595
|
+
}
|
|
596
|
+
int dim2 = dim / 2;
|
|
597
|
+
for(int ld = 1; ld <= log2_dim; ld++) {
|
|
598
|
+
for (int i = 0; i < dim2; i++) {
|
|
599
|
+
int r2a = norm2s[2 * i];
|
|
600
|
+
int r2b = norm2s[2 * i + 1];
|
|
601
|
+
|
|
602
|
+
uint64_t code_a = codes[2 * i];
|
|
603
|
+
uint64_t code_b = codes[2 * i + 1];
|
|
604
|
+
|
|
605
|
+
codes[i] =
|
|
606
|
+
get_nv_cum(ld, r2a + r2b, r2a) +
|
|
607
|
+
code_a * get_nv(ld - 1, r2b) +
|
|
608
|
+
code_b;
|
|
609
|
+
norm2s[i] = r2a + r2b;
|
|
610
|
+
}
|
|
611
|
+
dim2 /= 2;
|
|
612
|
+
}
|
|
613
|
+
return codes[0];
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
void ZnSphereCodecRec::decode(uint64_t code, float *c) const
|
|
619
|
+
{
|
|
620
|
+
uint64_t codes[dim];
|
|
621
|
+
int norm2s[dim];
|
|
622
|
+
codes[0] = code;
|
|
623
|
+
norm2s[0] = r2;
|
|
624
|
+
|
|
625
|
+
int dim2 = 1;
|
|
626
|
+
for(int ld = log2_dim; ld > decode_cache_ld; ld--) {
|
|
627
|
+
for (int i = dim2 - 1; i >= 0; i--) {
|
|
628
|
+
int r2sub = norm2s[i];
|
|
629
|
+
int i0 = 0, i1 = r2sub + 1;
|
|
630
|
+
uint64_t codei = codes[i];
|
|
631
|
+
const uint64_t *cum =
|
|
632
|
+
&all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
|
|
633
|
+
while (i1 > i0 + 1) {
|
|
634
|
+
int imed = (i0 + i1) / 2;
|
|
635
|
+
if (cum[imed] <= codei)
|
|
636
|
+
i0 = imed;
|
|
637
|
+
else
|
|
638
|
+
i1 = imed;
|
|
639
|
+
}
|
|
640
|
+
int r2a = i0, r2b = r2sub - i0;
|
|
641
|
+
codei -= cum[r2a];
|
|
642
|
+
norm2s[2 * i] = r2a;
|
|
643
|
+
norm2s[2 * i + 1] = r2b;
|
|
644
|
+
|
|
645
|
+
uint64_t code_a = codei / get_nv(ld - 1, r2b);
|
|
646
|
+
uint64_t code_b = codei % get_nv(ld - 1, r2b);
|
|
647
|
+
|
|
648
|
+
codes[2 * i] = code_a;
|
|
649
|
+
codes[2 * i + 1] = code_b;
|
|
650
|
+
|
|
651
|
+
}
|
|
652
|
+
dim2 *= 2;
|
|
653
|
+
}
|
|
654
|
+
|
|
655
|
+
if (decode_cache_ld == 0) {
|
|
656
|
+
for(int i = 0; i < dim; i++) {
|
|
657
|
+
if (norm2s[i] == 0) {
|
|
658
|
+
c[i] = 0;
|
|
659
|
+
} else {
|
|
660
|
+
float r = sqrt(norm2s[i]);
|
|
661
|
+
assert(r * r == norm2s[i]);
|
|
662
|
+
c[i] = codes[i] == 0 ? r : -r;
|
|
663
|
+
}
|
|
664
|
+
}
|
|
665
|
+
} else {
|
|
666
|
+
int subdim = 1 << decode_cache_ld;
|
|
667
|
+
assert ((dim2 * subdim) == dim);
|
|
668
|
+
|
|
669
|
+
for(int i = 0; i < dim2; i++) {
|
|
670
|
+
|
|
671
|
+
const std::vector<float> & cache =
|
|
672
|
+
decode_cache[norm2s[i]];
|
|
673
|
+
assert(codes[i] < cache.size());
|
|
674
|
+
memcpy(c + i * subdim,
|
|
675
|
+
&cache[codes[i] * subdim],
|
|
676
|
+
sizeof(*c)* subdim);
|
|
677
|
+
}
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
// if not use_rec, instantiate an arbitrary harmless znc_rec
|
|
682
|
+
/// Sets up both codecs: the ZnSphereCodec base is always built with
/// (dim, r2); the recursive codec znc_rec is built with (dim, r2) only when
/// dim has at most one bit set, and with the placeholder arguments (8, 14)
/// otherwise.
ZnSphereCodecAlt::ZnSphereCodecAlt (int dim, int r2):
    ZnSphereCodec (dim, r2),
    // true when clearing the lowest set bit of dim leaves nothing
    use_rec (!(dim & (dim - 1))),
    znc_rec (use_rec ? dim : 8,
             use_rec ? r2 : 14)
{
}
|
|
688
|
+
|
|
689
|
+
uint64_t ZnSphereCodecAlt::encode(const float *x) const
|
|
690
|
+
{
|
|
691
|
+
if (!use_rec) {
|
|
692
|
+
// it's ok if the vector is not normalized
|
|
693
|
+
return ZnSphereCodec::encode(x);
|
|
694
|
+
} else {
|
|
695
|
+
// find nearest centroid
|
|
696
|
+
std::vector<float> centroid(dim);
|
|
697
|
+
search (x, centroid.data());
|
|
698
|
+
return znc_rec.encode(centroid.data());
|
|
699
|
+
}
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
void ZnSphereCodecAlt::decode(uint64_t code, float *c) const
|
|
703
|
+
{
|
|
704
|
+
if (!use_rec) {
|
|
705
|
+
ZnSphereCodec::decode (code, c);
|
|
706
|
+
} else {
|
|
707
|
+
znc_rec.decode (code, c);
|
|
708
|
+
}
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
} // namespace faiss
|