faiss 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +20 -2
@@ -10,62 +10,56 @@
|
|
10
10
|
#ifndef FAISS_POLYSEMOUS_TRAINING_INCLUDED
|
11
11
|
#define FAISS_POLYSEMOUS_TRAINING_INCLUDED
|
12
12
|
|
13
|
-
|
14
13
|
#include <faiss/impl/ProductQuantizer.h>
|
15
14
|
|
16
|
-
|
17
15
|
namespace faiss {
|
18
16
|
|
19
|
-
|
20
17
|
/// parameters used for the simulated annealing method
|
21
18
|
struct SimulatedAnnealingParameters {
|
22
|
-
|
23
19
|
// optimization parameters
|
24
|
-
double init_temperature;
|
25
|
-
double temperature_decay;
|
26
|
-
|
27
|
-
int
|
28
|
-
int
|
20
|
+
double init_temperature; // init probability of accepting a bad swap
|
21
|
+
double temperature_decay; // at each iteration the temp is multiplied by
|
22
|
+
// this
|
23
|
+
int n_iter; // nb of iterations
|
24
|
+
int n_redo; // nb of runs of the simulation
|
25
|
+
int seed; // random seed
|
29
26
|
int verbose;
|
30
27
|
bool only_bit_flips; // restrict permutation changes to bit flips
|
31
|
-
bool init_random;
|
28
|
+
bool init_random; // initialize with a random permutation (not identity)
|
32
29
|
|
33
30
|
// set reasonable defaults
|
34
|
-
SimulatedAnnealingParameters
|
35
|
-
|
31
|
+
SimulatedAnnealingParameters();
|
36
32
|
};
|
37
33
|
|
38
|
-
|
39
34
|
/// abstract class for the loss function
|
40
35
|
struct PermutationObjective {
|
41
|
-
|
42
36
|
int n;
|
43
37
|
|
44
|
-
virtual double compute_cost
|
38
|
+
virtual double compute_cost(const int* perm) const = 0;
|
45
39
|
|
46
40
|
// what would the cost update be if iw and jw were swapped?
|
47
41
|
// default implementation just computes both and computes the difference
|
48
|
-
virtual double cost_update
|
42
|
+
virtual double cost_update(const int* perm, int iw, int jw) const;
|
49
43
|
|
50
|
-
virtual ~PermutationObjective
|
44
|
+
virtual ~PermutationObjective() {}
|
51
45
|
};
|
52
46
|
|
53
|
-
|
54
47
|
struct ReproduceDistancesObjective : PermutationObjective {
|
55
|
-
|
56
48
|
double dis_weight_factor;
|
57
49
|
|
58
|
-
static double sqr
|
50
|
+
static double sqr(double x) {
|
51
|
+
return x * x;
|
52
|
+
}
|
59
53
|
|
60
54
|
// weighting of distances: it is more important to reproduce small
|
61
55
|
// distances well
|
62
|
-
double dis_weight
|
56
|
+
double dis_weight(double x) const;
|
63
57
|
|
64
58
|
std::vector<double> source_dis; ///< "real" corrected distances (size n^2)
|
65
|
-
const double
|
59
|
+
const double* target_dis; ///< wanted distances (size n^2)
|
66
60
|
std::vector<double> weights; ///< weights for each distance (size n^2)
|
67
61
|
|
68
|
-
double get_source_dis
|
62
|
+
double get_source_dis(int i, int j) const;
|
69
63
|
|
70
64
|
// cost = quadratic difference between actual distance and Hamming distance
|
71
65
|
double compute_cost(const int* perm) const override;
|
@@ -74,16 +68,19 @@ struct ReproduceDistancesObjective : PermutationObjective {
|
|
74
68
|
// computed in O(n) instead of O(n^2) for the full re-computation
|
75
69
|
double cost_update(const int* perm, int iw, int jw) const override;
|
76
70
|
|
77
|
-
ReproduceDistancesObjective
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
71
|
+
ReproduceDistancesObjective(
|
72
|
+
int n,
|
73
|
+
const double* source_dis_in,
|
74
|
+
const double* target_dis_in,
|
75
|
+
double dis_weight_factor);
|
82
76
|
|
83
|
-
static void compute_mean_stdev
|
84
|
-
|
77
|
+
static void compute_mean_stdev(
|
78
|
+
const double* tab,
|
79
|
+
size_t n2,
|
80
|
+
double* mean_out,
|
81
|
+
double* stddev_out);
|
85
82
|
|
86
|
-
void set_affine_target_dis
|
83
|
+
void set_affine_target_dis(const double* source_dis_in);
|
87
84
|
|
88
85
|
~ReproduceDistancesObjective() override {}
|
89
86
|
};
|
@@ -91,39 +88,36 @@ struct ReproduceDistancesObjective : PermutationObjective {
|
|
91
88
|
struct RandomGenerator;
|
92
89
|
|
93
90
|
/// Simulated annealing optimization algorithm for permutations.
|
94
|
-
|
95
|
-
|
96
|
-
PermutationObjective *obj;
|
91
|
+
struct SimulatedAnnealingOptimizer : SimulatedAnnealingParameters {
|
92
|
+
PermutationObjective* obj;
|
97
93
|
int n; ///< size of the permutation
|
98
|
-
FILE
|
94
|
+
FILE* logfile; /// logs values of the cost function
|
99
95
|
|
100
|
-
SimulatedAnnealingOptimizer
|
101
|
-
|
102
|
-
|
96
|
+
SimulatedAnnealingOptimizer(
|
97
|
+
PermutationObjective* obj,
|
98
|
+
const SimulatedAnnealingParameters& p);
|
99
|
+
RandomGenerator* rnd;
|
103
100
|
|
104
101
|
/// remember initial cost of optimization
|
105
102
|
double init_cost;
|
106
103
|
|
107
104
|
// main entry point. Perform the optimization loop, starting from
|
108
105
|
// and modifying permutation in-place
|
109
|
-
double optimize
|
106
|
+
double optimize(int* perm);
|
110
107
|
|
111
108
|
// run the optimization and return the best result in best_perm
|
112
|
-
double run_optimization
|
109
|
+
double run_optimization(int* best_perm);
|
113
110
|
|
114
|
-
virtual ~SimulatedAnnealingOptimizer
|
111
|
+
virtual ~SimulatedAnnealingOptimizer();
|
115
112
|
};
|
116
113
|
|
117
|
-
|
118
|
-
|
119
|
-
|
120
114
|
/// optimizes the order of indices in a ProductQuantizer
|
121
|
-
struct PolysemousTraining: SimulatedAnnealingParameters {
|
122
|
-
|
115
|
+
struct PolysemousTraining : SimulatedAnnealingParameters {
|
123
116
|
enum Optimization_type_t {
|
124
117
|
OT_None,
|
125
|
-
OT_ReproduceDistances_affine,
|
126
|
-
OT_Ranking_weighted_diff
|
118
|
+
OT_ReproduceDistances_affine, ///< default
|
119
|
+
OT_Ranking_weighted_diff ///< same as _2, but use rank of y+ - rank of
|
120
|
+
///< y-
|
127
121
|
};
|
128
122
|
Optimization_type_t optimization_type;
|
129
123
|
|
@@ -133,26 +127,29 @@ struct PolysemousTraining: SimulatedAnnealingParameters {
|
|
133
127
|
int ntrain_permutation;
|
134
128
|
double dis_weight_factor; ///< decay of exp that weights distance loss
|
135
129
|
|
130
|
+
/// refuse to train if it would require more than that amount of RAM
|
131
|
+
size_t max_memory;
|
132
|
+
|
136
133
|
// filename pattern for the logging of iterations
|
137
134
|
std::string log_pattern;
|
138
135
|
|
139
136
|
// sets default values
|
140
|
-
PolysemousTraining
|
137
|
+
PolysemousTraining();
|
141
138
|
|
142
139
|
/// reorder the centroids so that the Hamming distance becomes a
|
143
140
|
/// good approximation of the SDC distance (called by train)
|
144
|
-
void optimize_pq_for_hamming
|
145
|
-
|
141
|
+
void optimize_pq_for_hamming(ProductQuantizer& pq, size_t n, const float* x)
|
142
|
+
const;
|
146
143
|
|
147
144
|
/// called by optimize_pq_for_hamming
|
148
|
-
void optimize_ranking
|
145
|
+
void optimize_ranking(ProductQuantizer& pq, size_t n, const float* x) const;
|
149
146
|
/// called by optimize_pq_for_hamming
|
150
|
-
void optimize_reproduce_distances
|
147
|
+
void optimize_reproduce_distances(ProductQuantizer& pq) const;
|
151
148
|
|
149
|
+
/// make sure we don't blow up the memory
|
150
|
+
size_t memory_usage_per_thread(const ProductQuantizer& pq) const;
|
152
151
|
};
|
153
152
|
|
154
|
-
|
155
153
|
} // namespace faiss
|
156
154
|
|
157
|
-
|
158
155
|
#endif
|
@@ -7,20 +7,18 @@
|
|
7
7
|
|
8
8
|
namespace faiss {
|
9
9
|
|
10
|
-
inline
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
{
|
10
|
+
inline PQEncoderGeneric::PQEncoderGeneric(
|
11
|
+
uint8_t* code,
|
12
|
+
int nbits,
|
13
|
+
uint8_t offset)
|
14
|
+
: code(code), offset(offset), nbits(nbits), reg(0) {
|
15
15
|
assert(nbits <= 64);
|
16
16
|
if (offset > 0) {
|
17
17
|
reg = (*code & ((1 << offset) - 1));
|
18
18
|
}
|
19
19
|
}
|
20
20
|
|
21
|
-
inline
|
22
|
-
void PQEncoderGeneric::encode(uint64_t x)
|
23
|
-
{
|
21
|
+
inline void PQEncoderGeneric::encode(uint64_t x) {
|
24
22
|
reg |= (uint8_t)(x << offset);
|
25
23
|
x >>= (8 - offset);
|
26
24
|
if (offset + nbits >= 8) {
|
@@ -39,51 +37,39 @@ void PQEncoderGeneric::encode(uint64_t x)
|
|
39
37
|
}
|
40
38
|
}
|
41
39
|
|
42
|
-
inline
|
43
|
-
PQEncoderGeneric::~PQEncoderGeneric()
|
44
|
-
{
|
40
|
+
inline PQEncoderGeneric::~PQEncoderGeneric() {
|
45
41
|
if (offset > 0) {
|
46
42
|
*code = reg;
|
47
43
|
}
|
48
44
|
}
|
49
45
|
|
50
|
-
|
51
|
-
inline
|
52
|
-
PQEncoder8::PQEncoder8(uint8_t *code, int nbits)
|
53
|
-
: code(code) {
|
46
|
+
inline PQEncoder8::PQEncoder8(uint8_t* code, int nbits) : code(code) {
|
54
47
|
assert(8 == nbits);
|
55
48
|
}
|
56
49
|
|
57
|
-
inline
|
58
|
-
void PQEncoder8::encode(uint64_t x) {
|
50
|
+
inline void PQEncoder8::encode(uint64_t x) {
|
59
51
|
*code++ = (uint8_t)x;
|
60
52
|
}
|
61
53
|
|
62
|
-
inline
|
63
|
-
|
64
|
-
: code((uint16_t *)code) {
|
54
|
+
inline PQEncoder16::PQEncoder16(uint8_t* code, int nbits)
|
55
|
+
: code((uint16_t*)code) {
|
65
56
|
assert(16 == nbits);
|
66
57
|
}
|
67
58
|
|
68
|
-
inline
|
69
|
-
void PQEncoder16::encode(uint64_t x) {
|
59
|
+
inline void PQEncoder16::encode(uint64_t x) {
|
70
60
|
*code++ = (uint16_t)x;
|
71
61
|
}
|
72
62
|
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
nbits(nbits),
|
80
|
-
mask((1ull << nbits) - 1),
|
81
|
-
reg(0) {
|
63
|
+
inline PQDecoderGeneric::PQDecoderGeneric(const uint8_t* code, int nbits)
|
64
|
+
: code(code),
|
65
|
+
offset(0),
|
66
|
+
nbits(nbits),
|
67
|
+
mask((1ull << nbits) - 1),
|
68
|
+
reg(0) {
|
82
69
|
assert(nbits <= 64);
|
83
70
|
}
|
84
71
|
|
85
|
-
inline
|
86
|
-
uint64_t PQDecoderGeneric::decode() {
|
72
|
+
inline uint64_t PQDecoderGeneric::decode() {
|
87
73
|
if (offset == 0) {
|
88
74
|
reg = *code;
|
89
75
|
}
|
@@ -110,27 +96,20 @@ uint64_t PQDecoderGeneric::decode() {
|
|
110
96
|
return c & mask;
|
111
97
|
}
|
112
98
|
|
113
|
-
|
114
|
-
|
115
|
-
PQDecoder8::PQDecoder8(const uint8_t *code, int nbits)
|
116
|
-
: code(code) {
|
117
|
-
assert(8 == nbits);
|
99
|
+
inline PQDecoder8::PQDecoder8(const uint8_t* code, int nbits_in) : code(code) {
|
100
|
+
assert(8 == nbits_in);
|
118
101
|
}
|
119
102
|
|
120
|
-
inline
|
121
|
-
uint64_t PQDecoder8::decode() {
|
103
|
+
inline uint64_t PQDecoder8::decode() {
|
122
104
|
return (uint64_t)(*code++);
|
123
105
|
}
|
124
106
|
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
: code((uint16_t *)code) {
|
129
|
-
assert(16 == nbits);
|
107
|
+
inline PQDecoder16::PQDecoder16(const uint8_t* code, int nbits_in)
|
108
|
+
: code((uint16_t*)code) {
|
109
|
+
assert(16 == nbits_in);
|
130
110
|
}
|
131
111
|
|
132
|
-
inline
|
133
|
-
uint64_t PQDecoder16::decode() {
|
112
|
+
inline uint64_t PQDecoder16::decode() {
|
134
113
|
return (uint64_t)(*code++);
|
135
114
|
}
|
136
115
|
|
@@ -9,113 +9,118 @@
|
|
9
9
|
|
10
10
|
#include <faiss/impl/ProductQuantizer.h>
|
11
11
|
|
12
|
-
|
13
12
|
#include <cstddef>
|
14
|
-
#include <cstring>
|
15
13
|
#include <cstdio>
|
14
|
+
#include <cstring>
|
16
15
|
#include <memory>
|
17
16
|
|
18
17
|
#include <algorithm>
|
19
18
|
|
20
|
-
#include <faiss/impl/FaissAssert.h>
|
21
|
-
#include <faiss/VectorTransform.h>
|
22
19
|
#include <faiss/IndexFlat.h>
|
20
|
+
#include <faiss/VectorTransform.h>
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
23
22
|
#include <faiss/utils/distances.h>
|
24
23
|
|
25
|
-
|
26
24
|
extern "C" {
|
27
25
|
|
28
26
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
29
27
|
|
30
|
-
int sgemm_
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
28
|
+
int sgemm_(
|
29
|
+
const char* transa,
|
30
|
+
const char* transb,
|
31
|
+
FINTEGER* m,
|
32
|
+
FINTEGER* n,
|
33
|
+
FINTEGER* k,
|
34
|
+
const float* alpha,
|
35
|
+
const float* a,
|
36
|
+
FINTEGER* lda,
|
37
|
+
const float* b,
|
38
|
+
FINTEGER* ldb,
|
39
|
+
float* beta,
|
40
|
+
float* c,
|
41
|
+
FINTEGER* ldc);
|
35
42
|
}
|
36
43
|
|
37
|
-
|
38
44
|
namespace faiss {
|
39
45
|
|
40
|
-
|
41
46
|
/* compute an estimator using look-up tables for typical values of M */
|
42
47
|
template <typename CT, class C>
|
43
|
-
void pq_estimators_from_tables_Mmul4
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
48
|
+
void pq_estimators_from_tables_Mmul4(
|
49
|
+
int M,
|
50
|
+
const CT* codes,
|
51
|
+
size_t ncodes,
|
52
|
+
const float* __restrict dis_table,
|
53
|
+
size_t ksub,
|
54
|
+
size_t k,
|
55
|
+
float* heap_dis,
|
56
|
+
int64_t* heap_ids) {
|
52
57
|
for (size_t j = 0; j < ncodes; j++) {
|
53
58
|
float dis = 0;
|
54
|
-
const float
|
59
|
+
const float* dt = dis_table;
|
55
60
|
|
56
|
-
for (size_t m = 0; m < M; m+=4) {
|
61
|
+
for (size_t m = 0; m < M; m += 4) {
|
57
62
|
float dism = 0;
|
58
|
-
dism
|
59
|
-
|
60
|
-
dism += dt[*codes++];
|
61
|
-
|
63
|
+
dism = dt[*codes++];
|
64
|
+
dt += ksub;
|
65
|
+
dism += dt[*codes++];
|
66
|
+
dt += ksub;
|
67
|
+
dism += dt[*codes++];
|
68
|
+
dt += ksub;
|
69
|
+
dism += dt[*codes++];
|
70
|
+
dt += ksub;
|
62
71
|
dis += dism;
|
63
72
|
}
|
64
73
|
|
65
|
-
if (C::cmp
|
66
|
-
heap_replace_top<C>
|
74
|
+
if (C::cmp(heap_dis[0], dis)) {
|
75
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
67
76
|
}
|
68
77
|
}
|
69
78
|
}
|
70
79
|
|
71
|
-
|
72
80
|
template <typename CT, class C>
|
73
|
-
void pq_estimators_from_tables_M4
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
{
|
81
|
-
|
81
|
+
void pq_estimators_from_tables_M4(
|
82
|
+
const CT* codes,
|
83
|
+
size_t ncodes,
|
84
|
+
const float* __restrict dis_table,
|
85
|
+
size_t ksub,
|
86
|
+
size_t k,
|
87
|
+
float* heap_dis,
|
88
|
+
int64_t* heap_ids) {
|
82
89
|
for (size_t j = 0; j < ncodes; j++) {
|
83
90
|
float dis = 0;
|
84
|
-
const float
|
85
|
-
dis
|
86
|
-
|
87
|
-
dis += dt[*codes++];
|
91
|
+
const float* dt = dis_table;
|
92
|
+
dis = dt[*codes++];
|
93
|
+
dt += ksub;
|
94
|
+
dis += dt[*codes++];
|
95
|
+
dt += ksub;
|
96
|
+
dis += dt[*codes++];
|
97
|
+
dt += ksub;
|
88
98
|
dis += dt[*codes++];
|
89
99
|
|
90
|
-
if (C::cmp
|
91
|
-
heap_replace_top<C>
|
100
|
+
if (C::cmp(heap_dis[0], dis)) {
|
101
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
92
102
|
}
|
93
103
|
}
|
94
104
|
}
|
95
105
|
|
96
|
-
|
97
106
|
template <typename CT, class C>
|
98
|
-
static inline void pq_estimators_from_tables
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
{
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
pq_estimators_from_tables_M4<CT, C> (codes, ncodes,
|
110
|
-
dis_table, pq.ksub, k,
|
111
|
-
heap_dis, heap_ids);
|
107
|
+
static inline void pq_estimators_from_tables(
|
108
|
+
const ProductQuantizer& pq,
|
109
|
+
const CT* codes,
|
110
|
+
size_t ncodes,
|
111
|
+
const float* dis_table,
|
112
|
+
size_t k,
|
113
|
+
float* heap_dis,
|
114
|
+
int64_t* heap_ids) {
|
115
|
+
if (pq.M == 4) {
|
116
|
+
pq_estimators_from_tables_M4<CT, C>(
|
117
|
+
codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
|
112
118
|
return;
|
113
119
|
}
|
114
120
|
|
115
121
|
if (pq.M % 4 == 0) {
|
116
|
-
pq_estimators_from_tables_Mmul4<CT, C>
|
117
|
-
|
118
|
-
heap_dis, heap_ids);
|
122
|
+
pq_estimators_from_tables_Mmul4<CT, C>(
|
123
|
+
pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
|
119
124
|
return;
|
120
125
|
}
|
121
126
|
|
@@ -124,132 +129,124 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
|
|
124
129
|
const size_t ksub = pq.ksub;
|
125
130
|
for (size_t j = 0; j < ncodes; j++) {
|
126
131
|
float dis = 0;
|
127
|
-
const float
|
132
|
+
const float* __restrict dt = dis_table;
|
128
133
|
for (int m = 0; m < M; m++) {
|
129
134
|
dis += dt[*codes++];
|
130
135
|
dt += ksub;
|
131
136
|
}
|
132
|
-
if (C::cmp
|
133
|
-
heap_replace_top<C>
|
137
|
+
if (C::cmp(heap_dis[0], dis)) {
|
138
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
134
139
|
}
|
135
140
|
}
|
136
141
|
}
|
137
142
|
|
138
143
|
template <class C>
|
139
|
-
static inline void pq_estimators_from_tables_generic(
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
{
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
dt += ksub;
|
160
|
-
}
|
144
|
+
static inline void pq_estimators_from_tables_generic(
|
145
|
+
const ProductQuantizer& pq,
|
146
|
+
size_t nbits,
|
147
|
+
const uint8_t* codes,
|
148
|
+
size_t ncodes,
|
149
|
+
const float* dis_table,
|
150
|
+
size_t k,
|
151
|
+
float* heap_dis,
|
152
|
+
int64_t* heap_ids) {
|
153
|
+
const size_t M = pq.M;
|
154
|
+
const size_t ksub = pq.ksub;
|
155
|
+
for (size_t j = 0; j < ncodes; ++j) {
|
156
|
+
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
|
157
|
+
float dis = 0;
|
158
|
+
const float* __restrict dt = dis_table;
|
159
|
+
for (size_t m = 0; m < M; m++) {
|
160
|
+
uint64_t c = decoder.decode();
|
161
|
+
dis += dt[c];
|
162
|
+
dt += ksub;
|
163
|
+
}
|
161
164
|
|
162
|
-
|
163
|
-
|
165
|
+
if (C::cmp(heap_dis[0], dis)) {
|
166
|
+
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
167
|
+
}
|
164
168
|
}
|
165
|
-
}
|
166
169
|
}
|
167
170
|
|
168
171
|
/*********************************************
|
169
172
|
* PQ implementation
|
170
173
|
*********************************************/
|
171
174
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
d(d), M(M), nbits(nbits), assign_index(nullptr)
|
176
|
-
{
|
177
|
-
set_derived_values ();
|
175
|
+
ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
|
176
|
+
: d(d), M(M), nbits(nbits), assign_index(nullptr) {
|
177
|
+
set_derived_values();
|
178
178
|
}
|
179
179
|
|
180
|
-
ProductQuantizer::ProductQuantizer ()
|
181
|
-
: ProductQuantizer(0, 1, 0) {}
|
180
|
+
ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
|
182
181
|
|
183
|
-
void ProductQuantizer::set_derived_values
|
182
|
+
void ProductQuantizer::set_derived_values() {
|
184
183
|
// quite a few derived values
|
185
|
-
FAISS_THROW_IF_NOT_MSG
|
184
|
+
FAISS_THROW_IF_NOT_MSG(
|
185
|
+
d % M == 0,
|
186
|
+
"The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
|
186
187
|
dsub = d / M;
|
187
188
|
code_size = (nbits * M + 7) / 8;
|
188
189
|
ksub = 1 << nbits;
|
189
|
-
centroids.resize
|
190
|
+
centroids.resize(d * ksub);
|
190
191
|
verbose = false;
|
191
192
|
train_type = Train_default;
|
192
193
|
}
|
193
194
|
|
194
|
-
void ProductQuantizer::set_params
|
195
|
-
|
196
|
-
|
197
|
-
|
195
|
+
void ProductQuantizer::set_params(const float* centroids_, int m) {
|
196
|
+
memcpy(get_centroids(m, 0),
|
197
|
+
centroids_,
|
198
|
+
ksub * dsub * sizeof(centroids_[0]));
|
198
199
|
}
|
199
200
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
std::vector<float> mean
|
201
|
+
static void init_hypercube(
|
202
|
+
int d,
|
203
|
+
int nbits,
|
204
|
+
int n,
|
205
|
+
const float* x,
|
206
|
+
float* centroids) {
|
207
|
+
std::vector<float> mean(d);
|
207
208
|
for (int i = 0; i < n; i++)
|
208
209
|
for (int j = 0; j < d; j++)
|
209
|
-
mean
|
210
|
+
mean[j] += x[i * d + j];
|
210
211
|
|
211
212
|
float maxm = 0;
|
212
213
|
for (int j = 0; j < d; j++) {
|
213
|
-
mean
|
214
|
-
if (fabs(mean[j]) > maxm)
|
214
|
+
mean[j] /= n;
|
215
|
+
if (fabs(mean[j]) > maxm)
|
216
|
+
maxm = fabs(mean[j]);
|
215
217
|
}
|
216
218
|
|
217
219
|
for (int i = 0; i < (1 << nbits); i++) {
|
218
|
-
float
|
220
|
+
float* cent = centroids + i * d;
|
219
221
|
for (int j = 0; j < nbits; j++)
|
220
|
-
cent[j] = mean
|
222
|
+
cent[j] = mean[j] + (((i >> j) & 1) ? 1 : -1) * maxm;
|
221
223
|
for (int j = nbits; j < d; j++)
|
222
|
-
cent[j] = mean
|
224
|
+
cent[j] = mean[j];
|
223
225
|
}
|
224
|
-
|
225
|
-
|
226
226
|
}
|
227
227
|
|
228
|
-
static void init_hypercube_pca
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
228
|
+
static void init_hypercube_pca(
|
229
|
+
int d,
|
230
|
+
int nbits,
|
231
|
+
int n,
|
232
|
+
const float* x,
|
233
|
+
float* centroids) {
|
234
|
+
PCAMatrix pca(d, nbits);
|
235
|
+
pca.train(n, x);
|
235
236
|
|
236
237
|
for (int i = 0; i < (1 << nbits); i++) {
|
237
|
-
float
|
238
|
+
float* cent = centroids + i * d;
|
238
239
|
for (int j = 0; j < d; j++) {
|
239
240
|
cent[j] = pca.mean[j];
|
240
241
|
float f = 1.0;
|
241
242
|
for (int k = 0; k < nbits; k++)
|
242
|
-
cent[j] += f *
|
243
|
-
|
244
|
-
(((i >> k) & 1) ? 1 : -1) *
|
245
|
-
pca.PCAMat [j + k * d];
|
243
|
+
cent[j] += f * sqrt(pca.eigenvalues[k]) *
|
244
|
+
(((i >> k) & 1) ? 1 : -1) * pca.PCAMat[j + k * d];
|
246
245
|
}
|
247
246
|
}
|
248
|
-
|
249
247
|
}
|
250
248
|
|
251
|
-
void ProductQuantizer::train
|
252
|
-
{
|
249
|
+
void ProductQuantizer::train(int n, const float* x) {
|
253
250
|
if (train_type != Train_shared) {
|
254
251
|
train_type_t final_train_type;
|
255
252
|
final_train_type = train_type;
|
@@ -257,234 +254,229 @@ void ProductQuantizer::train (int n, const float * x)
|
|
257
254
|
train_type == Train_hypercube_pca) {
|
258
255
|
if (dsub < nbits) {
|
259
256
|
final_train_type = Train_default;
|
260
|
-
printf
|
261
|
-
|
257
|
+
printf("cannot train hypercube: nbits=%zd > log2(d=%zd)\n",
|
258
|
+
nbits,
|
259
|
+
dsub);
|
262
260
|
}
|
263
261
|
}
|
264
262
|
|
265
|
-
float
|
266
|
-
ScopeDeleter<float> del
|
263
|
+
float* xslice = new float[n * dsub];
|
264
|
+
ScopeDeleter<float> del(xslice);
|
267
265
|
for (int m = 0; m < M; m++) {
|
268
266
|
for (int j = 0; j < n; j++)
|
269
|
-
memcpy
|
270
|
-
|
271
|
-
|
267
|
+
memcpy(xslice + j * dsub,
|
268
|
+
x + j * d + m * dsub,
|
269
|
+
dsub * sizeof(float));
|
272
270
|
|
273
|
-
Clustering clus
|
271
|
+
Clustering clus(dsub, ksub, cp);
|
274
272
|
|
275
273
|
// we have some initialization for the centroids
|
276
274
|
if (final_train_type != Train_default) {
|
277
|
-
clus.centroids.resize
|
275
|
+
clus.centroids.resize(dsub * ksub);
|
278
276
|
}
|
279
277
|
|
280
278
|
switch (final_train_type) {
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
279
|
+
case Train_hypercube:
|
280
|
+
init_hypercube(
|
281
|
+
dsub, nbits, n, xslice, clus.centroids.data());
|
282
|
+
break;
|
283
|
+
case Train_hypercube_pca:
|
284
|
+
init_hypercube_pca(
|
285
|
+
dsub, nbits, n, xslice, clus.centroids.data());
|
286
|
+
break;
|
287
|
+
case Train_hot_start:
|
288
|
+
memcpy(clus.centroids.data(),
|
289
|
+
get_centroids(m, 0),
|
290
|
+
dsub * ksub * sizeof(float));
|
291
|
+
break;
|
292
|
+
default:;
|
295
293
|
}
|
296
294
|
|
297
|
-
if(verbose) {
|
295
|
+
if (verbose) {
|
298
296
|
clus.verbose = true;
|
299
|
-
printf
|
297
|
+
printf("Training PQ slice %d/%zd\n", m, M);
|
300
298
|
}
|
301
|
-
IndexFlatL2 index
|
302
|
-
clus.train
|
303
|
-
set_params
|
299
|
+
IndexFlatL2 index(dsub);
|
300
|
+
clus.train(n, xslice, assign_index ? *assign_index : index);
|
301
|
+
set_params(clus.centroids.data(), m);
|
304
302
|
}
|
305
303
|
|
306
|
-
|
307
304
|
} else {
|
305
|
+
Clustering clus(dsub, ksub, cp);
|
308
306
|
|
309
|
-
|
310
|
-
|
311
|
-
if(verbose) {
|
307
|
+
if (verbose) {
|
312
308
|
clus.verbose = true;
|
313
|
-
printf
|
309
|
+
printf("Training all PQ slices at once\n");
|
314
310
|
}
|
315
311
|
|
316
|
-
IndexFlatL2 index
|
312
|
+
IndexFlatL2 index(dsub);
|
317
313
|
|
318
|
-
clus.train
|
314
|
+
clus.train(n * M, x, assign_index ? *assign_index : index);
|
319
315
|
for (int m = 0; m < M; m++) {
|
320
|
-
set_params
|
316
|
+
set_params(clus.centroids.data(), m);
|
321
317
|
}
|
322
|
-
|
323
318
|
}
|
324
319
|
}
|
325
320
|
|
326
|
-
template<class PQEncoder>
|
327
|
-
void compute_code(const ProductQuantizer& pq, const float
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
321
|
+
template <class PQEncoder>
|
322
|
+
void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
|
323
|
+
std::vector<float> distances(pq.ksub);
|
324
|
+
PQEncoder encoder(code, pq.nbits);
|
325
|
+
for (size_t m = 0; m < pq.M; m++) {
|
326
|
+
float mindis = 1e20;
|
327
|
+
uint64_t idxm = 0;
|
328
|
+
const float* xsub = x + m * pq.dsub;
|
329
|
+
|
330
|
+
fvec_L2sqr_ny(
|
331
|
+
distances.data(),
|
332
|
+
xsub,
|
333
|
+
pq.get_centroids(m, 0),
|
334
|
+
pq.dsub,
|
335
|
+
pq.ksub);
|
336
|
+
|
337
|
+
/* Find best centroid */
|
338
|
+
for (size_t i = 0; i < pq.ksub; i++) {
|
339
|
+
float dis = distances[i];
|
340
|
+
if (dis < mindis) {
|
341
|
+
mindis = dis;
|
342
|
+
idxm = i;
|
343
|
+
}
|
344
|
+
}
|
345
345
|
|
346
|
-
|
347
|
-
|
346
|
+
encoder.encode(idxm);
|
347
|
+
}
|
348
348
|
}
|
349
349
|
|
350
|
-
void ProductQuantizer::compute_code(const float
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
350
|
+
void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
|
351
|
+
switch (nbits) {
|
352
|
+
case 8:
|
353
|
+
faiss::compute_code<PQEncoder8>(*this, x, code);
|
354
|
+
break;
|
355
355
|
|
356
|
-
|
357
|
-
|
358
|
-
|
356
|
+
case 16:
|
357
|
+
faiss::compute_code<PQEncoder16>(*this, x, code);
|
358
|
+
break;
|
359
359
|
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
360
|
+
default:
|
361
|
+
faiss::compute_code<PQEncoderGeneric>(*this, x, code);
|
362
|
+
break;
|
363
|
+
}
|
364
364
|
}
|
365
365
|
|
366
|
-
template<class PQDecoder>
|
367
|
-
void decode(const ProductQuantizer& pq, const uint8_t
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
366
|
+
template <class PQDecoder>
|
367
|
+
void decode(const ProductQuantizer& pq, const uint8_t* code, float* x) {
|
368
|
+
PQDecoder decoder(code, pq.nbits);
|
369
|
+
for (size_t m = 0; m < pq.M; m++) {
|
370
|
+
uint64_t c = decoder.decode();
|
371
|
+
memcpy(x + m * pq.dsub,
|
372
|
+
pq.get_centroids(m, c),
|
373
|
+
sizeof(float) * pq.dsub);
|
374
|
+
}
|
374
375
|
}
|
375
376
|
|
376
|
-
void ProductQuantizer::decode
|
377
|
-
{
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
break;
|
386
|
-
|
387
|
-
default:
|
388
|
-
faiss::decode<PQDecoderGeneric>(*this, code, x);
|
389
|
-
break;
|
390
|
-
}
|
391
|
-
}
|
377
|
+
void ProductQuantizer::decode(const uint8_t* code, float* x) const {
|
378
|
+
switch (nbits) {
|
379
|
+
case 8:
|
380
|
+
faiss::decode<PQDecoder8>(*this, code, x);
|
381
|
+
break;
|
382
|
+
|
383
|
+
case 16:
|
384
|
+
faiss::decode<PQDecoder16>(*this, code, x);
|
385
|
+
break;
|
392
386
|
|
387
|
+
default:
|
388
|
+
faiss::decode<PQDecoderGeneric>(*this, code, x);
|
389
|
+
break;
|
390
|
+
}
|
391
|
+
}
|
393
392
|
|
394
|
-
void ProductQuantizer::decode
|
395
|
-
{
|
393
|
+
void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
|
396
394
|
for (size_t i = 0; i < n; i++) {
|
397
|
-
this->decode
|
395
|
+
this->decode(code + code_size * i, x + d * i);
|
398
396
|
}
|
399
397
|
}
|
400
398
|
|
399
|
+
void ProductQuantizer::compute_code_from_distance_table(
|
400
|
+
const float* tab,
|
401
|
+
uint8_t* code) const {
|
402
|
+
PQEncoderGeneric encoder(code, nbits);
|
403
|
+
for (size_t m = 0; m < M; m++) {
|
404
|
+
float mindis = 1e20;
|
405
|
+
uint64_t idxm = 0;
|
406
|
+
|
407
|
+
/* Find best centroid */
|
408
|
+
for (size_t j = 0; j < ksub; j++) {
|
409
|
+
float dis = *tab++;
|
410
|
+
if (dis < mindis) {
|
411
|
+
mindis = dis;
|
412
|
+
idxm = j;
|
413
|
+
}
|
414
|
+
}
|
401
415
|
|
402
|
-
|
403
|
-
uint8_t *code) const
|
404
|
-
{
|
405
|
-
PQEncoderGeneric encoder(code, nbits);
|
406
|
-
for (size_t m = 0; m < M; m++) {
|
407
|
-
float mindis = 1e20;
|
408
|
-
uint64_t idxm = 0;
|
409
|
-
|
410
|
-
/* Find best centroid */
|
411
|
-
for (size_t j = 0; j < ksub; j++) {
|
412
|
-
float dis = *tab++;
|
413
|
-
if (dis < mindis) {
|
414
|
-
mindis = dis;
|
415
|
-
idxm = j;
|
416
|
-
}
|
416
|
+
encoder.encode(idxm);
|
417
417
|
}
|
418
|
-
|
419
|
-
encoder.encode(idxm);
|
420
|
-
}
|
421
418
|
}
|
422
419
|
|
423
|
-
void ProductQuantizer::compute_codes_with_assign_index
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
FAISS_THROW_IF_NOT (assign_index && assign_index->d == dsub);
|
420
|
+
void ProductQuantizer::compute_codes_with_assign_index(
|
421
|
+
const float* x,
|
422
|
+
uint8_t* codes,
|
423
|
+
size_t n) {
|
424
|
+
FAISS_THROW_IF_NOT(assign_index && assign_index->d == dsub);
|
429
425
|
|
430
426
|
for (size_t m = 0; m < M; m++) {
|
431
|
-
assign_index->reset
|
432
|
-
assign_index->add
|
427
|
+
assign_index->reset();
|
428
|
+
assign_index->add(ksub, get_centroids(m, 0));
|
433
429
|
size_t bs = 65536;
|
434
|
-
float
|
435
|
-
ScopeDeleter<float> del
|
436
|
-
idx_t
|
437
|
-
ScopeDeleter<idx_t> del2
|
430
|
+
float* xslice = new float[bs * dsub];
|
431
|
+
ScopeDeleter<float> del(xslice);
|
432
|
+
idx_t* assign = new idx_t[bs];
|
433
|
+
ScopeDeleter<idx_t> del2(assign);
|
438
434
|
|
439
435
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
440
436
|
size_t i1 = std::min(i0 + bs, n);
|
441
437
|
|
442
438
|
for (size_t i = i0; i < i1; i++) {
|
443
|
-
memcpy
|
444
|
-
|
445
|
-
|
439
|
+
memcpy(xslice + (i - i0) * dsub,
|
440
|
+
x + i * d + m * dsub,
|
441
|
+
dsub * sizeof(float));
|
446
442
|
}
|
447
443
|
|
448
|
-
assign_index->assign
|
444
|
+
assign_index->assign(i1 - i0, xslice, assign);
|
449
445
|
|
450
446
|
if (nbits == 8) {
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
447
|
+
uint8_t* c = codes + code_size * i0 + m;
|
448
|
+
for (size_t i = i0; i < i1; i++) {
|
449
|
+
*c = assign[i - i0];
|
450
|
+
c += M;
|
451
|
+
}
|
456
452
|
} else if (nbits == 16) {
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
453
|
+
uint16_t* c = (uint16_t*)(codes + code_size * i0 + m * 2);
|
454
|
+
for (size_t i = i0; i < i1; i++) {
|
455
|
+
*c = assign[i - i0];
|
456
|
+
c += M;
|
457
|
+
}
|
462
458
|
} else {
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
459
|
+
for (size_t i = i0; i < i1; ++i) {
|
460
|
+
uint8_t* c = codes + code_size * i + ((m * nbits) / 8);
|
461
|
+
uint8_t offset = (m * nbits) % 8;
|
462
|
+
uint64_t ass = assign[i - i0];
|
463
|
+
|
464
|
+
PQEncoderGeneric encoder(c, nbits, offset);
|
465
|
+
encoder.encode(ass);
|
466
|
+
}
|
471
467
|
}
|
472
|
-
|
473
468
|
}
|
474
469
|
}
|
475
|
-
|
476
470
|
}
|
477
471
|
|
478
|
-
void ProductQuantizer::compute_codes
|
479
|
-
|
480
|
-
|
481
|
-
{
|
482
|
-
// process by blocks to avoid using too much RAM
|
472
|
+
void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
473
|
+
const {
|
474
|
+
// process by blocks to avoid using too much RAM
|
483
475
|
size_t bs = 256 * 1024;
|
484
476
|
if (n > bs) {
|
485
477
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
486
478
|
size_t i1 = std::min(i0 + bs, n);
|
487
|
-
compute_codes
|
479
|
+
compute_codes(x + d * i0, codes + code_size * i0, i1 - i0);
|
488
480
|
}
|
489
481
|
return;
|
490
482
|
}
|
@@ -493,282 +485,300 @@ void ProductQuantizer::compute_codes (const float * x,
|
|
493
485
|
|
494
486
|
#pragma omp parallel for
|
495
487
|
for (int64_t i = 0; i < n; i++)
|
496
|
-
compute_code
|
488
|
+
compute_code(x + i * d, codes + i * code_size);
|
497
489
|
|
498
490
|
} else { // worthwile to use BLAS
|
499
|
-
float
|
500
|
-
ScopeDeleter<float> del
|
501
|
-
compute_distance_tables
|
491
|
+
float* dis_tables = new float[n * ksub * M];
|
492
|
+
ScopeDeleter<float> del(dis_tables);
|
493
|
+
compute_distance_tables(n, x, dis_tables);
|
502
494
|
|
503
495
|
#pragma omp parallel for
|
504
496
|
for (int64_t i = 0; i < n; i++) {
|
505
|
-
uint8_t
|
506
|
-
const float
|
507
|
-
compute_code_from_distance_table
|
497
|
+
uint8_t* code = codes + i * code_size;
|
498
|
+
const float* tab = dis_tables + i * ksub * M;
|
499
|
+
compute_code_from_distance_table(tab, code);
|
508
500
|
}
|
509
501
|
}
|
510
502
|
}
|
511
503
|
|
512
|
-
|
513
|
-
|
514
|
-
float * dis_table) const
|
515
|
-
{
|
504
|
+
void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
|
505
|
+
const {
|
516
506
|
size_t m;
|
517
507
|
|
518
508
|
for (m = 0; m < M; m++) {
|
519
|
-
fvec_L2sqr_ny
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
509
|
+
fvec_L2sqr_ny(
|
510
|
+
dis_table + m * ksub,
|
511
|
+
x + m * dsub,
|
512
|
+
get_centroids(m, 0),
|
513
|
+
dsub,
|
514
|
+
ksub);
|
524
515
|
}
|
525
516
|
}
|
526
517
|
|
527
|
-
void ProductQuantizer::compute_inner_prod_table
|
528
|
-
|
529
|
-
{
|
518
|
+
void ProductQuantizer::compute_inner_prod_table(
|
519
|
+
const float* x,
|
520
|
+
float* dis_table) const {
|
530
521
|
size_t m;
|
531
522
|
|
532
523
|
for (m = 0; m < M; m++) {
|
533
|
-
fvec_inner_products_ny
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
524
|
+
fvec_inner_products_ny(
|
525
|
+
dis_table + m * ksub,
|
526
|
+
x + m * dsub,
|
527
|
+
get_centroids(m, 0),
|
528
|
+
dsub,
|
529
|
+
ksub);
|
538
530
|
}
|
539
531
|
}
|
540
532
|
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
{
|
547
|
-
|
548
|
-
#ifdef __AVX2__
|
533
|
+
void ProductQuantizer::compute_distance_tables(
|
534
|
+
size_t nx,
|
535
|
+
const float* x,
|
536
|
+
float* dis_tables) const {
|
537
|
+
#if defined(__AVX2__) || defined(__aarch64__)
|
549
538
|
if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
|
550
539
|
compute_PQ_dis_tables_dsub2(
|
551
|
-
|
552
|
-
nx, x, false, dis_tables
|
553
|
-
);
|
540
|
+
d, ksub, centroids.data(), nx, x, false, dis_tables);
|
554
541
|
} else
|
555
542
|
#endif
|
556
|
-
|
543
|
+
if (dsub < 16) {
|
557
544
|
|
558
545
|
#pragma omp parallel for
|
559
546
|
for (int64_t i = 0; i < nx; i++) {
|
560
|
-
compute_distance_table
|
547
|
+
compute_distance_table(x + i * d, dis_tables + i * ksub * M);
|
561
548
|
}
|
562
549
|
|
563
550
|
} else { // use BLAS
|
564
551
|
|
565
552
|
for (int m = 0; m < M; m++) {
|
566
|
-
pairwise_L2sqr
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
553
|
+
pairwise_L2sqr(
|
554
|
+
dsub,
|
555
|
+
nx,
|
556
|
+
x + dsub * m,
|
557
|
+
ksub,
|
558
|
+
centroids.data() + m * dsub * ksub,
|
559
|
+
dis_tables + ksub * m,
|
560
|
+
d,
|
561
|
+
dsub,
|
562
|
+
ksub * M);
|
571
563
|
}
|
572
564
|
}
|
573
565
|
}
|
574
566
|
|
575
|
-
void ProductQuantizer::compute_inner_prod_tables
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
#ifdef __AVX2__
|
567
|
+
void ProductQuantizer::compute_inner_prod_tables(
|
568
|
+
size_t nx,
|
569
|
+
const float* x,
|
570
|
+
float* dis_tables) const {
|
571
|
+
#if defined(__AVX2__) || defined(__aarch64__)
|
581
572
|
if (dsub == 2 && nbits < 8) {
|
582
573
|
compute_PQ_dis_tables_dsub2(
|
583
|
-
|
584
|
-
nx, x, true, dis_tables
|
585
|
-
);
|
574
|
+
d, ksub, centroids.data(), nx, x, true, dis_tables);
|
586
575
|
} else
|
587
576
|
#endif
|
588
|
-
|
577
|
+
if (dsub < 16) {
|
589
578
|
|
590
579
|
#pragma omp parallel for
|
591
580
|
for (int64_t i = 0; i < nx; i++) {
|
592
|
-
compute_inner_prod_table
|
581
|
+
compute_inner_prod_table(x + i * d, dis_tables + i * ksub * M);
|
593
582
|
}
|
594
583
|
|
595
584
|
} else { // use BLAS
|
596
585
|
|
597
586
|
// compute distance tables
|
598
587
|
for (int m = 0; m < M; m++) {
|
599
|
-
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub,
|
600
|
-
|
588
|
+
FINTEGER ldc = ksub * M, nxi = nx, ksubi = ksub, dsubi = dsub,
|
589
|
+
di = d;
|
601
590
|
float one = 1.0, zero = 0;
|
602
591
|
|
603
|
-
sgemm_
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
592
|
+
sgemm_("Transposed",
|
593
|
+
"Not transposed",
|
594
|
+
&ksubi,
|
595
|
+
&nxi,
|
596
|
+
&dsubi,
|
597
|
+
&one,
|
598
|
+
¢roids[m * dsub * ksub],
|
599
|
+
&dsubi,
|
600
|
+
x + dsub * m,
|
601
|
+
&di,
|
602
|
+
&zero,
|
603
|
+
dis_tables + ksub * m,
|
604
|
+
&ldc);
|
608
605
|
}
|
609
|
-
|
610
606
|
}
|
611
607
|
}
|
612
608
|
|
613
609
|
template <class C>
|
614
|
-
static void pq_knn_search_with_tables
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
{
|
610
|
+
static void pq_knn_search_with_tables(
|
611
|
+
const ProductQuantizer& pq,
|
612
|
+
size_t nbits,
|
613
|
+
const float* dis_tables,
|
614
|
+
const uint8_t* codes,
|
615
|
+
const size_t ncodes,
|
616
|
+
HeapArray<C>* res,
|
617
|
+
bool init_finalize_heap) {
|
623
618
|
size_t k = res->k, nx = res->nh;
|
624
619
|
size_t ksub = pq.ksub, M = pq.M;
|
625
620
|
|
626
|
-
|
627
621
|
#pragma omp parallel for
|
628
622
|
for (int64_t i = 0; i < nx; i++) {
|
629
623
|
/* query preparation for asymmetric search: compute look-up tables */
|
630
624
|
const float* dis_table = dis_tables + i * ksub * M;
|
631
625
|
|
632
626
|
/* Compute distances and keep smallest values */
|
633
|
-
int64_t
|
634
|
-
float
|
627
|
+
int64_t* __restrict heap_ids = res->ids + i * k;
|
628
|
+
float* __restrict heap_dis = res->val + i * k;
|
635
629
|
|
636
630
|
if (init_finalize_heap) {
|
637
|
-
heap_heapify<C>
|
631
|
+
heap_heapify<C>(k, heap_dis, heap_ids);
|
638
632
|
}
|
639
633
|
|
640
634
|
switch (nbits) {
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
635
|
+
case 8:
|
636
|
+
pq_estimators_from_tables<uint8_t, C>(
|
637
|
+
pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
|
638
|
+
break;
|
639
|
+
|
640
|
+
case 16:
|
641
|
+
pq_estimators_from_tables<uint16_t, C>(
|
642
|
+
pq,
|
643
|
+
(uint16_t*)codes,
|
644
|
+
ncodes,
|
645
|
+
dis_table,
|
646
|
+
k,
|
647
|
+
heap_dis,
|
648
|
+
heap_ids);
|
649
|
+
break;
|
650
|
+
|
651
|
+
default:
|
652
|
+
pq_estimators_from_tables_generic<C>(
|
653
|
+
pq,
|
654
|
+
nbits,
|
655
|
+
codes,
|
656
|
+
ncodes,
|
657
|
+
dis_table,
|
658
|
+
k,
|
659
|
+
heap_dis,
|
660
|
+
heap_ids);
|
661
|
+
break;
|
662
662
|
}
|
663
663
|
|
664
664
|
if (init_finalize_heap) {
|
665
|
-
heap_reorder<C>
|
665
|
+
heap_reorder<C>(k, heap_dis, heap_ids);
|
666
666
|
}
|
667
667
|
}
|
668
668
|
}
|
669
669
|
|
670
|
-
void ProductQuantizer::search
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
{
|
677
|
-
FAISS_THROW_IF_NOT
|
678
|
-
std::unique_ptr<float[]> dis_tables(new float
|
679
|
-
compute_distance_tables
|
680
|
-
|
681
|
-
pq_knn_search_with_tables<CMax<float, int64_t>>
|
682
|
-
|
670
|
+
void ProductQuantizer::search(
|
671
|
+
const float* __restrict x,
|
672
|
+
size_t nx,
|
673
|
+
const uint8_t* codes,
|
674
|
+
const size_t ncodes,
|
675
|
+
float_maxheap_array_t* res,
|
676
|
+
bool init_finalize_heap) const {
|
677
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
678
|
+
std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
|
679
|
+
compute_distance_tables(nx, x, dis_tables.get());
|
680
|
+
|
681
|
+
pq_knn_search_with_tables<CMax<float, int64_t>>(
|
682
|
+
*this,
|
683
|
+
nbits,
|
684
|
+
dis_tables.get(),
|
685
|
+
codes,
|
686
|
+
ncodes,
|
687
|
+
res,
|
688
|
+
init_finalize_heap);
|
683
689
|
}
|
684
690
|
|
685
|
-
void ProductQuantizer::search_ip
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
{
|
692
|
-
FAISS_THROW_IF_NOT
|
693
|
-
std::unique_ptr<float[]> dis_tables(new float
|
694
|
-
compute_inner_prod_tables
|
695
|
-
|
696
|
-
pq_knn_search_with_tables<CMin<float, int64_t
|
697
|
-
|
691
|
+
void ProductQuantizer::search_ip(
|
692
|
+
const float* __restrict x,
|
693
|
+
size_t nx,
|
694
|
+
const uint8_t* codes,
|
695
|
+
const size_t ncodes,
|
696
|
+
float_minheap_array_t* res,
|
697
|
+
bool init_finalize_heap) const {
|
698
|
+
FAISS_THROW_IF_NOT(nx == res->nh);
|
699
|
+
std::unique_ptr<float[]> dis_tables(new float[nx * ksub * M]);
|
700
|
+
compute_inner_prod_tables(nx, x, dis_tables.get());
|
701
|
+
|
702
|
+
pq_knn_search_with_tables<CMin<float, int64_t>>(
|
703
|
+
*this,
|
704
|
+
nbits,
|
705
|
+
dis_tables.get(),
|
706
|
+
codes,
|
707
|
+
ncodes,
|
708
|
+
res,
|
709
|
+
init_finalize_heap);
|
698
710
|
}
|
699
711
|
|
700
|
-
|
701
|
-
|
702
|
-
static float sqr (float x) {
|
712
|
+
static float sqr(float x) {
|
703
713
|
return x * x;
|
704
714
|
}
|
705
715
|
|
706
|
-
void ProductQuantizer::compute_sdc_table
|
707
|
-
|
708
|
-
sdc_table.resize (M * ksub * ksub);
|
709
|
-
|
710
|
-
for (int m = 0; m < M; m++) {
|
716
|
+
void ProductQuantizer::compute_sdc_table() {
|
717
|
+
sdc_table.resize(M * ksub * ksub);
|
711
718
|
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
719
|
+
if (dsub < 4) {
|
720
|
+
#pragma omp parallel for
|
721
|
+
for (int mk = 0; mk < M * ksub; mk++) {
|
722
|
+
// allow omp to schedule in a more fine-grained way
|
723
|
+
// `collapse` is not supported in OpenMP 2.x
|
724
|
+
int m = mk / ksub;
|
725
|
+
int k = mk % ksub;
|
726
|
+
const float* cents = centroids.data() + m * ksub * dsub;
|
727
|
+
const float* centi = cents + k * dsub;
|
728
|
+
float* dis_tab = sdc_table.data() + m * ksub * ksub;
|
729
|
+
fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
|
730
|
+
}
|
731
|
+
} else {
|
732
|
+
// NOTE: it would disable the omp loop in pairwise_L2sqr
|
733
|
+
// but still accelerate especially when M >= 4
|
734
|
+
#pragma omp parallel for
|
735
|
+
for (int m = 0; m < M; m++) {
|
736
|
+
const float* cents = centroids.data() + m * ksub * dsub;
|
737
|
+
float* dis_tab = sdc_table.data() + m * ksub * ksub;
|
738
|
+
pairwise_L2sqr(
|
739
|
+
dsub, ksub, cents, ksub, cents, dis_tab, dsub, dsub, ksub);
|
725
740
|
}
|
726
741
|
}
|
727
742
|
}
|
728
743
|
|
729
|
-
void ProductQuantizer::search_sdc
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
{
|
736
|
-
FAISS_THROW_IF_NOT
|
737
|
-
FAISS_THROW_IF_NOT
|
744
|
+
void ProductQuantizer::search_sdc(
|
745
|
+
const uint8_t* qcodes,
|
746
|
+
size_t nq,
|
747
|
+
const uint8_t* bcodes,
|
748
|
+
const size_t nb,
|
749
|
+
float_maxheap_array_t* res,
|
750
|
+
bool init_finalize_heap) const {
|
751
|
+
FAISS_THROW_IF_NOT(sdc_table.size() == M * ksub * ksub);
|
752
|
+
FAISS_THROW_IF_NOT(nbits == 8);
|
738
753
|
size_t k = res->k;
|
739
754
|
|
740
|
-
|
741
755
|
#pragma omp parallel for
|
742
756
|
for (int64_t i = 0; i < nq; i++) {
|
743
|
-
|
744
757
|
/* Compute distances and keep smallest values */
|
745
|
-
idx_t
|
746
|
-
float
|
747
|
-
const uint8_t
|
758
|
+
idx_t* heap_ids = res->ids + i * k;
|
759
|
+
float* heap_dis = res->val + i * k;
|
760
|
+
const uint8_t* qcode = qcodes + i * code_size;
|
748
761
|
|
749
762
|
if (init_finalize_heap)
|
750
|
-
maxheap_heapify
|
763
|
+
maxheap_heapify(k, heap_dis, heap_ids);
|
751
764
|
|
752
|
-
const uint8_t
|
765
|
+
const uint8_t* bcode = bcodes;
|
753
766
|
for (size_t j = 0; j < nb; j++) {
|
754
767
|
float dis = 0;
|
755
|
-
const float
|
768
|
+
const float* tab = sdc_table.data();
|
756
769
|
for (int m = 0; m < M; m++) {
|
757
770
|
dis += tab[bcode[m] + qcode[m] * ksub];
|
758
771
|
tab += ksub * ksub;
|
759
772
|
}
|
760
773
|
if (dis < heap_dis[0]) {
|
761
|
-
maxheap_replace_top
|
774
|
+
maxheap_replace_top(k, heap_dis, heap_ids, dis, j);
|
762
775
|
}
|
763
776
|
bcode += code_size;
|
764
777
|
}
|
765
778
|
|
766
779
|
if (init_finalize_heap)
|
767
|
-
maxheap_reorder
|
780
|
+
maxheap_reorder(k, heap_dis, heap_ids);
|
768
781
|
}
|
769
|
-
|
770
782
|
}
|
771
783
|
|
772
|
-
|
773
|
-
|
774
|
-
} // namespace faiss
|
784
|
+
} // namespace faiss
|