faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,75 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef FAISS_INDEX_IVFSH_H
|
11
|
+
#define FAISS_INDEX_IVFSH_H
|
12
|
+
|
13
|
+
|
14
|
+
#include <vector>
|
15
|
+
|
16
|
+
#include <faiss/IndexIVF.h>
|
17
|
+
|
18
|
+
|
19
|
+
namespace faiss {
|
20
|
+
|
21
|
+
struct VectorTransform;
|
22
|
+
|
23
|
+
/** Inverted list that stores binary codes of size nbit. Before the
|
24
|
+
* binary conversion, the dimension of the vectors is transformed from
|
25
|
+
* dim d into dim nbit by vt (a random rotation by default).
|
26
|
+
*
|
27
|
+
* Each coordinate is subtracted from a value determined by
|
28
|
+
* threshold_type, and split into intervals of size period. Half of
|
29
|
+
* the interval is a 0 bit, the other half a 1.
|
30
|
+
*/
|
31
|
+
struct IndexIVFSpectralHash: IndexIVF {
|
32
|
+
|
33
|
+
VectorTransform *vt; // transformation from d to nbit dim
|
34
|
+
bool own_fields;
|
35
|
+
|
36
|
+
int nbit;
|
37
|
+
float period;
|
38
|
+
|
39
|
+
enum ThresholdType {
|
40
|
+
Thresh_global,
|
41
|
+
Thresh_centroid,
|
42
|
+
Thresh_centroid_half,
|
43
|
+
Thresh_median
|
44
|
+
};
|
45
|
+
ThresholdType threshold_type;
|
46
|
+
|
47
|
+
// size nlist * nbit or 0 if Thresh_global
|
48
|
+
std::vector<float> trained;
|
49
|
+
|
50
|
+
IndexIVFSpectralHash (Index * quantizer, size_t d, size_t nlist,
|
51
|
+
int nbit, float period);
|
52
|
+
|
53
|
+
IndexIVFSpectralHash ();
|
54
|
+
|
55
|
+
void train_residual(idx_t n, const float* x) override;
|
56
|
+
|
57
|
+
void encode_vectors(idx_t n, const float* x,
|
58
|
+
const idx_t *list_nos,
|
59
|
+
uint8_t * codes,
|
60
|
+
bool include_listnos = false) const override;
|
61
|
+
|
62
|
+
InvertedListScanner *get_InvertedListScanner (bool store_pairs)
|
63
|
+
const override;
|
64
|
+
|
65
|
+
~IndexIVFSpectralHash () override;
|
66
|
+
|
67
|
+
};
|
68
|
+
|
69
|
+
|
70
|
+
|
71
|
+
|
72
|
+
}; // namespace faiss
|
73
|
+
|
74
|
+
|
75
|
+
#endif
|
@@ -0,0 +1,225 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/IndexLSH.h>
|
11
|
+
|
12
|
+
#include <cstdio>
|
13
|
+
#include <cstring>
|
14
|
+
|
15
|
+
#include <algorithm>
|
16
|
+
|
17
|
+
#include <faiss/utils/utils.h>
|
18
|
+
#include <faiss/utils/hamming.h>
|
19
|
+
#include <faiss/impl/FaissAssert.h>
|
20
|
+
|
21
|
+
|
22
|
+
namespace faiss {
|
23
|
+
|
24
|
+
/***************************************************************
|
25
|
+
* IndexLSH
|
26
|
+
***************************************************************/
|
27
|
+
|
28
|
+
|
29
|
+
/** Build an LSH index producing nbits-bit signatures for d-dim vectors.
 *
 * @param d                 dimension of the input vectors
 * @param nbits             number of bits per signature
 * @param rotate_data       initialize and use the random rotation rrot
 * @param train_thresholds  thresholds must be trained before use
 *
 * Without rotation the signature is built from the leading components
 * of the input, so nbits may not exceed d (checked, throws otherwise).
 */
IndexLSH::IndexLSH (idx_t d, int nbits, bool rotate_data, bool train_thresholds):
    Index(d),
    nbits(nbits),
    bytes_per_vec((nbits + 7) / 8),   // bits rounded up to whole bytes
    rotate_data(rotate_data),
    train_thresholds(train_thresholds),
    rrot(d, nbits)
{
    // the index only needs training when thresholds have to be learned
    is_trained = !train_thresholds;

    if (!rotate_data) {
        FAISS_THROW_IF_NOT (d >= nbits);
        return;
    }

    // NOTE(review): 5 is presumably the random seed of the rotation -- confirm
    rrot.init(5);
}
|
43
|
+
|
44
|
+
/// Default constructor: zero-sized signatures, no rotation and no
/// threshold training.
IndexLSH::IndexLSH ():
    nbits (0),
    bytes_per_vec (0),
    rotate_data (false),
    train_thresholds (false)
{}
|
48
|
+
|
49
|
+
|
50
|
+
/** Turn n input vectors of dimension d into n rows of nbits floats that
 * are ready to be binarized.
 *
 * @param n  number of vectors
 * @param x  input vectors, size n * d
 * @return   either x itself (nothing had to be done) or a newly
 *           allocated array of size n * nbits that the caller must
 *           delete[]
 */
const float * IndexLSH::apply_preprocess (idx_t n, const float *x) const
{
    float *out = nullptr;

    if (rotate_data) {
        // also applies bias if exists
        out = rrot.apply (n, x);
    } else if (d != nbits) {
        // no rotation: keep only the first nbits components of each vector
        assert (nbits < d);
        out = new float [nbits * n];
        for (idx_t row = 0; row < n; row++) {
            const float *src = x + row * d;
            float *dst = out + row * nbits;
            for (int col = 0; col < nbits; col++) {
                dst [col] = src [col];
            }
        }
    }

    if (!train_thresholds) {
        return out ? out : x;
    }

    // thresholds are subtracted in place, so a private copy is needed
    if (!out) {
        out = new float [nbits * n];
        memcpy (out, x, sizeof(*x) * n * nbits);
    }

    for (idx_t row = 0; row < n; row++) {
        float *dst = out + row * nbits;
        for (int col = 0; col < nbits; col++) {
            dst [col] -= thresholds [col];
        }
    }

    return out;
}
|
83
|
+
|
84
|
+
|
85
|
+
|
86
|
+
/** Train the index on n vectors.
 *
 * When train_thresholds is set, the threshold of each bit becomes the
 * median of the corresponding preprocessed component over the training
 * set. Otherwise there is nothing to learn and the index is simply
 * flagged as trained.
 *
 * @param n  number of training vectors (must be > 0 when thresholds
 *           are trained, throws otherwise)
 * @param x  training vectors, size n * d
 */
void IndexLSH::train (idx_t n, const float *x)
{
    if (train_thresholds) {
        // a median of zero samples is undefined (the code below would
        // read outside the buffer)
        FAISS_THROW_IF_NOT (n > 0);

        thresholds.resize (nbits);

        // preprocess without subtracting the (not yet known) thresholds;
        // make sure the flag is restored even if preprocessing throws
        train_thresholds = false;
        const float *xt = nullptr;
        try {
            xt = apply_preprocess (n, x);
        } catch (...) {
            train_thresholds = true;
            throw;
        }
        train_thresholds = true;
        ScopeDeleter<float> del (xt == x ? nullptr : xt);

        // transpose so that all samples of one bit are contiguous
        float * transposed_x = new float [n * nbits];
        ScopeDeleter<float> del2 (transposed_x);

        for (idx_t i = 0; i < n; i++)
            for (idx_t j = 0; j < nbits; j++)
                transposed_x [j * n + i] = xt [i * nbits + j];

        const idx_t mid = n / 2;
        for (idx_t i = 0; i < nbits; i++) {
            float *xi = transposed_x + i * n;
            // partial selection is enough to find the median: after this
            // call xi[mid] is the value a full sort would put there and
            // everything before it is <= xi[mid]
            std::nth_element (xi, xi + mid, xi + n);
            if (n % 2 == 1) {
                thresholds [i] = xi [mid];
            } else {
                // even count: average the two middle values; the lower
                // one is the largest element of the first half
                // (non-empty because n >= 2 here)
                const float lower = *std::max_element (xi, xi + mid);
                thresholds [i] = (lower + xi [mid]) / 2;
            }
        }
    }
    is_trained = true;
}
|
115
|
+
|
116
|
+
|
117
|
+
/** Encode n vectors and append their codes to the index.
 *
 * @param n  number of vectors to add
 * @param x  vectors to add, size n * d
 *
 * Throws if the index is not trained.
 */
void IndexLSH::add (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (is_trained);
    codes.resize ((ntotal + n) * bytes_per_vec);

    // use data() + offset rather than &codes[offset]: when n == 0 the
    // offset equals codes.size(), which operator[] may not be used with
    sa_encode (n, x, codes.data() + ntotal * bytes_per_vec);

    ntotal += n;
}
|
126
|
+
|
127
|
+
|
128
|
+
/** Search the k nearest stored codes, in Hamming distance, of n queries.
 *
 * @param n          number of query vectors
 * @param x          query vectors, size n * d
 * @param k          number of results per query
 * @param distances  output Hamming distances converted to float, size n * k
 * @param labels     output labels of the results, size n * k
 *
 * Throws if the index is not trained.
 */
void IndexLSH::search (
        idx_t n,
        const float *x,
        idx_t k,
        float *distances,
        idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    const float *xt = apply_preprocess (n, x);
    ScopeDeleter<float> del (xt == x ? nullptr : xt);

    // binarize the queries the same way the stored vectors were
    uint8_t * qcodes = new uint8_t [n * bytes_per_vec];
    ScopeDeleter<uint8_t> del2 (qcodes);

    fvecs2bitvecs (xt, qcodes, nbits, n);

    // the Hamming search works on integer distances
    int * idistances = new int [n * k];
    ScopeDeleter<int> del3 (idistances);

    int_maxheap_array_t res = { size_t(n), size_t(k), labels, idistances};

    hammings_knn_hc (&res, qcodes, codes.data(),
                     ntotal, bytes_per_vec, true);

    // convert distances to floats; the counter has the same 64-bit type
    // as n and k so that it cannot overflow on large result sets
    const idx_t nres = n * k;
    for (idx_t i = 0; i < nres; i++)
        distances[i] = idistances[i];
}
|
158
|
+
|
159
|
+
|
160
|
+
/** Move the trained thresholds into the bias of a linear transform.
 *
 * Does nothing when thresholds are not in use. Otherwise the thresholds
 * are subtracted from vt->b (which is created, zero-filled, if vt had no
 * bias), then this index stops applying thresholds itself.
 *
 * @param vt  transform receiving the thresholds; its d_out must equal
 *            nbits (throws otherwise)
 */
void IndexLSH::transfer_thresholds (LinearTransform *vt) {
    if (train_thresholds) {
        FAISS_THROW_IF_NOT (nbits == vt->d_out);

        const bool bias_missing = !vt->have_bias;
        if (bias_missing) {
            vt->b.resize (nbits, 0);
            vt->have_bias = true;
        }

        for (int bit = 0; bit < nbits; bit++) {
            vt->b[bit] -= thresholds[bit];
        }

        // the transform now carries the thresholds
        thresholds.clear();
        train_thresholds = false;
    }
}
|
172
|
+
|
173
|
+
/// Forget all stored codes; the vector count goes back to zero.
void IndexLSH::reset()
{
    ntotal = 0;
    codes.clear();
}
|
177
|
+
|
178
|
+
|
179
|
+
size_t IndexLSH::sa_code_size () const
|
180
|
+
{
|
181
|
+
return bytes_per_vec;
|
182
|
+
}
|
183
|
+
|
184
|
+
/** Encode n vectors into binary codes (standalone codec interface).
 *
 * @param n      number of vectors
 * @param x      input vectors, size n * d
 * @param bytes  output codes
 *               (NOTE(review): presumably n * sa_code_size() bytes -- confirm)
 *
 * Throws if the index is not trained.
 */
void IndexLSH::sa_encode (idx_t n, const float *x,
                          uint8_t *bytes) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *preprocessed = apply_preprocess (n, x);

