faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_PRODUCT_QUANTIZER_H
|
|
11
|
+
#define FAISS_PRODUCT_QUANTIZER_H
|
|
12
|
+
|
|
13
|
+
#include <stdint.h>
|
|
14
|
+
|
|
15
|
+
#include <vector>
|
|
16
|
+
|
|
17
|
+
#include <faiss/Clustering.h>
|
|
18
|
+
#include <faiss/utils/Heap.h>
|
|
19
|
+
|
|
20
|
+
namespace faiss {
|
|
21
|
+
|
|
22
|
+
/** Product Quantizer. Implemented only for METRIC_L2 */
|
|
23
|
+
struct ProductQuantizer {
|
|
24
|
+
|
|
25
|
+
using idx_t = Index::idx_t;
|
|
26
|
+
|
|
27
|
+
size_t d; ///< size of the input vectors
|
|
28
|
+
size_t M; ///< number of subquantizers
|
|
29
|
+
size_t nbits; ///< number of bits per quantization index
|
|
30
|
+
|
|
31
|
+
// values derived from the above
|
|
32
|
+
size_t dsub; ///< dimensionality of each subvector
|
|
33
|
+
size_t code_size; ///< bytes per indexed vector
|
|
34
|
+
size_t ksub; ///< number of centroids for each subquantizer
|
|
35
|
+
bool verbose; ///< verbose during training?
|
|
36
|
+
|
|
37
|
+
/// initialization
|
|
38
|
+
enum train_type_t {
|
|
39
|
+
Train_default,
|
|
40
|
+
Train_hot_start, ///< the centroids are already initialized
|
|
41
|
+
Train_shared, ///< share dictionary accross PQ segments
|
|
42
|
+
Train_hypercube, ///< intialize centroids with nbits-D hypercube
|
|
43
|
+
Train_hypercube_pca, ///< intialize centroids with nbits-D hypercube
|
|
44
|
+
};
|
|
45
|
+
train_type_t train_type;
|
|
46
|
+
|
|
47
|
+
ClusteringParameters cp; ///< parameters used during clustering
|
|
48
|
+
|
|
49
|
+
/// if non-NULL, use this index for assignment (should be of size
|
|
50
|
+
/// d / M)
|
|
51
|
+
Index *assign_index;
|
|
52
|
+
|
|
53
|
+
/// Centroid table, size M * ksub * dsub
|
|
54
|
+
std::vector<float> centroids;
|
|
55
|
+
|
|
56
|
+
/// return the centroids associated with subvector m
|
|
57
|
+
float * get_centroids (size_t m, size_t i) {
|
|
58
|
+
return ¢roids [(m * ksub + i) * dsub];
|
|
59
|
+
}
|
|
60
|
+
const float * get_centroids (size_t m, size_t i) const {
|
|
61
|
+
return ¢roids [(m * ksub + i) * dsub];
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
// Train the product quantizer on a set of points. A clustering
|
|
65
|
+
// can be set on input to define non-default clustering parameters
|
|
66
|
+
void train (int n, const float *x);
|
|
67
|
+
|
|
68
|
+
ProductQuantizer(size_t d, /* dimensionality of the input vectors */
|
|
69
|
+
size_t M, /* number of subquantizers */
|
|
70
|
+
size_t nbits); /* number of bit per subvector index */
|
|
71
|
+
|
|
72
|
+
ProductQuantizer ();
|
|
73
|
+
|
|
74
|
+
/// compute derived values when d, M and nbits have been set
|
|
75
|
+
void set_derived_values ();
|
|
76
|
+
|
|
77
|
+
/// Define the centroids for subquantizer m
|
|
78
|
+
void set_params (const float * centroids, int m);
|
|
79
|
+
|
|
80
|
+
/// Quantize one vector with the product quantizer
|
|
81
|
+
void compute_code (const float * x, uint8_t * code) const ;
|
|
82
|
+
|
|
83
|
+
/// same as compute_code for several vectors
|
|
84
|
+
void compute_codes (const float * x,
|
|
85
|
+
uint8_t * codes,
|
|
86
|
+
size_t n) const ;
|
|
87
|
+
|
|
88
|
+
/// speed up code assignment using assign_index
|
|
89
|
+
/// (non-const because the index is changed)
|
|
90
|
+
void compute_codes_with_assign_index (
|
|
91
|
+
const float * x,
|
|
92
|
+
uint8_t * codes,
|
|
93
|
+
size_t n);
|
|
94
|
+
|
|
95
|
+
/// decode a vector from a given code (or n vectors if third argument)
|
|
96
|
+
void decode (const uint8_t *code, float *x) const;
|
|
97
|
+
void decode (const uint8_t *code, float *x, size_t n) const;
|
|
98
|
+
|
|
99
|
+
/// If we happen to have the distance tables precomputed, this is
|
|
100
|
+
/// more efficient to compute the codes.
|
|
101
|
+
void compute_code_from_distance_table (const float *tab,
|
|
102
|
+
uint8_t *code) const;
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
/** Compute distance table for one vector.
|
|
106
|
+
*
|
|
107
|
+
* The distance table for x = [x_0 x_1 .. x_(M-1)] is a M * ksub
|
|
108
|
+
* matrix that contains
|
|
109
|
+
*
|
|
110
|
+
* dis_table (m, j) = || x_m - c_(m, j)||^2
|
|
111
|
+
* for m = 0..M-1 and j = 0 .. ksub - 1
|
|
112
|
+
*
|
|
113
|
+
* where c_(m, j) is the centroid no j of sub-quantizer m.
|
|
114
|
+
*
|
|
115
|
+
* @param x input vector size d
|
|
116
|
+
* @param dis_table output table, size M * ksub
|
|
117
|
+
*/
|
|
118
|
+
void compute_distance_table (const float * x,
|
|
119
|
+
float * dis_table) const;
|
|
120
|
+
|
|
121
|
+
void compute_inner_prod_table (const float * x,
|
|
122
|
+
float * dis_table) const;
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
/** compute distance table for several vectors
|
|
126
|
+
* @param nx nb of input vectors
|
|
127
|
+
* @param x input vector size nx * d
|
|
128
|
+
* @param dis_table output table, size nx * M * ksub
|
|
129
|
+
*/
|
|
130
|
+
void compute_distance_tables (size_t nx,
|
|
131
|
+
const float * x,
|
|
132
|
+
float * dis_tables) const;
|
|
133
|
+
|
|
134
|
+
void compute_inner_prod_tables (size_t nx,
|
|
135
|
+
const float * x,
|
|
136
|
+
float * dis_tables) const;
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
/** perform a search (L2 distance)
|
|
140
|
+
* @param x query vectors, size nx * d
|
|
141
|
+
* @param nx nb of queries
|
|
142
|
+
* @param codes database codes, size ncodes * code_size
|
|
143
|
+
* @param ncodes nb of nb vectors
|
|
144
|
+
* @param res heap array to store results (nh == nx)
|
|
145
|
+
* @param init_finalize_heap initialize heap (input) and sort (output)?
|
|
146
|
+
*/
|
|
147
|
+
void search (const float * x,
|
|
148
|
+
size_t nx,
|
|
149
|
+
const uint8_t * codes,
|
|
150
|
+
const size_t ncodes,
|
|
151
|
+
float_maxheap_array_t *res,
|
|
152
|
+
bool init_finalize_heap = true) const;
|
|
153
|
+
|
|
154
|
+
/** same as search, but with inner product similarity */
|
|
155
|
+
void search_ip (const float * x,
|
|
156
|
+
size_t nx,
|
|
157
|
+
const uint8_t * codes,
|
|
158
|
+
const size_t ncodes,
|
|
159
|
+
float_minheap_array_t *res,
|
|
160
|
+
bool init_finalize_heap = true) const;
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
/// Symmetric Distance Table
|
|
164
|
+
std::vector<float> sdc_table;
|
|
165
|
+
|
|
166
|
+
// intitialize the SDC table from the centroids
|
|
167
|
+
void compute_sdc_table ();
|
|
168
|
+
|
|
169
|
+
void search_sdc (const uint8_t * qcodes,
|
|
170
|
+
size_t nq,
|
|
171
|
+
const uint8_t * bcodes,
|
|
172
|
+
const size_t ncodes,
|
|
173
|
+
float_maxheap_array_t * res,
|
|
174
|
+
bool init_finalize_heap = true) const;
|
|
175
|
+
|
|
176
|
+
struct PQEncoderGeneric {
|
|
177
|
+
uint8_t *code; ///< code for this vector
|
|
178
|
+
uint8_t offset;
|
|
179
|
+
const int nbits; ///< number of bits per subquantizer index
|
|
180
|
+
|
|
181
|
+
uint8_t reg;
|
|
182
|
+
|
|
183
|
+
PQEncoderGeneric(uint8_t *code, int nbits, uint8_t offset = 0);
|
|
184
|
+
|
|
185
|
+
void encode(uint64_t x);
|
|
186
|
+
|
|
187
|
+
~PQEncoderGeneric();
|
|
188
|
+
};
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
struct PQEncoder8 {
|
|
192
|
+
uint8_t *code;
|
|
193
|
+
|
|
194
|
+
PQEncoder8(uint8_t *code, int nbits);
|
|
195
|
+
|
|
196
|
+
void encode(uint64_t x);
|
|
197
|
+
};
|
|
198
|
+
|
|
199
|
+
struct PQEncoder16 {
|
|
200
|
+
uint16_t *code;
|
|
201
|
+
|
|
202
|
+
PQEncoder16(uint8_t *code, int nbits);
|
|
203
|
+
|
|
204
|
+
void encode(uint64_t x);
|
|
205
|
+
};
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
struct PQDecoderGeneric {
|
|
209
|
+
const uint8_t *code;
|
|
210
|
+
uint8_t offset;
|
|
211
|
+
const int nbits;
|
|
212
|
+
const uint64_t mask;
|
|
213
|
+
uint8_t reg;
|
|
214
|
+
|
|
215
|
+
PQDecoderGeneric(const uint8_t *code, int nbits);
|
|
216
|
+
|
|
217
|
+
uint64_t decode();
|
|
218
|
+
};
|
|
219
|
+
|
|
220
|
+
struct PQDecoder8 {
|
|
221
|
+
const uint8_t *code;
|
|
222
|
+
|
|
223
|
+
PQDecoder8(const uint8_t *code, int nbits);
|
|
224
|
+
|
|
225
|
+
uint64_t decode();
|
|
226
|
+
};
|
|
227
|
+
|
|
228
|
+
struct PQDecoder16 {
|
|
229
|
+
const uint16_t *code;
|
|
230
|
+
|
|
231
|
+
PQDecoder16(const uint8_t *code, int nbits);
|
|
232
|
+
|
|
233
|
+
uint64_t decode();
|
|
234
|
+
};
|
|
235
|
+
|
|
236
|
+
};
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
} // namespace faiss
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
#endif
|
|
@@ -0,0 +1,1628 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/ScalarQuantizer.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdio>
#include <cstring>
|
|
13
|
+
#include <algorithm>
|
|
14
|
+
|
|
15
|
+
#include <omp.h>
|
|
16
|
+
|
|
17
|
+
#ifdef __SSE__
|
|
18
|
+
#include <immintrin.h>
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#include <faiss/utils/utils.h>
|
|
22
|
+
#include <faiss/impl/FaissAssert.h>
|
|
23
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
24
|
+
|
|
25
|
+
namespace faiss {
|
|
26
|
+
|
|
27
|
+
/*******************************************************************
|
|
28
|
+
* ScalarQuantizer implementation
|
|
29
|
+
*
|
|
30
|
+
* The main source of complexity is to support combinations of 4
|
|
31
|
+
* variants without incurring runtime tests or virtual function calls:
|
|
32
|
+
*
|
|
33
|
+
* - 4 / 8 bits per code component
|
|
34
|
+
* - uniform / non-uniform
|
|
35
|
+
* - IP / L2 distance search
|
|
36
|
+
* - scalar / AVX distance computation
|
|
37
|
+
*
|
|
38
|
+
* The appropriate Quantizer object is returned via select_quantizer
|
|
39
|
+
* that hides the template mess.
|
|
40
|
+
********************************************************************/
|
|
41
|
+
|
|
42
|
+
#ifdef __AVX__
|
|
43
|
+
#define USE_AVX
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
#ifdef __F16C__
|
|
47
|
+
#define USE_F16C
|
|
48
|
+
#endif
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
namespace {
|
|
52
|
+
|
|
53
|
+
typedef Index::idx_t idx_t;
|
|
54
|
+
typedef ScalarQuantizer::QuantizerType QuantizerType;
|
|
55
|
+
typedef ScalarQuantizer::RangeStat RangeStat;
|
|
56
|
+
using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
/*******************************************************************
|
|
60
|
+
* Codec: converts between values in [0, 1] and an index in a code
|
|
61
|
+
* array. The "i" parameter is the vector component index (not byte
|
|
62
|
+
* index).
|
|
63
|
+
*/
|
|
64
|
+
|
|
65
|
+
/// Codec storing one vector component per byte.
struct Codec8bit {

