faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,127 @@
|
|
1
|
+
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#ifndef FAISS_ON_DISK_INVERTED_LISTS_H
#define FAISS_ON_DISK_INVERTED_LISTS_H

#include <vector>
#include <list>

#include <faiss/IndexIVF.h>

namespace faiss {


struct LockLevels;

/** On-disk storage of inverted lists.
 *
 * The data is stored in a mmapped chunk of memory (base pointer ptr,
 * size totsize). Each list is a range of memory that contains (object
 * List) that contains:
 *
 * - uint8_t codes[capacity * code_size]
 * - followed by idx_t ids[capacity]
 *
 * in each of the arrays, the size <= capacity first elements are
 * used, the rest is not initialized.
 *
 * Addition and resize are supported by:
 * - rounding up the capacity of the lists to a power of two
 * - maintaining a list of empty slots, sorted by size.
 * - resizing the mmapped block is adjusted as needed.
 *
 * An OnDiskInvertedLists is compact if the size == capacity for all
 * lists and there are no available slots.
 *
 * Addition to the invlists is slow. For incremental add it is better
 * to use a default ArrayInvertedLists object and convert it to an
 * OnDisk with merge_from.
 *
 * When it is known that a set of lists will be accessed, it is useful
 * to call prefetch_lists, that launches a set of threads to read the
 * lists in parallel.
 */
struct OnDiskInvertedLists: InvertedLists {

    /// descriptor of one inverted list inside the mmapped region
    struct List {
        size_t size;     // size of inverted list (entries)
        size_t capacity; // allocated size (entries)
        size_t offset;   // offset in buffer (bytes)
        List ();
    };

    // size nlist
    std::vector<List> lists;

    /// a free range of the mmapped region, reusable by allocate_slot
    struct Slot {
        size_t offset;    // bytes
        size_t capacity;  // bytes
        Slot (size_t offset, size_t capacity);
        Slot ();
    };

    // size whatever space remains
    std::list<Slot> slots;

    std::string filename;  // backing file for the mmapped region
    size_t totsize;        // total size of the mmapped region (bytes)
    uint8_t *ptr;          // mmap base pointer
    bool read_only;        /// are inverted lists mapped read-only

    OnDiskInvertedLists (size_t nlist, size_t code_size,
                         const char *filename);

    size_t list_size(size_t list_no) const override;
    const uint8_t * get_codes (size_t list_no) const override;
    const idx_t * get_ids (size_t list_no) const override;

    size_t add_entries (
           size_t list_no, size_t n_entry,
           const idx_t* ids, const uint8_t *code) override;

    void update_entries (size_t list_no, size_t offset, size_t n_entry,
                         const idx_t *ids, const uint8_t *code) override;

    void resize (size_t list_no, size_t new_size) override;

    // copy all inverted lists into *this, in compact form (without
    // allocating slots)
    size_t merge_from (const InvertedLists **ils, int n_il, bool verbose=false);

    /// restrict the inverted lists to l0:l1 without touching the mmapped region
    void crop_invlists(size_t l0, size_t l1);

    void prefetch_lists (const idx_t *list_nos, int nlist) const override;

    virtual ~OnDiskInvertedLists ();

    // private

    LockLevels * locks;  // fine-grained locking for concurrent access

    // encapsulates the threads that are busy prefeteching
    struct OngoingPrefetch;
    OngoingPrefetch *pf;
    int prefetch_nthread;  // number of threads used by prefetch_lists

    void do_mmap ();
    void update_totsize (size_t new_totsize);
    void resize_locked (size_t list_no, size_t new_size);
    size_t allocate_slot (size_t capacity);
    void free_slot (size_t offset, size_t capacity);

    // empty constructor for the I/O functions
    OnDiskInvertedLists ();
};


} // namespace faiss