    // only free the buffer when preprocessing allocated a new one
    const float *owned = nullptr;
    if (preprocessed != x) {
        owned = preprocessed;
    }
    ScopeDeleter<float> guard (owned);

    fvecs2bitvecs (preprocessed, bytes, nbits, n);
}
|
192
|
+
|
193
|
+
/** Decode n binary codes back to float vectors (standalone codec
 * interface).
 *
 * The bits are expanded to nbits floats per vector, the thresholds are
 * added back when they are in use, and the result is mapped to dimension
 * d: through the reverse rotation if rotate_data is set, otherwise by
 * copying the nbits values into the leading components and zero-filling
 * the remaining d - nbits components, which were never encoded.
 *
 * @param n      number of codes
 * @param bytes  input codes
 * @param x      output vectors, size n * d
 */
void IndexLSH::sa_decode (idx_t n, const uint8_t *bytes,
                          float *x) const
{
    // decode straight into x when no dimension change is needed
    float *xt = x;
    ScopeDeleter<float> del;
    if (rotate_data || nbits != d) {
        xt = new float [n * nbits];
        del.set(xt);
    }
    bitvecs2fvecs (bytes, xt, nbits, n);

    if (train_thresholds) {
        // undo the threshold subtraction done at encoding time
        float *xp = xt;
        for (idx_t i = 0; i < n; i++) {
            for (int j = 0; j < nbits; j++) {
                *xp++ += thresholds [j];
            }
        }
    }

    if (rotate_data) {
        rrot.reverse_transform (n, xt, x);
    } else if (nbits != d) {
        for (idx_t i = 0; i < n; i++) {
            float *row = x + i * d;
            memcpy (row, xt + i * nbits,
                    nbits * sizeof(xt[0]));
            // the components beyond nbits carry no information: set them
            // to 0 instead of leaving the caller's buffer uninitialized
            if (d > nbits) {
                memset (row + nbits, 0, (d - nbits) * sizeof(row[0]));
            }
        }
    }
}
|
222
|
+
|
223
|
+
|
224
|
+
|
225
|
+
} // namespace faiss
|
@@ -0,0 +1,87 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef INDEX_LSH_H
|
11
|
+
#define INDEX_LSH_H
|
12
|
+
|
13
|
+
#include <vector>
|
14
|
+
|
15
|
+
#include <faiss/Index.h>
|
16
|
+
#include <faiss/VectorTransform.h>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
|
21
|
+
/** The sign of each vector component is put in a binary signature */
|
22
|
+
struct IndexLSH:Index {
|
23
|
+
typedef unsigned char uint8_t;
|
24
|
+
|
25
|
+
int nbits; ///< nb of bits per vector
|
26
|
+
int bytes_per_vec; ///< nb of 8-bits per encoded vector
|
27
|
+
bool rotate_data; ///< whether to apply a random rotation to input
|
28
|
+
bool train_thresholds; ///< whether we train thresholds or use 0
|
29
|
+
|
30
|
+
RandomRotationMatrix rrot; ///< optional random rotation
|
31
|
+
|
32
|
+
std::vector <float> thresholds; ///< thresholds to compare with
|
33
|
+
|
34
|
+
/// encoded dataset
|
35
|
+
std::vector<uint8_t> codes;
|
36
|
+
|
37
|
+
IndexLSH (
|
38
|
+
idx_t d, int nbits,
|
39
|
+
bool rotate_data = true,
|
40
|
+
bool train_thresholds = false);
|
41
|
+
|
42
|
+
/** Preprocesses and resizes the input to the size required to
|
43
|
+
* binarize the data
|
44
|
+
*
|
45
|
+
* @param x input vectors, size n * d
|
46
|
+
* @return output vectors, size n * bits. May be the same pointer
|
47
|
+
* as x, otherwise it should be deleted by the caller
|
48
|
+
*/
|
49
|
+
const float *apply_preprocess (idx_t n, const float *x) const;
|
50
|
+
|
51
|
+
void train(idx_t n, const float* x) override;
|
52
|
+
|
53
|
+
void add(idx_t n, const float* x) override;
|
54
|
+
|
55
|
+
void search(
|
56
|
+
idx_t n,
|
57
|
+
const float* x,
|
58
|
+
idx_t k,
|
59
|
+
float* distances,
|
60
|
+
idx_t* labels) const override;
|
61
|
+
|
62
|
+
void reset() override;
|
63
|
+
|
64
|
+
/// transfer the thresholds to a pre-processing stage (and unset
|
65
|
+
/// train_thresholds)
|
66
|
+
void transfer_thresholds (LinearTransform * vt);
|
67
|
+
|
68
|
+
~IndexLSH() override {}
|
69
|
+
|
70
|
+
IndexLSH ();
|
71
|
+
|
72
|
+
/* standalone codec interface */
|
73
|
+
size_t sa_code_size () const override;
|
74
|
+
|
75
|
+
void sa_encode (idx_t n, const float *x,
|
76
|
+
uint8_t *bytes) const override;
|
77
|
+
|
78
|
+
void sa_decode (idx_t n, const uint8_t *bytes,
|
79
|
+
float *x) const override;
|
80
|
+
|
81
|
+
};
|
82
|
+
|
83
|
+
|
84
|
+
}
|
85
|
+
|
86
|
+
|
87
|
+
#endif
|
@@ -0,0 +1,143 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
|
11
|
+
#include <faiss/IndexLattice.h>
|
12
|
+
#include <faiss/utils/hamming.h> // for the bitstring routines
|
13
|
+
#include <faiss/impl/FaissAssert.h>
|
14
|
+
#include <faiss/utils/distances.h>
|
15
|
+
|
16
|
+
namespace faiss {
|
17
|
+
|
18
|
+
|
19
|
+
/** Build a lattice codec index.
 *
 * @param d           dimension of the vectors; must be a multiple of nsq
 *                    (throws otherwise)
 * @param nsq         number of sub-blocks each vector is split into
 * @param scale_nbit  number of bits used to store the norm of a sub-block
 * @param r2          parameter passed on to the Zn sphere codec
 *
 * The code size is derived from the number of bits needed to index the
 * zn_sphere_codec.nv lattice points plus the scale bits, per sub-block.
 * The index starts untrained.
 */
IndexLattice::IndexLattice (idx_t d, int nsq, int scale_nbit, int r2):
    Index (d),
    nsq (nsq),
    dsq (d / nsq),
    zn_sphere_codec (dsq, r2),
    scale_nbit (scale_nbit)
{
    FAISS_THROW_IF_NOT (d % nsq == 0);

    is_trained = false;

    // smallest number of bits b such that 2^b >= nv
    for (lattice_nbit = 0;
         ((uint64_t)1 << lattice_nbit) < zn_sphere_codec.nv;
         lattice_nbit++) {
    }

    // all sub-block fields are packed, then rounded up to whole bytes
    const int bits_per_vector = (lattice_nbit + scale_nbit) * nsq;
    code_size = (bits_per_vector + 7) / 8;
}
|
39
|
+
|
40
|
+
/** Record, for every sub-block, the smallest and largest L2 norm seen
 * in the training set.
 *
 * trained receives 2 * nsq values: the nsq minima followed by the nsq
 * maxima.
 *
 * @param n  number of training vectors
 * @param x  training vectors, size n * d
 */
void IndexLattice::train(idx_t n, const float* x)
{
    trained.resize (nsq * 2);
    float * lo = trained.data();
    float * hi = lo + nsq;

    // start from an empty range for every sub-block
    for (int sq = 0; sq < nsq; sq++) {
        lo[sq] = HUGE_VAL;
        hi[sq] = -1;
    }

    // the ranges are accumulated on squared norms...
    for (idx_t i = 0; i < n; i++) {
        const float *block = x + i * d;
        for (int sq = 0; sq < nsq; sq++, block += dsq) {
            const float sqnorm = fvec_norm_L2sqr (block, dsq);
            if (sqnorm < lo[sq]) lo[sq] = sqnorm;
            if (sqnorm > hi[sq]) hi[sq] = sqnorm;
        }
    }

    // ...and converted to plain norms at the end
    for (int sq = 0; sq < nsq; sq++) {
        lo[sq] = sqrtf (lo[sq]);
        hi[sq] = sqrtf (hi[sq]);
    }

    is_trained = true;
}
|
66
|
+
|
67
|
+
/* The standalone codec interface */
|
68
|
+
size_t IndexLattice::sa_code_size () const
|
69
|
+
{
|
70
|
+
return code_size;
|
71
|
+
}
|
72
|
+
|
73
|
+
|
74
|
+
|
75
|
+
/** Encode n vectors.
 *
 * For each of the nsq sub-blocks of a vector, two fields are written to
 * the bit string: the norm of the sub-block, quantized on scale_nbit
 * bits over the [min, max] range found at training time, followed by the
 * lattice code of the sub-block on lattice_nbit bits.
 *
 * @param n      number of vectors
 * @param x      input vectors, size n * d
 * @param codes  output codes, size n * code_size
 *
 * Throws if the index is not trained (the norm ranges would be missing).
 */
void IndexLattice::sa_encode (idx_t n, const float *x, uint8_t *codes) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float * mins = trained.data();
    const float * maxs = mins + nsq;
    int64_t sc = int64_t(1) << scale_nbit;

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringWriter wr(codes + i * code_size, code_size);
        const float *xi = x + i * d;
        for (int j = 0; j < nsq; j++) {
            float nj =
                (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j])
                * sc / (maxs[j] - mins[j]);
            // clamp to [0, sc - 1]. The first test is written negated so
            // that it also catches NaN, which appears when the trained
            // range is empty (max == min) and the norm equals it: 0 / 0.
            // Converting NaN to an integer would be undefined behavior.
            if (!(nj >= 0)) nj = 0;
            if (nj >= sc) nj = sc - 1;
            wr.write(static_cast<int64_t>(nj), scale_nbit);
            wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
            xi += dsq;
        }
    }
}
|
98
|
+
|
99
|
+
/** Decode n codes back to vectors.
 *
 * For each sub-block, the quantized norm is read and mapped back to the
 * middle of its quantization cell, the lattice point is decoded into the
 * output, and the point is rescaled so that a point of norm sqrt(r2)
 * gets the decoded norm.
 *
 * @param n      number of codes
 * @param codes  input codes, size n * code_size
 * @param x      output vectors, size n * d
 */