    /// Store x as byte i of code. The value 255 * x is truncated to an
    /// integer, so x is expected to be in [0, 1].
    static void encode_component (float x, uint8_t *code, int i) {
        code[i] = (int)(255 * x);
    }

    /// Return the value represented by byte i: (code[i] + 0.5) / 255.
    static float decode_component (const uint8_t *code, int i) {
        return (code[i] + 0.5f) / 255.0f;
    }

#ifdef USE_AVX
    /// Decode the 8 consecutive components i .. i + 7 at once.
    static __m256 decode_8_components (const uint8_t *code, int i) {
        // _mm_loadl_epi64 has no alignment requirement, unlike a
        // dereference of a casted uint64_t pointer.
        __m128i c8 = _mm_loadl_epi64 ((const __m128i *)(code + i));
        // widen bytes 0..3 and bytes 4..7 to 32-bit integers
        __m128i c4lo = _mm_cvtepu8_epi32 (c8);
        __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128 (c8, 4));
        // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
        __m256i i8 = _mm256_castsi128_si256 (c4lo);
        i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
        __m256 f8 = _mm256_cvtepi32_ps (i8);
        __m256 half = _mm256_set1_ps (0.5f);
        f8 += half;
        __m256 one_255 = _mm256_set1_ps (1.f / 255.f);
        return f8 * one_255;
    }
#endif
};
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
/// Codec packing two vector components per byte (even component in the
/// low nibble, odd component in the high nibble).
struct Codec4bit {

    /// OR the 4-bit value (int)(x * 15) into the nibble of component i.
    /// Because the bits are OR-ed in, the target nibble must be zero
    /// beforehand. x is expected to be in [0, 1].
    static void encode_component (float x, uint8_t *code, int i) {
        code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
    }

    /// Return the value represented by nibble i: (nibble + 0.5) / 15.
    static float decode_component (const uint8_t *code, int i) {
        return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
    }


#ifdef USE_AVX
    /// Decode the 8 consecutive components i .. i + 7 (4 bytes) at once.
    static __m256 decode_8_components (const uint8_t *code, int i) {
        // Assemble the 4 bytes in little-endian order instead of
        // dereferencing a casted uint32_t pointer, which may be
        // misaligned and violates strict aliasing.
        const uint8_t *p = code + (i >> 1);
        uint32_t c4 = (uint32_t)p[0] |
                      ((uint32_t)p[1] << 8) |
                      ((uint32_t)p[2] << 16) |
                      ((uint32_t)p[3] << 24);
        uint32_t mask = 0x0f0f0f0f;
        uint32_t c4ev = c4 & mask;        // even components (low nibbles)
        uint32_t c4od = (c4 >> 4) & mask; // odd components (high nibbles)

        // the 8 lower bytes of c8 contain the values
        __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev),
                                        _mm_set1_epi32(c4od));
        __m128i c4lo = _mm_cvtepu8_epi32 (c8);
        __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4));
        __m256i i8 = _mm256_castsi128_si256 (c4lo);
        i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
        __m256 f8 = _mm256_cvtepi32_ps (i8);
        __m256 half = _mm256_set1_ps (0.5f);
        f8 += half;
        __m256 one_15 = _mm256_set1_ps (1.f / 15.f);
        return f8 * one_15;
    }
#endif
};
|
|
126
|
+
|
|
127
|
+
/// Codec packing four 6-bit vector components into every 3 bytes.
struct Codec6bit {

    /// OR the 6-bit value (int)(x * 63) into the slot of component i.
    /// Because the bits are OR-ed in, the target bytes must be zero
    /// beforehand. x is expected to be in [0, 1].
    static void encode_component (float x, uint8_t *code, int i) {
        int bits = (int)(x * 63.0);
        code += (i >> 2) * 3;  // start of the 3-byte group holding i
        switch(i & 3) {
        case 0:
            code[0] |= bits;
            break;
        case 1:
            // low 2 bits go to the top of byte 0, the rest to byte 1
            code[0] |= bits << 6;
            code[1] |= bits >> 2;
            break;
        case 2:
            // low 4 bits go to the top of byte 1, the rest to byte 2
            code[1] |= bits << 4;
            code[2] |= bits >> 4;
            break;
        case 3:
            code[2] |= bits << 2;
            break;
        }
    }

    /// Return the value represented by slot i: (bits + 0.5) / 63.
    static float decode_component (const uint8_t *code, int i) {
        // initialized so that no path can read an indeterminate value
        uint8_t bits = 0;
        code += (i >> 2) * 3;  // start of the 3-byte group holding i
        switch(i & 3) {
        case 0:
            bits = code[0] & 0x3f;
            break;
        case 1:
            bits = code[0] >> 6;
            bits |= (code[1] & 0xf) << 2;
            break;
        case 2:
            bits = code[1] >> 4;
            bits |= (code[2] & 3) << 4;
            break;
        case 3:
            bits = code[2] >> 2;
            break;
        }
        return (bits + 0.5f) / 63.0f;
    }

#ifdef USE_AVX
    /// Decode the 8 consecutive components i .. i + 7, one at a time.
    static __m256 decode_8_components (const uint8_t *code, int i) {
        // _mm256_set_ps takes its arguments from the highest lane down
        return _mm256_set_ps
            (decode_component(code, i + 7),
             decode_component(code, i + 6),
             decode_component(code, i + 5),
             decode_component(code, i + 4),
             decode_component(code, i + 3),
             decode_component(code, i + 2),
             decode_component(code, i + 1),
             decode_component(code, i + 0));
    }
#endif
};
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
#ifdef USE_F16C


