faiss 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +20 -2
|
@@ -10,11 +10,11 @@
|
|
|
10
10
|
#ifndef META_INDEXES_H
|
|
11
11
|
#define META_INDEXES_H
|
|
12
12
|
|
|
13
|
-
#include <vector>
|
|
14
|
-
#include <unordered_map>
|
|
15
13
|
#include <faiss/Index.h>
|
|
16
|
-
#include <faiss/IndexShards.h>
|
|
17
14
|
#include <faiss/IndexReplicas.h>
|
|
15
|
+
#include <faiss/IndexShards.h>
|
|
16
|
+
#include <unordered_map>
|
|
17
|
+
#include <vector>
|
|
18
18
|
|
|
19
19
|
namespace faiss {
|
|
20
20
|
|
|
@@ -25,22 +25,25 @@ struct IndexIDMapTemplate : IndexT {
|
|
|
25
25
|
using component_t = typename IndexT::component_t;
|
|
26
26
|
using distance_t = typename IndexT::distance_t;
|
|
27
27
|
|
|
28
|
-
IndexT
|
|
29
|
-
bool own_fields;
|
|
28
|
+
IndexT* index; ///! the sub-index
|
|
29
|
+
bool own_fields; ///! whether pointers are deleted in destructor
|
|
30
30
|
std::vector<idx_t> id_map;
|
|
31
31
|
|
|
32
|
-
explicit IndexIDMapTemplate
|
|
32
|
+
explicit IndexIDMapTemplate(IndexT* index);
|
|
33
33
|
|
|
34
34
|
/// @param xids if non-null, ids to store for the vectors (size n)
|
|
35
|
-
void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
|
|
35
|
+
void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
|
|
36
|
+
override;
|
|
36
37
|
|
|
37
38
|
/// this will fail. Use add_with_ids
|
|
38
39
|
void add(idx_t n, const component_t* x) override;
|
|
39
40
|
|
|
40
41
|
void search(
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
42
|
+
idx_t n,
|
|
43
|
+
const component_t* x,
|
|
44
|
+
idx_t k,
|
|
45
|
+
distance_t* distances,
|
|
46
|
+
idx_t* labels) const override;
|
|
44
47
|
|
|
45
48
|
void train(idx_t n, const component_t* x) override;
|
|
46
49
|
|
|
@@ -49,17 +52,22 @@ struct IndexIDMapTemplate : IndexT {
|
|
|
49
52
|
/// remove ids adapted to IndexFlat
|
|
50
53
|
size_t remove_ids(const IDSelector& sel) override;
|
|
51
54
|
|
|
52
|
-
void range_search
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
55
|
+
void range_search(
|
|
56
|
+
idx_t n,
|
|
57
|
+
const component_t* x,
|
|
58
|
+
distance_t radius,
|
|
59
|
+
RangeSearchResult* result) const override;
|
|
60
|
+
|
|
61
|
+
~IndexIDMapTemplate() override;
|
|
62
|
+
IndexIDMapTemplate() {
|
|
63
|
+
own_fields = false;
|
|
64
|
+
index = nullptr;
|
|
65
|
+
}
|
|
57
66
|
};
|
|
58
67
|
|
|
59
68
|
using IndexIDMap = IndexIDMapTemplate<Index>;
|
|
60
69
|
using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
|
|
61
70
|
|
|
62
|
-
|
|
63
71
|
/** same as IndexIDMap but also provides an efficient reconstruction
|
|
64
72
|
* implementation via a 2-way index */
|
|
65
73
|
template <typename IndexT>
|
|
@@ -70,47 +78,47 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
|
|
|
70
78
|
|
|
71
79
|
std::unordered_map<idx_t, idx_t> rev_map;
|
|
72
80
|
|
|
73
|
-
explicit IndexIDMap2Template
|
|
81
|
+
explicit IndexIDMap2Template(IndexT* index);
|
|
74
82
|
|
|
75
83
|
/// make the rev_map from scratch
|
|
76
|
-
void construct_rev_map
|
|
84
|
+
void construct_rev_map();
|
|
77
85
|
|
|
78
|
-
void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
|
|
86
|
+
void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
|
|
87
|
+
override;
|
|
79
88
|
|
|
80
89
|
size_t remove_ids(const IDSelector& sel) override;
|
|
81
90
|
|
|
82
|
-
void reconstruct
|
|
91
|
+
void reconstruct(idx_t key, component_t* recons) const override;
|
|
83
92
|
|
|
84
93
|
~IndexIDMap2Template() override {}
|
|
85
|
-
IndexIDMap2Template
|
|
94
|
+
IndexIDMap2Template() {}
|
|
86
95
|
};
|
|
87
96
|
|
|
88
97
|
using IndexIDMap2 = IndexIDMap2Template<Index>;
|
|
89
98
|
using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
|
|
90
99
|
|
|
91
|
-
|
|
92
100
|
/** splits input vectors in segments and assigns each segment to a sub-index
|
|
93
101
|
* used to distribute a MultiIndexQuantizer
|
|
94
102
|
*/
|
|
95
|
-
struct IndexSplitVectors: Index {
|
|
103
|
+
struct IndexSplitVectors : Index {
|
|
96
104
|
bool own_fields;
|
|
97
105
|
bool threaded;
|
|
98
106
|
std::vector<Index*> sub_indexes;
|
|
99
|
-
idx_t sum_d;
|
|
107
|
+
idx_t sum_d; /// sum of dimensions seen so far
|
|
100
108
|
|
|
101
|
-
explicit IndexSplitVectors
|
|
109
|
+
explicit IndexSplitVectors(idx_t d, bool threaded = false);
|
|
102
110
|
|
|
103
|
-
void add_sub_index
|
|
104
|
-
void sync_with_sub_indexes
|
|
111
|
+
void add_sub_index(Index*);
|
|
112
|
+
void sync_with_sub_indexes();
|
|
105
113
|
|
|
106
114
|
void add(idx_t n, const float* x) override;
|
|
107
115
|
|
|
108
116
|
void search(
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
117
|
+
idx_t n,
|
|
118
|
+
const float* x,
|
|
119
|
+
idx_t k,
|
|
120
|
+
float* distances,
|
|
121
|
+
idx_t* labels) const override;
|
|
114
122
|
|
|
115
123
|
void train(idx_t n, const float* x) override;
|
|
116
124
|
|
|
@@ -119,8 +127,6 @@ struct IndexSplitVectors: Index {
|
|
|
119
127
|
~IndexSplitVectors() override;
|
|
120
128
|
};
|
|
121
129
|
|
|
122
|
-
|
|
123
130
|
} // namespace faiss
|
|
124
131
|
|
|
125
|
-
|
|
126
132
|
#endif
|
|
@@ -18,12 +18,12 @@ namespace faiss {
|
|
|
18
18
|
/// (brute-force) indices supporting additional metric types for vector
|
|
19
19
|
/// comparison.
|
|
20
20
|
enum MetricType {
|
|
21
|
-
METRIC_INNER_PRODUCT = 0,
|
|
22
|
-
METRIC_L2 = 1,
|
|
23
|
-
METRIC_L1,
|
|
24
|
-
METRIC_Linf,
|
|
25
|
-
METRIC_Lp,
|
|
26
|
-
|
|
21
|
+
METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
|
|
22
|
+
METRIC_L2 = 1, ///< squared L2 search
|
|
23
|
+
METRIC_L1, ///< L1 (aka cityblock)
|
|
24
|
+
METRIC_Linf, ///< infinity distance
|
|
25
|
+
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
|
|
26
|
+
/// metric_arg
|
|
27
27
|
|
|
28
28
|
/// some additional metrics defined in scipy.spatial.distance
|
|
29
29
|
METRIC_Canberra = 20,
|
|
@@ -31,6 +31,6 @@ enum MetricType {
|
|
|
31
31
|
METRIC_JensenShannon,
|
|
32
32
|
};
|
|
33
33
|
|
|
34
|
-
}
|
|
34
|
+
} // namespace faiss
|
|
35
35
|
|
|
36
36
|
#endif
|
|
@@ -10,20 +10,19 @@
|
|
|
10
10
|
#include <faiss/VectorTransform.h>
|
|
11
11
|
|
|
12
12
|
#include <cinttypes>
|
|
13
|
-
#include <cstdio>
|
|
14
13
|
#include <cmath>
|
|
14
|
+
#include <cstdio>
|
|
15
15
|
#include <cstring>
|
|
16
16
|
#include <memory>
|
|
17
17
|
|
|
18
|
+
#include <faiss/IndexPQ.h>
|
|
19
|
+
#include <faiss/impl/FaissAssert.h>
|
|
18
20
|
#include <faiss/utils/distances.h>
|
|
19
21
|
#include <faiss/utils/random.h>
|
|
20
22
|
#include <faiss/utils/utils.h>
|
|
21
|
-
#include <faiss/impl/FaissAssert.h>
|
|
22
|
-
#include <faiss/IndexPQ.h>
|
|
23
23
|
|
|
24
24
|
using namespace faiss;
|
|
25
25
|
|
|
26
|
-
|
|
27
26
|
extern "C" {
|
|
28
27
|
|
|
29
28
|
// this is to keep the clang syntax checker happy
|
|
@@ -31,134 +30,183 @@ extern "C" {
|
|
|
31
30
|
#define FINTEGER int
|
|
32
31
|
#endif
|
|
33
32
|
|
|
34
|
-
|
|
35
33
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
36
34
|
|
|
37
|
-
int sgemm_
|
|
38
|
-
const char
|
|
39
|
-
|
|
40
|
-
FINTEGER
|
|
41
|
-
FINTEGER
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
FINTEGER
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
35
|
+
int sgemm_(
|
|
36
|
+
const char* transa,
|
|
37
|
+
const char* transb,
|
|
38
|
+
FINTEGER* m,
|
|
39
|
+
FINTEGER* n,
|
|
40
|
+
FINTEGER* k,
|
|
41
|
+
const float* alpha,
|
|
42
|
+
const float* a,
|
|
43
|
+
FINTEGER* lda,
|
|
44
|
+
const float* b,
|
|
45
|
+
FINTEGER* ldb,
|
|
46
|
+
float* beta,
|
|
47
|
+
float* c,
|
|
48
|
+
FINTEGER* ldc);
|
|
49
|
+
|
|
50
|
+
int dgemm_(
|
|
51
|
+
const char* transa,
|
|
52
|
+
const char* transb,
|
|
53
|
+
FINTEGER* m,
|
|
54
|
+
FINTEGER* n,
|
|
55
|
+
FINTEGER* k,
|
|
56
|
+
const double* alpha,
|
|
57
|
+
const double* a,
|
|
58
|
+
FINTEGER* lda,
|
|
59
|
+
const double* b,
|
|
60
|
+
FINTEGER* ldb,
|
|
61
|
+
double* beta,
|
|
62
|
+
double* c,
|
|
63
|
+
FINTEGER* ldc);
|
|
64
|
+
|
|
65
|
+
int ssyrk_(
|
|
66
|
+
const char* uplo,
|
|
67
|
+
const char* trans,
|
|
68
|
+
FINTEGER* n,
|
|
69
|
+
FINTEGER* k,
|
|
70
|
+
float* alpha,
|
|
71
|
+
float* a,
|
|
72
|
+
FINTEGER* lda,
|
|
73
|
+
float* beta,
|
|
74
|
+
float* c,
|
|
75
|
+
FINTEGER* ldc);
|
|
55
76
|
|
|
56
77
|
/* Lapack functions from http://www.netlib.org/clapack/old/single/ */
|
|
57
78
|
|
|
58
|
-
int ssyev_
|
|
59
|
-
const char
|
|
60
|
-
|
|
61
|
-
FINTEGER
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
FINTEGER
|
|
79
|
+
int ssyev_(
|
|
80
|
+
const char* jobz,
|
|
81
|
+
const char* uplo,
|
|
82
|
+
FINTEGER* n,
|
|
83
|
+
float* a,
|
|
84
|
+
FINTEGER* lda,
|
|
85
|
+
float* w,
|
|
86
|
+
float* work,
|
|
87
|
+
FINTEGER* lwork,
|
|
88
|
+
FINTEGER* info);
|
|
89
|
+
|
|
90
|
+
int dsyev_(
|
|
91
|
+
const char* jobz,
|
|
92
|
+
const char* uplo,
|
|
93
|
+
FINTEGER* n,
|
|
94
|
+
double* a,
|
|
95
|
+
FINTEGER* lda,
|
|
96
|
+
double* w,
|
|
97
|
+
double* work,
|
|
98
|
+
FINTEGER* lwork,
|
|
99
|
+
FINTEGER* info);
|
|
67
100
|
|
|
68
101
|
int sgesvd_(
|
|
69
|
-
const char
|
|
70
|
-
|
|
71
|
-
FINTEGER
|
|
72
|
-
|
|
102
|
+
const char* jobu,
|
|
103
|
+
const char* jobvt,
|
|
104
|
+
FINTEGER* m,
|
|
105
|
+
FINTEGER* n,
|
|
106
|
+
float* a,
|
|
107
|
+
FINTEGER* lda,
|
|
108
|
+
float* s,
|
|
109
|
+
float* u,
|
|
110
|
+
FINTEGER* ldu,
|
|
111
|
+
float* vt,
|
|
112
|
+
FINTEGER* ldvt,
|
|
113
|
+
float* work,
|
|
114
|
+
FINTEGER* lwork,
|
|
115
|
+
FINTEGER* info);
|
|
73
116
|
|
|
74
117
|
int dgesvd_(
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
118
|
+
const char* jobu,
|
|
119
|
+
const char* jobvt,
|
|
120
|
+
FINTEGER* m,
|
|
121
|
+
FINTEGER* n,
|
|
122
|
+
double* a,
|
|
123
|
+
FINTEGER* lda,
|
|
124
|
+
double* s,
|
|
125
|
+
double* u,
|
|
126
|
+
FINTEGER* ldu,
|
|
127
|
+
double* vt,
|
|
128
|
+
FINTEGER* ldvt,
|
|
129
|
+
double* work,
|
|
130
|
+
FINTEGER* lwork,
|
|
131
|
+
FINTEGER* info);
|
|
79
132
|
}
|
|
80
133
|
|
|
81
134
|
/*********************************************
|
|
82
135
|
* VectorTransform
|
|
83
136
|
*********************************************/
|
|
84
137
|
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
{
|
|
89
|
-
float * xt = new float[n * d_out];
|
|
90
|
-
apply_noalloc (n, x, xt);
|
|
138
|
+
float* VectorTransform::apply(Index::idx_t n, const float* x) const {
|
|
139
|
+
float* xt = new float[n * d_out];
|
|
140
|
+
apply_noalloc(n, x, xt);
|
|
91
141
|
return xt;
|
|
92
142
|
}
|
|
93
143
|
|
|
94
|
-
|
|
95
|
-
void VectorTransform::train (idx_t, const float *) {
|
|
144
|
+
void VectorTransform::train(idx_t, const float*) {
|
|
96
145
|
// does nothing by default
|
|
97
146
|
}
|
|
98
147
|
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
idx_t , const float *,
|
|
102
|
-
float *) const
|
|
103
|
-
{
|
|
104
|
-
FAISS_THROW_MSG ("reverse transform not implemented");
|
|
148
|
+
void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
|
|
149
|
+
FAISS_THROW_MSG("reverse transform not implemented");
|
|
105
150
|
}
|
|
106
151
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
152
|
/*********************************************
|
|
111
153
|
* LinearTransform
|
|
112
154
|
*********************************************/
|
|
113
155
|
|
|
114
156
|
/// both d_in > d_out and d_in < d_out are supported
|
|
115
|
-
LinearTransform::LinearTransform
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
{
|
|
157
|
+
LinearTransform::LinearTransform(int d_in, int d_out, bool have_bias)
|
|
158
|
+
: VectorTransform(d_in, d_out),
|
|
159
|
+
have_bias(have_bias),
|
|
160
|
+
is_orthonormal(false),
|
|
161
|
+
verbose(false) {
|
|
120
162
|
is_trained = false; // will be trained when A and b are initialized
|
|
121
163
|
}
|
|
122
164
|
|
|
123
|
-
void LinearTransform::apply_noalloc
|
|
124
|
-
|
|
125
|
-
{
|
|
165
|
+
void LinearTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
|
|
166
|
+
const {
|
|
126
167
|
FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
|
|
127
168
|
|
|
128
169
|
float c_factor;
|
|
129
170
|
if (have_bias) {
|
|
130
|
-
FAISS_THROW_IF_NOT_MSG
|
|
131
|
-
float
|
|
171
|
+
FAISS_THROW_IF_NOT_MSG(b.size() == d_out, "Bias not initialized");
|
|
172
|
+
float* xi = xt;
|
|
132
173
|
for (int i = 0; i < n; i++)
|
|
133
|
-
for(int j = 0; j < d_out; j++)
|
|
174
|
+
for (int j = 0; j < d_out; j++)
|
|
134
175
|
*xi++ = b[j];
|
|
135
176
|
c_factor = 1.0;
|
|
136
177
|
} else {
|
|
137
178
|
c_factor = 0.0;
|
|
138
179
|
}
|
|
139
180
|
|
|
140
|
-
FAISS_THROW_IF_NOT_MSG
|
|
141
|
-
|
|
181
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
182
|
+
A.size() == d_out * d_in, "Transformation matrix not initialized");
|
|
142
183
|
|
|
143
184
|
float one = 1;
|
|
144
185
|
FINTEGER nbiti = d_out, ni = n, di = d_in;
|
|
145
|
-
sgemm_
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
186
|
+
sgemm_("Transposed",
|
|
187
|
+
"Not transposed",
|
|
188
|
+
&nbiti,
|
|
189
|
+
&ni,
|
|
190
|
+
&di,
|
|
191
|
+
&one,
|
|
192
|
+
A.data(),
|
|
193
|
+
&di,
|
|
194
|
+
x,
|
|
195
|
+
&di,
|
|
196
|
+
&c_factor,
|
|
197
|
+
xt,
|
|
198
|
+
&nbiti);
|
|
149
199
|
}
|
|
150
200
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
float *x) const
|
|
154
|
-
{
|
|
201
|
+
void LinearTransform::transform_transpose(idx_t n, const float* y, float* x)
|
|
202
|
+
const {
|
|
155
203
|
if (have_bias) { // allocate buffer to store bias-corrected data
|
|
156
|
-
float
|
|
157
|
-
const float
|
|
158
|
-
float
|
|
204
|
+
float* y_new = new float[n * d_out];
|
|
205
|
+
const float* yr = y;
|
|
206
|
+
float* yw = y_new;
|
|
159
207
|
for (idx_t i = 0; i < n; i++) {
|
|
160
208
|
for (int j = 0; j < d_out; j++) {
|
|
161
|
-
*yw++ = *yr++ - b
|
|
209
|
+
*yw++ = *yr++ - b[j];
|
|
162
210
|
}
|
|
163
211
|
}
|
|
164
212
|
y = y_new;
|
|
@@ -167,15 +215,26 @@ void LinearTransform::transform_transpose (idx_t n, const float * y,
|
|
|
167
215
|
{
|
|
168
216
|
FINTEGER dii = d_in, doi = d_out, ni = n;
|
|
169
217
|
float one = 1.0, zero = 0.0;
|
|
170
|
-
sgemm_
|
|
171
|
-
|
|
218
|
+
sgemm_("Not",
|
|
219
|
+
"Not",
|
|
220
|
+
&dii,
|
|
221
|
+
&ni,
|
|
222
|
+
&doi,
|
|
223
|
+
&one,
|
|
224
|
+
A.data(),
|
|
225
|
+
&dii,
|
|
226
|
+
y,
|
|
227
|
+
&doi,
|
|
228
|
+
&zero,
|
|
229
|
+
x,
|
|
230
|
+
&dii);
|
|
172
231
|
}
|
|
173
232
|
|
|
174
|
-
if (have_bias)
|
|
233
|
+
if (have_bias)
|
|
234
|
+
delete[] y;
|
|
175
235
|
}
|
|
176
236
|
|
|
177
|
-
void LinearTransform::set_is_orthonormal
|
|
178
|
-
{
|
|
237
|
+
void LinearTransform::set_is_orthonormal() {
|
|
179
238
|
if (d_out > d_in) {
|
|
180
239
|
// not clear what we should do in this case
|
|
181
240
|
is_orthonormal = false;
|
|
@@ -193,44 +252,53 @@ void LinearTransform::set_is_orthonormal ()
|
|
|
193
252
|
FINTEGER dii = d_in, doi = d_out;
|
|
194
253
|
float one = 1.0, zero = 0.0;
|
|
195
254
|
|
|
196
|
-
sgemm_
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
255
|
+
sgemm_("Transposed",
|
|
256
|
+
"Not",
|
|
257
|
+
&doi,
|
|
258
|
+
&doi,
|
|
259
|
+
&dii,
|
|
260
|
+
&one,
|
|
261
|
+
A.data(),
|
|
262
|
+
&dii,
|
|
263
|
+
A.data(),
|
|
264
|
+
&dii,
|
|
265
|
+
&zero,
|
|
266
|
+
ATA.data(),
|
|
267
|
+
&doi);
|
|
200
268
|
|
|
201
269
|
is_orthonormal = true;
|
|
202
270
|
for (long i = 0; i < d_out; i++) {
|
|
203
271
|
for (long j = 0; j < d_out; j++) {
|
|
204
272
|
float v = ATA[i + j * d_out];
|
|
205
|
-
if (i == j)
|
|
273
|
+
if (i == j)
|
|
274
|
+
v -= 1;
|
|
206
275
|
if (fabs(v) > eps) {
|
|
207
276
|
is_orthonormal = false;
|
|
208
277
|
}
|
|
209
278
|
}
|
|
210
279
|
}
|
|
211
280
|
}
|
|
212
|
-
|
|
213
281
|
}
|
|
214
282
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
float *x) const
|
|
218
|
-
{
|
|
283
|
+
void LinearTransform::reverse_transform(idx_t n, const float* xt, float* x)
|
|
284
|
+
const {
|
|
219
285
|
if (is_orthonormal) {
|
|
220
|
-
transform_transpose
|
|
286
|
+
transform_transpose(n, xt, x);
|
|
221
287
|
} else {
|
|
222
|
-
FAISS_THROW_MSG
|
|
288
|
+
FAISS_THROW_MSG(
|
|
289
|
+
"reverse transform not implemented for non-orthonormal matrices");
|
|
223
290
|
}
|
|
224
291
|
}
|
|
225
292
|
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
{
|
|
231
|
-
if (!verbose)
|
|
293
|
+
void LinearTransform::print_if_verbose(
|
|
294
|
+
const char* name,
|
|
295
|
+
const std::vector<double>& mat,
|
|
296
|
+
int n,
|
|
297
|
+
int d) const {
|
|
298
|
+
if (!verbose)
|
|
299
|
+
return;
|
|
232
300
|
printf("matrix %s: %d*%d [\n", name, n, d);
|
|
233
|
-
FAISS_THROW_IF_NOT
|
|
301
|
+
FAISS_THROW_IF_NOT(mat.size() >= n * d);
|
|
234
302
|
for (int i = 0; i < n; i++) {
|
|
235
303
|
for (int j = 0; j < d; j++) {
|
|
236
304
|
printf("%10.5g ", mat[i * d + j]);
|
|
@@ -244,24 +312,22 @@ void LinearTransform::print_if_verbose (
|
|
|
244
312
|
* RandomRotationMatrix
|
|
245
313
|
*********************************************/
|
|
246
314
|
|
|
247
|
-
void RandomRotationMatrix::init
|
|
248
|
-
{
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
A.resize (d_out * d_in);
|
|
252
|
-
float *q = A.data();
|
|
315
|
+
void RandomRotationMatrix::init(int seed) {
|
|
316
|
+
if (d_out <= d_in) {
|
|
317
|
+
A.resize(d_out * d_in);
|
|
318
|
+
float* q = A.data();
|
|
253
319
|
float_randn(q, d_out * d_in, seed);
|
|
254
320
|
matrix_qr(d_in, d_out, q);
|
|
255
321
|
} else {
|
|
256
322
|
// use tight-frame transformation
|
|
257
|
-
A.resize
|
|
258
|
-
float
|
|
323
|
+
A.resize(d_out * d_out);
|
|
324
|
+
float* q = A.data();
|
|
259
325
|
float_randn(q, d_out * d_out, seed);
|
|
260
326
|
matrix_qr(d_out, d_out, q);
|
|
261
327
|
// remove columns
|
|
262
328
|
int i, j;
|
|
263
329
|
for (i = 0; i < d_out; i++) {
|
|
264
|
-
for(j = 0; j < d_in; j++) {
|
|
330
|
+
for (j = 0; j < d_in; j++) {
|
|
265
331
|
q[i * d_in + j] = q[i * d_out + j];
|
|
266
332
|
}
|
|
267
333
|
}
|
|
@@ -271,247 +337,280 @@ void RandomRotationMatrix::init (int seed)
|
|
|
271
337
|
is_trained = true;
|
|
272
338
|
}
|
|
273
339
|
|
|
274
|
-
void RandomRotationMatrix::train
|
|
275
|
-
{
|
|
340
|
+
void RandomRotationMatrix::train(Index::idx_t /*n*/, const float* /*x*/) {
|
|
276
341
|
// initialize with some arbitrary seed
|
|
277
|
-
init
|
|
342
|
+
init(12345);
|
|
278
343
|
}
|
|
279
344
|
|
|
280
|
-
|
|
281
345
|
/*********************************************
|
|
282
346
|
* PCAMatrix
|
|
283
347
|
*********************************************/
|
|
284
348
|
|
|
285
|
-
PCAMatrix::PCAMatrix
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
349
|
+
PCAMatrix::PCAMatrix(
|
|
350
|
+
int d_in,
|
|
351
|
+
int d_out,
|
|
352
|
+
float eigen_power,
|
|
353
|
+
bool random_rotation)
|
|
354
|
+
: LinearTransform(d_in, d_out, true),
|
|
355
|
+
eigen_power(eigen_power),
|
|
356
|
+
random_rotation(random_rotation) {
|
|
290
357
|
is_trained = false;
|
|
291
358
|
max_points_per_d = 1000;
|
|
292
359
|
balanced_bins = 0;
|
|
293
360
|
}
|
|
294
361
|
|
|
295
|
-
|
|
296
362
|
namespace {
|
|
297
363
|
|
|
298
364
|
/// Compute the eigenvalue decomposition of symmetric matrix cov,
|
|
299
365
|
/// dimensions d_in-by-d_in. Output eigenvectors in cov.
|
|
300
366
|
|
|
301
|
-
void eig(size_t d_in, double
|
|
302
|
-
{
|
|
367
|
+
void eig(size_t d_in, double* cov, double* eigenvalues, int verbose) {
|
|
303
368
|
{ // compute eigenvalues and vectors
|
|
304
369
|
FINTEGER info = 0, lwork = -1, di = d_in;
|
|
305
370
|
double workq;
|
|
306
371
|
|
|
307
|
-
dsyev_
|
|
308
|
-
|
|
372
|
+
dsyev_("Vectors as well",
|
|
373
|
+
"Upper",
|
|
374
|
+
&di,
|
|
375
|
+
cov,
|
|
376
|
+
&di,
|
|
377
|
+
eigenvalues,
|
|
378
|
+
&workq,
|
|
379
|
+
&lwork,
|
|
380
|
+
&info);
|
|
309
381
|
lwork = FINTEGER(workq);
|
|
310
|
-
double
|
|
382
|
+
double* work = new double[lwork];
|
|
311
383
|
|
|
312
|
-
dsyev_
|
|
313
|
-
|
|
384
|
+
dsyev_("Vectors as well",
|
|
385
|
+
"Upper",
|
|
386
|
+
&di,
|
|
387
|
+
cov,
|
|
388
|
+
&di,
|
|
389
|
+
eigenvalues,
|
|
390
|
+
work,
|
|
391
|
+
&lwork,
|
|
392
|
+
&info);
|
|
314
393
|
|
|
315
|
-
delete
|
|
394
|
+
delete[] work;
|
|
316
395
|
|
|
317
396
|
if (info != 0) {
|
|
318
|
-
fprintf
|
|
319
|
-
|
|
320
|
-
|
|
397
|
+
fprintf(stderr,
|
|
398
|
+
"WARN ssyev info returns %d, "
|
|
399
|
+
"a very bad PCA matrix is learnt\n",
|
|
400
|
+
int(info));
|
|
321
401
|
// do not throw exception, as the matrix could still be useful
|
|
322
402
|
}
|
|
323
403
|
|
|
324
|
-
|
|
325
|
-
if(verbose && d_in <= 10) {
|
|
404
|
+
if (verbose && d_in <= 10) {
|
|
326
405
|
printf("info=%ld new eigvals=[", long(info));
|
|
327
|
-
for(int j = 0; j < d_in; j++)
|
|
406
|
+
for (int j = 0; j < d_in; j++)
|
|
407
|
+
printf("%g ", eigenvalues[j]);
|
|
328
408
|
printf("]\n");
|
|
329
409
|
|
|
330
|
-
double
|
|
410
|
+
double* ci = cov;
|
|
331
411
|
printf("eigenvecs=\n");
|
|
332
|
-
for(int i = 0; i < d_in; i++) {
|
|
333
|
-
for(int j = 0; j < d_in; j++)
|
|
412
|
+
for (int i = 0; i < d_in; i++) {
|
|
413
|
+
for (int j = 0; j < d_in; j++)
|
|
334
414
|
printf("%10.4g ", *ci++);
|
|
335
415
|
printf("\n");
|
|
336
416
|
}
|
|
337
417
|
}
|
|
338
|
-
|
|
339
418
|
}
|
|
340
419
|
|
|
341
420
|
// revert order of eigenvectors & values
|
|
342
421
|
|
|
343
|
-
for(int i = 0; i < d_in / 2; i++) {
|
|
344
|
-
|
|
422
|
+
for (int i = 0; i < d_in / 2; i++) {
|
|
345
423
|
std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
|
|
346
|
-
double
|
|
347
|
-
double
|
|
348
|
-
for(int j = 0; j < d_in; j++)
|
|
424
|
+
double* v1 = cov + i * d_in;
|
|
425
|
+
double* v2 = cov + (d_in - 1 - i) * d_in;
|
|
426
|
+
for (int j = 0; j < d_in; j++)
|
|
349
427
|
std::swap(v1[j], v2[j]);
|
|
350
428
|
}
|
|
351
|
-
|
|
352
429
|
}
|
|
353
430
|
|
|
431
|
+
} // namespace
|
|
354
432
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
void PCAMatrix::train (Index::idx_t n, const float *x)
|
|
358
|
-
{
|
|
359
|
-
const float * x_in = x;
|
|
433
|
+
void PCAMatrix::train(Index::idx_t n, const float* x) {
|
|
434
|
+
const float* x_in = x;
|
|
360
435
|
|
|
361
|
-
x = fvecs_maybe_subsample
|
|
362
|
-
|
|
436
|
+
x = fvecs_maybe_subsample(
|
|
437
|
+
d_in, (size_t*)&n, max_points_per_d * d_in, x, verbose);
|
|
363
438
|
|
|
364
|
-
ScopeDeleter<float> del_x
|
|
439
|
+
ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
|
|
365
440
|
|
|
366
441
|
// compute mean
|
|
367
|
-
mean.clear();
|
|
442
|
+
mean.clear();
|
|
443
|
+
mean.resize(d_in, 0.0);
|
|
368
444
|
if (have_bias) { // we may want to skip the bias
|
|
369
|
-
const float
|
|
445
|
+
const float* xi = x;
|
|
370
446
|
for (int i = 0; i < n; i++) {
|
|
371
|
-
for(int j = 0; j < d_in; j++)
|
|
447
|
+
for (int j = 0; j < d_in; j++)
|
|
372
448
|
mean[j] += *xi++;
|
|
373
449
|
}
|
|
374
|
-
for(int j = 0; j < d_in; j++)
|
|
450
|
+
for (int j = 0; j < d_in; j++)
|
|
375
451
|
mean[j] /= n;
|
|
376
452
|
}
|
|
377
|
-
if(verbose) {
|
|
453
|
+
if (verbose) {
|
|
378
454
|
printf("mean=[");
|
|
379
|
-
for(int j = 0; j < d_in; j++)
|
|
455
|
+
for (int j = 0; j < d_in; j++)
|
|
456
|
+
printf("%g ", mean[j]);
|
|
380
457
|
printf("]\n");
|
|
381
458
|
}
|
|
382
459
|
|
|
383
|
-
if(n >= d_in) {
|
|
460
|
+
if (n >= d_in) {
|
|
384
461
|
// compute covariance matrix, store it in PCA matrix
|
|
385
462
|
PCAMat.resize(d_in * d_in);
|
|
386
|
-
float
|
|
463
|
+
float* cov = PCAMat.data();
|
|
387
464
|
{ // initialize with mean * mean^T term
|
|
388
|
-
float
|
|
389
|
-
for(int i = 0; i < d_in; i++) {
|
|
390
|
-
for(int j = 0; j < d_in; j++)
|
|
391
|
-
*ci++ = -
|
|
465
|
+
float* ci = cov;
|
|
466
|
+
for (int i = 0; i < d_in; i++) {
|
|
467
|
+
for (int j = 0; j < d_in; j++)
|
|
468
|
+
*ci++ = -n * mean[i] * mean[j];
|
|
392
469
|
}
|
|
393
470
|
}
|
|
394
471
|
{
|
|
395
472
|
FINTEGER di = d_in, ni = n;
|
|
396
473
|
float one = 1.0;
|
|
397
|
-
ssyrk_
|
|
398
|
-
|
|
399
|
-
|
|
474
|
+
ssyrk_("Up",
|
|
475
|
+
"Non transposed",
|
|
476
|
+
&di,
|
|
477
|
+
&ni,
|
|
478
|
+
&one,
|
|
479
|
+
(float*)x,
|
|
480
|
+
&di,
|
|
481
|
+
&one,
|
|
482
|
+
cov,
|
|
483
|
+
&di);
|
|
400
484
|
}
|
|
401
|
-
if(verbose && d_in <= 10) {
|
|
402
|
-
float
|
|
485
|
+
if (verbose && d_in <= 10) {
|
|
486
|
+
float* ci = cov;
|
|
403
487
|
printf("cov=\n");
|
|
404
|
-
for(int i = 0; i < d_in; i++) {
|
|
405
|
-
for(int j = 0; j < d_in; j++)
|
|
488
|
+
for (int i = 0; i < d_in; i++) {
|
|
489
|
+
for (int j = 0; j < d_in; j++)
|
|
406
490
|
printf("%10g ", *ci++);
|
|
407
491
|
printf("\n");
|
|
408
492
|
}
|
|
409
493
|
}
|
|
410
494
|
|
|
411
|
-
std::vector<double> covd
|
|
412
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
495
|
+
std::vector<double> covd(d_in * d_in);
|
|
496
|
+
for (size_t i = 0; i < d_in * d_in; i++)
|
|
497
|
+
covd[i] = cov[i];
|
|
413
498
|
|
|
414
|
-
std::vector<double> eigenvaluesd
|
|
499
|
+
std::vector<double> eigenvaluesd(d_in);
|
|
415
500
|
|
|
416
|
-
eig
|
|
501
|
+
eig(d_in, covd.data(), eigenvaluesd.data(), verbose);
|
|
417
502
|
|
|
418
|
-
for (size_t i = 0; i < d_in * d_in; i++)
|
|
419
|
-
|
|
503
|
+
for (size_t i = 0; i < d_in * d_in; i++)
|
|
504
|
+
PCAMat[i] = covd[i];
|
|
505
|
+
eigenvalues.resize(d_in);
|
|
420
506
|
|
|
421
507
|
for (size_t i = 0; i < d_in; i++)
|
|
422
|
-
eigenvalues
|
|
423
|
-
|
|
508
|
+
eigenvalues[i] = eigenvaluesd[i];
|
|
424
509
|
|
|
425
510
|
} else {
|
|
426
|
-
|
|
427
|
-
std::vector<float> xc (n * d_in);
|
|
511
|
+
std::vector<float> xc(n * d_in);
|
|
428
512
|
|
|
429
513
|
for (size_t i = 0; i < n; i++)
|
|
430
|
-
for(size_t j = 0; j < d_in; j++)
|
|
431
|
-
xc
|
|
514
|
+
for (size_t j = 0; j < d_in; j++)
|
|
515
|
+
xc[i * d_in + j] = x[i * d_in + j] - mean[j];
|
|
432
516
|
|
|
433
517
|
// compute Gram matrix
|
|
434
|
-
std::vector<float> gram
|
|
518
|
+
std::vector<float> gram(n * n);
|
|
435
519
|
{
|
|
436
520
|
FINTEGER di = d_in, ni = n;
|
|
437
521
|
float one = 1.0, zero = 0.0;
|
|
438
|
-
ssyrk_
|
|
439
|
-
|
|
522
|
+
ssyrk_("Up",
|
|
523
|
+
"Transposed",
|
|
524
|
+
&ni,
|
|
525
|
+
&di,
|
|
526
|
+
&one,
|
|
527
|
+
xc.data(),
|
|
528
|
+
&di,
|
|
529
|
+
&zero,
|
|
530
|
+
gram.data(),
|
|
531
|
+
&ni);
|
|
440
532
|
}
|
|
441
533
|
|
|
442
|
-
if(verbose && d_in <= 10) {
|
|
443
|
-
float
|
|
534
|
+
if (verbose && d_in <= 10) {
|
|
535
|
+
float* ci = gram.data();
|
|
444
536
|
printf("gram=\n");
|
|
445
|
-
for(int i = 0; i < n; i++) {
|
|
446
|
-
for(int j = 0; j < n; j++)
|
|
537
|
+
for (int i = 0; i < n; i++) {
|
|
538
|
+
for (int j = 0; j < n; j++)
|
|
447
539
|
printf("%10g ", *ci++);
|
|
448
540
|
printf("\n");
|
|
449
541
|
}
|
|
450
542
|
}
|
|
451
543
|
|
|
452
|
-
std::vector<double> gramd
|
|
544
|
+
std::vector<double> gramd(n * n);
|
|
453
545
|
for (size_t i = 0; i < n * n; i++)
|
|
454
|
-
gramd
|
|
546
|
+
gramd[i] = gram[i];
|
|
455
547
|
|
|
456
|
-
std::vector<double> eigenvaluesd
|
|
548
|
+
std::vector<double> eigenvaluesd(n);
|
|
457
549
|
|
|
458
550
|
// eig will fill in only the n first eigenvals
|
|
459
551
|
|
|
460
|
-
eig
|
|
552
|
+
eig(n, gramd.data(), eigenvaluesd.data(), verbose);
|
|
461
553
|
|
|
462
554
|
PCAMat.resize(d_in * n);
|
|
463
555
|
|
|
464
556
|
for (size_t i = 0; i < n * n; i++)
|
|
465
|
-
gram
|
|
557
|
+
gram[i] = gramd[i];
|
|
466
558
|
|
|
467
|
-
eigenvalues.resize
|
|
559
|
+
eigenvalues.resize(d_in);
|
|
468
560
|
// fill in only the n first ones
|
|
469
561
|
for (size_t i = 0; i < n; i++)
|
|
470
|
-
eigenvalues
|
|
562
|
+
eigenvalues[i] = eigenvaluesd[i];
|
|
471
563
|
|
|
472
564
|
{ // compute PCAMat = x' * v
|
|
473
565
|
FINTEGER di = d_in, ni = n;
|
|
474
566
|
float one = 1.0;
|
|
475
567
|
|
|
476
|
-
sgemm_
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
568
|
+
sgemm_("Non",
|
|
569
|
+
"Non Trans",
|
|
570
|
+
&di,
|
|
571
|
+
&ni,
|
|
572
|
+
&ni,
|
|
573
|
+
&one,
|
|
574
|
+
xc.data(),
|
|
575
|
+
&di,
|
|
576
|
+
gram.data(),
|
|
577
|
+
&ni,
|
|
578
|
+
&one,
|
|
579
|
+
PCAMat.data(),
|
|
580
|
+
&di);
|
|
480
581
|
}
|
|
481
582
|
|
|
482
|
-
if(verbose && d_in <= 10) {
|
|
483
|
-
float
|
|
583
|
+
if (verbose && d_in <= 10) {
|
|
584
|
+
float* ci = PCAMat.data();
|
|
484
585
|
printf("PCAMat=\n");
|
|
485
|
-
for(int i = 0; i < n; i++) {
|
|
486
|
-
for(int j = 0; j < d_in; j++)
|
|
586
|
+
for (int i = 0; i < n; i++) {
|
|
587
|
+
for (int j = 0; j < d_in; j++)
|
|
487
588
|
printf("%10g ", *ci++);
|
|
488
589
|
printf("\n");
|
|
489
590
|
}
|
|
490
591
|
}
|
|
491
|
-
fvec_renorm_L2
|
|
492
|
-
|
|
592
|
+
fvec_renorm_L2(d_in, n, PCAMat.data());
|
|
493
593
|
}
|
|
494
594
|
|
|
495
595
|
prepare_Ab();
|
|
496
596
|
is_trained = true;
|
|
497
597
|
}
|
|
498
598
|
|
|
499
|
-
void PCAMatrix::copy_from
|
|
500
|
-
|
|
501
|
-
FAISS_THROW_IF_NOT (other.is_trained);
|
|
599
|
+
void PCAMatrix::copy_from(const PCAMatrix& other) {
|
|
600
|
+
FAISS_THROW_IF_NOT(other.is_trained);
|
|
502
601
|
mean = other.mean;
|
|
503
602
|
eigenvalues = other.eigenvalues;
|
|
504
603
|
PCAMat = other.PCAMat;
|
|
505
|
-
prepare_Ab
|
|
604
|
+
prepare_Ab();
|
|
506
605
|
is_trained = true;
|
|
507
606
|
}
|
|
508
607
|
|
|
509
|
-
void PCAMatrix::prepare_Ab
|
|
510
|
-
|
|
511
|
-
FAISS_THROW_IF_NOT_FMT (
|
|
608
|
+
void PCAMatrix::prepare_Ab() {
|
|
609
|
+
FAISS_THROW_IF_NOT_FMT(
|
|
512
610
|
d_out * d_in <= PCAMat.size(),
|
|
513
611
|
"PCA matrix cannot output %d dimensions from %d ",
|
|
514
|
-
d_out,
|
|
612
|
+
d_out,
|
|
613
|
+
d_in);
|
|
515
614
|
|
|
516
615
|
if (!random_rotation) {
|
|
517
616
|
A = PCAMat;
|
|
@@ -519,23 +618,23 @@ void PCAMatrix::prepare_Ab ()
|
|
|
519
618
|
|
|
520
619
|
// first scale the components
|
|
521
620
|
if (eigen_power != 0) {
|
|
522
|
-
float
|
|
621
|
+
float* ai = A.data();
|
|
523
622
|
for (int i = 0; i < d_out; i++) {
|
|
524
623
|
float factor = pow(eigenvalues[i], eigen_power);
|
|
525
|
-
for(int j = 0; j < d_in; j++)
|
|
624
|
+
for (int j = 0; j < d_in; j++)
|
|
526
625
|
*ai++ *= factor;
|
|
527
626
|
}
|
|
528
627
|
}
|
|
529
628
|
|
|
530
629
|
if (balanced_bins != 0) {
|
|
531
|
-
FAISS_THROW_IF_NOT
|
|
630
|
+
FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
|
|
532
631
|
int dsub = d_out / balanced_bins;
|
|
533
|
-
std::vector
|
|
632
|
+
std::vector<float> Ain;
|
|
534
633
|
std::swap(A, Ain);
|
|
535
634
|
A.resize(d_out * d_in);
|
|
536
635
|
|
|
537
|
-
std::vector
|
|
538
|
-
std::vector
|
|
636
|
+
std::vector<float> accu(balanced_bins);
|
|
637
|
+
std::vector<int> counter(balanced_bins);
|
|
539
638
|
|
|
540
639
|
// greedy assignment
|
|
541
640
|
for (int i = 0; i < d_out; i++) {
|
|
@@ -550,9 +649,8 @@ void PCAMatrix::prepare_Ab ()
|
|
|
550
649
|
}
|
|
551
650
|
int row_dst = best_j * dsub + counter[best_j];
|
|
552
651
|
accu[best_j] += eigenvalues[i];
|
|
553
|
-
counter[best_j]
|
|
554
|
-
memcpy
|
|
555
|
-
d_in * sizeof (A[0]));
|
|
652
|
+
counter[best_j]++;
|
|
653
|
+
memcpy(&A[row_dst * d_in], &Ain[i * d_in], d_in * sizeof(A[0]));
|
|
556
654
|
}
|
|
557
655
|
|
|
558
656
|
if (verbose) {
|
|
@@ -563,11 +661,11 @@ void PCAMatrix::prepare_Ab ()
|
|
|
563
661
|
}
|
|
564
662
|
}
|
|
565
663
|
|
|
566
|
-
|
|
567
664
|
} else {
|
|
568
|
-
FAISS_THROW_IF_NOT_MSG
|
|
569
|
-
|
|
570
|
-
|
|
665
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
666
|
+
balanced_bins == 0,
|
|
667
|
+
"both balancing bins and applying a random rotation "
|
|
668
|
+
"does not make sense");
|
|
571
669
|
RandomRotationMatrix rr(d_out, d_out);
|
|
572
670
|
|
|
573
671
|
rr.init(5);
|
|
@@ -576,8 +674,8 @@ void PCAMatrix::prepare_Ab ()
|
|
|
576
674
|
if (eigen_power != 0) {
|
|
577
675
|
for (int i = 0; i < d_out; i++) {
|
|
578
676
|
float factor = pow(eigenvalues[i], eigen_power);
|
|
579
|
-
for(int j = 0; j < d_out; j++)
|
|
580
|
-
|
|
677
|
+
for (int j = 0; j < d_out; j++)
|
|
678
|
+
rr.A[j * d_out + i] *= factor;
|
|
581
679
|
}
|
|
582
680
|
}
|
|
583
681
|
|
|
@@ -586,15 +684,24 @@ void PCAMatrix::prepare_Ab ()
|
|
|
586
684
|
FINTEGER dii = d_in, doo = d_out;
|
|
587
685
|
float one = 1.0, zero = 0.0;
|
|
588
686
|
|
|
589
|
-
sgemm_
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
687
|
+
sgemm_("Not",
|
|
688
|
+
"Not",
|
|
689
|
+
&dii,
|
|
690
|
+
&doo,
|
|
691
|
+
&doo,
|
|
692
|
+
&one,
|
|
693
|
+
PCAMat.data(),
|
|
694
|
+
&dii,
|
|
695
|
+
rr.A.data(),
|
|
696
|
+
&doo,
|
|
697
|
+
&zero,
|
|
698
|
+
A.data(),
|
|
699
|
+
&dii);
|
|
593
700
|
}
|
|
594
|
-
|
|
595
701
|
}
|
|
596
702
|
|
|
597
|
-
b.clear();
|
|
703
|
+
b.clear();
|
|
704
|
+
b.resize(d_out);
|
|
598
705
|
|
|
599
706
|
for (int i = 0; i < d_out; i++) {
|
|
600
707
|
float accu = 0;
|
|
@@ -604,57 +711,61 @@ void PCAMatrix::prepare_Ab ()
|
|
|
604
711
|
}
|
|
605
712
|
|
|
606
713
|
is_orthonormal = eigen_power == 0;
|
|
607
|
-
|
|
608
714
|
}
|
|
609
715
|
|
|
610
716
|
/*********************************************
|
|
611
717
|
* ITQMatrix
|
|
612
718
|
*********************************************/
|
|
613
719
|
|
|
614
|
-
ITQMatrix::ITQMatrix
|
|
615
|
-
|
|
616
|
-
max_iter (50),
|
|
617
|
-
seed (123)
|
|
618
|
-
{
|
|
619
|
-
}
|
|
620
|
-
|
|
720
|
+
ITQMatrix::ITQMatrix(int d)
|
|
721
|
+
: LinearTransform(d, d, false), max_iter(50), seed(123) {}
|
|
621
722
|
|
|
622
723
|
/** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
|
|
623
|
-
void ITQMatrix::train
|
|
624
|
-
{
|
|
724
|
+
void ITQMatrix::train(Index::idx_t n, const float* xf) {
|
|
625
725
|
size_t d = d_in;
|
|
626
|
-
std::vector<double> rotation
|
|
726
|
+
std::vector<double> rotation(d * d);
|
|
627
727
|
|
|
628
728
|
if (init_rotation.size() == d * d) {
|
|
629
|
-
memcpy
|
|
630
|
-
|
|
729
|
+
memcpy(rotation.data(),
|
|
730
|
+
init_rotation.data(),
|
|
731
|
+
d * d * sizeof(rotation[0]));
|
|
631
732
|
} else {
|
|
632
|
-
RandomRotationMatrix rrot
|
|
633
|
-
rrot.init
|
|
733
|
+
RandomRotationMatrix rrot(d, d);
|
|
734
|
+
rrot.init(seed);
|
|
634
735
|
for (size_t i = 0; i < d * d; i++) {
|
|
635
736
|
rotation[i] = rrot.A[i];
|
|
636
737
|
}
|
|
637
738
|
}
|
|
638
739
|
|
|
639
|
-
std::vector<double> x
|
|
740
|
+
std::vector<double> x(n * d);
|
|
640
741
|
|
|
641
742
|
for (size_t i = 0; i < n * d; i++) {
|
|
642
743
|
x[i] = xf[i];
|
|
643
744
|
}
|
|
644
745
|
|
|
645
|
-
std::vector<double> rotated_x
|
|
646
|
-
std::vector<double> u
|
|
746
|
+
std::vector<double> rotated_x(n * d), cov_mat(d * d);
|
|
747
|
+
std::vector<double> u(d * d), vt(d * d), singvals(d);
|
|
647
748
|
|
|
648
749
|
for (int i = 0; i < max_iter; i++) {
|
|
649
|
-
print_if_verbose
|
|
750
|
+
print_if_verbose("rotation", rotation, d, d);
|
|
650
751
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
651
752
|
FINTEGER di = d, ni = n;
|
|
652
753
|
double one = 1, zero = 0;
|
|
653
|
-
dgemm_
|
|
654
|
-
|
|
655
|
-
|
|
754
|
+
dgemm_("N",
|
|
755
|
+
"N",
|
|
756
|
+
&di,
|
|
757
|
+
&ni,
|
|
758
|
+
&di,
|
|
759
|
+
&one,
|
|
760
|
+
rotation.data(),
|
|
761
|
+
&di,
|
|
762
|
+
x.data(),
|
|
763
|
+
&di,
|
|
764
|
+
&zero,
|
|
765
|
+
rotated_x.data(),
|
|
766
|
+
&di);
|
|
656
767
|
}
|
|
657
|
-
print_if_verbose
|
|
768
|
+
print_if_verbose("rotated_x", rotated_x, n, d);
|
|
658
769
|
// binarize
|
|
659
770
|
for (size_t j = 0; j < n * d; j++) {
|
|
660
771
|
rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
|
|
@@ -663,88 +774,119 @@ void ITQMatrix::train (Index::idx_t n, const float* xf)
|
|
|
663
774
|
{ // rotated_data = np.dot(training_data, rotation)
|
|
664
775
|
FINTEGER di = d, ni = n;
|
|
665
776
|
double one = 1, zero = 0;
|
|
666
|
-
dgemm_
|
|
667
|
-
|
|
668
|
-
|
|
777
|
+
dgemm_("N",
|
|
778
|
+
"T",
|
|
779
|
+
&di,
|
|
780
|
+
&di,
|
|
781
|
+
&ni,
|
|
782
|
+
&one,
|
|
783
|
+
rotated_x.data(),
|
|
784
|
+
&di,
|
|
785
|
+
x.data(),
|
|
786
|
+
&di,
|
|
787
|
+
&zero,
|
|
788
|
+
cov_mat.data(),
|
|
789
|
+
&di);
|
|
669
790
|
}
|
|
670
|
-
print_if_verbose
|
|
791
|
+
print_if_verbose("cov_mat", cov_mat, d, d);
|
|
671
792
|
// SVD
|
|
672
793
|
{
|
|
673
|
-
|
|
674
794
|
FINTEGER di = d;
|
|
675
795
|
FINTEGER lwork = -1, info;
|
|
676
796
|
double lwork1;
|
|
677
797
|
|
|
678
798
|
// workspace query
|
|
679
|
-
dgesvd_
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
799
|
+
dgesvd_("A",
|
|
800
|
+
"A",
|
|
801
|
+
&di,
|
|
802
|
+
&di,
|
|
803
|
+
cov_mat.data(),
|
|
804
|
+
&di,
|
|
805
|
+
singvals.data(),
|
|
806
|
+
u.data(),
|
|
807
|
+
&di,
|
|
808
|
+
vt.data(),
|
|
809
|
+
&di,
|
|
810
|
+
&lwork1,
|
|
811
|
+
&lwork,
|
|
812
|
+
&info);
|
|
813
|
+
|
|
814
|
+
FAISS_THROW_IF_NOT(info == 0);
|
|
815
|
+
lwork = size_t(lwork1);
|
|
816
|
+
std::vector<double> work(lwork);
|
|
817
|
+
dgesvd_("A",
|
|
818
|
+
"A",
|
|
819
|
+
&di,
|
|
820
|
+
&di,
|
|
821
|
+
cov_mat.data(),
|
|
822
|
+
&di,
|
|
823
|
+
singvals.data(),
|
|
824
|
+
u.data(),
|
|
825
|
+
&di,
|
|
826
|
+
vt.data(),
|
|
827
|
+
&di,
|
|
828
|
+
work.data(),
|
|
829
|
+
&lwork,
|
|
830
|
+
&info);
|
|
831
|
+
FAISS_THROW_IF_NOT_FMT(info == 0, "sgesvd returned info=%d", info);
|
|
693
832
|
}
|
|
694
|
-
print_if_verbose
|
|
695
|
-
print_if_verbose
|
|
833
|
+
print_if_verbose("u", u, d, d);
|
|
834
|
+
print_if_verbose("vt", vt, d, d);
|
|
696
835
|
// update rotation
|
|
697
836
|
{
|
|
698
837
|
FINTEGER di = d;
|
|
699
838
|
double one = 1, zero = 0;
|
|
700
|
-
dgemm_
|
|
701
|
-
|
|
702
|
-
|
|
839
|
+
dgemm_("N",
|
|
840
|
+
"T",
|
|
841
|
+
&di,
|
|
842
|
+
&di,
|
|
843
|
+
&di,
|
|
844
|
+
&one,
|
|
845
|
+
u.data(),
|
|
846
|
+
&di,
|
|
847
|
+
vt.data(),
|
|
848
|
+
&di,
|
|
849
|
+
&zero,
|
|
850
|
+
rotation.data(),
|
|
851
|
+
&di);
|
|
703
852
|
}
|
|
704
|
-
print_if_verbose
|
|
705
|
-
|
|
853
|
+
print_if_verbose("final rot", rotation, d, d);
|
|
706
854
|
}
|
|
707
|
-
A.resize
|
|
855
|
+
A.resize(d * d);
|
|
708
856
|
for (size_t i = 0; i < d; i++) {
|
|
709
857
|
for (size_t j = 0; j < d; j++) {
|
|
710
858
|
A[i + d * j] = rotation[j + d * i];
|
|
711
859
|
}
|
|
712
860
|
}
|
|
713
861
|
is_trained = true;
|
|
714
|
-
|
|
715
862
|
}
|
|
716
863
|
|
|
717
|
-
ITQTransform::ITQTransform
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
{
|
|
864
|
+
ITQTransform::ITQTransform(int d_in, int d_out, bool do_pca)
|
|
865
|
+
: VectorTransform(d_in, d_out),
|
|
866
|
+
do_pca(do_pca),
|
|
867
|
+
itq(d_out),
|
|
868
|
+
pca_then_itq(d_in, d_out, false) {
|
|
723
869
|
if (!do_pca) {
|
|
724
|
-
FAISS_THROW_IF_NOT
|
|
870
|
+
FAISS_THROW_IF_NOT(d_in == d_out);
|
|
725
871
|
}
|
|
726
872
|
max_train_per_dim = 10;
|
|
727
873
|
is_trained = false;
|
|
728
874
|
}
|
|
729
875
|
|
|
876
|
+
void ITQTransform::train(idx_t n, const float* x) {
|
|
877
|
+
FAISS_THROW_IF_NOT(!is_trained);
|
|
730
878
|
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
void ITQTransform::train (idx_t n, const float *x)
|
|
734
|
-
{
|
|
735
|
-
FAISS_THROW_IF_NOT (!is_trained);
|
|
736
|
-
|
|
737
|
-
const float * x_in = x;
|
|
879
|
+
const float* x_in = x;
|
|
738
880
|
size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
|
|
739
|
-
x = fvecs_maybe_subsample
|
|
881
|
+
x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x);
|
|
740
882
|
|
|
741
|
-
ScopeDeleter<float> del_x
|
|
883
|
+
ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
|
|
742
884
|
|
|
743
|
-
std::unique_ptr<float
|
|
885
|
+
std::unique_ptr<float[]> x_norm(new float[n * d_in]);
|
|
744
886
|
{ // normalize
|
|
745
887
|
int d = d_in;
|
|
746
888
|
|
|
747
|
-
mean.resize
|
|
889
|
+
mean.resize(d, 0);
|
|
748
890
|
for (idx_t i = 0; i < n; i++) {
|
|
749
891
|
for (idx_t j = 0; j < d; j++) {
|
|
750
892
|
mean[j] += x[i * d + j];
|
|
@@ -755,38 +897,47 @@ void ITQTransform::train (idx_t n, const float *x)
|
|
|
755
897
|
}
|
|
756
898
|
for (idx_t i = 0; i < n; i++) {
|
|
757
899
|
for (idx_t j = 0; j < d; j++) {
|
|
758
|
-
|
|
900
|
+
x_norm[i * d + j] = x[i * d + j] - mean[j];
|
|
759
901
|
}
|
|
760
902
|
}
|
|
761
|
-
fvec_renorm_L2
|
|
903
|
+
fvec_renorm_L2(d_in, n, x_norm.get());
|
|
762
904
|
}
|
|
763
905
|
|
|
764
906
|
// train PCA
|
|
765
907
|
|
|
766
|
-
PCAMatrix pca
|
|
767
|
-
float
|
|
768
|
-
std::unique_ptr<float
|
|
908
|
+
PCAMatrix pca(d_in, d_out);
|
|
909
|
+
float* x_pca;
|
|
910
|
+
std::unique_ptr<float[]> x_pca_del;
|
|
769
911
|
if (do_pca) {
|
|
770
|
-
pca.have_bias = false;
|
|
771
|
-
pca.train
|
|
772
|
-
x_pca = pca.apply
|
|
912
|
+
pca.have_bias = false; // for consistency with reference implem
|
|
913
|
+
pca.train(n, x_norm.get());
|
|
914
|
+
x_pca = pca.apply(n, x_norm.get());
|
|
773
915
|
x_pca_del.reset(x_pca);
|
|
774
916
|
} else {
|
|
775
917
|
x_pca = x_norm.get();
|
|
776
918
|
}
|
|
777
919
|
|
|
778
920
|
// train ITQ
|
|
779
|
-
itq.train
|
|
921
|
+
itq.train(n, x_pca);
|
|
780
922
|
|
|
781
923
|
// merge PCA and ITQ
|
|
782
924
|
if (do_pca) {
|
|
783
925
|
FINTEGER di = d_out, dini = d_in;
|
|
784
926
|
float one = 1, zero = 0;
|
|
785
927
|
pca_then_itq.A.resize(d_in * d_out);
|
|
786
|
-
sgemm_
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
928
|
+
sgemm_("N",
|
|
929
|
+
"N",
|
|
930
|
+
&dini,
|
|
931
|
+
&di,
|
|
932
|
+
&di,
|
|
933
|
+
&one,
|
|
934
|
+
pca.A.data(),
|
|
935
|
+
&dini,
|
|
936
|
+
itq.A.data(),
|
|
937
|
+
&di,
|
|
938
|
+
&zero,
|
|
939
|
+
pca_then_itq.A.data(),
|
|
940
|
+
&dini);
|
|
790
941
|
} else {
|
|
791
942
|
pca_then_itq.A = itq.A;
|
|
792
943
|
}
|
|
@@ -794,12 +945,11 @@ void ITQTransform::train (idx_t n, const float *x)
|
|
|
794
945
|
is_trained = true;
|
|
795
946
|
}
|
|
796
947
|
|
|
797
|
-
void ITQTransform::apply_noalloc
|
|
798
|
-
|
|
799
|
-
{
|
|
948
|
+
void ITQTransform::apply_noalloc(Index::idx_t n, const float* x, float* xt)
|
|
949
|
+
const {
|
|
800
950
|
FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");
|
|
801
951
|
|
|
802
|
-
std::unique_ptr<float
|
|
952
|
+
std::unique_ptr<float[]> x_norm(new float[n * d_in]);
|
|
803
953
|
{ // normalize
|
|
804
954
|
int d = d_in;
|
|
805
955
|
for (idx_t i = 0; i < n; i++) {
|
|
@@ -809,41 +959,36 @@ void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
|
|
|
809
959
|
}
|
|
810
960
|
// this is not really useful if we are going to binarize right
|
|
811
961
|
// afterwards but OK
|
|
812
|
-
fvec_renorm_L2
|
|
962
|
+
fvec_renorm_L2(d_in, n, x_norm.get());
|
|
813
963
|
}
|
|
814
964
|
|
|
815
|
-
pca_then_itq.apply_noalloc
|
|
965
|
+
pca_then_itq.apply_noalloc(n, x_norm.get(), xt);
|
|
816
966
|
}
|
|
817
967
|
|
|
818
968
|
/*********************************************
|
|
819
969
|
* OPQMatrix
|
|
820
970
|
*********************************************/
|
|
821
971
|
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
{
|
|
972
|
+
OPQMatrix::OPQMatrix(int d, int M, int d2)
|
|
973
|
+
: LinearTransform(d, d2 == -1 ? d : d2, false),
|
|
974
|
+
M(M),
|
|
975
|
+
niter(50),
|
|
976
|
+
niter_pq(4),
|
|
977
|
+
niter_pq_0(40),
|
|
978
|
+
verbose(false),
|
|
979
|
+
pq(nullptr) {
|
|
830
980
|
is_trained = false;
|
|
831
981
|
// OPQ is quite expensive to train, so set this right.
|
|
832
982
|
max_train_points = 256 * 256;
|
|
833
983
|
pq = nullptr;
|
|
834
984
|
}
|
|
835
985
|
|
|
986
|
+
void OPQMatrix::train(Index::idx_t n, const float* x) {
|
|
987
|
+
const float* x_in = x;
|
|
836
988
|
|
|
989
|
+
x = fvecs_maybe_subsample(d_in, (size_t*)&n, max_train_points, x, verbose);
|
|
837
990
|
|
|
838
|
-
|
|
839
|
-
{
|
|
840
|
-
|
|
841
|
-
const float * x_in = x;
|
|
842
|
-
|
|
843
|
-
x = fvecs_maybe_subsample (d_in, (size_t*)&n,
|
|
844
|
-
max_train_points, x, verbose);
|
|
845
|
-
|
|
846
|
-
ScopeDeleter<float> del_x (x != x_in ? x : nullptr);
|
|
991
|
+
ScopeDeleter<float> del_x(x != x_in ? x : nullptr);
|
|
847
992
|
|
|
848
993
|
// To support d_out > d_in, we pad input vectors with 0s to d_out
|
|
849
994
|
size_t d = d_out <= d_in ? d_in : d_out;
|
|
@@ -867,22 +1012,26 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
|
|
|
867
1012
|
#endif
|
|
868
1013
|
|
|
869
1014
|
if (verbose) {
|
|
870
|
-
printf
|
|
871
|
-
|
|
872
|
-
|
|
1015
|
+
printf("OPQMatrix::train: training an OPQ rotation matrix "
|
|
1016
|
+
"for M=%d from %" PRId64 " vectors in %dD -> %dD\n",
|
|
1017
|
+
M,
|
|
1018
|
+
n,
|
|
1019
|
+
d_in,
|
|
1020
|
+
d_out);
|
|
873
1021
|
}
|
|
874
1022
|
|
|
875
|
-
std::vector<float> xtrain
|
|
1023
|
+
std::vector<float> xtrain(n * d);
|
|
876
1024
|
// center x
|
|
877
1025
|
{
|
|
878
|
-
std::vector<float> sum
|
|
879
|
-
const float
|
|
1026
|
+
std::vector<float> sum(d);
|
|
1027
|
+
const float* xi = x;
|
|
880
1028
|
for (size_t i = 0; i < n; i++) {
|
|
881
1029
|
for (int j = 0; j < d_in; j++)
|
|
882
|
-
sum
|
|
1030
|
+
sum[j] += *xi++;
|
|
883
1031
|
}
|
|
884
|
-
for (int i = 0; i < d; i++)
|
|
885
|
-
|
|
1032
|
+
for (int i = 0; i < d; i++)
|
|
1033
|
+
sum[i] /= n;
|
|
1034
|
+
float* yi = xtrain.data();
|
|
886
1035
|
xi = x;
|
|
887
1036
|
for (size_t i = 0; i < n; i++) {
|
|
888
1037
|
for (int j = 0; j < d_in; j++)
|
|
@@ -890,71 +1039,80 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
|
|
|
890
1039
|
yi += d - d_in;
|
|
891
1040
|
}
|
|
892
1041
|
}
|
|
893
|
-
float
|
|
1042
|
+
float* rotation;
|
|
894
1043
|
|
|
895
|
-
if (A.size
|
|
896
|
-
A.resize
|
|
1044
|
+
if (A.size() == 0) {
|
|
1045
|
+
A.resize(d * d);
|
|
897
1046
|
rotation = A.data();
|
|
898
1047
|
if (verbose)
|
|
899
1048
|
printf(" OPQMatrix::train: making random %zd*%zd rotation\n",
|
|
900
|
-
d,
|
|
901
|
-
|
|
902
|
-
|
|
1049
|
+
d,
|
|
1050
|
+
d);
|
|
1051
|
+
float_randn(rotation, d * d, 1234);
|
|
1052
|
+
matrix_qr(d, d, rotation);
|
|
903
1053
|
// we use only the d * d2 upper part of the matrix
|
|
904
|
-
A.resize
|
|
1054
|
+
A.resize(d * d2);
|
|
905
1055
|
} else {
|
|
906
|
-
FAISS_THROW_IF_NOT
|
|
1056
|
+
FAISS_THROW_IF_NOT(A.size() == d * d2);
|
|
907
1057
|
rotation = A.data();
|
|
908
1058
|
}
|
|
909
1059
|
|
|
910
|
-
std::vector<float>
|
|
911
|
-
|
|
912
|
-
tmp(d * d * 4);
|
|
913
|
-
|
|
1060
|
+
std::vector<float> xproj(d2 * n), pq_recons(d2 * n), xxr(d * n),
|
|
1061
|
+
tmp(d * d * 4);
|
|
914
1062
|
|
|
915
|
-
ProductQuantizer pq_default
|
|
916
|
-
ProductQuantizer
|
|
917
|
-
std::vector<uint8_t> codes
|
|
1063
|
+
ProductQuantizer pq_default(d2, M, 8);
|
|
1064
|
+
ProductQuantizer& pq_regular = pq ? *pq : pq_default;
|
|
1065
|
+
std::vector<uint8_t> codes(pq_regular.code_size * n);
|
|
918
1066
|
|
|
919
1067
|
double t0 = getmillisecs();
|
|
920
1068
|
for (int iter = 0; iter < niter; iter++) {
|
|
921
|
-
|
|
922
1069
|
{ // torch.mm(xtrain, rotation:t())
|
|
923
1070
|
FINTEGER di = d, d2i = d2, ni = n;
|
|
924
1071
|
float zero = 0, one = 1;
|
|
925
|
-
sgemm_
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
1072
|
+
sgemm_("Transposed",
|
|
1073
|
+
"Not transposed",
|
|
1074
|
+
&d2i,
|
|
1075
|
+
&ni,
|
|
1076
|
+
&di,
|
|
1077
|
+
&one,
|
|
1078
|
+
rotation,
|
|
1079
|
+
&di,
|
|
1080
|
+
xtrain.data(),
|
|
1081
|
+
&di,
|
|
1082
|
+
&zero,
|
|
1083
|
+
xproj.data(),
|
|
1084
|
+
&d2i);
|
|
930
1085
|
}
|
|
931
1086
|
|
|
932
1087
|
pq_regular.cp.max_points_per_centroid = 1000;
|
|
933
1088
|
pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
|
|
934
1089
|
pq_regular.verbose = verbose;
|
|
935
|
-
pq_regular.train
|
|
1090
|
+
pq_regular.train(n, xproj.data());
|
|
936
1091
|
|
|
937
1092
|
if (verbose) {
|
|
938
1093
|
printf(" encode / decode\n");
|
|
939
1094
|
}
|
|
940
1095
|
if (pq_regular.assign_index) {
|
|
941
|
-
pq_regular.compute_codes_with_assign_index
|
|
942
|
-
|
|
1096
|
+
pq_regular.compute_codes_with_assign_index(
|
|
1097
|
+
xproj.data(), codes.data(), n);
|
|
943
1098
|
} else {
|
|
944
|
-
pq_regular.compute_codes
|
|
1099
|
+
pq_regular.compute_codes(xproj.data(), codes.data(), n);
|
|
945
1100
|
}
|
|
946
|
-
pq_regular.decode
|
|
1101
|
+
pq_regular.decode(codes.data(), pq_recons.data(), n);
|
|
947
1102
|
|
|
948
|
-
float pq_err = fvec_L2sqr
|
|
1103
|
+
float pq_err = fvec_L2sqr(pq_recons.data(), xproj.data(), n * d2) / n;
|
|
949
1104
|
|
|
950
1105
|
if (verbose)
|
|
951
|
-
printf
|
|
952
|
-
|
|
953
|
-
|
|
1106
|
+
printf(" Iteration %d (%d PQ iterations):"
|
|
1107
|
+
"%.3f s, obj=%g\n",
|
|
1108
|
+
iter,
|
|
1109
|
+
pq_regular.cp.niter,
|
|
1110
|
+
(getmillisecs() - t0) / 1000.0,
|
|
1111
|
+
pq_err);
|
|
954
1112
|
|
|
955
1113
|
{
|
|
956
|
-
float *u = tmp.data(), *vt = &tmp
|
|
957
|
-
float
|
|
1114
|
+
float *u = tmp.data(), *vt = &tmp[d * d];
|
|
1115
|
+
float* sing_val = &tmp[2 * d * d];
|
|
958
1116
|
FINTEGER di = d, d2i = d2, ni = n;
|
|
959
1117
|
float one = 1, zero = 0;
|
|
960
1118
|
|
|
@@ -962,36 +1120,69 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
|
|
|
962
1120
|
printf(" X * recons\n");
|
|
963
1121
|
}
|
|
964
1122
|
// torch.mm(xtrain:t(), pq_recons)
|
|
965
|
-
sgemm_
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
1123
|
+
sgemm_("Not",
|
|
1124
|
+
"Transposed",
|
|
1125
|
+
&d2i,
|
|
1126
|
+
&di,
|
|
1127
|
+
&ni,
|
|
1128
|
+
&one,
|
|
1129
|
+
pq_recons.data(),
|
|
1130
|
+
&d2i,
|
|
1131
|
+
xtrain.data(),
|
|
1132
|
+
&di,
|
|
1133
|
+
&zero,
|
|
1134
|
+
xxr.data(),
|
|
1135
|
+
&d2i);
|
|
971
1136
|
|
|
972
1137
|
FINTEGER lwork = -1, info = -1;
|
|
973
1138
|
float worksz;
|
|
974
1139
|
// workspace query
|
|
975
|
-
sgesvd_
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
1140
|
+
sgesvd_("All",
|
|
1141
|
+
"All",
|
|
1142
|
+
&d2i,
|
|
1143
|
+
&di,
|
|
1144
|
+
xxr.data(),
|
|
1145
|
+
&d2i,
|
|
1146
|
+
sing_val,
|
|
1147
|
+
vt,
|
|
1148
|
+
&d2i,
|
|
1149
|
+
u,
|
|
1150
|
+
&di,
|
|
1151
|
+
&worksz,
|
|
1152
|
+
&lwork,
|
|
1153
|
+
&info);
|
|
980
1154
|
|
|
981
1155
|
lwork = int(worksz);
|
|
982
|
-
std::vector<float> work
|
|
1156
|
+
std::vector<float> work(lwork);
|
|
983
1157
|
// u and vt swapped
|
|
984
|
-
sgesvd_
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
&
|
|
993
|
-
|
|
994
|
-
|
|
1158
|
+
sgesvd_("All",
|
|
1159
|
+
"All",
|
|
1160
|
+
&d2i,
|
|
1161
|
+
&di,
|
|
1162
|
+
xxr.data(),
|
|
1163
|
+
&d2i,
|
|
1164
|
+
sing_val,
|
|
1165
|
+
vt,
|
|
1166
|
+
&d2i,
|
|
1167
|
+
u,
|
|
1168
|
+
&di,
|
|
1169
|
+
work.data(),
|
|
1170
|
+
&lwork,
|
|
1171
|
+
&info);
|
|
1172
|
+
|
|
1173
|
+
sgemm_("Transposed",
|
|
1174
|
+
"Transposed",
|
|
1175
|
+
&di,
|
|
1176
|
+
&d2i,
|
|
1177
|
+
&d2i,
|
|
1178
|
+
&one,
|
|
1179
|
+
u,
|
|
1180
|
+
&di,
|
|
1181
|
+
vt,
|
|
1182
|
+
&d2i,
|
|
1183
|
+
&zero,
|
|
1184
|
+
rotation,
|
|
1185
|
+
&di);
|
|
995
1186
|
}
|
|
996
1187
|
pq_regular.train_type = ProductQuantizer::Train_hot_start;
|
|
997
1188
|
}
|
|
@@ -999,59 +1190,52 @@ void OPQMatrix::train (Index::idx_t n, const float *x)
|
|
|
999
1190
|
// revert A matrix
|
|
1000
1191
|
if (d > d_in) {
|
|
1001
1192
|
for (long i = 0; i < d_out; i++)
|
|
1002
|
-
memmove
|
|
1003
|
-
A.resize
|
|
1193
|
+
memmove(&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
|
|
1194
|
+
A.resize(d_in * d_out);
|
|
1004
1195
|
}
|
|
1005
1196
|
|
|
1006
1197
|
is_trained = true;
|
|
1007
1198
|
is_orthonormal = true;
|
|
1008
1199
|
}
|
|
1009
1200
|
|
|
1010
|
-
|
|
1011
1201
|
/*********************************************
|
|
1012
1202
|
* NormalizationTransform
|
|
1013
1203
|
*********************************************/
|
|
1014
1204
|
|
|
1015
|
-
NormalizationTransform::NormalizationTransform
|
|
1016
|
-
|
|
1017
|
-
{
|
|
1018
|
-
}
|
|
1205
|
+
NormalizationTransform::NormalizationTransform(int d, float norm)
|
|
1206
|
+
: VectorTransform(d, d), norm(norm) {}
|
|
1019
1207
|
|
|
1020
|
-
NormalizationTransform::NormalizationTransform
|
|
1021
|
-
|
|
1022
|
-
{
|
|
1023
|
-
}
|
|
1208
|
+
NormalizationTransform::NormalizationTransform()
|
|
1209
|
+
: VectorTransform(-1, -1), norm(-1) {}
|
|
1024
1210
|
|
|
1025
|
-
void NormalizationTransform::apply_noalloc
|
|
1026
|
-
|
|
1027
|
-
{
|
|
1211
|
+
void NormalizationTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
1212
|
+
const {
|
|
1028
1213
|
if (norm == 2.0) {
|
|
1029
|
-
memcpy
|
|
1030
|
-
fvec_renorm_L2
|
|
1214
|
+
memcpy(xt, x, sizeof(x[0]) * n * d_in);
|
|
1215
|
+
fvec_renorm_L2(d_in, n, xt);
|
|
1031
1216
|
} else {
|
|
1032
|
-
FAISS_THROW_MSG
|
|
1217
|
+
FAISS_THROW_MSG("not implemented");
|
|
1033
1218
|
}
|
|
1034
1219
|
}
|
|
1035
1220
|
|
|
1036
|
-
void NormalizationTransform::reverse_transform
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1221
|
+
void NormalizationTransform::reverse_transform(
|
|
1222
|
+
idx_t n,
|
|
1223
|
+
const float* xt,
|
|
1224
|
+
float* x) const {
|
|
1225
|
+
memcpy(x, xt, sizeof(xt[0]) * n * d_in);
|
|
1040
1226
|
}
|
|
1041
1227
|
|
|
1042
1228
|
/*********************************************
|
|
1043
1229
|
* CenteringTransform
|
|
1044
1230
|
*********************************************/
|
|
1045
1231
|
|
|
1046
|
-
CenteringTransform::CenteringTransform
|
|
1047
|
-
VectorTransform (d, d)
|
|
1048
|
-
{
|
|
1232
|
+
CenteringTransform::CenteringTransform(int d) : VectorTransform(d, d) {
|
|
1049
1233
|
is_trained = false;
|
|
1050
1234
|
}
|
|
1051
1235
|
|
|
1052
|
-
void CenteringTransform::train(Index::idx_t n, const float
|
|
1236
|
+
void CenteringTransform::train(Index::idx_t n, const float* x) {
|
|
1053
1237
|
FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
|
|
1054
|
-
mean.resize
|
|
1238
|
+
mean.resize(d_in, 0);
|
|
1055
1239
|
for (idx_t i = 0; i < n; i++) {
|
|
1056
1240
|
for (size_t j = 0; j < d_in; j++) {
|
|
1057
1241
|
mean[j] += *x++;
|
|
@@ -1064,11 +1248,9 @@ void CenteringTransform::train(Index::idx_t n, const float *x) {
|
|
|
1064
1248
|
is_trained = true;
|
|
1065
1249
|
}
|
|
1066
1250
|
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
{
|
|
1071
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
1251
|
+
void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
1252
|
+
const {
|
|
1253
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
1072
1254
|
|
|
1073
1255
|
for (idx_t i = 0; i < n; i++) {
|
|
1074
1256
|
for (size_t j = 0; j < d_in; j++) {
|
|
@@ -1077,64 +1259,58 @@ void CenteringTransform::apply_noalloc
|
|
|
1077
1259
|
}
|
|
1078
1260
|
}
|
|
1079
1261
|
|
|
1080
|
-
void CenteringTransform::reverse_transform
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
FAISS_THROW_IF_NOT (is_trained);
|
|
1262
|
+
void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
|
|
1263
|
+
const {
|
|
1264
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
1084
1265
|
|
|
1085
1266
|
for (idx_t i = 0; i < n; i++) {
|
|
1086
1267
|
for (size_t j = 0; j < d_in; j++) {
|
|
1087
1268
|
*x++ = *xt++ + mean[j];
|
|
1088
1269
|
}
|
|
1089
1270
|
}
|
|
1090
|
-
|
|
1091
1271
|
}
|
|
1092
1272
|
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
1273
|
/*********************************************
|
|
1098
1274
|
* RemapDimensionsTransform
|
|
1099
1275
|
*********************************************/
|
|
1100
1276
|
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
int
|
|
1104
|
-
|
|
1105
|
-
{
|
|
1106
|
-
map.resize
|
|
1277
|
+
RemapDimensionsTransform::RemapDimensionsTransform(
|
|
1278
|
+
int d_in,
|
|
1279
|
+
int d_out,
|
|
1280
|
+
const int* map_in)
|
|
1281
|
+
: VectorTransform(d_in, d_out) {
|
|
1282
|
+
map.resize(d_out);
|
|
1107
1283
|
for (int i = 0; i < d_out; i++) {
|
|
1108
1284
|
map[i] = map_in[i];
|
|
1109
|
-
FAISS_THROW_IF_NOT
|
|
1285
|
+
FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
|
|
1110
1286
|
}
|
|
1111
1287
|
}
|
|
1112
1288
|
|
|
1113
|
-
RemapDimensionsTransform::RemapDimensionsTransform
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1289
|
+
RemapDimensionsTransform::RemapDimensionsTransform(
|
|
1290
|
+
int d_in,
|
|
1291
|
+
int d_out,
|
|
1292
|
+
bool uniform)
|
|
1293
|
+
: VectorTransform(d_in, d_out) {
|
|
1294
|
+
map.resize(d_out, -1);
|
|
1117
1295
|
|
|
1118
1296
|
if (uniform) {
|
|
1119
1297
|
if (d_in < d_out) {
|
|
1120
1298
|
for (int i = 0; i < d_in; i++) {
|
|
1121
|
-
map
|
|
1122
|
-
|
|
1299
|
+
map[i * d_out / d_in] = i;
|
|
1300
|
+
}
|
|
1123
1301
|
} else {
|
|
1124
1302
|
for (int i = 0; i < d_out; i++) {
|
|
1125
|
-
map
|
|
1303
|
+
map[i] = i * d_in / d_out;
|
|
1126
1304
|
}
|
|
1127
1305
|
}
|
|
1128
1306
|
} else {
|
|
1129
1307
|
for (int i = 0; i < d_in && i < d_out; i++)
|
|
1130
|
-
map
|
|
1308
|
+
map[i] = i;
|
|
1131
1309
|
}
|
|
1132
1310
|
}
|
|
1133
1311
|
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
float *xt) const
|
|
1137
|
-
{
|
|
1312
|
+
void RemapDimensionsTransform::apply_noalloc(idx_t n, const float* x, float* xt)
|
|
1313
|
+
const {
|
|
1138
1314
|
for (idx_t i = 0; i < n; i++) {
|
|
1139
1315
|
for (int j = 0; j < d_out; j++) {
|
|
1140
1316
|
xt[j] = map[j] < 0 ? 0 : x[map[j]];
|
|
@@ -1144,13 +1320,15 @@ void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
|
|
|
1144
1320
|
}
|
|
1145
1321
|
}
|
|
1146
1322
|
|
|
1147
|
-
void RemapDimensionsTransform::reverse_transform
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1323
|
+
void RemapDimensionsTransform::reverse_transform(
|
|
1324
|
+
idx_t n,
|
|
1325
|
+
const float* xt,
|
|
1326
|
+
float* x) const {
|
|
1327
|
+
memset(x, 0, sizeof(*x) * n * d_in);
|
|
1151
1328
|
for (idx_t i = 0; i < n; i++) {
|
|
1152
1329
|
for (int j = 0; j < d_out; j++) {
|
|
1153
|
-
if (map[j] >= 0)
|
|
1330
|
+
if (map[j] >= 0)
|
|
1331
|
+
x[map[j]] = xt[j];
|
|
1154
1332
|
}
|
|
1155
1333
|
x += d_in;
|
|
1156
1334
|
xt += d_out;
|