faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_INDEX_PQ_H
|
|
11
|
+
#define FAISS_INDEX_PQ_H
|
|
12
|
+
|
|
13
|
+
#include <stdint.h>
|
|
14
|
+
|
|
15
|
+
#include <vector>
|
|
16
|
+
|
|
17
|
+
#include <faiss/Index.h>
|
|
18
|
+
#include <faiss/impl/ProductQuantizer.h>
|
|
19
|
+
#include <faiss/impl/PolysemousTraining.h>
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
/** Index based on a product quantizer. Stored vectors are
|
|
25
|
+
* approximated by PQ codes. */
|
|
26
|
+
struct IndexPQ: Index {
|
|
27
|
+
|
|
28
|
+
/// The product quantizer used to encode the vectors
|
|
29
|
+
ProductQuantizer pq;
|
|
30
|
+
|
|
31
|
+
/// Codes. Size ntotal * pq.code_size
|
|
32
|
+
std::vector<uint8_t> codes;
|
|
33
|
+
|
|
34
|
+
/** Constructor.
|
|
35
|
+
*
|
|
36
|
+
* @param d dimensionality of the input vectors
|
|
37
|
+
* @param M number of subquantizers
|
|
38
|
+
* @param nbits number of bit per subvector index
|
|
39
|
+
*/
|
|
40
|
+
IndexPQ (int d, ///< dimensionality of the input vectors
|
|
41
|
+
size_t M, ///< number of subquantizers
|
|
42
|
+
size_t nbits, ///< number of bit per subvector index
|
|
43
|
+
MetricType metric = METRIC_L2);
|
|
44
|
+
|
|
45
|
+
IndexPQ ();
|
|
46
|
+
|
|
47
|
+
void train(idx_t n, const float* x) override;
|
|
48
|
+
|
|
49
|
+
void add(idx_t n, const float* x) override;
|
|
50
|
+
|
|
51
|
+
void search(
|
|
52
|
+
idx_t n,
|
|
53
|
+
const float* x,
|
|
54
|
+
idx_t k,
|
|
55
|
+
float* distances,
|
|
56
|
+
idx_t* labels) const override;
|
|
57
|
+
|
|
58
|
+
void reset() override;
|
|
59
|
+
|
|
60
|
+
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
|
61
|
+
|
|
62
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
63
|
+
|
|
64
|
+
size_t remove_ids(const IDSelector& sel) override;
|
|
65
|
+
|
|
66
|
+
/* The standalone codec interface */
|
|
67
|
+
size_t sa_code_size () const override;
|
|
68
|
+
|
|
69
|
+
void sa_encode (idx_t n, const float *x,
|
|
70
|
+
uint8_t *bytes) const override;
|
|
71
|
+
|
|
72
|
+
void sa_decode (idx_t n, const uint8_t *bytes,
|
|
73
|
+
float *x) const override;
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
DistanceComputer * get_distance_computer() const override;
|
|
77
|
+
|
|
78
|
+
/******************************************************
|
|
79
|
+
* Polysemous codes implementation
|
|
80
|
+
******************************************************/
|
|
81
|
+
bool do_polysemous_training; ///< false = standard PQ
|
|
82
|
+
|
|
83
|
+
/// parameters used for the polysemous training
|
|
84
|
+
PolysemousTraining polysemous_training;
|
|
85
|
+
|
|
86
|
+
/// how to perform the search in search_core
|
|
87
|
+
enum Search_type_t {
|
|
88
|
+
ST_PQ, ///< asymmetric product quantizer (default)
|
|
89
|
+
ST_HE, ///< Hamming distance on codes
|
|
90
|
+
ST_generalized_HE, ///< nb of same codes
|
|
91
|
+
ST_SDC, ///< symmetric product quantizer (SDC)
|
|
92
|
+
ST_polysemous, ///< HE filter (using ht) + PQ combination
|
|
93
|
+
ST_polysemous_generalize, ///< Filter on generalized Hamming
|
|
94
|
+
};
|
|
95
|
+
|
|
96
|
+
Search_type_t search_type;
|
|
97
|
+
|
|
98
|
+
// just encode the sign of the components, instead of using the PQ encoder
|
|
99
|
+
// used only for the queries
|
|
100
|
+
bool encode_signs;
|
|
101
|
+
|
|
102
|
+
/// Hamming threshold used for polysemy
|
|
103
|
+
int polysemous_ht;
|
|
104
|
+
|
|
105
|
+
// actual polysemous search
|
|
106
|
+
void search_core_polysemous (idx_t n, const float *x, idx_t k,
|
|
107
|
+
float *distances, idx_t *labels) const;
|
|
108
|
+
|
|
109
|
+
/// prepare query for a polysemous search, but instead of
|
|
110
|
+
/// computing the result, just get the histogram of Hamming
|
|
111
|
+
/// distances. May be computed on a provided dataset if xb != NULL
|
|
112
|
+
/// @param dist_histogram (M * nbits + 1)
|
|
113
|
+
void hamming_distance_histogram (idx_t n, const float *x,
|
|
114
|
+
idx_t nb, const float *xb,
|
|
115
|
+
int64_t *dist_histogram);
|
|
116
|
+
|
|
117
|
+
/** compute pairwise distances between queries and database
|
|
118
|
+
*
|
|
119
|
+
* @param n nb of query vectors
|
|
120
|
+
* @param x query vector, size n * d
|
|
121
|
+
* @param dis output distances, size n * ntotal
|
|
122
|
+
*/
|
|
123
|
+
void hamming_distance_table (idx_t n, const float *x,
|
|
124
|
+
int32_t *dis) const;
|
|
125
|
+
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
/// These statistics are robust to internal threading, but not to
/// IndexPQ::search being called by multiple threads.
struct IndexPQStats {
    size_t nq;             ///< nb of queries run
    size_t ncode;          ///< nb of codes visited
    size_t n_hamming_pass; ///< nb of passed Hamming distance tests (for polysemy)

