faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
|
|
11
|
+
#define FAISS_POLYSEMOUS_TRAINING_INCLUDED
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#include <faiss/impl/ProductQuantizer.h>
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
namespace faiss {
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
/// parameters used for the simulated annealing method
|
|
21
|
+
/// Tuning knobs shared by the simulated-annealing permutation optimizer and
/// by PolysemousTraining.
struct SimulatedAnnealingParameters {

    /* ---- optimization parameters ---- */

    /// initial probability of accepting a bad swap
    double init_temperature;

    /// the temperature is multiplied by this factor at every iteration
    double temperature_decay;

    /// number of iterations
    int n_iter;

    /// number of runs of the simulation
    int n_redo;

    /// random seed
    int seed;

    /// verbosity
    int verbose;

    /// restrict permutation changes to bit flips
    bool only_bit_flips;

    /// initialize with a random permutation instead of the identity
    bool init_random;

    /// sets reasonable defaults
    SimulatedAnnealingParameters ();
};
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
/// abstract class for the loss function
|
|
40
|
+
/// Abstract loss function evaluated on permutations.
struct PermutationObjective {

    /// number of permuted elements (presumably the permutation size — confirm
    /// against the derived classes)
    int n;

    /// cost of the permutation `perm`
    virtual double compute_cost (const int *perm) const = 0;

    /// What would the cost update be if iw and jw were swapped?
    /// The default implementation computes both costs and returns the
    /// difference.
    virtual double cost_update (const int *perm, int iw, int jw) const;

    virtual ~PermutationObjective () = default;
};
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
struct ReproduceDistancesObjective : PermutationObjective {
|
|
55
|
+
|
|
56
|
+
double dis_weight_factor;
|
|
57
|
+
|
|
58
|
+
static double sqr (double x) { return x * x; }
|
|
59
|
+
|
|
60
|
+
// weihgting of distances: it is more important to reproduce small
|
|
61
|
+
// distances well
|
|
62
|
+
double dis_weight (double x) const;
|
|
63
|
+
|
|
64
|
+
std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
|
|
65
|
+
const double * target_dis; ///< wanted distances (size n^2)
|
|
66
|
+
std::vector<double> weights; ///< weights for each distance (size n^2)
|
|
67
|
+
|
|
68
|
+
double get_source_dis (int i, int j) const;
|
|
69
|
+
|
|
70
|
+
// cost = quadratic difference between actual distance and Hamming distance
|
|
71
|
+
double compute_cost(const int* perm) const override;
|
|
72
|
+
|
|
73
|
+
// what would the cost update be if iw and jw were swapped?
|
|
74
|
+
// computed in O(n) instead of O(n^2) for the full re-computation
|
|
75
|
+
double cost_update(const int* perm, int iw, int jw) const override;
|
|
76
|
+
|
|
77
|
+
ReproduceDistancesObjective (
|
|
78
|
+
int n,
|
|
79
|
+
const double *source_dis_in,
|
|
80
|
+
const double *target_dis_in,
|
|
81
|
+
double dis_weight_factor);
|
|
82
|
+
|
|
83
|
+
static void compute_mean_stdev (const double *tab, size_t n2,
|
|
84
|
+
double *mean_out, double *stddev_out);
|
|
85
|
+
|
|
86
|
+
void set_affine_target_dis (const double *source_dis_in);
|
|
87
|
+
|
|
88
|
+
~ReproduceDistancesObjective() override {}
|
|
89
|
+
};
|
|
90
|
+
|
|
91
|
+
struct RandomGenerator;
|
|
92
|
+
|
|
93
|
+
/// Simulated annealing optimization algorithm for permutations.
|
|
94
|
+
struct SimulatedAnnealingOptimizer: SimulatedAnnealingParameters {
|
|
95
|
+
|
|
96
|
+
PermutationObjective *obj;
|
|
97
|
+
int n; ///< size of the permutation
|
|
98
|
+
FILE *logfile; /// logs values of the cost function
|
|
99
|
+
|
|
100
|
+
SimulatedAnnealingOptimizer (PermutationObjective *obj,
|
|
101
|
+
const SimulatedAnnealingParameters &p);
|
|
102
|
+
RandomGenerator *rnd;
|
|
103
|
+
|
|
104
|
+
/// remember intial cost of optimization
|
|
105
|
+
double init_cost;
|
|
106
|
+
|
|
107
|
+
// main entry point. Perform the optimization loop, starting from
|
|
108
|
+
// and modifying permutation in-place
|
|
109
|
+
double optimize (int *perm);
|
|
110
|
+
|
|
111
|
+
// run the optimization and return the best result in best_perm
|
|
112
|
+
double run_optimization (int * best_perm);
|
|
113
|
+
|
|
114
|
+
virtual ~SimulatedAnnealingOptimizer ();
|
|
115
|
+
};
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
/// optimizes the order of indices in a ProductQuantizer
|
|
121
|
+
struct PolysemousTraining: SimulatedAnnealingParameters {
|
|
122
|
+
|
|
123
|
+
enum Optimization_type_t {
|
|
124
|
+
OT_None,
|
|
125
|
+
OT_ReproduceDistances_affine, ///< default
|
|
126
|
+
OT_Ranking_weighted_diff /// same as _2, but use rank of y+ - rank of y-
|
|
127
|
+
};
|
|
128
|
+
Optimization_type_t optimization_type;
|
|
129
|
+
|
|
130
|
+
// use 1/4 of the training points for the optimization, with
|
|
131
|
+
// max. ntrain_permutation. If ntrain_permutation == 0: train on
|
|
132
|
+
// centroids
|
|
133
|
+
int ntrain_permutation;
|
|
134
|
+
double dis_weight_factor; // decay of exp that weights distance loss
|
|
135
|
+
|
|
136
|
+
// filename pattern for the logging of iterations
|
|
137
|
+
std::string log_pattern;
|
|
138
|
+
|
|
139
|
+
// sets default values
|
|
140
|
+
PolysemousTraining ();
|
|
141
|
+
|
|
142
|
+
/// reorder the centroids so that the Hamming distace becomes a
|
|
143
|
+
/// good approximation of the SDC distance (called by train)
|
|
144
|
+
void optimize_pq_for_hamming (ProductQuantizer & pq,
|
|
145
|
+
size_t n, const float *x) const;
|
|
146
|
+
|
|
147
|
+
/// called by optimize_pq_for_hamming
|
|
148
|
+
void optimize_ranking (ProductQuantizer &pq, size_t n, const float *x) const;
|
|
149
|
+
/// called by optimize_pq_for_hamming
|
|
150
|
+
void optimize_reproduce_distances (ProductQuantizer &pq) const;
|
|
151
|
+
|
|
152
|
+
};
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
} // namespace faiss
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
#endif
|
|
@@ -0,0 +1,876 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/ProductQuantizer.h>