#endif
|
@@ -0,0 +1,1157 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/VectorTransform.h>
|
11
|
+
|
12
|
+
#include <cstdio>
|
13
|
+
#include <cmath>
|
14
|
+
#include <cstring>
|
15
|
+
#include <memory>
|
16
|
+
|
17
|
+
#include <faiss/utils/distances.h>
|
18
|
+
#include <faiss/utils/random.h>
|
19
|
+
#include <faiss/utils/utils.h>
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
21
|
+
#include <faiss/IndexPQ.h>
|
22
|
+
|
23
|
+
using namespace faiss;
|
24
|
+
|
25
|
+
|
26
|
+
extern "C" {
|
27
|
+
|
28
|
+
// this is to keep the clang syntax checker happy
|
29
|
+
#ifndef FINTEGER
|
30
|
+
#define FINTEGER int
|
31
|
+
#endif
|
32
|
+
|
33
|
+
|
34
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
35
|
+
|
36
|
+
int sgemm_ (
|
37
|
+
const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
38
|
+
n, FINTEGER *k, const float *alpha, const float *a,
|
39
|
+
FINTEGER *lda, const float *b,
|
40
|
+
FINTEGER *ldb, float *beta,
|
41
|
+
float *c, FINTEGER *ldc);
|
42
|
+
|
43
|
+
int dgemm_ (
|
44
|
+
const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
45
|
+
n, FINTEGER *k, const double *alpha, const double *a,
|
46
|
+
FINTEGER *lda, const double *b,
|
47
|
+
FINTEGER *ldb, double *beta,
|
48
|
+
double *c, FINTEGER *ldc);
|
49
|
+
|
50
|
+
int ssyrk_ (
|
51
|
+
const char *uplo, const char *trans, FINTEGER *n, FINTEGER *k,
|
52
|
+
float *alpha, float *a, FINTEGER *lda,
|
53
|
+
float *beta, float *c, FINTEGER *ldc);
|
54
|
+
|
55
|
+
/* Lapack functions from http://www.netlib.org/clapack/old/single/ */
|
56
|
+
|
57
|
+
int ssyev_ (
|
58
|
+
const char *jobz, const char *uplo, FINTEGER *n, float *a,
|
59
|
+
FINTEGER *lda, float *w, float *work, FINTEGER *lwork,
|
60
|
+
FINTEGER *info);
|
61
|
+
|
62
|
+
int dsyev_ (
|
63
|
+
const char *jobz, const char *uplo, FINTEGER *n, double *a,
|
64
|
+
FINTEGER *lda, double *w, double *work, FINTEGER *lwork,
|
65
|
+
FINTEGER *info);
|
66
|
+
|
67
|
+
int sgesvd_(
|
68
|
+
const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
|
69
|
+
float *a, FINTEGER *lda, float *s, float *u, FINTEGER *ldu, float *vt,
|
70
|
+
FINTEGER *ldvt, float *work, FINTEGER *lwork, FINTEGER *info);
|
71
|
+
|
72
|
+
|
73
|
+
int dgesvd_(
|
74
|
+
const char *jobu, const char *jobvt, FINTEGER *m, FINTEGER *n,
|
75
|
+
double *a, FINTEGER *lda, double *s, double *u, FINTEGER *ldu, double *vt,
|
76
|
+
FINTEGER *ldvt, double *work, FINTEGER *lwork, FINTEGER *info);
|
77
|
+
|
78
|
+
}
|
79
|
+
|
80
|
+
/*********************************************
|
81
|
+
* VectorTransform
|
82
|
+
*********************************************/
|
83
|
+
|
84
|
+
|
85
|
+
|
86
|
+
float * VectorTransform::apply (Index::idx_t n, const float * x) const
{
    // Convenience wrapper around apply_noalloc: allocates the output
    // buffer (n * d_out floats) and fills it. Caller owns the returned
    // array and must delete [] it.
    float *out = new float[n * d_out];
    apply_noalloc (n, x, out);
    return out;
}
|
92
|
+
|
93
|
+
|
94
|
+
// Default training: a plain VectorTransform has no parameters to fit,
// so this is a no-op. Subclasses (PCAMatrix, ...) override it.
void VectorTransform::train (idx_t, const float *) {
    // does nothing by default
}
|
97
|
+
|
98
|
+
|
99
|
+
// Default reverse transform: not every transform is invertible, so the
// base class simply raises. Subclasses that can invert (e.g. orthonormal
// LinearTransform) override this.
void VectorTransform::reverse_transform (
             idx_t , const float *,
             float *) const
{
    FAISS_THROW_MSG ("reverse transform not implemented");
}
|
105
|
+
|
106
|
+
|
107
|
+
|
108
|
+
|
109
|
+
/*********************************************
|
110
|
+
* LinearTransform
|
111
|
+
*********************************************/
|
112
|
+
|
113
|
+
/// both d_in > d_out and d_out < d_in are supported
|
114
|
+
/// both d_in > d_out and d_out < d_in are supported
/// The transform is y = A * x (+ b when have_bias); A and b are filled
/// in later (by training or by the caller), hence is_trained = false here.
LinearTransform::LinearTransform (int d_in, int d_out,
                                  bool have_bias):
    VectorTransform (d_in, d_out), have_bias (have_bias),
    is_orthonormal (false), verbose (false)
{
    is_trained = false; // will be trained when A and b are initialized
}
|
121
|
+
|
122
|
+
// Apply y = A*x (+ b) for n input vectors, writing the result into xt
// (preallocated, n * d_out floats). The bias is handled via the BLAS
// beta parameter: xt is pre-filled with b and beta (c_factor) set to 1
// so sgemm accumulates A*x on top of it; without bias beta is 0.
void LinearTransform::apply_noalloc (Index::idx_t n, const float * x,
                                     float * xt) const
{
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

    float c_factor;  // beta passed to sgemm (see note above)
    if (have_bias) {
        FAISS_THROW_IF_NOT_MSG (b.size() == d_out, "Bias not initialized");
        // replicate the bias vector into every output row
        float * xi = xt;
        for (int i = 0; i < n; i++)
            for(int j = 0; j < d_out; j++)
                *xi++ = b[j];
        c_factor = 1.0;
    } else {
        c_factor = 0.0;
    }

    FAISS_THROW_IF_NOT_MSG (A.size() == d_out * d_in,
                            "Transformation matrix not initialized");

    float one = 1;
    FINTEGER nbiti = d_out, ni = n, di = d_in;
    // column-major BLAS: computes xt = A^T-layout * x + c_factor * xt
    sgemm_ ("Transposed", "Not transposed",
            &nbiti, &ni, &di,
            &one, A.data(), &di, x, &di, &c_factor, xt, &nbiti);

}
|
149
|
+
|
150
|
+
|
151
|
+
// Compute x = A^T * (y - b) for n vectors: multiply by the transpose of
// A, subtracting the bias first if there is one. When A is orthonormal
// this is the inverse transform (see reverse_transform).
void LinearTransform::transform_transpose (idx_t n, const float * y,
                                           float *x) const
{
    if (have_bias) { // allocate buffer to store bias-corrected data
        float *y_new = new float [n * d_out];
        const float *yr = y;
        float *yw = y_new;
        for (idx_t i = 0; i < n; i++) {
            for (int j = 0; j < d_out; j++) {
                *yw++ = *yr++ - b [j];
            }
        }
        // repoint the (local) parameter at the corrected copy; freed below
        y = y_new;
    }

    {
        FINTEGER dii = d_in, doi = d_out, ni = n;
        float one = 1.0, zero = 0.0;
        // x = A (stored d_in-major) * y, i.e. A^T applied to each vector
        sgemm_ ("Not", "Not", &dii, &ni, &doi,
                &one, A.data (), &dii, y, &doi, &zero, x, &dii);
    }

    // only frees the temporary copy allocated above, never the caller's y
    if (have_bias) delete [] y;
}
|
175
|
+
|
176
|
+
// Set the is_orthonormal flag by explicitly checking whether A^T * A is
// (numerically) the identity, within tolerance eps. Orthonormality is
// what makes reverse_transform valid via transform_transpose.
void LinearTransform::set_is_orthonormal ()
{
    if (d_out > d_in) {
        // not clear what we should do in this case
        is_orthonormal = false;
        return;
    }
    if (d_out == 0) { // borderline case, unnormalized matrix
        is_orthonormal = true;
        return;
    }

    double eps = 4e-5;  // tolerance for off-identity entries of A^T A
    FAISS_ASSERT(A.size() >= d_out * d_in);
    {
        std::vector<float> ATA(d_out * d_out);
        FINTEGER dii = d_in, doi = d_out;
        float one = 1.0, zero = 0.0;

        // ATA = A^T * A (d_out x d_out)
        sgemm_ ("Transposed", "Not", &doi, &doi, &dii,
                &one, A.data (), &dii,
                A.data(), &dii,
                &zero, ATA.data(), &doi);

        // orthonormal iff ATA == I within eps
        is_orthonormal = true;
        for (long i = 0; i < d_out; i++) {
            for (long j = 0; j < d_out; j++) {
                float v = ATA[i + j * d_out];
                if (i == j) v-= 1;
                if (fabs(v) > eps) {
                    is_orthonormal = false;
                }
            }
        }
    }

}
|
213
|
+
|
214
|
+
|
215
|
+
// Invert the transform: only possible when A is orthonormal, in which
// case the inverse of A is its transpose and transform_transpose does
// the work (including undoing the bias).
void LinearTransform::reverse_transform (idx_t n, const float * xt,
                                         float *x) const
{
    if (!is_orthonormal) {
        FAISS_THROW_MSG ("reverse transform not implemented for non-orthonormal matrices");
    }
    transform_transpose (n, xt, x);
}
|
224
|
+
|
225
|
+
|
226
|
+
// Debug helper: dump an n-by-d matrix (row-major, stored as doubles)
// to stdout under the given name, but only when verbose is set.
void LinearTransform::print_if_verbose (
         const char*name, const std::vector<double> &mat,
         int n, int d) const
{
    if (!verbose) return;
    printf("matrix %s: %d*%d [\n", name, n, d);
    FAISS_THROW_IF_NOT (mat.size() >= n * d);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < d; j++) {
            printf("%10.5g ", mat[i * d + j]);
        }
        printf("\n");
    }
    printf("]\n");
}
|
241
|
+
|
242
|
+
/*********************************************
|
243
|
+
* RandomRotationMatrix
|
244
|
+
*********************************************/
|
245
|
+
|
246
|
+
// Initialize A with a random rotation drawn from the given seed.
// d_out <= d_in: QR-orthonormalize a random d_out x d_in matrix.
// d_out >  d_in: build a full d_out x d_out rotation and keep only the
// first d_in columns of each row (a tight frame).
void RandomRotationMatrix::init (int seed)
{

    if(d_out <= d_in) {
        A.resize (d_out * d_in);
        float *q = A.data();
        float_randn(q, d_out * d_in, seed);
        matrix_qr(d_in, d_out, q);
    } else {
        // use tight-frame transformation
        A.resize (d_out * d_out);
        float *q = A.data();
        float_randn(q, d_out * d_out, seed);
        matrix_qr(d_out, d_out, q);
        // remove columns
        // in-place compaction: since d_in < d_out the write index
        // i * d_in + j never overtakes the read index i * d_out + j
        int i, j;
        for (i = 0; i < d_out; i++) {
            for(j = 0; j < d_in; j++) {
                q[i * d_in + j] = q[i * d_out + j];
            }
        }
        A.resize(d_in * d_out);
    }
    is_orthonormal = true;
    is_trained = true;
}
|
272
|
+
|
273
|
+
// "Training" a random rotation ignores the data entirely: it just
// draws the rotation from a fixed seed, so results are reproducible.
void RandomRotationMatrix::train (Index::idx_t /*n*/, const float */*x*/)
{
    // initialize with some arbitrary seed
    init (12345);
}
|
278
|
+
|
279
|
+
|
280
|
+
/*********************************************
|
281
|
+
* PCAMatrix
|
282
|
+
*********************************************/
|
283
|
+
|
284
|
+
// PCA transform d_in -> d_out. eigen_power scales each component by
// eigenvalue^eigen_power (e.g. -0.5 gives whitening); random_rotation
// applies a random rotation after the PCA. Training is required before
// use (is_trained starts false).
PCAMatrix::PCAMatrix (int d_in, int d_out,
                      float eigen_power, bool random_rotation):
    LinearTransform(d_in, d_out, true),
    eigen_power(eigen_power), random_rotation(random_rotation)
{
    is_trained = false;
    max_points_per_d = 1000;  // cap on training points per input dim
    balanced_bins = 0;        // 0 = no balancing of variance across bins
}
|
293
|
+
|
294
|
+
|
295
|
+
namespace {

/// Compute the eigenvalue decomposition of symmetric matrix cov,
/// dimensions d_in-by-d_in. Output eigenvectors in cov.
/// On return, eigenvectors/values are reordered from largest to
/// smallest eigenvalue (LAPACK returns them ascending).

void eig(size_t d_in, double *cov, double *eigenvalues, int verbose)
{
    { // compute eigenvalues and vectors
        FINTEGER info = 0, lwork = -1, di = d_in;
        double workq;

        // first call with lwork = -1: LAPACK workspace-size query,
        // the optimal size is returned in workq
        dsyev_ ("Vectors as well", "Upper",
                &di, cov, &di, eigenvalues, &workq, &lwork, &info);
        lwork = FINTEGER(workq);
        double *work = new double[lwork];

        // second call does the actual decomposition
        dsyev_ ("Vectors as well", "Upper",
                &di, cov, &di, eigenvalues, work, &lwork, &info);

        delete [] work;

        if (info != 0) {
            fprintf (stderr, "WARN ssyev info returns %d, "
                     "a very bad PCA matrix is learnt\n",
                     int(info));
            // do not throw exception, as the matrix could still be useful
        }


        if(verbose && d_in <= 10) {
            printf("info=%ld new eigvals=[", long(info));
            for(int j = 0; j < d_in; j++) printf("%g ", eigenvalues[j]);
            printf("]\n");

            double *ci = cov;
            printf("eigenvecs=\n");
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10.4g ", *ci++);
                printf("\n");
            }
        }

    }