    /// Construction goes through reset().
    IndexPQStats() { reset(); }

    void reset();
};

/// Global statistics object, defined outside this header.
extern IndexPQStats indexPQ_stats;
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
/** Quantizer where centroids are virtual: they are the Cartesian
|
|
146
|
+
* product of sub-centroids. */
|
|
147
|
+
struct MultiIndexQuantizer: Index {
|
|
148
|
+
ProductQuantizer pq;
|
|
149
|
+
|
|
150
|
+
MultiIndexQuantizer (int d, ///< dimension of the input vectors
|
|
151
|
+
size_t M, ///< number of subquantizers
|
|
152
|
+
size_t nbits); ///< number of bit per subvector index
|
|
153
|
+
|
|
154
|
+
void train(idx_t n, const float* x) override;
|
|
155
|
+
|
|
156
|
+
void search(
|
|
157
|
+
idx_t n, const float* x, idx_t k,
|
|
158
|
+
float* distances, idx_t* labels) const override;
|
|
159
|
+
|
|
160
|
+
/// add and reset will crash at runtime
|
|
161
|
+
void add(idx_t n, const float* x) override;
|
|
162
|
+
void reset() override;
|
|
163
|
+
|
|
164
|
+
MultiIndexQuantizer () {}
|
|
165
|
+
|
|
166
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
167
|
+
};
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
/** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
|
|
171
|
+
*/
|
|
172
|
+
struct MultiIndexQuantizer2: MultiIndexQuantizer {
|
|
173
|
+
|
|
174
|
+
/// M Indexes on d / M dimensions
|
|
175
|
+
std::vector<Index*> assign_indexes;
|
|
176
|
+
bool own_fields;
|
|
177
|
+
|
|
178
|
+
MultiIndexQuantizer2 (
|
|
179
|
+
int d, size_t M, size_t nbits,
|
|
180
|
+
Index **indexes);
|
|
181
|
+
|
|
182
|
+
MultiIndexQuantizer2 (
|
|
183
|
+
int d, size_t nbits,
|
|
184
|
+
Index *assign_index_0,
|
|
185
|
+
Index *assign_index_1);
|
|
186
|
+
|
|
187
|
+
void train(idx_t n, const float* x) override;
|
|
188
|
+
|
|
189
|
+
void search(
|
|
190
|
+
idx_t n, const float* x, idx_t k,
|
|
191
|
+
float* distances, idx_t* labels) const override;
|
|
192
|
+
|
|
193
|
+
};
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
} // namespace faiss
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
#endif
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/IndexPreTransform.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdio>
|
|
13
|
+
#include <cmath>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <memory>
|
|
16
|
+
|
|
17
|
+
#include <faiss/utils/utils.h>
|
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
|
19
|
+
|
|
20
|
+
namespace faiss {
|
|
21
|
+
|
|
22
|
+
/*********************************************
|
|
23
|
+
* IndexPreTransform
|
|
24
|
+
*********************************************/
|
|
25
|
+
|
|
26
|
+
/// Default constructor: no sub-index is attached and nothing is owned.
IndexPreTransform::IndexPreTransform()
    : index(nullptr),
      own_fields(false)
{}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
/// Wraps an existing index with an (initially empty) transform chain.
/// The wrapper takes its dimension, metric type, trained flag and vector
/// count from the wrapped index; it does not take ownership of it.
IndexPreTransform::IndexPreTransform(Index* index)
    : Index(index->d, index->metric_type),
      index(index),
      own_fields(false)
{
    // `index` here names the constructor argument, not the member.
    this->ntotal = index->ntotal;
    this->is_trained = index->is_trained;
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
/// Wraps an existing index and installs `ltrans` as the first transform.
/// State is first copied from the wrapped index, then prepend_transform()
/// adjusts the dimension and the trained flag for the added transform.
IndexPreTransform::IndexPreTransform(VectorTransform* ltrans, Index* index)
    : Index(index->d, index->metric_type),
      index(index),
      own_fields(false)
{
    // `index` here names the constructor argument, not the member.
    this->ntotal = index->ntotal;
    // Must be set before prepend_transform(), which combines it with
    // the transform's own trained flag.
    this->is_trained = index->is_trained;
    prepend_transform(ltrans);
}
|
|
52
|
+
|
|
53
|
+
/// Inserts `ltrans` at the front of the transform chain.
/// The transform's output dimension must equal the current input dimension
/// of this index (otherwise the check below throws); afterwards the index
/// accepts vectors of dimension ltrans->d_in.
void IndexPreTransform::prepend_transform(VectorTransform* ltrans)
{
    FAISS_THROW_IF_NOT (ltrans->d_out == d);

    // The whole pipeline is trained only if the new stage is too.
    if (!ltrans->is_trained) {
        is_trained = false;
    }

    chain.insert(chain.begin(), ltrans);
    d = ltrans->d_in;
}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
/// Destroys the transforms and the sub-index, but only when own_fields is set.
IndexPreTransform::~IndexPreTransform()
{
    if (!own_fields) {
        return;
    }
    for (VectorTransform* stage : chain) {
        delete stage;
    }
    delete index;
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
/// Trains the untrained transforms of the chain and, if it is untrained,
/// the sub-index. Each stage is trained on the training set as transformed
/// by all preceding stages.
///
/// @param n  number of training vectors
/// @param x  training vectors, passed unchanged to the first stage
///
/// Stages are numbered 0 .. chain.size()-1 for the transforms, and
/// chain.size() for the sub-index. Processing stops at the last stage that
/// needs training, so no transform is applied beyond that point.
/// Note: with an empty chain the sub-index's train() is always invoked,
/// even if it is already trained.
void IndexPreTransform::train (idx_t n, const float *x)
{
    const int nchain = static_cast<int>(chain.size());

    // Find the last stage that requires training.
    int last_untrained = 0;
    if (!index->is_trained) {
        last_untrained = nchain;
    } else {
        for (int i = nchain - 1; i >= 0; i--) {
            if (!chain[i]->is_trained) {
                last_untrained = i;
                break;
            }
        }
    }

    const float *prev_x = x;
    // Owns the current intermediate buffer (never the caller's x).
    ScopeDeleter<float> del;

    if (verbose) {
        printf("IndexPreTransform::train: training chain 0 to %d\n",
               last_untrained);
    }

    for (int i = 0; i <= last_untrained; i++) {

        if (i < nchain) {
            VectorTransform *ltrans = chain [i];
            if (!ltrans->is_trained) {
                if (verbose) {
                    // chain.size() is a size_t, hence %zu
                    printf(" Training chain component %d/%zu\n",
                           i, chain.size());
                    if (OPQMatrix *opqm = dynamic_cast<OPQMatrix*>(ltrans)) {
                        opqm->verbose = true;
                    }
                }
                ltrans->train (n, prev_x);
            }
        } else {
            if (verbose) {
                printf(" Training sub-index\n");
            }
            index->train (n, prev_x);
        }

        // Nothing after the last untrained stage needs the transformed data.
        if (i == last_untrained) break;

        if (verbose) {
            printf(" Applying transform %d/%zu\n",
                   i, chain.size());
        }

        float * xt = chain[i]->apply (n, prev_x);

        // Hand xt over to `del`; the previous intermediate buffer (if any)
        // moves into `del2` and is freed when this iteration ends.
        ScopeDeleter<float> del2 (xt);
        del2.swap (del);
        prev_x = xt;
    }

    is_trained = true;
}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
/// Runs the vectors through every transform of the chain, first to last.
///
/// @return x itself when the chain is empty; otherwise the buffer returned
///         by the last transform's apply(), released from the local
///         deleter (callers in this file free it when it differs from x).
const float *IndexPreTransform::apply_chain (idx_t n, const float *x) const
{
    const float *current = x;
    // Holds the most recent intermediate buffer so it is freed on error.
    ScopeDeleter<float> owner;

    for (VectorTransform *stage : chain) {
        float *out = stage->apply (n, current);
        // After the swap `owner` guards `out`, and `previous` frees the
        // buffer of the preceding stage at the end of this iteration.
        ScopeDeleter<float> previous (out);
        previous.swap (owner);
        current = out;
    }

    // The final buffer is handed over to the caller.
    owner.release ();
    return current;
}
|
|
146
|
+
|
|
147
|
+
/// Undoes the transform chain, from the last transform back to the first.
///
/// @param n   number of vectors
/// @param xt  vectors in the sub-index space
/// @param x   output buffer, written by the first transform of the chain;
///            left untouched when the chain is empty
void IndexPreTransform::reverse_chain (idx_t n, const float* xt, float* x) const
{
    const float *source = xt;
    // Keeps the latest temporary alive while the next stage reads from it.
    ScopeDeleter<float> owner;

    for (int stage = static_cast<int>(chain.size()) - 1; stage >= 0; --stage) {
        VectorTransform *vt = chain [stage];

        // The first transform writes straight into the caller's buffer;
        // the others write into a temporary.
        float *dest;
        if (stage == 0) {
            dest = x;
        } else {
            dest = new float [n * vt->d_in];
        }
        ScopeDeleter<float> guard ((dest == x) ? nullptr : dest);

        vt->reverse_transform (n, source, dest);

        // `owner` takes over dest; the previous temporary is freed when
        // `guard` goes out of scope.
        guard.swap (owner);
        source = dest;
    }
}
|
|
160
|
+
|
|
161
|
+
/// Transforms the vectors through the chain, adds them to the sub-index
/// and refreshes ntotal from it. Throws if the index is not trained.
void IndexPreTransform::add (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *transformed = apply_chain (n, x);
    // Only free the buffer if apply_chain produced a new one.
    ScopeDeleter<float> guard (transformed == x ? nullptr : transformed);

    index->add (n, transformed);
    ntotal = index->ntotal;
}
|
|
169
|
+
|
|
170
|
+
/// Same as add(), but forwards the caller-provided ids to the sub-index.
/// Throws if the index is not trained.
void IndexPreTransform::add_with_ids (idx_t n, const float * x,
                                      const idx_t *xids)
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *transformed = apply_chain (n, x);
    // Only free the buffer if apply_chain produced a new one.
    ScopeDeleter<float> guard (transformed == x ? nullptr : transformed);

    index->add_with_ids (n, transformed, xids);
    ntotal = index->ntotal;
}
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
/// Transforms the queries through the chain and delegates the search to the
/// sub-index. Throws if the index is not trained.
void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
                                float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *queries = apply_chain (n, x);
    // Only free the buffer if apply_chain produced a new one.
    ScopeDeleter<float> guard (queries == x ? nullptr : queries);

    index->search (n, queries, k, distances, labels);
}
|
|
191
|
+
|
|
192
|
+
/// Transforms the queries through the chain and delegates the range search
/// to the sub-index. Throws if the index is not trained.
void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
                                      RangeSearchResult* result) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *queries = apply_chain (n, x);
    // Only free the buffer if apply_chain produced a new one.
    ScopeDeleter<float> guard (queries == x ? nullptr : queries);

    index->range_search (n, queries, radius, result);
}
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
/// Resets the sub-index and sets this wrapper's vector count to zero.
void IndexPreTransform::reset ()
{
    index->reset ();
    ntotal = 0;
}
|
|
207
|
+
|
|
208
|
+
/// Forwards the removal to the sub-index, then re-reads ntotal from it.
/// @return the value returned by the sub-index's remove_ids()
size_t IndexPreTransform::remove_ids (const IDSelector & sel)
{
    const size_t removed = index->remove_ids (sel);
    ntotal = index->ntotal;
    return removed;
}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
/// Reconstructs one stored vector in the sub-index space, then maps it back
/// through the chain into `recons`.
void IndexPreTransform::reconstruct (idx_t key, float * recons) const
{
    // With an empty chain the sub-index writes directly into the output;
    // otherwise it writes into a temporary of the sub-index's dimension.
    const bool direct = chain.empty();
    float *raw = recons;
    if (!direct) {
        raw = new float [index->d];
    }
    ScopeDeleter<float> guard (direct ? nullptr : raw);

    // Reconstruction in the sub-index space
    index->reconstruct (key, raw);

    // Undo the transforms, last to first
    reverse_chain (1, raw, recons);
}
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
/// Reconstructs `ni` stored vectors starting at `i0` in the sub-index space,
/// then maps them back through the chain into `recons`.
void IndexPreTransform::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
{
    // With an empty chain the sub-index writes directly into the output;
    // otherwise it writes into a temporary of ni * index->d floats.
    const bool direct = chain.empty();
    float *raw = recons;
    if (!direct) {
        raw = new float [ni * index->d];
    }
    ScopeDeleter<float> guard (direct ? nullptr : raw);

    // Reconstruction in the sub-index space
    index->reconstruct_n (i0, ni, raw);

    // Undo the transforms, last to first
    reverse_chain (ni, raw, recons);
}
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
/// Searches with transformed queries and reconstructs the results: the
/// sub-index's reconstructions (n * k vectors) are mapped back through the
/// chain into `recons`. Throws if the index is not trained.
void IndexPreTransform::search_and_reconstruct (
        idx_t n, const float *x, idx_t k,
        float *distances, idx_t *labels, float* recons) const
{
    FAISS_THROW_IF_NOT (is_trained);

    const float *queries = apply_chain (n, x);
    // Only free the query buffer if apply_chain produced a new one.
    ScopeDeleter<float> query_guard ((queries == x) ? nullptr : queries);

    // With an empty chain the sub-index writes directly into the output;
    // otherwise it writes into a temporary of n * k * index->d floats.
    const bool direct = chain.empty();
    float *raw = recons;
    if (!direct) {
        raw = new float [n * k * index->d];
    }
    ScopeDeleter<float> raw_guard (direct ? nullptr : raw);

    index->search_and_reconstruct (n, queries, k, distances, labels, raw);

    // Undo the transforms, last to first
    reverse_chain (n * k, raw, recons);
}
|
|
255
|
+
|
|
256
|
+
size_t IndexPreTransform::sa_code_size () const
|
|
257
|
+
{
|
|
258
|
+
return index->sa_code_size ();
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
/// Standalone codec: encodes the vectors with the sub-index, after running
/// them through the transform chain when there is one.
void IndexPreTransform::sa_encode (idx_t n, const float *x,
                                   uint8_t *bytes) const
{
    // No transforms: encode the input as is.
    if (chain.empty()) {
        index->sa_encode (n, x, bytes);
        return;
    }

    const float *transformed = apply_chain (n, x);
    // Only free the buffer if apply_chain produced a new one.
    ScopeDeleter<float> guard (transformed == x ? nullptr : transformed);
    index->sa_encode (n, transformed, bytes);
}
|
|
272
|
+
|
|
273
|
+
/// Standalone codec: decodes with the sub-index, then maps the decoded
/// vectors back through the transform chain when there is one.
void IndexPreTransform::sa_decode (idx_t n, const uint8_t *bytes,
                                   float *x) const
{
    // No transforms: decode straight into the output.
    if (chain.empty()) {
        index->sa_decode (n, bytes, x);
        return;
    }

    // Temporary holding the vectors in the sub-index space.
    std::unique_ptr<float []> decoded (new float [index->d * n]);
    index->sa_decode (n, bytes, decoded.get());

    // Undo the transforms, last to first
    reverse_chain (n, decoded.get(), x);
}
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
} // namespace faiss
|