void IndexLattice::sa_decode (idx_t n, const uint8_t *codes, float *x) const
{
    const float * lo = trained.data();
    const float * hi = lo + nsq;
    float nlevels = int64_t(1) << scale_nbit;
    float radius = sqrtf(zn_sphere_codec.r2);

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringReader rd(codes + i * code_size, code_size);
        float *block = x + i * d;
        for (int sq = 0; sq < nsq; sq++, block += dsq) {
            // fields are read in the order they were written: norm first
            const double level = rd.read (scale_nbit) + 0.5;
            float norm = level * (hi[sq] - lo[sq]) / nlevels + lo[sq];
            norm /= radius;

            zn_sphere_codec.decode (rd.read (lattice_nbit), block);

            for (int c = 0; c < dsq; c++) {
                block[c] *= norm;
            }
        }
    }
}
|
123
|
+
|
124
|
+
/// Adding vectors is not supported by this index: always throws.
void IndexLattice::add(idx_t /* n */, const float* /* x */)
{
    FAISS_THROW_MSG("not implemented");
}
|
128
|
+
|
129
|
+
|
130
|
+
/// Searching is not supported by this index: always throws.
void IndexLattice::search(
        idx_t /* n */,
        const float* /* x */,
        idx_t /* k */,
        float* /* distances */,
        idx_t* /* labels */) const
{
    FAISS_THROW_MSG("not implemented");
}
|
135
|
+
|
136
|
+
|
137
|
+
void IndexLattice::reset()
|
138
|
+
{
|
139
|
+
FAISS_THROW_MSG("not implemented");
|
140
|
+
}
|
141
|
+
|
142
|
+
|
143
|
+
} // namespace faiss
|