#include <cstddef>
#include <cstring>
#include <cstdio>
#include <memory>
#include <vector>

#include <algorithm>

#include <faiss/impl/FaissAssert.h>
#include <faiss/VectorTransform.h>
#include <faiss/IndexFlat.h>
#include <faiss/utils/distances.h>
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
extern "C" {
|
|
27
|
+
|
|
28
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
29
|
+
|
|
30
|
+
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
31
|
+
n, FINTEGER *k, const float *alpha, const float *a,
|
|
32
|
+
FINTEGER *lda, const float *b, FINTEGER *
|
|
33
|
+
ldb, float *beta, float *c, FINTEGER *ldc);
|
|
34
|
+
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
namespace faiss {
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
/* compute an estimator using look-up tables for typical values of M */
|
|
42
|
+
template <typename CT, class C>
|
|
43
|
+
void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
|
|
44
|
+
size_t ncodes,
|
|
45
|
+
const float * __restrict dis_table,
|
|
46
|
+
size_t ksub,
|
|
47
|
+
size_t k,
|
|
48
|
+
float * heap_dis,
|
|
49
|
+
int64_t * heap_ids)
|
|
50
|
+
{
|
|
51
|
+
|
|
52
|
+
for (size_t j = 0; j < ncodes; j++) {
|
|
53
|
+
float dis = 0;
|
|
54
|
+
const float *dt = dis_table;
|
|
55
|
+
|
|
56
|
+
for (size_t m = 0; m < M; m+=4) {
|
|
57
|
+
float dism = 0;
|
|
58
|
+
dism = dt[*codes++]; dt += ksub;
|
|
59
|
+
dism += dt[*codes++]; dt += ksub;
|
|
60
|
+
dism += dt[*codes++]; dt += ksub;
|
|
61
|
+
dism += dt[*codes++]; dt += ksub;
|
|
62
|
+
dis += dism;
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
if (C::cmp (heap_dis[0], dis)) {
|
|
66
|
+
heap_pop<C> (k, heap_dis, heap_ids);
|
|
67
|
+
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
|
68
|
+
}
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
template <typename CT, class C>
|
|
74
|
+
void pq_estimators_from_tables_M4 (const CT * codes,
|
|
75
|
+
size_t ncodes,
|
|
76
|
+
const float * __restrict dis_table,
|
|
77
|
+
size_t ksub,
|
|
78
|
+
size_t k,
|
|
79
|
+
float * heap_dis,
|
|
80
|
+
int64_t * heap_ids)
|
|
81
|
+
{
|
|
82
|
+
|
|
83
|
+
for (size_t j = 0; j < ncodes; j++) {
|
|
84
|
+
float dis = 0;
|
|
85
|
+
const float *dt = dis_table;
|
|
86
|
+
dis = dt[*codes++]; dt += ksub;
|
|
87
|
+
dis += dt[*codes++]; dt += ksub;
|
|
88
|
+
dis += dt[*codes++]; dt += ksub;
|
|
89
|
+
dis += dt[*codes++];
|
|
90
|
+
|
|
91
|
+
if (C::cmp (heap_dis[0], dis)) {
|
|
92
|
+
heap_pop<C> (k, heap_dis, heap_ids);
|
|
93
|
+
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
|
94
|
+
}
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
template <typename CT, class C>
|
|
100
|
+
static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
|
|
101
|
+
const CT * codes,
|
|
102
|
+
size_t ncodes,
|
|
103
|
+
const float * dis_table,
|
|
104
|
+
size_t k,
|
|
105
|
+
float * heap_dis,
|
|
106
|
+
int64_t * heap_ids)
|
|
107
|
+
{
|
|
108
|
+
|
|
109
|
+
if (pq.M == 4) {
|
|
110
|
+
|
|
111
|
+
pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
|
|
112
|
+
dis_table, pq.ksub, k,
|
|
113
|
+
heap_dis, heap_ids);
|
|
114
|
+
return;
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
if (pq.M % 4 == 0) {
|
|
118
|
+
pq_estimators_from_tables_Mmul4<CT, C> (pq.M, codes, ncodes,
|
|
119
|
+
dis_table, pq.ksub, k,
|
|
120
|
+
heap_dis, heap_ids);
|
|
121
|
+
return;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
/* Default is relatively slow */
|
|
125
|
+
const size_t M = pq.M;
|
|
126
|
+
const size_t ksub = pq.ksub;
|
|
127
|
+
for (size_t j = 0; j < ncodes; j++) {
|
|
128
|
+
float dis = 0;
|
|
129
|
+
const float * __restrict dt = dis_table;
|
|
130
|
+
for (int m = 0; m < M; m++) {
|
|
131
|
+
dis += dt[*codes++];
|
|
132
|
+
dt += ksub;
|
|
133
|
+
}
|
|
134
|
+
if (C::cmp (heap_dis[0], dis)) {
|
|
135
|
+
heap_pop<C> (k, heap_dis, heap_ids);
|
|
136
|
+
heap_push<C> (k, heap_dis, heap_ids, dis, j);
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
template <class C>
|
|
142
|
+
static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
|
|
143
|
+
size_t nbits,
|
|
144
|
+
const uint8_t *codes,
|
|
145
|
+
size_t ncodes,
|
|
146
|
+
const float *dis_table,
|
|
147
|
+
size_t k,
|
|
148
|
+
float *heap_dis,
|
|
149
|
+
int64_t *heap_ids)
|
|
150
|
+
{
|
|
151
|
+
const size_t M = pq.M;
|
|
152
|
+
const size_t ksub = pq.ksub;
|
|
153
|
+
for (size_t j = 0; j < ncodes; ++j) {
|
|
154
|
+
faiss::ProductQuantizer::PQDecoderGeneric decoder(
|
|
155
|
+
codes + j * pq.code_size, nbits
|
|
156
|
+
);
|
|
157
|
+
float dis = 0;
|
|
158
|
+
const float * __restrict dt = dis_table;
|
|
159
|
+
for (size_t m = 0; m < M; m++) {
|
|
160
|
+
uint64_t c = decoder.decode();
|
|
161
|
+
dis += dt[c];
|
|
162
|
+
dt += ksub;
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
if (C::cmp(heap_dis[0], dis)) {
|
|
166
|
+
heap_pop<C>(k, heap_dis, heap_ids);
|
|
167
|
+
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
/*********************************************
|
|
173
|
+
* PQ implementation
|
|
174
|
+
*********************************************/
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
/// Quantizer for d-dimensional vectors cut into M sub-vectors, each encoded
/// on nbits bits. The derived sizes (dsub, code_size, ksub, centroid storage)
/// are computed by set_derived_values(), which throws unless M divides d.
ProductQuantizer::ProductQuantizer (size_t d, size_t M, size_t nbits)
    : d(d),
      M(M),
      nbits(nbits),
      assign_index(nullptr)
{
    set_derived_values ();
}

/// Default quantizer: d = 0, a single sub-quantizer, 0 bits.
ProductQuantizer::ProductQuantizer ()
    : ProductQuantizer(0, 1, 0)
{
}
|
|
186
|
+
|
|
187
|
+
/// Recompute every quantity derived from (d, M, nbits): sub-vector dimension
/// dsub, code size in bytes, number of centroids per sub-quantizer ksub, and
/// the size of the centroid table. Also resets verbose and train_type.
/// Throws if M is 0 or does not divide d.
void ProductQuantizer::set_derived_values () {
    // quite a few derived values
    // Check M first: d % M with M == 0 is an integer division by zero.
    FAISS_THROW_IF_NOT (M > 0);
    FAISS_THROW_IF_NOT (d % M == 0);
    dsub = d / M;
    code_size = (nbits * M + 7) / 8;
    // Shift a size_t: `1 << nbits` is an int shift that overflows for
    // nbits >= 31 before being widened to ksub's type.
    ksub = size_t(1) << nbits;
    centroids.resize (d * ksub);
    verbose = false;
    train_type = Train_default;
}
|
|
197
|
+
|
|
198
|
+
void ProductQuantizer::set_params (const float * centroids_, int m)
|
|
199
|
+
{
|
|
200
|
+
memcpy (get_centroids(m, 0), centroids_,
|
|
201
|
+
ksub * dsub * sizeof (centroids_[0]));
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
/* Hypercube initialization of 2^nbits centroids of dimension d.
 *
 * The centroids are written one after another to `centroids`. Each one
 * starts from the mean of the n training vectors in x (n rows of d floats).
 * For the first nbits coordinates, bit j of the centroid number selects
 * whether +maxm (bit set) or -maxm (bit clear) is added, where maxm is the
 * largest absolute coordinate of the mean. The remaining coordinates equal
 * the mean. The caller must ensure nbits <= d. */
static void init_hypercube (int d, int nbits,
                            int n, const float * x,
                            float *centroids)
{
    // per-dimension sum, then mean, of the training vectors
    std::vector<float> mean (d);
    for (int row = 0; row < n; row++) {
        const float *xrow = x + row * d;
        for (int col = 0; col < d; col++) {
            mean[col] += xrow[col];
        }
    }

    // half-width of the hypercube: largest |mean| coordinate
    float maxm = 0;
    for (float &mj : mean) {
        mj /= n;
        const float amp = fabs (mj);
        if (amp > maxm) {
            maxm = amp;
        }
    }

    const int ncent = 1 << nbits;
    for (int i = 0; i < ncent; i++) {
        float *cent = centroids + i * d;
        int j = 0;
        for (; j < nbits; j++) {
            const float shift = ((i >> j) & 1) ? maxm : -maxm;
            cent[j] = mean[j] + shift;
        }
        for (; j < d; j++) {
            cent[j] = mean[j];
        }
    }
}
|
|
231
|
+
|
|
232
|
+
static void init_hypercube_pca (int d, int nbits,
|
|
233
|
+
int n, const float * x,
|
|
234
|
+
float *centroids)
|
|
235
|
+
{
|
|
236
|
+
PCAMatrix pca (d, nbits);
|
|
237
|
+
pca.train (n, x);
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
for (int i = 0; i < (1 << nbits); i++) {
|
|
241
|
+
float * cent = centroids + i * d;
|
|
242
|
+
for (int j = 0; j < d; j++) {
|
|
243
|
+
cent[j] = pca.mean[j];
|
|
244
|
+
float f = 1.0;
|
|
245
|
+
for (int k = 0; k < nbits; k++)
|
|
246
|
+
cent[j] += f *
|
|
247
|
+
sqrt (pca.eigenvalues [k]) *
|
|
248
|
+
(((i >> k) & 1) ? 1 : -1) *
|
|
249
|
+
pca.PCAMat [j + k * d];
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
void ProductQuantizer::train (int n, const float * x)
|
|
256
|
+
{
|
|
257
|
+
if (train_type != Train_shared) {
|
|
258
|
+
train_type_t final_train_type;
|
|
259
|
+
final_train_type = train_type;
|
|
260
|
+
if (train_type == Train_hypercube ||
|
|
261
|
+
train_type == Train_hypercube_pca) {
|
|
262
|
+
if (dsub < nbits) {
|
|
263
|
+
final_train_type = Train_default;
|
|
264
|
+
printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
|
|
265
|
+
nbits, dsub);
|
|
266
|
+
}
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
float * xslice = new float[n * dsub];
|
|
270
|
+
ScopeDeleter<float> del (xslice);
|
|
271
|
+
for (int m = 0; m < M; m++) {
|
|
272
|
+
for (int j = 0; j < n; j++)
|
|
273
|
+
memcpy (xslice + j * dsub,
|
|
274
|
+
x + j * d + m * dsub,
|
|
275
|
+
dsub * sizeof(float));
|
|
276
|
+
|
|
277
|
+
Clustering clus (dsub, ksub, cp);
|
|
278
|
+
|
|
279
|
+
// we have some initialization for the centroids
|
|
280
|
+
if (final_train_type != Train_default) {
|
|
281
|
+
clus.centroids.resize (dsub * ksub);
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
switch (final_train_type) {
|
|
285
|
+
case Train_hypercube:
|
|
286
|
+
init_hypercube (dsub, nbits, n, xslice,
|
|
287
|
+
clus.centroids.data ());
|
|
288
|
+
break;
|
|
289
|
+
case Train_hypercube_pca:
|
|
290
|
+
init_hypercube_pca (dsub, nbits, n, xslice,
|
|
291
|
+
clus.centroids.data ());
|
|
292
|
+
break;
|
|
293
|
+
case Train_hot_start:
|
|
294
|
+
memcpy (clus.centroids.data(),
|
|
295
|
+
get_centroids (m, 0),
|
|
296
|
+
dsub * ksub * sizeof (float));
|
|
297
|
+
break;
|
|
298
|
+
default: ;
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
if(verbose) {
|
|
302
|
+
clus.verbose = true;
|
|
303
|
+
printf ("Training PQ slice %d/%zd\n", m, M);
|
|
304
|
+
}
|
|
305
|
+
IndexFlatL2 index (dsub);
|
|
306
|
+
clus.train (n, xslice, assign_index ? *assign_index : index);
|
|
307
|
+
set_params (clus.centroids.data(), m);
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
} else {
|
|
312
|
+
|
|
313
|
+
Clustering clus (dsub, ksub, cp);
|
|
314
|
+
|
|
315
|
+
if(verbose) {
|
|
316
|
+
clus.verbose = true;
|
|
317
|
+
printf ("Training all PQ slices at once\n");
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
IndexFlatL2 index (dsub);
|
|
321
|
+
|
|
322
|
+
clus.train (n * M, x, assign_index ? *assign_index : index);
|
|
323
|
+
for (int m = 0; m < M; m++) {
|
|
324
|
+
set_params (clus.centroids.data(), m);
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
}
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
template<class PQEncoder>
|
|
331
|
+
void compute_code(const ProductQuantizer& pq, const float *x, uint8_t *code) {
|
|
332
|
+
float distances [pq.ksub];
|
|
333
|
+
PQEncoder encoder(code, pq.nbits);
|
|
334
|
+
for (size_t m = 0; m < pq.M; m++) {
|
|
335
|
+
float mindis = 1e20;
|
|
336
|
+
uint64_t idxm = 0;
|
|
337
|
+
const float * xsub = x + m * pq.dsub;
|
|
338
|
+
|
|
339
|
+
fvec_L2sqr_ny(distances, xsub, pq.get_centroids(m, 0), pq.dsub, pq.ksub);
|
|
340
|
+
|
|
341
|
+
/* Find best centroid */
|
|
342
|
+
for (size_t i = 0; i < pq.ksub; i++) {
|
|
343
|
+
float dis = distances[i];
|
|
344
|
+
if (dis < mindis) {
|
|
345
|
+
mindis = dis;
|
|
346
|
+
idxm = i;
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
encoder.encode(idxm);
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
void ProductQuantizer::compute_code(const float * x, uint8_t * code) const {
|
|
355
|
+
switch (nbits) {
|
|
356
|
+
case 8:
|
|
357
|
+
faiss::compute_code<PQEncoder8>(*this, x, code);
|
|
358
|
+
break;
|
|
359
|
+
|
|
360
|
+
case 16:
|
|
361
|
+
faiss::compute_code<PQEncoder16>(*this, x, code);
|
|
362
|
+
break;
|
|
363
|
+
|
|
364
|
+
default:
|
|
365
|
+
faiss::compute_code<PQEncoderGeneric>(*this, x, code);
|
|
366
|
+
break;
|
|
367
|
+
}
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
template<class PQDecoder>
|
|
371
|
+
void decode(const ProductQuantizer& pq, const uint8_t *code, float *x)
|
|
372
|
+
{
|
|
373
|
+
PQDecoder decoder(code, pq.nbits);
|
|
374
|
+
for (size_t m = 0; m < pq.M; m++) {
|
|
375
|
+
uint64_t c = decoder.decode();
|
|
376
|
+
memcpy(x + m * pq.dsub, pq.get_centroids(m, c), sizeof(float) * pq.dsub);
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
void ProductQuantizer::decode (const uint8_t *code, float *x) const
|
|
381
|
+
{
|
|
382
|
+
switch (nbits) {
|
|
383
|
+
case 8:
|
|
384
|
+
faiss::decode<PQDecoder8>(*this, code, x);
|
|
385
|
+
break;
|
|
386
|
+
|
|
387
|
+
case 16:
|
|
388
|
+
faiss::decode<PQDecoder16>(*this, code, x);
|
|
389
|
+
break;
|
|
390
|
+
|
|
391
|
+
default:
|
|
392
|
+
faiss::decode<PQDecoderGeneric>(*this, code, x);
|
|
393
|
+
break;
|
|
394
|
+
}
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
void ProductQuantizer::decode (const uint8_t *code, float *x, size_t n) const
|
|
399
|
+
{
|
|
400
|
+
for (size_t i = 0; i < n; i++) {
|
|
401
|
+
this->decode (code + code_size * i, x + d * i);
|
|
402
|
+
}
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
void ProductQuantizer::compute_code_from_distance_table (const float *tab,
|
|
407
|
+
uint8_t *code) const
|
|
408
|
+
{
|
|
409
|
+
PQEncoderGeneric encoder(code, nbits);
|
|
410
|
+
for (size_t m = 0; m < M; m++) {
|
|
411
|
+
float mindis = 1e20;
|
|
412
|
+
uint64_t idxm = 0;
|
|
413
|
+
|
|
414
|
+
/* Find best centroid */
|
|
415
|
+
for (size_t j = 0; j < ksub; j++) {
|
|
416
|
+
float dis = *tab++;
|
|
417
|
+
if (dis < mindis) {
|
|
418
|
+
mindis = dis;
|
|
419
|
+
idxm = j;
|
|
420
|
+
}
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
encoder.encode(idxm);
|
|
424
|
+
}
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
void ProductQuantizer::compute_codes_with_assign_index (
|
|
428
|
+
const float * x,
|
|
429
|
+
uint8_t * codes,
|
|
430
|
+
size_t n)
|
|
431
|
+
{
|
|
432
|
+
FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
|
|
433
|
+
|
|
434
|
+
for (size_t m = 0; m < M; m++) {
|
|
435
|
+
assign_index->reset ();
|
|
436
|
+
assign_index->add (ksub, get_centroids (m, 0));
|
|
437
|
+
size_t bs = 65536;
|
|
438
|
+
float * xslice = new float[bs * dsub];
|
|
439
|
+
ScopeDeleter<float> del (xslice);
|
|
440
|
+
idx_t *assign = new idx_t[bs];
|
|
441
|
+
ScopeDeleter<idx_t> del2 (assign);
|
|
442
|
+
|
|
443
|
+
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
444
|
+
size_t i1 = std::min(i0 + bs, n);
|
|
445
|
+
|
|
446
|
+
for (size_t i = i0; i < i1; i++) {
|
|
447
|
+
memcpy (xslice + (i - i0) * dsub,
|
|
448
|
+
x + i * d + m * dsub,
|
|
449
|
+
dsub * sizeof(float));
|
|
450
|
+
}
|
|
451
|
+
|
|
452
|
+
assign_index->assign (i1 - i0, xslice, assign);
|
|
453
|
+
|
|
454
|
+
if (nbits == 8) {
|
|
455
|
+
uint8_t *c = codes + code_size * i0 + m;
|
|
456
|
+
for (size_t i = i0; i < i1; i++) {
|
|
457
|
+
*c = assign[i - i0];
|
|
458
|
+
c += M;
|
|
459
|
+
}
|
|
460
|
+
} else if (nbits == 16) {
|
|
461
|
+
uint16_t *c = (uint16_t*)(codes + code_size * i0 + m * 2);
|
|
462
|
+
for (size_t i = i0; i < i1; i++) {
|
|
463
|
+
*c = assign[i - i0];
|
|
464
|
+
c += M;
|
|
465
|
+
}
|
|
466
|
+
} else {
|
|
467
|
+
for (size_t i = i0; i < i1; ++i) {
|
|
468
|
+
uint8_t *c = codes + code_size * i + ((m * nbits) / 8);
|
|
469
|
+
uint8_t offset = (m * nbits) % 8;
|
|
470
|
+
uint64_t ass = assign[i - i0];
|
|
471
|
+
|
|
472
|
+
PQEncoderGeneric encoder(c, nbits, offset);
|
|
473
|
+
encoder.encode(ass);
|
|
474
|
+
}
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
}
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
void ProductQuantizer::compute_codes (const float * x,
|
|
483
|
+
uint8_t * codes,
|
|
484
|
+
size_t n) const
|
|
485
|
+
{
|
|
486
|
+
// process by blocks to avoid using too much RAM
|
|
487
|
+
size_t bs = 256 * 1024;
|
|
488
|
+
if (n > bs) {
|
|
489
|
+
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
|
490
|
+
size_t i1 = std::min(i0 + bs, n);
|
|
491
|
+
compute_codes (x + d * i0, codes + code_size * i0, i1 - i0);
|
|
492
|
+
}
|
|
493
|
+
return;
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
if (dsub < 16) { // simple direct computation
|
|
497
|
+
|
|
498
|
+
#pragma omp parallel for
|
|
499
|
+
for (size_t i = 0; i < n; i++)
|
|
500
|
+
compute_code (x + i * d, codes + i * code_size);
|
|
501
|
+
|
|
502
|
+
} else { // worthwile to use BLAS
|
|
503
|
+
float *dis_tables = new float [n * ksub * M];
|
|
504
|
+
ScopeDeleter<float> del (dis_tables);
|
|
505
|
+
compute_distance_tables (n, x, dis_tables);
|
|
506
|
+
|
|
507
|
+
#pragma omp parallel for
|
|
508
|
+
for (size_t i = 0; i < n; i++) {
|
|
509
|
+
uint8_t * code = codes + i * code_size;
|
|
510
|
+
const float * tab = dis_tables + i * ksub * M;
|
|
511
|
+
compute_code_from_distance_table (tab, code);
|
|
512
|
+
}
|
|
513
|
+
}
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
void ProductQuantizer::compute_distance_table (const float * x,
|
|
518
|
+
float * dis_table) const
|
|
519
|
+
{
|
|
520
|
+
size_t m;
|
|
521
|
+
|
|
522
|
+
for (m = 0; m < M; m++) {
|
|
523
|
+
fvec_L2sqr_ny (dis_table + m * ksub,
|
|
524
|
+
x + m * dsub,
|
|
525
|
+
get_centroids(m, 0),
|
|
526
|
+
dsub,
|
|
527
|
+
ksub);
|
|
528
|
+
}
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
void ProductQuantizer::compute_inner_prod_table (const float * x,
|
|
532
|
+
float * dis_table) const
|
|
533
|
+
{
|
|
534
|
+
size_t m;
|
|
535
|
+
|
|
536
|
+
for (m = 0; m < M; m++) {
|
|
537
|
+
fvec_inner_products_ny (dis_table + m * ksub,
|
|
538
|
+
x + m * dsub,
|
|
539
|
+
get_centroids(m, 0),
|
|
540
|
+
dsub,
|
|
541
|
+
ksub);
|
|
542
|
+
}
|
|
543
|
+
}
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
void ProductQuantizer::compute_distance_tables (
|
|
547
|
+
size_t nx,
|
|
548
|
+
const float * x,
|
|
549
|
+
float * dis_tables) const
|
|
550
|
+
{
|
|
551
|
+
|
|
552
|
+
if (dsub < 16) {
|
|
553
|
+
|
|
554
|
+
#pragma omp parallel for
|
|
555
|
+
for (size_t i = 0; i < nx; i++) {
|
|
556
|
+
compute_distance_table (x + i * d, dis_tables + i * ksub * M);
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
} else { // use BLAS
|
|
560
|
+
|
|
561
|
+
for (int m = 0; m < M; m++) {
|
|
562
|
+
pairwise_L2sqr (dsub,
|
|
563
|
+
nx, x + dsub * m,
|
|
564
|
+
ksub, centroids.data() + m * dsub * ksub,
|
|
565
|
+
dis_tables + ksub * m,
|
|
566
|
+
d, dsub, ksub * M);
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
void ProductQuantizer::compute_inner_prod_tables (
|
|
572
|
+
size_t nx,
|
|
573
|
+
const float * x,
|
|
574
|
+
float * dis_tables) const
|
|
575
|
+
{
|
|
576
|
+
|
|
577
|
+
if (dsub < 16) {
|
|
578
|
+
|
|
579
|
+
#pragma omp parallel for
|
|
580
|
+
for (size_t i = 0; i < nx; i++) {
|
|
581
|
+
compute_inner_prod_table (x + i * d, dis_tables + i * ksub * M);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
} else { // use BLAS
|
|
585
|
+
|
|
586
|
+
// compute distance tables
|
|
587
|
+
for (int m = 0; m < M; m++) {
|
|
588
|
+
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
|
|
589
|
+
dsubi = dsub, di = d;
|
|
590
|
+
float one = 1.0, zero = 0;
|
|
591
|
+
|
|
592
|
+
sgemm_ ("Transposed", "Not transposed",
|
|
593
|
+
&ksubi, &nxi, &dsubi,
|
|
594
|
+
&one, ¢roids [m * dsub * ksub], &dsubi,
|
|
595
|
+
x + dsub * m, &di,
|
|
596
|
+
&zero, dis_tables + ksub * m, &ldc);
|
|
597
|
+
}
|
|
598
|
+
|
|
599
|
+
}
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
template <class C>
|
|
603
|
+
static void pq_knn_search_with_tables (
|
|
604
|
+
const ProductQuantizer& pq,
|
|
605
|
+
size_t nbits,
|
|
606
|
+
const float *dis_tables,
|
|
607
|
+
const uint8_t * codes,
|
|
608
|
+
const size_t ncodes,
|
|
609
|
+
HeapArray<C> * res,
|
|
610
|
+
bool init_finalize_heap)
|
|
611
|
+
{
|
|
612
|
+
size_t k = res->k, nx = res->nh;
|
|
613
|
+
size_t ksub = pq.ksub, M = pq.M;
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
#pragma omp parallel for
|
|
617
|
+
for (size_t i = 0; i < nx; i++) {
|
|
618
|
+
/* query preparation for asymmetric search: compute look-up tables */
|
|
619
|
+
const float* dis_table = dis_tables + i * ksub * M;
|
|
620
|
+
|
|
621
|
+
/* Compute distances and keep smallest values */
|
|
622
|
+
int64_t * __restrict heap_ids = res->ids + i * k;
|
|
623
|
+
float * __restrict heap_dis = res->val + i * k;
|
|
624
|
+
|
|
625
|
+
if (init_finalize_heap) {
|
|
626
|
+
heap_heapify<C> (k, heap_dis, heap_ids);
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
switch (nbits) {
|
|
630
|
+
case 8:
|
|
631
|
+
pq_estimators_from_tables<uint8_t, C> (pq,
|
|
632
|
+
codes, ncodes,
|
|
633
|
+
dis_table,
|
|
634
|
+
k, heap_dis, heap_ids);
|
|
635
|
+
break;
|
|
636
|
+
|
|
637
|
+
case 16:
|
|
638
|
+
pq_estimators_from_tables<uint16_t, C> (pq,
|
|
639
|
+
(uint16_t*)codes, ncodes,
|
|
640
|
+
dis_table,
|
|
641
|
+
k, heap_dis, heap_ids);
|
|
642
|
+
break;
|
|
643
|
+
|
|
644
|
+
default:
|
|
645
|
+
pq_estimators_from_tables_generic<C> (pq,
|
|
646
|
+
nbits,
|
|
647
|
+
codes, ncodes,
|
|
648
|
+
dis_table,
|
|
649
|
+
k, heap_dis, heap_ids);
|
|
650
|
+
break;
|
|
651
|
+
}
|
|
652
|
+
|
|
653
|
+
if (init_finalize_heap) {
|
|
654
|
+
heap_reorder<C> (k, heap_dis, heap_ids);
|
|
655
|
+
}
|
|
656
|
+
}
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
void ProductQuantizer::search (const float * __restrict x,
|
|
660
|
+
size_t nx,
|
|
661
|
+
const uint8_t * codes,
|
|
662
|
+
const size_t ncodes,
|
|
663
|
+
float_maxheap_array_t * res,
|
|
664
|
+
bool init_finalize_heap) const
|
|
665
|
+
{
|
|
666
|
+
FAISS_THROW_IF_NOT (nx == res->nh);
|
|
667
|
+
std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
|
|
668
|
+
compute_distance_tables (nx, x, dis_tables.get());
|
|
669
|
+
|
|
670
|
+
pq_knn_search_with_tables<CMax<float, int64_t>> (
|
|
671
|
+
*this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
|
|
672
|
+
}
|
|
673
|
+
|
|
674
|
+
void ProductQuantizer::search_ip (const float * __restrict x,
|
|
675
|
+
size_t nx,
|
|
676
|
+
const uint8_t * codes,
|
|
677
|
+
const size_t ncodes,
|
|
678
|
+
float_minheap_array_t * res,
|
|
679
|
+
bool init_finalize_heap) const
|
|
680
|
+
{
|
|
681
|
+
FAISS_THROW_IF_NOT (nx == res->nh);
|
|
682
|
+
std::unique_ptr<float[]> dis_tables(new float [nx * ksub * M]);
|
|
683
|
+
compute_inner_prod_tables (nx, x, dis_tables.get());
|
|
684
|
+
|
|
685
|
+
pq_knn_search_with_tables<CMin<float, int64_t> > (
|
|
686
|
+
*this, nbits, dis_tables.get(), codes, ncodes, res, init_finalize_heap);
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
static float sqr (float x) {
|
|
692
|
+
return x * x;
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
void ProductQuantizer::compute_sdc_table ()
|
|
696
|
+
{
|
|
697
|
+
sdc_table.resize (M * ksub * ksub);
|
|
698
|
+
|
|
699
|
+
for (int m = 0; m < M; m++) {
|
|
700
|
+
|
|
701
|
+
const float *cents = centroids.data() + m * ksub * dsub;
|
|
702
|
+
float * dis_tab = sdc_table.data() + m * ksub * ksub;
|
|
703
|
+
|
|
704
|
+
// TODO optimize with BLAS
|
|
705
|
+
for (int i = 0; i < ksub; i++) {
|
|
706
|
+
const float *centi = cents + i * dsub;
|
|
707
|
+
for (int j = 0; j < ksub; j++) {
|
|
708
|
+
float accu = 0;
|
|
709
|
+
const float *centj = cents + j * dsub;
|
|
710
|
+
for (int k = 0; k < dsub; k++)
|
|
711
|
+
accu += sqr (centi[k] - centj[k]);
|
|
712
|
+
dis_tab [i + j * ksub] = accu;
|
|
713
|
+
}
|
|
714
|
+
}
|
|
715
|
+
}
|
|
716
|
+
}
|
|
717
|
+
|
|
718
|
+
void ProductQuantizer::search_sdc (const uint8_t * qcodes,
|
|
719
|
+
size_t nq,
|
|
720
|
+
const uint8_t * bcodes,
|
|
721
|
+
const size_t nb,
|
|
722
|
+
float_maxheap_array_t * res,
|
|
723
|
+
bool init_finalize_heap) const
|
|
724
|
+
{
|
|
725
|
+
FAISS_THROW_IF_NOT (sdc_table.size() == M * ksub * ksub);
|
|
726
|
+
FAISS_THROW_IF_NOT (nbits == 8);
|
|
727
|
+
size_t k = res->k;
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
#pragma omp parallel for
|
|
731
|
+
for (size_t i = 0; i < nq; i++) {
|
|
732
|
+
|
|
733
|
+
/* Compute distances and keep smallest values */
|
|
734
|
+
idx_t * heap_ids = res->ids + i * k;
|
|
735
|
+
float * heap_dis = res->val + i * k;
|
|
736
|
+
const uint8_t * qcode = qcodes + i * code_size;
|
|
737
|
+
|
|
738
|
+
if (init_finalize_heap)
|
|
739
|
+
maxheap_heapify (k, heap_dis, heap_ids);
|
|
740
|
+
|
|
741
|
+
const uint8_t * bcode = bcodes;
|
|
742
|
+
for (size_t j = 0; j < nb; j++) {
|
|
743
|
+
float dis = 0;
|
|
744
|
+
const float * tab = sdc_table.data();
|
|
745
|
+
for (int m = 0; m < M; m++) {
|
|
746
|
+
dis += tab[bcode[m] + qcode[m] * ksub];
|
|
747
|
+
tab += ksub * ksub;
|
|
748
|
+
}
|
|
749
|
+
if (dis < heap_dis[0]) {
|
|
750
|
+
maxheap_pop (k, heap_dis, heap_ids);
|
|
751
|
+
maxheap_push (k, heap_dis, heap_ids, dis, j);
|
|
752
|
+
}
|
|
753
|
+
bcode += code_size;
|
|
754
|
+
}
|
|
755
|
+
|
|
756
|
+
if (init_finalize_heap)
|
|
757
|
+
maxheap_reorder (k, heap_dis, heap_ids);
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
}
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
ProductQuantizer::PQEncoderGeneric::PQEncoderGeneric(uint8_t *code, int nbits,
|
|
764
|
+
uint8_t offset)
|
|
765
|
+
: code(code), offset(offset), nbits(nbits), reg(0) {
|
|
766
|
+
assert(nbits <= 64);
|
|
767
|
+
if (offset > 0) {
|
|
768
|
+
reg = (*code & ((1 << offset) - 1));
|
|
769
|
+
}
|
|
770
|
+
}
|
|
771
|
+
|
|
772
|
+
void ProductQuantizer::PQEncoderGeneric::encode(uint64_t x) {
|
|
773
|
+
reg |= (uint8_t)(x << offset);
|
|
774
|
+
x >>= (8 - offset);
|
|
775
|
+
if (offset + nbits >= 8) {
|
|
776
|
+
*code++ = reg;
|
|
777
|
+
|
|
778
|
+
for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
|
|
779
|
+
*code++ = (uint8_t)x;
|
|
780
|
+
x >>= 8;
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
offset += nbits;
|
|
784
|
+
offset &= 7;
|
|
785
|
+
reg = (uint8_t)x;
|
|
786
|
+
} else {
|
|
787
|
+
offset += nbits;
|
|
788
|
+
}
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
ProductQuantizer::PQEncoderGeneric::~PQEncoderGeneric() {
|
|
792
|
+
if (offset > 0) {
|
|
793
|
+
*code = reg;
|
|
794
|
+
}
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
ProductQuantizer::PQEncoder8::PQEncoder8(uint8_t *code, int nbits)
|
|
799
|
+
: code(code) {
|
|
800
|
+
assert(8 == nbits);
|
|
801
|
+
}
|
|
802
|
+
|
|
803
|
+
void ProductQuantizer::PQEncoder8::encode(uint64_t x) {
|
|
804
|
+
*code++ = (uint8_t)x;
|
|
805
|
+
}
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
ProductQuantizer::PQEncoder16::PQEncoder16(uint8_t *code, int nbits)
|
|
809
|
+
: code((uint16_t *)code) {
|
|
810
|
+
assert(16 == nbits);
|
|
811
|
+
}
|
|
812
|
+
|
|
813
|
+
void ProductQuantizer::PQEncoder16::encode(uint64_t x) {
|
|
814
|
+
*code++ = (uint16_t)x;
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
ProductQuantizer::PQDecoderGeneric::PQDecoderGeneric(const uint8_t *code,
|
|
819
|
+
int nbits)
|
|
820
|
+
: code(code),
|
|
821
|
+
offset(0),
|
|
822
|
+
nbits(nbits),
|
|
823
|
+
mask((1ull << nbits) - 1),
|
|
824
|
+
reg(0) {
|
|
825
|
+
assert(nbits <= 64);
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
uint64_t ProductQuantizer::PQDecoderGeneric::decode() {
|
|
829
|
+
if (offset == 0) {
|
|
830
|
+
reg = *code;
|
|
831
|
+
}
|
|
832
|
+
uint64_t c = (reg >> offset);
|
|
833
|
+
|
|
834
|
+
if (offset + nbits >= 8) {
|
|
835
|
+
uint64_t e = 8 - offset;
|
|
836
|
+
++code;
|
|
837
|
+
for (int i = 0; i < (nbits - (8 - offset)) / 8; ++i) {
|
|
838
|
+
c |= ((uint64_t)(*code++) << e);
|
|
839
|
+
e += 8;
|
|
840
|
+
}
|
|
841
|
+
|
|
842
|
+
offset += nbits;
|
|
843
|
+
offset &= 7;
|
|
844
|
+
if (offset > 0) {
|
|
845
|
+
reg = *code;
|
|
846
|
+
c |= ((uint64_t)reg << e);
|
|
847
|
+
}
|
|
848
|
+
} else {
|
|
849
|
+
offset += nbits;
|
|
850
|
+
}
|
|
851
|
+
|
|
852
|
+
return c & mask;
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
ProductQuantizer::PQDecoder8::PQDecoder8(const uint8_t *code, int nbits)
|
|
857
|
+
: code(code) {
|
|
858
|
+
assert(8 == nbits);
|
|
859
|
+
}
|
|
860
|
+
|
|
861
|
+
uint64_t ProductQuantizer::PQDecoder8::decode() {
|
|
862
|
+
return (uint64_t)(*code++);
|
|
863
|
+
}
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
ProductQuantizer::PQDecoder16::PQDecoder16(const uint8_t *code, int nbits)
|
|
867
|
+
: code((uint16_t *)code) {
|
|
868
|
+
assert(16 == nbits);
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
uint64_t ProductQuantizer::PQDecoder16::decode() {
|
|
872
|
+
return (uint64_t)(*code++);
|
|
873
|
+
}
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
} // namespace faiss
|