    // revert order of eigenvectors & values

    for(int i = 0; i < d_in / 2; i++) {

        std::swap(eigenvalues[i], eigenvalues[d_in - 1 - i]);
        double *v1 = cov + i * d_in;
        double *v2 = cov + (d_in - 1 - i) * d_in;
        for(int j = 0; j < d_in; j++)
            std::swap(v1[j], v2[j]);
    }

}


}
|
355
|
+
|
356
|
+
// Train the PCA on n vectors of dimension d_in. Two code paths:
// - n >= d_in: eigendecomposition of the d_in x d_in covariance matrix;
// - n <  d_in: eigendecomposition of the n x n Gram matrix of centered
//   data (the "dual" trick), then back-projection x' * v; only the first
//   n eigenvalues/components are meaningful in that case.
// Finishes with prepare_Ab() which builds the actual A and b.
void PCAMatrix::train (Index::idx_t n, const float *x)
{
    const float * x_in = x;

    // cap the training set at max_points_per_d * d_in points (may
    // return a subsampled copy, tracked by del_x below)
    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
                               max_points_per_d * d_in, x, verbose);

    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);

    // compute mean
    mean.clear(); mean.resize(d_in, 0.0);
    if (have_bias) { // we may want to skip the bias
        const float *xi = x;
        for (int i = 0; i < n; i++) {
            for(int j = 0; j < d_in; j++)
                mean[j] += *xi++;
        }
        for(int j = 0; j < d_in; j++)
            mean[j] /= n;
    }
    if(verbose) {
        printf("mean=[");
        for(int j = 0; j < d_in; j++) printf("%g ", mean[j]);
        printf("]\n");
    }

    if(n >= d_in) {
        // compute covariance matrix, store it in PCA matrix
        PCAMat.resize(d_in * d_in);
        float * cov = PCAMat.data();
        { // initialize with  mean * mean^T term
            // cov = sum_i x_i x_i^T - n * mean mean^T, the mean term is
            // seeded here and ssyrk accumulates the data term on top
            float *ci = cov;
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    *ci++ = - n * mean[i] * mean[j];
            }
        }
        {
            FINTEGER di = d_in, ni = n;
            float one = 1.0;
            // cov += x * x^T (upper triangle; dsyev only reads "Upper")
            ssyrk_ ("Up", "Non transposed",
                    &di, &ni, &one, (float*)x, &di, &one, cov, &di);

        }
        if(verbose && d_in <= 10) {
            float *ci = cov;
            printf("cov=\n");
            for(int i = 0; i < d_in; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }

        // eigendecomposition is done in double precision
        std::vector<double> covd (d_in * d_in);
        for (size_t i = 0; i < d_in * d_in; i++) covd [i] = cov [i];

        std::vector<double> eigenvaluesd (d_in);

        eig (d_in, covd.data (), eigenvaluesd.data (), verbose);

        for (size_t i = 0; i < d_in * d_in; i++) PCAMat [i] = covd [i];
        eigenvalues.resize (d_in);

        for (size_t i = 0; i < d_in; i++)
            eigenvalues [i] = eigenvaluesd [i];


    } else {

        // dual path: center the data explicitly
        std::vector<float> xc (n * d_in);

        for (size_t i = 0; i < n; i++)
            for(size_t j = 0; j < d_in; j++)
                xc [i * d_in + j] = x [i * d_in + j] - mean[j];

        // compute Gram matrix
        std::vector<float> gram (n * n);
        {
            FINTEGER di = d_in, ni = n;
            float one = 1.0, zero = 0.0;
            ssyrk_ ("Up", "Transposed",
                    &ni, &di, &one, xc.data(), &di, &zero, gram.data(), &ni);
        }

        if(verbose && d_in <= 10) {
            float *ci = gram.data();
            printf("gram=\n");
            for(int i = 0; i < n; i++) {
                for(int j = 0; j < n; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }

        std::vector<double> gramd (n * n);
        for (size_t i = 0; i < n * n; i++)
            gramd [i] = gram [i];

        std::vector<double> eigenvaluesd (n);

        // eig will fill in only the n first eigenvals

        eig (n, gramd.data (), eigenvaluesd.data (), verbose);

        PCAMat.resize(d_in * n);

        // reuse gram to hold the (float) eigenvectors of the Gram matrix
        for (size_t i = 0; i < n * n; i++)
            gram [i] = gramd [i];

        eigenvalues.resize (d_in);
        // fill in only the n first ones
        for (size_t i = 0; i < n; i++)
            eigenvalues [i] = eigenvaluesd [i];

        { // compute PCAMat = x' * v
            FINTEGER di = d_in, ni = n;
            float one = 1.0;

            // beta = 1 is safe: PCAMat was resized from empty, so it is
            // zero-initialized before the accumulation
            sgemm_ ("Non", "Non Trans",
                    &di, &ni, &ni,
                    &one, xc.data(), &di, gram.data(), &ni,
                    &one, PCAMat.data(), &di);
        }

        if(verbose && d_in <= 10) {
            float *ci = PCAMat.data();
            printf("PCAMat=\n");
            for(int i = 0; i < n; i++) {
                for(int j = 0; j < d_in; j++)
                    printf("%10g ", *ci++);
                printf("\n");
            }
        }
        // back-projected eigenvectors are not unit norm: renormalize
        fvec_renorm_L2 (d_in, n, PCAMat.data());

    }

    prepare_Ab();
    is_trained = true;
}
|
497
|
+
|
498
|
+
void PCAMatrix::copy_from (const PCAMatrix & other)
|
499
|
+
{
|
500
|
+
FAISS_THROW_IF_NOT (other.is_trained);
|
501
|
+
mean = other.mean;
|
502
|
+
eigenvalues = other.eigenvalues;
|
503
|
+
PCAMat = other.PCAMat;
|
504
|
+
prepare_Ab ();
|
505
|
+
is_trained = true;
|
506
|
+
}
|
507
|
+
|
508
|
+
/// Derive the linear transform (A, b) from the trained PCA data:
/// - keeps the first d_out principal components (rows of PCAMat),
/// - optionally scales each component by eigenvalue^eigen_power,
/// - optionally re-orders components into balanced_bins groups of
///   roughly equal eigenvalue mass (greedy assignment),
/// - or, alternatively, composes the PCA with a random rotation.
/// Also computes the bias b = -A * mean.
void PCAMatrix::prepare_Ab ()
{
    FAISS_THROW_IF_NOT_FMT (
            d_out * d_in <= PCAMat.size(),
            "PCA matrix cannot output %d dimensions from %d ",
            d_out, d_in);

    if (!random_rotation) {
        A = PCAMat;
        A.resize(d_out * d_in); // strip off useless dimensions

        // first scale the components
        if (eigen_power != 0) {
            float *ai = A.data();
            for (int i = 0; i < d_out; i++) {
                // whitening-style per-component scaling
                float factor = pow(eigenvalues[i], eigen_power);
                for(int j = 0; j < d_in; j++)
                    *ai++ *= factor;
            }
        }

        if (balanced_bins != 0) {
            FAISS_THROW_IF_NOT (d_out % balanced_bins == 0);
            int dsub = d_out / balanced_bins;
            std::vector <float> Ain;
            std::swap(A, Ain);
            A.resize(d_out * d_in);

            // accumulated eigenvalue mass and fill count per bin
            std::vector <float> accu(balanced_bins);
            std::vector <int> counter(balanced_bins);

            // greedy assignment
            for (int i = 0; i < d_out; i++) {
                // find best bin: the non-full bin with the least
                // accumulated eigenvalue mass so far
                int best_j = -1;
                float min_w = 1e30;
                for (int j = 0; j < balanced_bins; j++) {
                    if (counter[j] < dsub && accu[j] < min_w) {
                        min_w = accu[j];
                        best_j = j;
                    }
                }
                // d_out % balanced_bins == 0 guarantees a bin is found
                int row_dst = best_j * dsub + counter[best_j];
                accu[best_j] += eigenvalues[i];
                counter[best_j] ++;
                memcpy (&A[row_dst * d_in], &Ain[i * d_in],
                        d_in * sizeof (A[0]));
            }

            if (verbose) {
                printf(" bin accu=[");
                for (int i = 0; i < balanced_bins; i++)
                    printf("%g ", accu[i]);
                printf("]\n");
            }
        }


    } else {
        FAISS_THROW_IF_NOT_MSG (balanced_bins == 0,
             "both balancing bins and applying a random rotation "
             "does not make sense");
        RandomRotationMatrix rr(d_out, d_out);

        rr.init(5);

        // apply scaling on the rotation matrix (right multiplication)
        if (eigen_power != 0) {
            for (int i = 0; i < d_out; i++) {
                float factor = pow(eigenvalues[i], eigen_power);
                for(int j = 0; j < d_out; j++)
                   rr.A[j * d_out + i] *= factor;
            }
        }

        // A = rr * PCAMat (column-major BLAS convention)
        A.resize(d_in * d_out);
        {
            FINTEGER dii = d_in, doo = d_out;
            float one = 1.0, zero = 0.0;

            sgemm_ ("Not", "Not", &dii, &doo, &doo,
                    &one, PCAMat.data(), &dii, rr.A.data(), &doo, &zero,
                    A.data(), &dii);

        }

    }

    // bias so that the transform centers the data: b = -A * mean
    b.clear(); b.resize(d_out);

    for (int i = 0; i < d_out; i++) {
        float accu = 0;
        for (int j = 0; j < d_in; j++)
            accu -= mean[j] * A[j + i * d_in];
        b[i] = accu;
    }

    // any eigen_power scaling breaks orthonormality of the rows
    is_orthonormal = eigen_power == 0;

}
|
608
|
+
|
609
|
+
/*********************************************
|
610
|
+
* ITQMatrix
|
611
|
+
*********************************************/
|
612
|
+
|
613
|
+
/// Construct a (not yet trained) square d x d ITQ rotation with
/// default settings: 50 alternating-optimization iterations, fixed
/// RNG seed 123 for the initial random rotation.
ITQMatrix::ITQMatrix (int d):
    LinearTransform(d, d, false),
    max_iter (50),
    seed (123)
{
}
|
619
|
+
|
620
|
+
|
621
|
+
/** translated from fbcode/deeplearning/catalyzer/catalyzer/quantizers.py */
|
622
|
+
/// Train the ITQ rotation by alternating optimization: binarize the
/// rotated data, then re-fit the rotation as the orthogonal Procrustes
/// solution (SVD of the sign/data covariance), for max_iter rounds.
///
/// @param n   number of training vectors
/// @param xf  training vectors, size n * d_in (row-major)
void ITQMatrix::train (Index::idx_t n, const float* xf)
{
    size_t d = d_in;
    std::vector<double> rotation (d * d);

    // start either from a user-provided rotation or a random one
    if (init_rotation.size() == d * d) {
        memcpy (rotation.data(), init_rotation.data(),
                d * d * sizeof(rotation[0]));
    } else {
        RandomRotationMatrix rrot (d, d);
        rrot.init (seed);
        for (size_t i = 0; i < d * d; i++) {
            rotation[i] = rrot.A[i];
        }
    }

    // the computation is carried out in double precision
    std::vector<double> x (n * d);

    for (size_t i = 0; i < n * d; i++) {
        x[i] = xf[i];
    }

    std::vector<double> rotated_x (n * d), cov_mat (d * d);
    std::vector<double> u (d * d), vt (d * d), singvals (d);

    for (int i = 0; i < max_iter; i++) {
        print_if_verbose ("rotation", rotation, d, d);
        { // rotated_data = np.dot(training_data, rotation)
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_ ("N", "N", &di, &ni, &di,
                    &one, rotation.data(), &di, x.data(), &di,
                    &zero, rotated_x.data(), &di);
        }
        print_if_verbose ("rotated_x", rotated_x, n, d);
        // binarize
        for (size_t j = 0; j < n * d; j++) {
            rotated_x[j] = rotated_x[j] < 0 ? -1 : 1;
        }
        // covariance matrix
        { // cov_mat = np.dot(binarized_data, training_data.T)
            FINTEGER di = d, ni = n;
            double one = 1, zero = 0;
            dgemm_ ("N", "T", &di, &di, &ni,
                    &one, rotated_x.data(), &di, x.data(), &di,
                    &zero, cov_mat.data(), &di);
        }
        print_if_verbose ("cov_mat", cov_mat, d, d);
        // SVD
        {

            FINTEGER di = d;
            FINTEGER lwork = -1, info;
            double lwork1;

            // workspace query
            dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
                     singvals.data(), u.data(), &di,
                     vt.data(), &di,
                     &lwork1, &lwork, &info);

            FAISS_THROW_IF_NOT (info == 0);
            lwork = FINTEGER (lwork1);
            std::vector<double> work (lwork);
            dgesvd_ ("A", "A", &di, &di, cov_mat.data(), &di,
                     singvals.data(), u.data(), &di,
                     vt.data(), &di,
                     work.data(), &lwork, &info);
            // report the routine that actually failed (dgesvd, not sgesvd)
            FAISS_THROW_IF_NOT_FMT (info == 0, "dgesvd returned info=%d", info);

        }
        print_if_verbose ("u", u, d, d);
        print_if_verbose ("vt", vt, d, d);
        // update rotation: closest orthogonal matrix is u * vt
        {
            FINTEGER di = d;
            double one = 1, zero = 0;
            dgemm_ ("N", "T", &di, &di, &di,
                    &one, u.data(), &di, vt.data(), &di,
                    &zero, rotation.data(), &di);
        }
        print_if_verbose ("final rot", rotation, d, d);

    }
    // store the transposed rotation in A (float, row-major convention)
    A.resize (d * d);
    for (size_t i = 0; i < d; i++) {
        for (size_t j = 0; j < d; j++) {
            A[i + d * j] = rotation[j + d * i];
        }
    }
    is_trained = true;

}
|
715
|
+
|
716
|
+
/// Construct a (not yet trained) ITQ transform.
///
/// @param d_in    input dimension
/// @param d_out   output dimension (must equal d_in when do_pca is false)
/// @param do_pca  whether to reduce d_in -> d_out with a PCA before ITQ
ITQTransform::ITQTransform (int d_in, int d_out, bool do_pca):
    VectorTransform (d_in, d_out),
    do_pca (do_pca),
    itq (d_out),
    pca_then_itq (d_in, d_out, false)
{
    // without PCA there is no dimensionality change possible
    if (!do_pca) {
        FAISS_THROW_IF_NOT (d_in == d_out);
    }
    // cap on training points: max_train_per_dim * d_in (see train())
    max_train_per_dim = 10;
    is_trained = false;
}
|
728
|
+
|
729
|
+
|
730
|
+
|
731
|
+
|
732
|
+
/// Train the full ITQ pipeline: (optionally subsample the training
/// set) -> center and L2-normalize -> train PCA (if do_pca) -> train
/// the ITQ rotation -> fold PCA and ITQ into the single linear
/// transform pca_then_itq used at apply time.
///
/// @param n  number of training vectors
/// @param x  training vectors, size n * d_in
void ITQTransform::train (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (!is_trained);

    const float * x_in = x;
    // limit training set size; n is updated in place by the subsampler
    size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
    x = fvecs_maybe_subsample (d_in, (size_t*)&n, max_train_points, x);

    // owns the subsampled copy iff one was made
    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);

    std::unique_ptr<float []> x_norm(new float[n * d_in]);
    { // normalize: subtract per-dimension mean, then renormalize to L2=1
        int d = d_in;

        mean.resize (d, 0);
        for (idx_t i = 0; i < n; i++) {
            for (idx_t j = 0; j < d; j++) {
                mean[j] += x[i * d + j];
            }
        }
        for (idx_t j = 0; j < d; j++) {
            mean[j] /= n;
        }
        for (idx_t i = 0; i < n; i++) {
            for (idx_t j = 0; j < d; j++) {
                x_norm[i * d + j] = x[i * d + j] - mean[j];
            }
        }
        fvec_renorm_L2 (d_in, n, x_norm.get());
    }

    // train PCA

    PCAMatrix pca (d_in, d_out);
    float *x_pca;
    std::unique_ptr<float []> x_pca_del;
    if (do_pca) {
        pca.have_bias = false; // for consistency with reference implem
        pca.train (n, x_norm.get());
        x_pca = pca.apply (n, x_norm.get());
        x_pca_del.reset(x_pca);
    } else {
        x_pca = x_norm.get();
    }

    // train ITQ
    itq.train (n, x_pca);

    // merge PCA and ITQ into one matrix: pca_then_itq.A = pca.A * itq.A
    if (do_pca) {
        FINTEGER di = d_out, dini = d_in;
        float one = 1, zero = 0;
        pca_then_itq.A.resize(d_in * d_out);
        sgemm_ ("N", "N", &dini, &di, &di,
                &one, pca.A.data(), &dini,
                itq.A.data(), &di,
                &zero, pca_then_itq.A.data(), &dini);
    } else {
        pca_then_itq.A = itq.A;
    }
    pca_then_itq.is_trained = true;
    is_trained = true;
}
|
795
|
+
|
796
|
+
/// Apply the trained ITQ transform: center by the training mean,
/// L2-renormalize, then apply the merged PCA+ITQ linear transform.
void ITQTransform::apply_noalloc (Index::idx_t n, const float * x,
                                  float * xt) const
{
    FAISS_THROW_IF_NOT_MSG(is_trained, "Transformation not trained yet");

    // build a centered copy of the input
    std::unique_ptr<float []> centered(new float[n * d_in]);
    {
        const float *src = x;
        float *dst = centered.get();
        for (idx_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++) {
                *dst++ = *src++ - mean[j];
            }
        }
        // renormalizing is not strictly useful before binarization,
        // but matches the training-time preprocessing
        fvec_renorm_L2 (d_in, n, centered.get());
    }

    pca_then_itq.apply_noalloc (n, centered.get(), xt);
}
|
816
|
+
|
817
|
+
/*********************************************
|
818
|
+
* OPQMatrix
|
819
|
+
*********************************************/
|
820
|
+
|
821
|
+
|
822
|
+
/// Construct a (not yet trained) OPQ rotation.
///
/// @param d   input dimension
/// @param M   number of product-quantizer subquantizers
/// @param d2  output dimension (-1 means same as input)
OPQMatrix::OPQMatrix (int d, int M, int d2):
    LinearTransform (d, d2 == -1 ? d : d2, false), M(M),
    niter (50),
    niter_pq (4), niter_pq_0 (40),
    verbose(false),
    pq(nullptr)
{
    is_trained = false;
    // OPQ is quite expensive to train, so set this right.
    max_train_points = 256 * 256;
    // (pq is already null-initialized in the member-initializer list;
    // the redundant body assignment was removed)
}
|
834
|
+
|
835
|
+
|
836
|
+
|
837
|
+
/// Train the OPQ rotation: alternately train a product quantizer on the
/// rotated data and re-fit the rotation as the orthogonal Procrustes
/// solution (SVD) that best aligns the data with its PQ reconstruction.
/// When d_out > d_in the input is zero-padded to d_out during training
/// and the padding columns are stripped from A at the end.
///
/// @param n  number of training vectors
/// @param x  training vectors, size n * d_in
void OPQMatrix::train (Index::idx_t n, const float *x)
{

    const float * x_in = x;

    // limit training set size; n is updated in place by the subsampler
    x = fvecs_maybe_subsample (d_in, (size_t*)&n,
                               max_train_points, x, verbose);

    // owns the subsampled copy iff one was made
    ScopeDeleter<float> del_x (x != x_in ? x : nullptr);

    // To support d_out > d_in, we pad input vectors with 0s to d_out
    size_t d = d_out <= d_in ? d_in : d_out;
    size_t d2 = d_out;

#if 0
    // what this test shows: the only way of getting bit-exact
    // reproducible results with sgeqrf and sgesvd seems to be forcing
    // single-threading.
    { // test repro
        std::vector<float> r (d * d);
        float * rotation = r.data();
        float_randn (rotation, d * d, 1234);
        printf("CS0: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        matrix_qr (d, d, rotation);
        printf("CS1: %016lx\n",
               ivec_checksum (128*128, (int*)rotation));
        return;
    }
#endif

    if (verbose) {
        printf ("OPQMatrix::train: training an OPQ rotation matrix "
                "for M=%d from %ld vectors in %dD -> %dD\n",
                M, n, d_in, d_out);
    }

    std::vector<float> xtrain (n * d);
    // center x (padding dimensions beyond d_in stay 0)
    {
        std::vector<float> sum (d);
        const float *xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                sum [j] += *xi++;
        }
        for (int i = 0; i < d; i++) sum[i] /= n;
        float *yi = xtrain.data();
        xi = x;
        for (size_t i = 0; i < n; i++) {
            for (int j = 0; j < d_in; j++)
                *yi++ = *xi++ - sum[j];
            yi += d - d_in;
        }
    }
    float *rotation;

    // initial rotation: reuse a preset A if provided, else a random
    // orthogonal matrix obtained by QR of a Gaussian matrix
    if (A.size () == 0) {
        A.resize (d * d);
        rotation = A.data();
        if (verbose)
            printf(" OPQMatrix::train: making random %ld*%ld rotation\n",
                   d, d);
        float_randn (rotation, d * d, 1234);
        matrix_qr (d, d, rotation);
        // we use only the d * d2 upper part of the matrix
        A.resize (d * d2);
    } else {
        FAISS_THROW_IF_NOT (A.size() == d * d2);
        rotation = A.data();
    }

    // tmp holds u, vt and singular values for the SVD below
    std::vector<float>
        xproj (d2 * n), pq_recons (d2 * n), xxr (d * n),
        tmp(d * d * 4);


    ProductQuantizer pq_default (d2, M, 8);
    // train either the caller-provided PQ or a default 8-bit one
    ProductQuantizer &pq_regular = pq ? *pq : pq_default;
    std::vector<uint8_t> codes (pq_regular.code_size * n);

    double t0 = getmillisecs();
    for (int iter = 0; iter < niter; iter++) {

        { // torch.mm(xtrain, rotation:t())
            FINTEGER di = d, d2i = d2, ni = n;
            float zero = 0, one = 1;
            sgemm_ ("Transposed", "Not transposed",
                    &d2i, &ni, &di,
                    &one, rotation, &di,
                    xtrain.data(), &di,
                    &zero, xproj.data(), &d2i);
        }

        pq_regular.cp.max_points_per_centroid = 1000;
        // the first iteration trains the PQ longer (cold start)
        pq_regular.cp.niter = iter == 0 ? niter_pq_0 : niter_pq;
        pq_regular.verbose = verbose;
        pq_regular.train (n, xproj.data());

        if (verbose) {
            printf(" encode / decode\n");
        }
        if (pq_regular.assign_index) {
            pq_regular.compute_codes_with_assign_index
                (xproj.data(), codes.data(), n);
        } else {
            pq_regular.compute_codes (xproj.data(), codes.data(), n);
        }
        pq_regular.decode (codes.data(), pq_recons.data(), n);

        // mean squared reconstruction error = OPQ objective
        float pq_err = fvec_L2sqr (pq_recons.data(), xproj.data(), n * d2) / n;

        if (verbose)
            printf (" Iteration %d (%d PQ iterations):"
                    "%.3f s, obj=%g\n", iter, pq_regular.cp.niter,
                    (getmillisecs () - t0) / 1000.0, pq_err);

        { // re-fit the rotation: SVD of xtrain' * pq_recons
            float *u = tmp.data(), *vt = &tmp [d * d];
            float *sing_val = &tmp [2 * d * d];
            FINTEGER di = d, d2i = d2, ni = n;
            float one = 1, zero = 0;

            if (verbose) {
                printf(" X * recons\n");
            }
            // torch.mm(xtrain:t(), pq_recons)
            sgemm_ ("Not", "Transposed",
                    &d2i, &di, &ni,
                    &one, pq_recons.data(), &d2i,
                    xtrain.data(), &di,
                    &zero, xxr.data(), &d2i);


            FINTEGER lwork = -1, info = -1;
            float worksz;
            // workspace query
            sgesvd_ ("All", "All",
                     &d2i, &di, xxr.data(), &d2i,
                     sing_val,
                     vt, &d2i, u, &di,
                     &worksz, &lwork, &info);

            lwork = int(worksz);
            std::vector<float> work (lwork);
            // u and vt swapped
            sgesvd_ ("All", "All",
                     &d2i, &di, xxr.data(), &d2i,
                     sing_val,
                     vt, &d2i, u, &di,
                     work.data(), &lwork, &info);

            // rotation = u * vt (closest orthogonal matrix)
            sgemm_ ("Transposed", "Transposed",
                    &di, &d2i, &d2i,
                    &one, u, &di, vt, &d2i,
                    &zero, rotation, &di);

        }
        // subsequent iterations warm-start the PQ from previous centroids
        pq_regular.train_type = ProductQuantizer::Train_hot_start;
    }

    // revert A matrix: drop the zero-padding columns if d was padded
    if (d > d_in) {
        for (long i = 0; i < d_out; i++)
            memmove (&A[i * d_in], &A[i * d], sizeof(A[0]) * d_in);
        A.resize (d_in * d_out);
    }

    is_trained = true;
    is_orthonormal = true;
}
|
1008
|
+
|
1009
|
+
|
1010
|
+
/*********************************************
|
1011
|
+
* NormalizationTransform
|
1012
|
+
*********************************************/
|
1013
|
+
|
1014
|
+
/// Construct a normalization transform on d-dimensional vectors.
/// @param norm  the Lp norm to normalize to (only 2.0 is implemented)
NormalizationTransform::NormalizationTransform (int d, float norm):
    VectorTransform (d, d), norm (norm)
{
}
|
1018
|
+
|
1019
|
+
/// Default constructor for deserialization: dimensions and norm are
/// placeholders (-1) until they are filled in by the reader.
NormalizationTransform::NormalizationTransform ():
    VectorTransform (-1, -1), norm (-1)
{
}
|
1023
|
+
|
1024
|
+
void NormalizationTransform::apply_noalloc
|
1025
|
+
(idx_t n, const float* x, float* xt) const
|
1026
|
+
{
|
1027
|
+
if (norm == 2.0) {
|
1028
|
+
memcpy (xt, x, sizeof (x[0]) * n * d_in);
|
1029
|
+
fvec_renorm_L2 (d_in, n, xt);
|
1030
|
+
} else {
|
1031
|
+
FAISS_THROW_MSG ("not implemented");
|
1032
|
+
}
|
1033
|
+
}
|
1034
|
+
|
1035
|
+
/// Identity reverse transform: normalization destroys the original
/// magnitudes, so the best available "inverse" is a plain copy.
void NormalizationTransform::reverse_transform (idx_t n, const float* xt,
                                                float* x) const
{
    memcpy (x, xt, sizeof (xt[0]) * n * d_in);
}
|
1040
|
+
|
1041
|
+
/*********************************************
|
1042
|
+
* CenteringTransform
|
1043
|
+
*********************************************/
|
1044
|
+
|
1045
|
+
/// Construct a (not yet trained) centering transform: subtracts the
/// per-dimension mean learned in train().
CenteringTransform::CenteringTransform (int d):
    VectorTransform (d, d)
{
    is_trained = false;
}
|
1050
|
+
|
1051
|
+
/// Compute the per-dimension mean of the training set.
///
/// @param n  number of training vectors (must be > 0)
/// @param x  training vectors, size n * d_in
void CenteringTransform::train(Index::idx_t n, const float *x) {
    FAISS_THROW_IF_NOT_MSG(n > 0, "need at least one training vector");
    // assign (not resize) so that retraining resets the accumulator:
    // resize(d_in, 0) would keep stale values from a previous train()
    mean.assign (d_in, 0);
    for (idx_t i = 0; i < n; i++) {
        for (size_t j = 0; j < d_in; j++) {
            mean[j] += *x++;
        }
    }

    for (size_t j = 0; j < d_in; j++) {
        mean[j] /= n;
    }
    is_trained = true;
}
|
1065
|
+
|
1066
|
+
|
1067
|
+
void CenteringTransform::apply_noalloc
|
1068
|
+
(idx_t n, const float* x, float* xt) const
|
1069
|
+
{
|
1070
|
+
FAISS_THROW_IF_NOT (is_trained);
|
1071
|
+
|
1072
|
+
for (idx_t i = 0; i < n; i++) {
|
1073
|
+
for (size_t j = 0; j < d_in; j++) {
|
1074
|
+
*xt++ = *x++ - mean[j];
|
1075
|
+
}
|
1076
|
+
}
|
1077
|
+
}
|
1078
|
+
|
1079
|
+
/// Exact inverse of apply: add the trained mean back to each vector.
void CenteringTransform::reverse_transform (idx_t n, const float* xt,
                                            float* x) const
{
    FAISS_THROW_IF_NOT (is_trained);

    for (idx_t i = 0; i < n; i++) {
        const float *xi = xt + i * d_in;
        float *xo = x + i * d_in;
        for (size_t j = 0; j < d_in; j++) {
            xo[j] = xi[j] + mean[j];
        }
    }
}
|
1091
|
+
|
1092
|
+
|
1093
|
+
|
1094
|
+
|
1095
|
+
|
1096
|
+
/*********************************************
|
1097
|
+
* RemapDimensionsTransform
|
1098
|
+
*********************************************/
|
1099
|
+
|
1100
|
+
|
1101
|
+
/// Construct from an explicit dimension map.
///
/// @param map_in  for each of the d_out output dimensions, the input
///                dimension it is copied from, or -1 to fill with 0.
///                Entries must be -1 or in [0, d_in).
RemapDimensionsTransform::RemapDimensionsTransform (
        int d_in, int d_out, const int *map_in):
    VectorTransform (d_in, d_out)
{
    map.resize (d_out);
    for (int i = 0; i < d_out; i++) {
        map[i] = map_in[i];
        // validate: every entry is -1 or a valid input dimension
        FAISS_THROW_IF_NOT (map[i] == -1 || (map[i] >= 0 && map[i] < d_in));
    }
}
|
1111
|
+
|
1112
|
+
/// Construct an automatic dimension map.
///
/// @param uniform  if true, spread (or subsample) the dimensions
///                 evenly across the output; if false, map the common
///                 prefix of dimensions identically and leave the rest
///                 unmapped (-1, filled with 0 on apply).
RemapDimensionsTransform::RemapDimensionsTransform (
        int d_in, int d_out, bool uniform): VectorTransform (d_in, d_out)
{
    // unmapped outputs stay at -1
    map.assign (d_out, -1);

    if (!uniform) {
        // identity on the common prefix of dimensions
        int prefix = d_in < d_out ? d_in : d_out;
        for (int i = 0; i < prefix; i++) {
            map[i] = i;
        }
    } else if (d_in < d_out) {
        // spread the d_in inputs evenly over the d_out outputs
        for (int i = 0; i < d_in; i++) {
            map[i * d_out / d_in] = i;
        }
    } else {
        // subsample the d_in inputs evenly down to d_out outputs
        for (int i = 0; i < d_out; i++) {
            map[i] = i * d_in / d_out;
        }
    }
}
|
1132
|
+
|
1133
|
+
|
1134
|
+
/// Apply the remapping: output dimension j gets input dimension
/// map[j], or 0 when map[j] is -1.
void RemapDimensionsTransform::apply_noalloc (idx_t n, const float * x,
                                              float *xt) const
{
    for (idx_t i = 0; i < n; i++) {
        const float *xi = x + i * d_in;
        float *xo = xt + i * d_out;
        for (int j = 0; j < d_out; j++) {
            int src = map[j];
            xo[j] = src < 0 ? 0 : xi[src];
        }
    }
}
|
1145
|
+
|
1146
|
+
/// Reverse the remapping: copy each mapped output dimension back to
/// its source input dimension; input dimensions that were never
/// mapped are set to 0.
void RemapDimensionsTransform::reverse_transform (idx_t n, const float * xt,
                                                  float *x) const
{
    memset (x, 0, sizeof (*x) * n * d_in);
    for (idx_t i = 0; i < n; i++) {
        const float *xi = xt + i * d_out;
        float *xo = x + i * d_in;
        for (int j = 0; j < d_out; j++) {
            int dst = map[j];
            if (dst >= 0) {
                xo[dst] = xi[j];
            }
        }
    }
}
|