/// Convert a float to its 16-bit half-precision representation using the
/// F16C conversion instruction (round to nearest).
uint16_t encode_fp16 (float x) {
    __m128 xf = _mm_set1_ps (x);
    __m128i xi = _mm_cvtps_ph (
         xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
    return _mm_cvtsi128_si32 (xi) & 0xffff;
}


/// Convert a 16-bit half-precision representation back to a float.
float decode_fp16 (uint16_t x) {
    __m128i xi = _mm_set1_epi16 (x);
    __m128 xf = _mm_cvtph_ps (xi);
    return _mm_cvtss_f32 (xf);
}

#else

// non-intrinsic FP16 <-> FP32 code adapted from
// https://github.com/ispc/ispc/blob/master/stdlib.ispc

/// Reinterpret the 32 bits of x as a float. memcpy is used because reading
/// the object through a casted pointer violates strict aliasing.
float floatbits (uint32_t x) {
    static_assert (sizeof(float) == sizeof(uint32_t),
                   "float is expected to be 32 bits wide");
    float f;
    std::memcpy (&f, &x, sizeof(f));
    return f;
}

/// Reinterpret the 32 bits of f as an unsigned integer (see floatbits).
uint32_t intbits (float f) {
    static_assert (sizeof(float) == sizeof(uint32_t),
                   "float is expected to be 32 bits wide");
    uint32_t x;
    std::memcpy (&x, &f, sizeof(x));
    return x;
}


/// Convert a float to its 16-bit half-precision representation.
uint16_t encode_fp16 (float f) {

    // via Fabian "ryg" Giesen.
    // https://gist.github.com/2156668
    uint32_t sign_mask = 0x80000000u;
    int32_t o;

    uint32_t fint = intbits(f);
    uint32_t sign = fint & sign_mask;
    fint ^= sign;  // work on the absolute value, re-apply the sign at the end

    // NOTE all the integer compares in this function can be safely
    // compiled into signed compares since all operands are below
    // 0x80000000. Important if you want fast straight SSE2 code (since
    // there's no unsigned PCMPGTD).

    // Inf or NaN (all exponent bits set)
    // NaN->qNaN and Inf->Inf
    // unconditional assignment here, will override with right value for
    // the regular case below.
    uint32_t f32infty = 255u << 23;
    o = (fint > f32infty) ? 0x7e00u : 0x7c00u;

    // (De)normalized number or zero
    // update fint unconditionally to save the blending; we don't need it
    // anymore for the Inf/NaN case anyway.

    const uint32_t round_mask = ~0xfffu;
    const uint32_t magic = 15u << 23;

    // Shift exponent down, denormalize if necessary.
    // NOTE This represents half-float denormals using single
    // precision denormals.  The main reason to do this is that
    // there's no shift with per-lane variable shifts in SSE*, which
    // we'd otherwise need. It has some funky side effects though:
    // - This conversion will actually respect the FTZ (Flush To Zero)
    //   flag in MXCSR - if it's set, no half-float denormals will be
    //   generated. I'm honestly not sure whether this is good or
    //   bad. It's definitely interesting.
    // - If the underlying HW doesn't support denormals (not an issue
    //   with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
    //   you will always get flush-to-zero behavior. This is bad,
    //   unless you're on a CPU where you don't care.
    // - Denormals tend to be slow. FP32 denormals are rare in
    //   practice outside of things like recursive filters in DSP -
    //   not a typical half-float application. Whether FP16 denormals
    //   are rare in practice, I don't know. Whatever slow path your
    //   HW may or may not have for denormals, this may well hit it.
    float fscale = floatbits(fint & round_mask) * floatbits(magic);
    fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
    int32_t fint2 = intbits(fscale) - round_mask;

    if (fint < f32infty)
        o = fint2 >> 13; // Take the bits!

    return (o | (sign >> 16));
}

/// Convert a 16-bit half-precision representation back to a float.
float decode_fp16 (uint16_t h) {

    // https://gist.github.com/2144712
    // Fabian "ryg" Giesen.

    const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift

    int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
    int32_t exp = shifted_exp & o;   // just the exponent
    o += (int32_t)(127 - 15) << 23; // exponent adjust

    int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
    int32_t zerodenorm_val = intbits(
                 floatbits(o + (1u<<23)) - floatbits(113u << 23));
    int32_t reg_val = (exp == 0) ? zerodenorm_val : o;

    int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
    return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
}

#endif
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
/*******************************************************************
|
|
305
|
+
* Quantizer: normalizes scalar vector components, then passes them
|
|
306
|
+
* through a codec
|
|
307
|
+
*******************************************************************/
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
template<class Codec, bool uniform, int SIMD>
|
|
314
|
+
struct QuantizerTemplate {};
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
template<class Codec>
|
|
318
|
+
struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
|
|
319
|
+
const size_t d;
|
|
320
|
+
const float vmin, vdiff;
|
|
321
|
+
|
|
322
|
+
QuantizerTemplate(size_t d, const std::vector<float> &trained):
|
|
323
|
+
d(d), vmin(trained[0]), vdiff(trained[1])
|
|
324
|
+
{
|
|
325
|
+
}
|
|
326
|
+
|
|
327
|
+
void encode_vector(const float* x, uint8_t* code) const final {
|
|
328
|
+
for (size_t i = 0; i < d; i++) {
|
|
329
|
+
float xi = (x[i] - vmin) / vdiff;
|
|
330
|
+
if (xi < 0) {
|
|
331
|
+
xi = 0;
|
|
332
|
+
}
|
|
333
|
+
if (xi > 1.0) {
|
|
334
|
+
xi = 1.0;
|
|
335
|
+
}
|
|
336
|
+
Codec::encode_component(xi, code, i);
|
|
337
|
+
}
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
void decode_vector(const uint8_t* code, float* x) const final {
|
|
341
|
+
for (size_t i = 0; i < d; i++) {
|
|
342
|
+
float xi = Codec::decode_component(code, i);
|
|
343
|
+
x[i] = vmin + xi * vdiff;
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
float reconstruct_component (const uint8_t * code, int i) const
|
|
348
|
+
{
|
|
349
|
+
float xi = Codec::decode_component (code, i);
|
|
350
|
+
return vmin + xi * vdiff;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
};
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
#ifdef USE_AVX
|
|
358
|
+
|
|
359
|
+
template<class Codec>
|
|
360
|
+
struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {
|
|
361
|
+
|
|
362
|
+
QuantizerTemplate (size_t d, const std::vector<float> &trained):
|
|
363
|
+
QuantizerTemplate<Codec, true, 1> (d, trained) {}
|
|
364
|
+
|
|
365
|
+
__m256 reconstruct_8_components (const uint8_t * code, int i) const
|
|
366
|
+
{
|
|
367
|
+
__m256 xi = Codec::decode_8_components (code, i);
|
|
368
|
+
return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff);
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
};
|
|
372
|
+
|
|
373
|
+
#endif
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
template<class Codec>
|
|
378
|
+
struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
|
|
379
|
+
const size_t d;
|
|
380
|
+
const float *vmin, *vdiff;
|
|
381
|
+
|
|
382
|
+
QuantizerTemplate (size_t d, const std::vector<float> &trained):
|
|
383
|
+
d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
|
|
384
|
+
|
|
385
|
+
void encode_vector(const float* x, uint8_t* code) const final {
|
|
386
|
+
for (size_t i = 0; i < d; i++) {
|
|
387
|
+
float xi = (x[i] - vmin[i]) / vdiff[i];
|
|
388
|
+
if (xi < 0)
|
|
389
|
+
xi = 0;
|
|
390
|
+
if (xi > 1.0)
|
|
391
|
+
xi = 1.0;
|
|
392
|
+
Codec::encode_component(xi, code, i);
|
|
393
|
+
}
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
void decode_vector(const uint8_t* code, float* x) const final {
|
|
397
|
+
for (size_t i = 0; i < d; i++) {
|
|
398
|
+
float xi = Codec::decode_component(code, i);
|
|
399
|
+
x[i] = vmin[i] + xi * vdiff[i];
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
float reconstruct_component (const uint8_t * code, int i) const
|
|
404
|
+
{
|
|
405
|
+
float xi = Codec::decode_component (code, i);
|
|
406
|
+
return vmin[i] + xi * vdiff[i];
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
};
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
#ifdef USE_AVX
|
|
413
|
+
|
|
414
|
+
template<class Codec>
|
|
415
|
+
struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
|
|
416
|
+
|
|
417
|
+
QuantizerTemplate (size_t d, const std::vector<float> &trained):
|
|
418
|
+
QuantizerTemplate<Codec, false, 1> (d, trained) {}
|
|
419
|
+
|
|
420
|
+
__m256 reconstruct_8_components (const uint8_t * code, int i) const
|
|
421
|
+
{
|
|
422
|
+
__m256 xi = Codec::decode_8_components (code, i);
|
|
423
|
+
return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
};
|
|
428
|
+
|
|
429
|
+
#endif
|
|
430
|
+
|
|
431
|
+
/*******************************************************************
|
|
432
|
+
* FP16 quantizer
|
|
433
|
+
*******************************************************************/
|
|
434
|
+
|
|
435
|
+
/// Primary template, intentionally empty: the fp16 quantizer is provided
/// only through explicit specializations on the SIMD width.
template<int simd_width>
struct QuantizerFP16 {
};
|
|
437
|
+
|
|
438
|
+
template<>
|
|
439
|
+
struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
|
|
440
|
+
const size_t d;
|
|
441
|
+
|
|
442
|
+
QuantizerFP16(size_t d, const std::vector<float> & /* unused */):
|
|
443
|
+
d(d) {}
|
|
444
|
+
|
|
445
|
+
void encode_vector(const float* x, uint8_t* code) const final {
|
|
446
|
+
for (size_t i = 0; i < d; i++) {
|
|
447
|
+
((uint16_t*)code)[i] = encode_fp16(x[i]);
|
|
448
|
+
}
|
|
449
|
+
}
|
|
450
|
+
|
|
451
|
+
void decode_vector(const uint8_t* code, float* x) const final {
|
|
452
|
+
for (size_t i = 0; i < d; i++) {
|
|
453
|
+
x[i] = decode_fp16(((uint16_t*)code)[i]);
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
float reconstruct_component (const uint8_t * code, int i) const
|
|
458
|
+
{
|
|
459
|
+
return decode_fp16(((uint16_t*)code)[i]);
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
};
|
|
463
|
+
|
|
464
|
+
#ifdef USE_F16C
|
|
465
|
+
|
|
466
|
+
template<>
|
|
467
|
+
struct QuantizerFP16<8>: QuantizerFP16<1> {
|
|
468
|
+
|
|
469
|
+
QuantizerFP16 (size_t d, const std::vector<float> &trained):
|
|
470
|
+
QuantizerFP16<1> (d, trained) {}
|
|
471
|
+
|
|
472
|
+
__m256 reconstruct_8_components (const uint8_t * code, int i) const
|
|
473
|
+
{
|
|
474
|
+
__m128i codei = _mm_loadu_si128 ((const __m128i*)(code + 2 * i));
|
|
475
|
+
return _mm256_cvtph_ps (codei);
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
};
|
|
479
|
+
|
|
480
|
+
#endif
|
|
481
|
+
|
|
482
|
+
/*******************************************************************
|
|
483
|
+
* 8bit_direct quantizer
|
|
484
|
+
*******************************************************************/
|
|
485
|
+
|
|
486
|
+
/// Primary template, intentionally empty: the direct 8-bit quantizer is
/// provided only through explicit specializations on the SIMD width.
template<int simd_width>
struct Quantizer8bitDirect {
};
|
|
488
|
+
|
|
489
|
+
template<>
|
|
490
|
+
struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
|
|
491
|
+
const size_t d;
|
|
492
|
+
|
|
493
|
+
Quantizer8bitDirect(size_t d, const std::vector<float> & /* unused */):
|
|
494
|
+
d(d) {}
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
void encode_vector(const float* x, uint8_t* code) const final {
|
|
498
|
+
for (size_t i = 0; i < d; i++) {
|
|
499
|
+
code[i] = (uint8_t)x[i];
|
|
500
|
+
}
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
void decode_vector(const uint8_t* code, float* x) const final {
|
|
504
|
+
for (size_t i = 0; i < d; i++) {
|
|
505
|
+
x[i] = code[i];
|
|
506
|
+
}
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
float reconstruct_component (const uint8_t * code, int i) const
|
|
510
|
+
{
|
|
511
|
+
return code[i];
|
|
512
|
+
}
|
|
513
|
+
|
|
514
|
+
};
|
|
515
|
+
|
|
516
|
+
#ifdef USE_AVX
|
|
517
|
+
|
|
518
|
+
template<>
|
|
519
|
+
struct Quantizer8bitDirect<8>: Quantizer8bitDirect<1> {
|
|
520
|
+
|
|
521
|
+
Quantizer8bitDirect (size_t d, const std::vector<float> &trained):
|
|
522
|
+
Quantizer8bitDirect<1> (d, trained) {}
|
|
523
|
+
|
|
524
|
+
__m256 reconstruct_8_components (const uint8_t * code, int i) const
|
|
525
|
+
{
|
|
526
|
+
__m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
|
|
527
|
+
__m256i y8 = _mm256_cvtepu8_epi32 (x8); // 8 * int32
|
|
528
|
+
return _mm256_cvtepi32_ps (y8); // 8 * float32
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
};
|
|
532
|
+
|
|
533
|
+
#endif
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
template<int SIMDWIDTH>
|
|
537
|
+
ScalarQuantizer::Quantizer *select_quantizer_1 (
|
|
538
|
+
QuantizerType qtype,
|
|
539
|
+
size_t d, const std::vector<float> & trained)
|
|
540
|
+
{
|
|
541
|
+
switch(qtype) {
|
|
542
|
+
case ScalarQuantizer::QT_8bit:
|
|
543
|
+
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(d, trained);
|
|
544
|
+
case ScalarQuantizer::QT_6bit:
|
|
545
|
+
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(d, trained);
|
|
546
|
+
case ScalarQuantizer::QT_4bit:
|
|
547
|
+
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(d, trained);
|
|
548
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
549
|
+
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(d, trained);
|
|
550
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
551
|
+
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(d, trained);
|
|
552
|
+
case ScalarQuantizer::QT_fp16:
|
|
553
|
+
return new QuantizerFP16<SIMDWIDTH> (d, trained);
|
|
554
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
555
|
+
return new Quantizer8bitDirect<SIMDWIDTH> (d, trained);
|
|
556
|
+
}
|
|
557
|
+
FAISS_THROW_MSG ("unknown qtype");
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
/*******************************************************************
|
|
564
|
+
* Quantizer range training
|
|
565
|
+
*/
|
|
566
|
+
|
|
567
|
+
/// Square of x, computed in single precision.
static float sqr (float x) {
    const float squared = x * x;
    return squared;
}
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
/// Estimate a single quantization range from n scalar samples.
///
/// @param rs       range statistic to use
/// @param rs_arg   meaning depends on rs:
///                 - RS_minmax: fraction of (max - min) added on each side
///                 - RS_meanstd: half-width of the range in standard
///                   deviations around the mean
///                 - RS_quantiles: fraction of the sorted samples skipped
///                   at each end
///                 - RS_optim: not used
/// @param n        number of samples in x
/// @param k        number of quantization levels (only used by RS_optim)
/// @param x        samples, size n
/// @param trained  output, resized to 2: trained[0] = range minimum,
///                 trained[1] = range width (max - min)
/// @throws via FAISS_THROW_MSG when rs is not one of the handled values
// NOTE(review): n == 0 is not handled (RS_quantiles would index an empty
// vector); callers are presumably expected to pass n > 0 - confirm.
void train_Uniform(RangeStat rs, float rs_arg,
                   idx_t n, int k, const float *x,
                   std::vector<float> & trained)
{
    trained.resize (2);
    float & vmin = trained[0];
    float & vmax = trained[1];

    if (rs == ScalarQuantizer::RS_minmax) {
        vmin = HUGE_VAL; vmax = -HUGE_VAL;
        for (idx_t i = 0; i < n; i++) {
            if (x[i] < vmin) vmin = x[i];
            if (x[i] > vmax) vmax = x[i];
        }
        float vexp = (vmax - vmin) * rs_arg;
        vmin -= vexp;
        vmax += vexp;
    } else if (rs == ScalarQuantizer::RS_meanstd) {
        double sum = 0, sum2 = 0;
        for (idx_t i = 0; i < n; i++) {
            sum += x[i];
            sum2 += x[i] * x[i];
        }
        float mean = sum / n;
        float var = sum2 / n - mean * mean;
        // fall back to a unit deviation when the variance is not positive
        float stddev = var <= 0 ? 1.0 : sqrt(var);

        vmin = mean - stddev * rs_arg ;
        vmax = mean + stddev * rs_arg ;
    } else if (rs == ScalarQuantizer::RS_quantiles) {
        std::vector<float> x_copy(n);
        memcpy(x_copy.data(), x, n * sizeof(*x));
        // TODO just do a quickselect
        std::sort(x_copy.begin(), x_copy.end());
        // Number of samples skipped at each end. Kept in idx_t so large n
        // does not overflow, and clamped so that o <= n - 1 - o: the lower
        // quantile can then never lie above the upper one.
        idx_t o = idx_t(rs_arg * n);
        if (o < 0) o = 0;
        if (o > (n - 1) / 2) o = (n - 1) / 2;
        vmin = x_copy[o];
        vmax = x_copy[n - 1 - o];

    } else if (rs == ScalarQuantizer::RS_optim) {
        // Iteratively refit x ~ a * ni + b, where ni is the quantization
        // level (0 .. k-1) currently assigned to each sample.
        float a, b;
        float sx = 0;
        {
            vmin = HUGE_VAL, vmax = -HUGE_VAL;
            for (idx_t i = 0; i < n; i++) {
                if (x[i] < vmin) vmin = x[i];
                if (x[i] > vmax) vmax = x[i];
                sx += x[i];
            }
            b = vmin;
            a = (vmax - vmin) / (k - 1);
        }
        const bool verbose = false;
        // With a step that is not positive (e.g. constant input) the
        // assignment below would divide by zero and propagate NaN, so the
        // refinement is skipped and the initial estimate is kept.
        const int niter = a > 0 ? 2000 : 0;
        float last_err = -1;
        int iter_last_err = 0;
        for (int it = 0; it < niter; it++) {
            float sn = 0, sn2 = 0, sxn = 0, err1 = 0;

            for (idx_t i = 0; i < n; i++) {
                float xi = x[i];
                float ni = floor ((xi - b) / a + 0.5);
                if (ni < 0) ni = 0;
                if (ni >= k) ni = k - 1;
                err1 += sqr (xi - (ni * a + b));
                sn += ni;
                sn2 += ni * ni;
                sxn += ni * xi;
            }

            // stop once the error has been exactly unchanged 16 times in a row
            if (err1 == last_err) {
                iter_last_err ++;
                if (iter_last_err == 16) break;
            } else {
                last_err = err1;
                iter_last_err = 0;
            }

            float det = sqr (sn) - sn2 * n;

            b = (sn * sxn - sn2 * sx) / det;
            a = (sn * sx - n * sxn) / det;
            if (verbose) {
                printf ("it %d, err1=%g \r", it, err1);
                fflush(stdout);
            }
        }
        if (verbose) printf("\n");

        vmin = b;
        vmax = b + a * (k - 1);

    } else {
        FAISS_THROW_MSG ("Invalid range statistic");
    }
    // trained[1] holds the width of the range, not its upper bound
    vmax -= vmin;
}
|
|
670
|
+
|
|
671
|
+
/// Estimate one quantization range per dimension from n vectors.
///
/// @param rs       range statistic to use
/// @param rs_arg   argument of the range statistic (see train_Uniform)
/// @param n        number of training vectors
/// @param d        vector dimension
/// @param k        number of quantization levels, forwarded to
///                 train_Uniform
/// @param x        training vectors, n rows of d floats
/// @param trained  output, resized to 2 * d: d range minima followed by
///                 d range widths
// NOTE(review): the RS_minmax branch reads the first row unconditionally,
// so n >= 1 is presumably required - confirm against callers.
void train_NonUniform(RangeStat rs, float rs_arg,
                      idx_t n, int d, int k, const float *x,
                      std::vector<float> & trained)
{

    trained.resize (2 * d);
    float * vmin = trained.data();
    float * vmax = trained.data() + d;
    if (rs == ScalarQuantizer::RS_minmax) {
        // seed min and max with the first vector, then scan the others
        memcpy (vmin, x, sizeof(*x) * d);
        memcpy (vmax, x, sizeof(*x) * d);
        for (idx_t i = 1; i < n; i++) {
            const float *xi = x + i * d;
            for (int j = 0; j < d; j++) {
                if (xi[j] < vmin[j]) vmin[j] = xi[j];
                if (xi[j] > vmax[j]) vmax[j] = xi[j];
            }
        }
        // the second half of `trained` ends up holding widths, not maxima
        float *vdiff = vmax;
        for (int j = 0; j < d; j++) {
            float vexp = (vmax[j] - vmin[j]) * rs_arg;
            vmin[j] -= vexp;
            vmax[j] += vexp;
            vdiff [j] = vmax[j] - vmin[j];
        }
    } else {
        // Transpose so that all samples of one dimension are contiguous.
        // Every row is copied, including row 0: starting at 1 would leave
        // the first sample of each dimension at its zero-initialized value.
        std::vector<float> xt(n * d);
        for (idx_t i = 0; i < n; i++) {
            const float *xi = x + i * d;
            for (int j = 0; j < d; j++) {
                xt[(size_t)j * n + i] = xi[j];
            }
        }
#pragma omp parallel for
        for (int j = 0; j < d; j++) {
            // Declared inside the loop so each iteration (and hence each
            // thread) works on its own output buffer.
            std::vector<float> trained_d(2);
            train_Uniform(rs, rs_arg,
                          n, k, xt.data() + (size_t)j * n,
                          trained_d);
            vmin[j] = trained_d[0];
            vmax[j] = trained_d[1];   // already a width, see train_Uniform
        }
    }
}
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
/*******************************************************************
|
|
720
|
+
* Similarity: gets vector components and computes a similarity wrt. a
|
|
721
|
+
* query vector stored in the object. The data fields just encapsulate
|
|
722
|
+
* an accumulator.
|
|
723
|
+
*/
|
|
724
|
+
|
|
725
|
+
/// Primary template, intentionally empty: the L2 accumulator is provided
/// only through explicit specializations on the SIMD width.
template<int simd_width>
struct SimilarityL2 {
};
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
template<>
|
|
730
|
+
struct SimilarityL2<1> {
|
|
731
|
+
static constexpr int simdwidth = 1;
|
|
732
|
+
static constexpr MetricType metric_type = METRIC_L2;
|
|
733
|
+
|
|
734
|
+
const float *y, *yi;
|
|
735
|
+
|
|
736
|
+
explicit SimilarityL2 (const float * y): y(y) {}
|
|
737
|
+
|
|
738
|
+
/******* scalar accumulator *******/
|
|
739
|
+
|
|
740
|
+
float accu;
|
|
741
|
+
|
|
742
|
+
void begin () {
|
|
743
|
+
accu = 0;
|
|
744
|
+
yi = y;
|
|
745
|
+
}
|
|
746
|
+
|
|
747
|
+
void add_component (float x) {
|
|
748
|
+
float tmp = *yi++ - x;
|
|
749
|
+
accu += tmp * tmp;
|
|
750
|
+
}
|
|
751
|
+
|
|
752
|
+
void add_component_2 (float x1, float x2) {
|
|
753
|
+
float tmp = x1 - x2;
|
|
754
|
+
accu += tmp * tmp;
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
float result () {
|
|
758
|
+
return accu;
|
|
759
|
+
}
|
|
760
|
+
};
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
#ifdef USE_AVX
|
|
764
|
+
template<>
|
|
765
|
+
struct SimilarityL2<8> {
|
|
766
|
+
static constexpr int simdwidth = 8;
|
|
767
|
+
static constexpr MetricType metric_type = METRIC_L2;
|
|
768
|
+
|
|
769
|
+
const float *y, *yi;
|
|
770
|
+
|
|
771
|
+
explicit SimilarityL2 (const float * y): y(y) {}
|
|
772
|
+
__m256 accu8;
|
|
773
|
+
|
|
774
|
+
void begin_8 () {
|
|
775
|
+
accu8 = _mm256_setzero_ps();
|
|
776
|
+
yi = y;
|
|
777
|
+
}
|
|
778
|
+
|
|
779
|
+
void add_8_components (__m256 x) {
|
|
780
|
+
__m256 yiv = _mm256_loadu_ps (yi);
|
|
781
|
+
yi += 8;
|
|
782
|
+
__m256 tmp = yiv - x;
|
|
783
|
+
accu8 += tmp * tmp;
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
void add_8_components_2 (__m256 x, __m256 y) {
|
|
787
|
+
__m256 tmp = y - x;
|
|
788
|
+
accu8 += tmp * tmp;
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
float result_8 () {
|
|
792
|
+
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
|
793
|
+
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
|
794
|
+
// now add the 0th and 4th component
|
|
795
|
+
return
|
|
796
|
+
_mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
|
|
797
|
+
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
};
|
|
801
|
+
|
|
802
|
+
#endif
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
/// Primary template, intentionally empty: the inner-product accumulator is
/// provided only through explicit specializations on the SIMD width.
template<int simd_width>
struct SimilarityIP {
};
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
template<>
|
|
810
|
+
struct SimilarityIP<1> {
|
|
811
|
+
static constexpr int simdwidth = 1;
|
|
812
|
+
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
|
813
|
+
const float *y, *yi;
|
|
814
|
+
|
|
815
|
+
float accu;
|
|
816
|
+
|
|
817
|
+
explicit SimilarityIP (const float * y):
|
|
818
|
+
y (y) {}
|
|
819
|
+
|
|
820
|
+
void begin () {
|
|
821
|
+
accu = 0;
|
|
822
|
+
yi = y;
|
|
823
|
+
}
|
|
824
|
+
|
|
825
|
+
void add_component (float x) {
|
|
826
|
+
accu += *yi++ * x;
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
void add_component_2 (float x1, float x2) {
|
|
830
|
+
accu += x1 * x2;
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
float result () {
|
|
834
|
+
return accu;
|
|
835
|
+
}
|
|
836
|
+
};
|
|
837
|
+
|
|
838
|
+
#ifdef USE_AVX
|
|
839
|
+
|
|
840
|
+
template<>
|
|
841
|
+
struct SimilarityIP<8> {
|
|
842
|
+
static constexpr int simdwidth = 8;
|
|
843
|
+
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
|
844
|
+
|
|
845
|
+
const float *y, *yi;
|
|
846
|
+
|
|
847
|
+
float accu;
|
|
848
|
+
|
|
849
|
+
explicit SimilarityIP (const float * y):
|
|
850
|
+
y (y) {}
|
|
851
|
+
|
|
852
|
+
__m256 accu8;
|
|
853
|
+
|
|
854
|
+
void begin_8 () {
|
|
855
|
+
accu8 = _mm256_setzero_ps();
|
|
856
|
+
yi = y;
|
|
857
|
+
}
|
|
858
|
+
|
|
859
|
+
void add_8_components (__m256 x) {
|
|
860
|
+
__m256 yiv = _mm256_loadu_ps (yi);
|
|
861
|
+
yi += 8;
|
|
862
|
+
accu8 += yiv * x;
|
|
863
|
+
}
|
|
864
|
+
|
|
865
|
+
void add_8_components_2 (__m256 x1, __m256 x2) {
|
|
866
|
+
accu8 += x1 * x2;
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
float result_8 () {
|
|
870
|
+
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
|
871
|
+
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
|
872
|
+
// now add the 0th and 4th component
|
|
873
|
+
return
|
|
874
|
+
_mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
|
|
875
|
+
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
|
876
|
+
}
|
|
877
|
+
};
|
|
878
|
+
#endif
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
/*******************************************************************
|
|
882
|
+
* DistanceComputer: combines a similarity and a quantizer to do
|
|
883
|
+
* code-to-vector or code-to-code comparisons
|
|
884
|
+
*******************************************************************/
|
|
885
|
+
|
|
886
|
+
template<class Quantizer, class Similarity, int SIMDWIDTH>
|
|
887
|
+
struct DCTemplate : SQDistanceComputer {};
|
|
888
|
+
|
|
889
|
+
template<class Quantizer, class Similarity>
|
|
890
|
+
struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
|
|
891
|
+
{
|
|
892
|
+
using Sim = Similarity;
|
|
893
|
+
|
|
894
|
+
Quantizer quant;
|
|
895
|
+
|
|
896
|
+
DCTemplate(size_t d, const std::vector<float> &trained):
|
|
897
|
+
quant(d, trained)
|
|
898
|
+
{}
|
|
899
|
+
|
|
900
|
+
float compute_distance(const float* x, const uint8_t* code) const {
|
|
901
|
+
|
|
902
|
+
Similarity sim(x);
|
|
903
|
+
sim.begin();
|
|
904
|
+
for (size_t i = 0; i < quant.d; i++) {
|
|
905
|
+
float xi = quant.reconstruct_component(code, i);
|
|
906
|
+
sim.add_component(xi);
|
|
907
|
+
}
|
|
908
|
+
return sim.result();
|
|
909
|
+
}
|
|
910
|
+
|
|
911
|
+
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
912
|
+
const {
|
|
913
|
+
Similarity sim(nullptr);
|
|
914
|
+
sim.begin();
|
|
915
|
+
for (size_t i = 0; i < quant.d; i++) {
|
|
916
|
+
float x1 = quant.reconstruct_component(code1, i);
|
|
917
|
+
float x2 = quant.reconstruct_component(code2, i);
|
|
918
|
+
sim.add_component_2(x1, x2);
|
|
919
|
+
}
|
|
920
|
+
return sim.result();
|
|
921
|
+
}
|
|
922
|
+
|
|
923
|
+
void set_query (const float *x) final {
|
|
924
|
+
q = x;
|
|
925
|
+
}
|
|
926
|
+
|
|
927
|
+
/// compute distance of vector i to current query
|
|
928
|
+
float operator () (idx_t i) final {
|
|
929
|
+
return compute_distance (q, codes + i * code_size);
|
|
930
|
+
}
|
|
931
|
+
|
|
932
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
|
933
|
+
return compute_code_distance (codes + i * code_size,
|
|
934
|
+
codes + j * code_size);
|
|
935
|
+
}
|
|
936
|
+
|
|
937
|
+
float query_to_code (const uint8_t * code) const {
|
|
938
|
+
return compute_distance (q, code);
|
|
939
|
+
}
|
|
940
|
+
|
|
941
|
+
};
|
|
942
|
+
|
|
943
|
+
#ifdef USE_F16C
|
|
944
|
+
|
|
945
|
+
template<class Quantizer, class Similarity>
|
|
946
|
+
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
|
|
947
|
+
{
|
|
948
|
+
using Sim = Similarity;
|
|
949
|
+
|
|
950
|
+
Quantizer quant;
|
|
951
|
+
|
|
952
|
+
DCTemplate(size_t d, const std::vector<float> &trained):
|
|
953
|
+
quant(d, trained)
|
|
954
|
+
{}
|
|
955
|
+
|
|
956
|
+
float compute_distance(const float* x, const uint8_t* code) const {
|
|
957
|
+
|
|
958
|
+
Similarity sim(x);
|
|
959
|
+
sim.begin_8();
|
|
960
|
+
for (size_t i = 0; i < quant.d; i += 8) {
|
|
961
|
+
__m256 xi = quant.reconstruct_8_components(code, i);
|
|
962
|
+
sim.add_8_components(xi);
|
|
963
|
+
}
|
|
964
|
+
return sim.result_8();
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
968
|
+
const {
|
|
969
|
+
Similarity sim(nullptr);
|
|
970
|
+
sim.begin_8();
|
|
971
|
+
for (size_t i = 0; i < quant.d; i += 8) {
|
|
972
|
+
__m256 x1 = quant.reconstruct_8_components(code1, i);
|
|
973
|
+
__m256 x2 = quant.reconstruct_8_components(code2, i);
|
|
974
|
+
sim.add_8_components_2(x1, x2);
|
|
975
|
+
}
|
|
976
|
+
return sim.result_8();
|
|
977
|
+
}
|
|
978
|
+
|
|
979
|
+
void set_query (const float *x) final {
|
|
980
|
+
q = x;
|
|
981
|
+
}
|
|
982
|
+
|
|
983
|
+
/// compute distance of vector i to current query
|
|
984
|
+
float operator () (idx_t i) final {
|
|
985
|
+
return compute_distance (q, codes + i * code_size);
|
|
986
|
+
}
|
|
987
|
+
|
|
988
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
|
989
|
+
return compute_code_distance (codes + i * code_size,
|
|
990
|
+
codes + j * code_size);
|
|
991
|
+
}
|
|
992
|
+
|
|
993
|
+
float query_to_code (const uint8_t * code) const {
|
|
994
|
+
return compute_distance (q, code);
|
|
995
|
+
}
|
|
996
|
+
|
|
997
|
+
};
|
|
998
|
+
|
|
999
|
+
#endif
|
|
1000
|
+
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
/*******************************************************************
|
|
1004
|
+
* DistanceComputerByte: computes distances in the integer domain
|
|
1005
|
+
*******************************************************************/
|
|
1006
|
+
|
|
1007
|
+
template<class Similarity, int SIMDWIDTH>
|
|
1008
|
+
struct DistanceComputerByte : SQDistanceComputer {};
|
|
1009
|
+
|
|
1010
|
+
template<class Similarity>
|
|
1011
|
+
struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
1012
|
+
using Sim = Similarity;
|
|
1013
|
+
|
|
1014
|
+
int d;
|
|
1015
|
+
std::vector<uint8_t> tmp;
|
|
1016
|
+
|
|
1017
|
+
DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
|
|
1018
|
+
}
|
|
1019
|
+
|
|
1020
|
+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
1021
|
+
const {
|
|
1022
|
+
int accu = 0;
|
|
1023
|
+
for (int i = 0; i < d; i++) {
|
|
1024
|
+
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1025
|
+
accu += int(code1[i]) * code2[i];
|
|
1026
|
+
} else {
|
|
1027
|
+
int diff = int(code1[i]) - code2[i];
|
|
1028
|
+
accu += diff * diff;
|
|
1029
|
+
}
|
|
1030
|
+
}
|
|
1031
|
+
return accu;
|
|
1032
|
+
}
|
|
1033
|
+
|
|
1034
|
+
void set_query (const float *x) final {
|
|
1035
|
+
for (int i = 0; i < d; i++) {
|
|
1036
|
+
tmp[i] = int(x[i]);
|
|
1037
|
+
}
|
|
1038
|
+
}
|
|
1039
|
+
|
|
1040
|
+
int compute_distance(const float* x, const uint8_t* code) {
|
|
1041
|
+
set_query(x);
|
|
1042
|
+
return compute_code_distance(tmp.data(), code);
|
|
1043
|
+
}
|
|
1044
|
+
|
|
1045
|
+
/// compute distance of vector i to current query
|
|
1046
|
+
float operator () (idx_t i) final {
|
|
1047
|
+
return compute_distance (q, codes + i * code_size);
|
|
1048
|
+
}
|
|
1049
|
+
|
|
1050
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
|
1051
|
+
return compute_code_distance (codes + i * code_size,
|
|
1052
|
+
codes + j * code_size);
|
|
1053
|
+
}
|
|
1054
|
+
|
|
1055
|
+
float query_to_code (const uint8_t * code) const {
|
|
1056
|
+
return compute_code_distance (tmp.data(), code);
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
};
|
|
1060
|
+
|
|
1061
|
+
#ifdef USE_AVX
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
template<class Similarity>
|
|
1065
|
+
struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
1066
|
+
using Sim = Similarity;
|
|
1067
|
+
|
|
1068
|
+
int d;
|
|
1069
|
+
std::vector<uint8_t> tmp;
|
|
1070
|
+
|
|
1071
|
+
DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
|
|
1072
|
+
}
|
|
1073
|
+
|
|
1074
|
+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
1075
|
+
const {
|
|
1076
|
+
// __m256i accu = _mm256_setzero_ps ();
|
|
1077
|
+
__m256i accu = _mm256_setzero_si256 ();
|
|
1078
|
+
for (int i = 0; i < d; i += 16) {
|
|
1079
|
+
// load 16 bytes, convert to 16 uint16_t
|
|
1080
|
+
__m256i c1 = _mm256_cvtepu8_epi16
|
|
1081
|
+
(_mm_loadu_si128((__m128i*)(code1 + i)));
|
|
1082
|
+
__m256i c2 = _mm256_cvtepu8_epi16
|
|
1083
|
+
(_mm_loadu_si128((__m128i*)(code2 + i)));
|
|
1084
|
+
__m256i prod32;
|
|
1085
|
+
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1086
|
+
prod32 = _mm256_madd_epi16(c1, c2);
|
|
1087
|
+
} else {
|
|
1088
|
+
__m256i diff = _mm256_sub_epi16(c1, c2);
|
|
1089
|
+
prod32 = _mm256_madd_epi16(diff, diff);
|
|
1090
|
+
}
|
|
1091
|
+
accu = _mm256_add_epi32 (accu, prod32);
|
|
1092
|
+
|
|
1093
|
+
}
|
|
1094
|
+
__m128i sum = _mm256_extractf128_si256(accu, 0);
|
|
1095
|
+
sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1));
|
|
1096
|
+
sum = _mm_hadd_epi32 (sum, sum);
|
|
1097
|
+
sum = _mm_hadd_epi32 (sum, sum);
|
|
1098
|
+
return _mm_cvtsi128_si32 (sum);
|
|
1099
|
+
}
|
|
1100
|
+
|
|
1101
|
+
void set_query (const float *x) final {
|
|
1102
|
+
/*
|
|
1103
|
+
for (int i = 0; i < d; i += 8) {
|
|
1104
|
+
__m256 xi = _mm256_loadu_ps (x + i);
|
|
1105
|
+
__m256i ci = _mm256_cvtps_epi32(xi);
|
|
1106
|
+
*/
|
|
1107
|
+
for (int i = 0; i < d; i++) {
|
|
1108
|
+
tmp[i] = int(x[i]);
|
|
1109
|
+
}
|
|
1110
|
+
}
|
|
1111
|
+
|
|
1112
|
+
int compute_distance(const float* x, const uint8_t* code) {
|
|
1113
|
+
set_query(x);
|
|
1114
|
+
return compute_code_distance(tmp.data(), code);
|
|
1115
|
+
}
|
|
1116
|
+
|
|
1117
|
+
/// compute distance of vector i to current query
|
|
1118
|
+
float operator () (idx_t i) final {
|
|
1119
|
+
return compute_distance (q, codes + i * code_size);
|
|
1120
|
+
}
|
|
1121
|
+
|
|
1122
|
+
float symmetric_dis (idx_t i, idx_t j) override {
|
|
1123
|
+
return compute_code_distance (codes + i * code_size,
|
|
1124
|
+
codes + j * code_size);
|
|
1125
|
+
}
|
|
1126
|
+
|
|
1127
|
+
float query_to_code (const uint8_t * code) const {
|
|
1128
|
+
return compute_code_distance (tmp.data(), code);
|
|
1129
|
+
}
|
|
1130
|
+
|
|
1131
|
+
|
|
1132
|
+
};
|
|
1133
|
+
|
|
1134
|
+
#endif
|
|
1135
|
+
|
|
1136
|
+
/*******************************************************************
|
|
1137
|
+
* select_distance_computer: runtime selection of template
|
|
1138
|
+
* specialization
|
|
1139
|
+
*******************************************************************/
|
|
1140
|
+
|
|
1141
|
+
|
|
1142
|
+
template<class Sim>
|
|
1143
|
+
SQDistanceComputer *select_distance_computer (
|
|
1144
|
+
QuantizerType qtype,
|
|
1145
|
+
size_t d, const std::vector<float> & trained)
|
|
1146
|
+
{
|
|
1147
|
+
constexpr int SIMDWIDTH = Sim::simdwidth;
|
|
1148
|
+
switch(qtype) {
|
|
1149
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
1150
|
+
return new DCTemplate<QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
|
|
1151
|
+
Sim, SIMDWIDTH>(d, trained);
|
|
1152
|
+
|
|
1153
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
1154
|
+
return new DCTemplate<QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
|
|
1155
|
+
Sim, SIMDWIDTH>(d, trained);
|
|
1156
|
+
|
|
1157
|
+
case ScalarQuantizer::QT_8bit:
|
|
1158
|
+
return new DCTemplate<QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
|
|
1159
|
+
Sim, SIMDWIDTH>(d, trained);
|
|
1160
|
+
|
|
1161
|
+
case ScalarQuantizer::QT_6bit:
|
|
1162
|
+
return new DCTemplate<QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
|
|
1163
|
+
Sim, SIMDWIDTH>(d, trained);
|
|
1164
|
+
|
|
1165
|
+
case ScalarQuantizer::QT_4bit:
|
|
1166
|
+
return new DCTemplate<QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
|
|
1167
|
+
Sim, SIMDWIDTH>(d, trained);
|
|
1168
|
+
|
|
1169
|
+
case ScalarQuantizer::QT_fp16:
|
|
1170
|
+
return new DCTemplate
|
|
1171
|
+
<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
|
|
1172
|
+
|
|
1173
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
1174
|
+
if (d % 16 == 0) {
|
|
1175
|
+
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
|
|
1176
|
+
} else {
|
|
1177
|
+
return new DCTemplate
|
|
1178
|
+
<Quantizer8bitDirect<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
|
|
1179
|
+
}
|
|
1180
|
+
}
|
|
1181
|
+
FAISS_THROW_MSG ("unknown qtype");
|
|
1182
|
+
return nullptr;
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
} // anonymous namespace
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
|
|
1191
|
+
/*******************************************************************
|
|
1192
|
+
* ScalarQuantizer implementation
|
|
1193
|
+
********************************************************************/
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
/// Construct a quantizer for d-dimensional vectors and derive the size in
/// bytes of one encoded vector from the quantizer type. The range
/// statistic defaults to RS_minmax with argument 0.
ScalarQuantizer::ScalarQuantizer
          (size_t d, QuantizerType qtype):
              qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d (d),
              code_size (0)   // defined value even if no case below matches
{
    switch (qtype) {
    case QT_8bit:
    case QT_8bit_uniform:
    case QT_8bit_direct:
        code_size = d;                  // one byte per component
        break;
    case QT_4bit:
    case QT_4bit_uniform:
        code_size = (d + 1) / 2;        // two components per byte, rounded up
        break;
    case QT_6bit:
        code_size = (d * 6 + 7) / 8;    // 6 bits per component, rounded up
        break;
    case QT_fp16:
        code_size = d * 2;              // two bytes per component
        break;
    }

}
|
|
1220
|
+
|
|
1221
|
+
/// Default constructor: QT_8bit with the RS_minmax range statistic, and
/// zero dimension and code size.
ScalarQuantizer::ScalarQuantizer ()
    : qtype(QT_8bit),
      rangestat(RS_minmax),
      rangestat_arg(0),
      d (0),
      code_size(0)
{
}
|
|
1225
|
+
|
|
1226
|
+
void ScalarQuantizer::train (size_t n, const float *x)
|
|
1227
|
+
{
|
|
1228
|
+
int bit_per_dim =
|
|
1229
|
+
qtype == QT_4bit_uniform ? 4 :
|
|
1230
|
+
qtype == QT_4bit ? 4 :
|
|
1231
|
+
qtype == QT_6bit ? 6 :
|
|
1232
|
+
qtype == QT_8bit_uniform ? 8 :
|
|
1233
|
+
qtype == QT_8bit ? 8 : -1;
|
|
1234
|
+
|
|
1235
|
+
switch (qtype) {
|
|
1236
|
+
case QT_4bit_uniform: case QT_8bit_uniform:
|
|
1237
|
+
train_Uniform (rangestat, rangestat_arg,
|
|
1238
|
+
n * d, 1 << bit_per_dim, x, trained);
|
|
1239
|
+
break;
|
|
1240
|
+
case QT_4bit: case QT_8bit: case QT_6bit:
|
|
1241
|
+
train_NonUniform (rangestat, rangestat_arg,
|
|
1242
|
+
n, d, 1 << bit_per_dim, x, trained);
|
|
1243
|
+
break;
|
|
1244
|
+
case QT_fp16:
|
|
1245
|
+
case QT_8bit_direct:
|
|
1246
|
+
// no training necessary
|
|
1247
|
+
break;
|
|
1248
|
+
}
|
|
1249
|
+
}
|
|
1250
|
+
|
|
1251
|
+
void ScalarQuantizer::train_residual(size_t n,
|
|
1252
|
+
const float *x,
|
|
1253
|
+
Index *quantizer,
|
|
1254
|
+
bool by_residual,
|
|
1255
|
+
bool verbose)
|
|
1256
|
+
{
|
|
1257
|
+
const float * x_in = x;
|
|
1258
|
+
|
|
1259
|
+
// 100k points more than enough
|
|
1260
|
+
x = fvecs_maybe_subsample (
|
|
1261
|
+
d, (size_t*)&n, 100000,
|
|
1262
|
+
x, verbose, 1234);
|
|
1263
|
+
|
|
1264
|
+
ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
|
|
1265
|
+
|
|
1266
|
+
if (by_residual) {
|
|
1267
|
+
std::vector<Index::idx_t> idx(n);
|
|
1268
|
+
quantizer->assign (n, x, idx.data());
|
|
1269
|
+
|
|
1270
|
+
std::vector<float> residuals(n * d);
|
|
1271
|
+
quantizer->compute_residual_n (n, x, residuals.data(), idx.data());
|
|
1272
|
+
|
|
1273
|
+
train (n, residuals.data());
|
|
1274
|
+
} else {
|
|
1275
|
+
train (n, x);
|
|
1276
|
+
}
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
/// Create the quantizer matching qtype. With USE_F16C defined and d a
/// multiple of 8 the 8-wide implementation is chosen, otherwise the scalar
/// one. The returned object is heap-allocated by select_quantizer_1 and
/// handed to the caller.
ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
{
#ifdef USE_F16C
    if (d % 8 == 0) {
        return select_quantizer_1<8> (qtype, d, trained);
    }
#endif
    return select_quantizer_1<1> (qtype, d, trained);
}
|
|
1291
|
+
|
|
1292
|
+
|
|
1293
|
+
void ScalarQuantizer::compute_codes (const float * x,
|
|
1294
|
+
uint8_t * codes,
|
|
1295
|
+
size_t n) const
|
|
1296
|
+
{
|
|
1297
|
+
std::unique_ptr<Quantizer> squant(select_quantizer ());
|
|
1298
|
+
|
|
1299
|
+
memset (codes, 0, code_size * n);
|
|
1300
|
+
#pragma omp parallel for
|
|
1301
|
+
for (size_t i = 0; i < n; i++)
|
|
1302
|
+
squant->encode_vector (x + i * d, codes + i * code_size);
|
|
1303
|
+
}
|
|
1304
|
+
|
|
1305
|
+
void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const
|
|
1306
|
+
{
|
|
1307
|
+
std::unique_ptr<Quantizer> squant(select_quantizer ());
|
|
1308
|
+
|
|
1309
|
+
#pragma omp parallel for
|
|
1310
|
+
for (size_t i = 0; i < n; i++)
|
|
1311
|
+
squant->decode_vector (codes + i * code_size, x + i * d);
|
|
1312
|
+
}
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
/** Build a distance computer for the given metric.
 *
 * Only METRIC_L2 and METRIC_INNER_PRODUCT are accepted; anything else
 * trips FAISS_THROW_IF_NOT. When compiled with USE_F16C and d is a multiple
 * of 8, the similarity templated on width 8 is used, otherwise width 1.
 *
 * @param metric  metric type to compute
 * @return pointer produced by select_distance_computer
 */
SQDistanceComputer *
ScalarQuantizer::get_distance_computer (MetricType metric) const
{
    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);

    const bool want_l2 = (metric == METRIC_L2);

#ifdef USE_F16C
    if (d % 8 == 0) {
        if (want_l2) {
            return select_distance_computer<SimilarityL2<8> >
                (qtype, d, trained);
        }
        return select_distance_computer<SimilarityIP<8> >
            (qtype, d, trained);
    }
#endif

    if (want_l2) {
        return select_distance_computer<SimilarityL2<1> >
            (qtype, d, trained);
    }
    return select_distance_computer<SimilarityIP<1> >
        (qtype, d, trained);
}
|
|
1340
|
+
|
|
1341
|
+
|
|
1342
|
+
/*******************************************************************
|
|
1343
|
+
* IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
|
|
1344
|
+
*
|
|
1345
|
+
* It is an InvertedListScanner, but is designed to work with
|
|
1346
|
+
* IndexScalarQuantizer as well.
|
|
1347
|
+
********************************************************************/
|
|
1348
|
+
|
|
1349
|
+
namespace {
|
|
1350
|
+
|
|
1351
|
+
|
|
1352
|
+
template<class DCClass>
|
|
1353
|
+
struct IVFSQScannerIP: InvertedListScanner {
|
|
1354
|
+
DCClass dc;
|
|
1355
|
+
bool store_pairs, by_residual;
|
|
1356
|
+
|
|
1357
|
+
size_t code_size;
|
|
1358
|
+
|
|
1359
|
+
idx_t list_no; /// current list (set to 0 for Flat index
|
|
1360
|
+
float accu0; /// added to all distances
|
|
1361
|
+
|
|
1362
|
+
IVFSQScannerIP(int d, const std::vector<float> & trained,
|
|
1363
|
+
size_t code_size, bool store_pairs,
|
|
1364
|
+
bool by_residual):
|
|
1365
|
+
dc(d, trained), store_pairs(store_pairs),
|
|
1366
|
+
by_residual(by_residual),
|
|
1367
|
+
code_size(code_size), list_no(0), accu0(0)
|
|
1368
|
+
{}
|
|
1369
|
+
|
|
1370
|
+
|
|
1371
|
+
void set_query (const float *query) override {
|
|
1372
|
+
dc.set_query (query);
|
|
1373
|
+
}
|
|
1374
|
+
|
|
1375
|
+
void set_list (idx_t list_no, float coarse_dis) override {
|
|
1376
|
+
this->list_no = list_no;
|
|
1377
|
+
accu0 = by_residual ? coarse_dis : 0;
|
|
1378
|
+
}
|
|
1379
|
+
|
|
1380
|
+
float distance_to_code (const uint8_t *code) const final {
|
|
1381
|
+
return accu0 + dc.query_to_code (code);
|
|
1382
|
+
}
|
|
1383
|
+
|
|
1384
|
+
size_t scan_codes (size_t list_size,
|
|
1385
|
+
const uint8_t *codes,
|
|
1386
|
+
const idx_t *ids,
|
|
1387
|
+
float *simi, idx_t *idxi,
|
|
1388
|
+
size_t k) const override
|
|
1389
|
+
{
|
|
1390
|
+
size_t nup = 0;
|
|
1391
|
+
|
|
1392
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1393
|
+
|
|
1394
|
+
float accu = accu0 + dc.query_to_code (codes);
|
|
1395
|
+
|
|
1396
|
+
if (accu > simi [0]) {
|
|
1397
|
+
minheap_pop (k, simi, idxi);
|
|
1398
|
+
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1399
|
+
minheap_push (k, simi, idxi, accu, id);
|
|
1400
|
+
nup++;
|
|
1401
|
+
}
|
|
1402
|
+
codes += code_size;
|
|
1403
|
+
}
|
|
1404
|
+
return nup;
|
|
1405
|
+
}
|
|
1406
|
+
|
|
1407
|
+
void scan_codes_range (size_t list_size,
|
|
1408
|
+
const uint8_t *codes,
|
|
1409
|
+
const idx_t *ids,
|
|
1410
|
+
float radius,
|
|
1411
|
+
RangeQueryResult & res) const override
|
|
1412
|
+
{
|
|
1413
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1414
|
+
float accu = accu0 + dc.query_to_code (codes);
|
|
1415
|
+
if (accu > radius) {
|
|
1416
|
+
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1417
|
+
res.add (accu, id);
|
|
1418
|
+
}
|
|
1419
|
+
codes += code_size;
|
|
1420
|
+
}
|
|
1421
|
+
}
|
|
1422
|
+
|
|
1423
|
+
|
|
1424
|
+
};
|
|
1425
|
+
|
|
1426
|
+
|
|
1427
|
+
template<class DCClass>
|
|
1428
|
+
struct IVFSQScannerL2: InvertedListScanner {
|
|
1429
|
+
|
|
1430
|
+
DCClass dc;
|
|
1431
|
+
|
|
1432
|
+
bool store_pairs, by_residual;
|
|
1433
|
+
size_t code_size;
|
|
1434
|
+
const Index *quantizer;
|
|
1435
|
+
idx_t list_no; /// current inverted list
|
|
1436
|
+
const float *x; /// current query
|
|
1437
|
+
|
|
1438
|
+
std::vector<float> tmp;
|
|
1439
|
+
|
|
1440
|
+
IVFSQScannerL2(int d, const std::vector<float> & trained,
|
|
1441
|
+
size_t code_size, const Index *quantizer,
|
|
1442
|
+
bool store_pairs, bool by_residual):
|
|
1443
|
+
dc(d, trained), store_pairs(store_pairs), by_residual(by_residual),
|
|
1444
|
+
code_size(code_size), quantizer(quantizer),
|
|
1445
|
+
list_no (0), x (nullptr), tmp (d)
|
|
1446
|
+
{
|
|
1447
|
+
}
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
void set_query (const float *query) override {
|
|
1451
|
+
x = query;
|
|
1452
|
+
if (!quantizer) {
|
|
1453
|
+
dc.set_query (query);
|
|
1454
|
+
}
|
|
1455
|
+
}
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
void set_list (idx_t list_no, float /*coarse_dis*/) override {
|
|
1459
|
+
if (by_residual) {
|
|
1460
|
+
this->list_no = list_no;
|
|
1461
|
+
// shift of x_in wrt centroid
|
|
1462
|
+
quantizer->compute_residual (x, tmp.data(), list_no);
|
|
1463
|
+
dc.set_query (tmp.data ());
|
|
1464
|
+
} else {
|
|
1465
|
+
dc.set_query (x);
|
|
1466
|
+
}
|
|
1467
|
+
}
|
|
1468
|
+
|
|
1469
|
+
float distance_to_code (const uint8_t *code) const final {
|
|
1470
|
+
return dc.query_to_code (code);
|
|
1471
|
+
}
|
|
1472
|
+
|
|
1473
|
+
size_t scan_codes (size_t list_size,
|
|
1474
|
+
const uint8_t *codes,
|
|
1475
|
+
const idx_t *ids,
|
|
1476
|
+
float *simi, idx_t *idxi,
|
|
1477
|
+
size_t k) const override
|
|
1478
|
+
{
|
|
1479
|
+
size_t nup = 0;
|
|
1480
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1481
|
+
|
|
1482
|
+
float dis = dc.query_to_code (codes);
|
|
1483
|
+
|
|
1484
|
+
if (dis < simi [0]) {
|
|
1485
|
+
maxheap_pop (k, simi, idxi);
|
|
1486
|
+
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1487
|
+
maxheap_push (k, simi, idxi, dis, id);
|
|
1488
|
+
nup++;
|
|
1489
|
+
}
|
|
1490
|
+
codes += code_size;
|
|
1491
|
+
}
|
|
1492
|
+
return nup;
|
|
1493
|
+
}
|
|
1494
|
+
|
|
1495
|
+
void scan_codes_range (size_t list_size,
|
|
1496
|
+
const uint8_t *codes,
|
|
1497
|
+
const idx_t *ids,
|
|
1498
|
+
float radius,
|
|
1499
|
+
RangeQueryResult & res) const override
|
|
1500
|
+
{
|
|
1501
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
1502
|
+
float dis = dc.query_to_code (codes);
|
|
1503
|
+
if (dis < radius) {
|
|
1504
|
+
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1505
|
+
res.add (dis, id);
|
|
1506
|
+
}
|
|
1507
|
+
codes += code_size;
|
|
1508
|
+
}
|
|
1509
|
+
}
|
|
1510
|
+
|
|
1511
|
+
|
|
1512
|
+
};
|
|
1513
|
+
|
|
1514
|
+
template<class DCClass>
|
|
1515
|
+
InvertedListScanner* sel2_InvertedListScanner
|
|
1516
|
+
(const ScalarQuantizer *sq,
|
|
1517
|
+
const Index *quantizer, bool store_pairs, bool r)
|
|
1518
|
+
{
|
|
1519
|
+
if (DCClass::Sim::metric_type == METRIC_L2) {
|
|
1520
|
+
return new IVFSQScannerL2<DCClass>(sq->d, sq->trained, sq->code_size,
|
|
1521
|
+
quantizer, store_pairs, r);
|
|
1522
|
+
} else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1523
|
+
return new IVFSQScannerIP<DCClass>(sq->d, sq->trained, sq->code_size,
|
|
1524
|
+
store_pairs, r);
|
|
1525
|
+
} else {
|
|
1526
|
+
FAISS_THROW_MSG("unsupported metric type");
|
|
1527
|
+
}
|
|
1528
|
+
}
|
|
1529
|
+
|
|
1530
|
+
template<class Similarity, class Codec, bool uniform>
|
|
1531
|
+
InvertedListScanner* sel12_InvertedListScanner
|
|
1532
|
+
(const ScalarQuantizer *sq,
|
|
1533
|
+
const Index *quantizer, bool store_pairs, bool r)
|
|
1534
|
+
{
|
|
1535
|
+
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1536
|
+
using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
|
|
1537
|
+
using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
|
|
1538
|
+
return sel2_InvertedListScanner<DCClass> (sq, quantizer, store_pairs, r);
|
|
1539
|
+
}
|
|
1540
|
+
|
|
1541
|
+
|
|
1542
|
+
|
|
1543
|
+
template<class Similarity>
|
|
1544
|
+
InvertedListScanner* sel1_InvertedListScanner
|
|
1545
|
+
(const ScalarQuantizer *sq, const Index *quantizer,
|
|
1546
|
+
bool store_pairs, bool r)
|
|
1547
|
+
{
|
|
1548
|
+
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1549
|
+
switch(sq->qtype) {
|
|
1550
|
+
case ScalarQuantizer::QT_8bit_uniform:
|
|
1551
|
+
return sel12_InvertedListScanner
|
|
1552
|
+
<Similarity, Codec8bit, true>(sq, quantizer, store_pairs, r);
|
|
1553
|
+
case ScalarQuantizer::QT_4bit_uniform:
|
|
1554
|
+
return sel12_InvertedListScanner
|
|
1555
|
+
<Similarity, Codec4bit, true>(sq, quantizer, store_pairs, r);
|
|
1556
|
+
case ScalarQuantizer::QT_8bit:
|
|
1557
|
+
return sel12_InvertedListScanner
|
|
1558
|
+
<Similarity, Codec8bit, false>(sq, quantizer, store_pairs, r);
|
|
1559
|
+
case ScalarQuantizer::QT_4bit:
|
|
1560
|
+
return sel12_InvertedListScanner
|
|
1561
|
+
<Similarity, Codec4bit, false>(sq, quantizer, store_pairs, r);
|
|
1562
|
+
case ScalarQuantizer::QT_6bit:
|
|
1563
|
+
return sel12_InvertedListScanner
|
|
1564
|
+
<Similarity, Codec6bit, false>(sq, quantizer, store_pairs, r);
|
|
1565
|
+
case ScalarQuantizer::QT_fp16:
|
|
1566
|
+
return sel2_InvertedListScanner
|
|
1567
|
+
<DCTemplate<QuantizerFP16<SIMDWIDTH>, Similarity, SIMDWIDTH> >
|
|
1568
|
+
(sq, quantizer, store_pairs, r);
|
|
1569
|
+
case ScalarQuantizer::QT_8bit_direct:
|
|
1570
|
+
if (sq->d % 16 == 0) {
|
|
1571
|
+
return sel2_InvertedListScanner
|
|
1572
|
+
<DistanceComputerByte<Similarity, SIMDWIDTH> >
|
|
1573
|
+
(sq, quantizer, store_pairs, r);
|
|
1574
|
+
} else {
|
|
1575
|
+
return sel2_InvertedListScanner
|
|
1576
|
+
<DCTemplate<Quantizer8bitDirect<SIMDWIDTH>,
|
|
1577
|
+
Similarity, SIMDWIDTH> >
|
|
1578
|
+
(sq, quantizer, store_pairs, r);
|
|
1579
|
+
}
|
|
1580
|
+
|
|
1581
|
+
}
|
|
1582
|
+
|
|
1583
|
+
FAISS_THROW_MSG ("unknown qtype");
|
|
1584
|
+
return nullptr;
|
|
1585
|
+
}
|
|
1586
|
+
|
|
1587
|
+
template<int SIMDWIDTH>
|
|
1588
|
+
InvertedListScanner* sel0_InvertedListScanner
|
|
1589
|
+
(MetricType mt, const ScalarQuantizer *sq,
|
|
1590
|
+
const Index *quantizer, bool store_pairs, bool by_residual)
|
|
1591
|
+
{
|
|
1592
|
+
if (mt == METRIC_L2) {
|
|
1593
|
+
return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH> >
|
|
1594
|
+
(sq, quantizer, store_pairs, by_residual);
|
|
1595
|
+
} else if (mt == METRIC_INNER_PRODUCT) {
|
|
1596
|
+
return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH> >
|
|
1597
|
+
(sq, quantizer, store_pairs, by_residual);
|
|
1598
|
+
} else {
|
|
1599
|
+
FAISS_THROW_MSG("unsupported metric type");
|
|
1600
|
+
}
|
|
1601
|
+
}
|
|
1602
|
+
|
|
1603
|
+
|
|
1604
|
+
|
|
1605
|
+
} // anonymous namespace
|
|
1606
|
+
|
|
1607
|
+
|
|
1608
|
+
/** Build an inverted-list scanner for this quantizer.
 *
 * When compiled with USE_F16C and d is a multiple of 8, the width-8
 * selection chain is used; otherwise the width-1 chain.
 *
 * @param mt           metric type
 * @param quantizer    coarse quantizer forwarded to the scanner
 * @param store_pairs  forwarded to the scanner
 * @param by_residual  forwarded to the scanner
 * @return pointer produced by sel0_InvertedListScanner
 */
InvertedListScanner* ScalarQuantizer::select_InvertedListScanner
        (MetricType mt, const Index *quantizer,
         bool store_pairs, bool by_residual) const
{
#ifdef USE_F16C
    if (d % 8 == 0) {
        return sel0_InvertedListScanner<8>
            (mt, this, quantizer, store_pairs, by_residual);
    }
#endif
    return sel0_InvertedListScanner<1>
        (mt, this, quantizer, store_pairs, by_residual);
}
|
|
1623
|
+
|
|
1624
|
+
|
|
1625
|
+
|
|
1626
|
+
|
|
1627
|
+
|
|
1628
|
+
} // namespace faiss
|