faiss 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -0,0 +1,467 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#ifdef COMPILE_SIMD_ARM_NEON
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/scalar_quantizer/codecs.h>
|
|
11
|
+
#include <faiss/impl/scalar_quantizer/distance_computers.h>
|
|
12
|
+
#include <faiss/impl/scalar_quantizer/quantizers.h>
|
|
13
|
+
#include <faiss/impl/scalar_quantizer/scanners.h>
|
|
14
|
+
#include <faiss/impl/scalar_quantizer/similarities.h>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
|
|
18
|
+
namespace scalar_quantizer {
|
|
19
|
+
|
|
20
|
+
/**********************************************************
|
|
21
|
+
* Codecs
|
|
22
|
+
**********************************************************/
|
|
23
|
+
|
|
24
|
+
template <>
|
|
25
|
+
struct Codec8bit<SIMDLevel::ARM_NEON> : Codec8bit<SIMDLevel::NONE> {
|
|
26
|
+
static FAISS_ALWAYS_INLINE simd8float32
|
|
27
|
+
decode_8_components(const uint8_t* code, size_t i) {
|
|
28
|
+
float32_t result[8] = {};
|
|
29
|
+
for (size_t j = 0; j < 8; j++) {
|
|
30
|
+
result[j] =
|
|
31
|
+
Codec8bit<SIMDLevel::NONE>::decode_component(code, i + j);
|
|
32
|
+
}
|
|
33
|
+
float32x4_t res1 = vld1q_f32(result);
|
|
34
|
+
float32x4_t res2 = vld1q_f32(result + 4);
|
|
35
|
+
return simd8float32(float32x4x2_t{res1, res2});
|
|
36
|
+
}
|
|
37
|
+
};
|
|
38
|
+
|
|
39
|
+
template <>
|
|
40
|
+
struct Codec4bit<SIMDLevel::ARM_NEON> : Codec4bit<SIMDLevel::NONE> {
|
|
41
|
+
static FAISS_ALWAYS_INLINE simd8float32
|
|
42
|
+
decode_8_components(const uint8_t* code, size_t i) {
|
|
43
|
+
float32_t result[8] = {};
|
|
44
|
+
for (size_t j = 0; j < 8; j++) {
|
|
45
|
+
result[j] =
|
|
46
|
+
Codec4bit<SIMDLevel::NONE>::decode_component(code, i + j);
|
|
47
|
+
}
|
|
48
|
+
float32x4_t res1 = vld1q_f32(result);
|
|
49
|
+
float32x4_t res2 = vld1q_f32(result + 4);
|
|
50
|
+
return simd8float32(float32x4x2_t{res1, res2});
|
|
51
|
+
}
|
|
52
|
+
};
|
|
53
|
+
|
|
54
|
+
template <>
|
|
55
|
+
struct Codec6bit<SIMDLevel::ARM_NEON> : Codec6bit<SIMDLevel::NONE> {
|
|
56
|
+
static FAISS_ALWAYS_INLINE simd8float32
|
|
57
|
+
decode_8_components(const uint8_t* code, size_t i) {
|
|
58
|
+
float32_t result[8] = {};
|
|
59
|
+
for (size_t j = 0; j < 8; j++) {
|
|
60
|
+
result[j] =
|
|
61
|
+
Codec6bit<SIMDLevel::NONE>::decode_component(code, i + j);
|
|
62
|
+
}
|
|
63
|
+
float32x4_t res1 = vld1q_f32(result);
|
|
64
|
+
float32x4_t res2 = vld1q_f32(result + 4);
|
|
65
|
+
return simd8float32(float32x4x2_t{res1, res2});
|
|
66
|
+
}
|
|
67
|
+
};
|
|
68
|
+
|
|
69
|
+
/**********************************************************
|
|
70
|
+
* Quantizers (uniform and non-uniform)
|
|
71
|
+
**********************************************************/
|
|
72
|
+
|
|
73
|
+
template <class Codec>
|
|
74
|
+
struct QuantizerTemplate<
|
|
75
|
+
Codec,
|
|
76
|
+
scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
|
|
77
|
+
SIMDLevel::ARM_NEON>
|
|
78
|
+
: QuantizerTemplate<
|
|
79
|
+
Codec,
|
|
80
|
+
scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
|
|
81
|
+
SIMDLevel::NONE> {
|
|
82
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
83
|
+
: QuantizerTemplate<
|
|
84
|
+
Codec,
|
|
85
|
+
scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
|
|
86
|
+
SIMDLevel::NONE>(d, trained) {
|
|
87
|
+
assert(d % 8 == 0);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
91
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
92
|
+
simd8float32 xi = Codec::decode_8_components(code, i);
|
|
93
|
+
return simd8float32(
|
|
94
|
+
float32x4x2_t{
|
|
95
|
+
vfmaq_n_f32(
|
|
96
|
+
vdupq_n_f32(this->vmin),
|
|
97
|
+
xi.data.val[0],
|
|
98
|
+
this->vdiff),
|
|
99
|
+
vfmaq_n_f32(
|
|
100
|
+
vdupq_n_f32(this->vmin),
|
|
101
|
+
xi.data.val[1],
|
|
102
|
+
this->vdiff)});
|
|
103
|
+
}
|
|
104
|
+
};
|
|
105
|
+
|
|
106
|
+
template <class Codec>
|
|
107
|
+
struct QuantizerTemplate<
|
|
108
|
+
Codec,
|
|
109
|
+
scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
|
|
110
|
+
SIMDLevel::ARM_NEON>
|
|
111
|
+
: QuantizerTemplate<
|
|
112
|
+
Codec,
|
|
113
|
+
scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
|
|
114
|
+
SIMDLevel::NONE> {
|
|
115
|
+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
|
|
116
|
+
: QuantizerTemplate<
|
|
117
|
+
Codec,
|
|
118
|
+
scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
|
|
119
|
+
SIMDLevel::NONE>(d, trained) {
|
|
120
|
+
assert(d % 8 == 0);
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
124
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
125
|
+
simd8float32 xi = Codec::decode_8_components(code, i);
|
|
126
|
+
return simd8float32(
|
|
127
|
+
float32x4x2_t{
|
|
128
|
+
vfmaq_f32(
|
|
129
|
+
vld1q_f32(this->vmin + i),
|
|
130
|
+
xi.data.val[0],
|
|
131
|
+
vld1q_f32(this->vdiff + i)),
|
|
132
|
+
vfmaq_f32(
|
|
133
|
+
vld1q_f32(this->vmin + i + 4),
|
|
134
|
+
xi.data.val[1],
|
|
135
|
+
vld1q_f32(this->vdiff + i + 4))});
|
|
136
|
+
}
|
|
137
|
+
};
|
|
138
|
+
|
|
139
|
+
/**********************************************************
|
|
140
|
+
* FP16 Quantizer
|
|
141
|
+
**********************************************************/
|
|
142
|
+
|
|
143
|
+
template <>
|
|
144
|
+
struct QuantizerFP16<SIMDLevel::ARM_NEON> : QuantizerFP16<SIMDLevel::NONE> {
|
|
145
|
+
QuantizerFP16(size_t d, const std::vector<float>& trained)
|
|
146
|
+
: QuantizerFP16<SIMDLevel::NONE>(d, trained) {
|
|
147
|
+
assert(d % 8 == 0);
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
151
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
152
|
+
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
|
|
153
|
+
return simd8float32(
|
|
154
|
+
float32x4x2_t{
|
|
155
|
+
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
|
|
156
|
+
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))});
|
|
157
|
+
}
|
|
158
|
+
};
|
|
159
|
+
|
|
160
|
+
/**********************************************************
|
|
161
|
+
* BF16 Quantizer
|
|
162
|
+
**********************************************************/
|
|
163
|
+
|
|
164
|
+
template <>
|
|
165
|
+
struct QuantizerBF16<SIMDLevel::ARM_NEON> : QuantizerBF16<SIMDLevel::NONE> {
|
|
166
|
+
QuantizerBF16(size_t d, const std::vector<float>& trained)
|
|
167
|
+
: QuantizerBF16<SIMDLevel::NONE>(d, trained) {
|
|
168
|
+
assert(d % 8 == 0);
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
172
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
173
|
+
uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
|
|
174
|
+
return simd8float32(
|
|
175
|
+
float32x4x2_t{
|
|
176
|
+
vreinterpretq_f32_u32(
|
|
177
|
+
vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
|
|
178
|
+
vreinterpretq_f32_u32(
|
|
179
|
+
vshlq_n_u32(vmovl_u16(codei.val[1]), 16))});
|
|
180
|
+
}
|
|
181
|
+
};
|
|
182
|
+
|
|
183
|
+
/**********************************************************
|
|
184
|
+
* 8bit Direct Quantizer
|
|
185
|
+
**********************************************************/
|
|
186
|
+
|
|
187
|
+
template <>
|
|
188
|
+
struct Quantizer8bitDirect<SIMDLevel::ARM_NEON>
|
|
189
|
+
: Quantizer8bitDirect<SIMDLevel::NONE> {
|
|
190
|
+
Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
|
|
191
|
+
: Quantizer8bitDirect<SIMDLevel::NONE>(d, trained) {
|
|
192
|
+
assert(d % 8 == 0);
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
196
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
197
|
+
uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
|
|
198
|
+
uint16x8_t y8 = vmovl_u8(x8);
|
|
199
|
+
uint16x4_t y8_0 = vget_low_u16(y8);
|
|
200
|
+
uint16x4_t y8_1 = vget_high_u16(y8);
|
|
201
|
+
return simd8float32(
|
|
202
|
+
float32x4x2_t{
|
|
203
|
+
vcvtq_f32_u32(vmovl_u16(y8_0)),
|
|
204
|
+
vcvtq_f32_u32(vmovl_u16(y8_1))});
|
|
205
|
+
}
|
|
206
|
+
};
|
|
207
|
+
|
|
208
|
+
/**********************************************************
|
|
209
|
+
* 8bit Direct Signed Quantizer
|
|
210
|
+
**********************************************************/
|
|
211
|
+
|
|
212
|
+
template <>
|
|
213
|
+
struct Quantizer8bitDirectSigned<SIMDLevel::ARM_NEON>
|
|
214
|
+
: Quantizer8bitDirectSigned<SIMDLevel::NONE> {
|
|
215
|
+
Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
|
|
216
|
+
: Quantizer8bitDirectSigned<SIMDLevel::NONE>(d, trained) {
|
|
217
|
+
assert(d % 8 == 0);
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
FAISS_ALWAYS_INLINE simd8float32
|
|
221
|
+
reconstruct_8_components(const uint8_t* code, int i) const {
|
|
222
|
+
uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
|
|
223
|
+
uint16x8_t y8 = vmovl_u8(x8);
|
|
224
|
+
int16x8_t z8 = vreinterpretq_s16_u16(
|
|
225
|
+
vsubq_u16(y8, vdupq_n_u16(128))); // subtract 128 from all lanes
|
|
226
|
+
int16x4_t z8_0 = vget_low_s16(z8);
|
|
227
|
+
int16x4_t z8_1 = vget_high_s16(z8);
|
|
228
|
+
return simd8float32(
|
|
229
|
+
float32x4x2_t{
|
|
230
|
+
vcvtq_f32_s32(vmovl_s16(z8_0)),
|
|
231
|
+
vcvtq_f32_s32(vmovl_s16(z8_1))});
|
|
232
|
+
}
|
|
233
|
+
};
|
|
234
|
+
|
|
235
|
+
/**********************************************************
|
|
236
|
+
* Similarities (L2 and IP)
|
|
237
|
+
**********************************************************/
|
|
238
|
+
|
|
239
|
+
template <>
|
|
240
|
+
struct SimilarityL2<SIMDLevel::ARM_NEON> {
|
|
241
|
+
static constexpr int simdwidth = 8;
|
|
242
|
+
static constexpr SIMDLevel simd_level = SIMDLevel::ARM_NEON;
|
|
243
|
+
static constexpr MetricType metric_type = METRIC_L2;
|
|
244
|
+
|
|
245
|
+
const float *y, *yi;
|
|
246
|
+
|
|
247
|
+
explicit SimilarityL2(const float* y) : y(y), yi(nullptr) {}
|
|
248
|
+
|
|
249
|
+
simd8float32 accu8;
|
|
250
|
+
|
|
251
|
+
FAISS_ALWAYS_INLINE void begin_8() {
|
|
252
|
+
accu8.clear();
|
|
253
|
+
yi = y;
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) {
|
|
257
|
+
simd8float32 yiv(yi);
|
|
258
|
+
yi += 8;
|
|
259
|
+
simd8float32 tmp = yiv - x;
|
|
260
|
+
accu8 = accu8 + tmp * tmp;
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
FAISS_ALWAYS_INLINE void add_8_components_2(
|
|
264
|
+
simd8float32 x,
|
|
265
|
+
simd8float32 y_2) {
|
|
266
|
+
simd8float32 tmp = y_2 - x;
|
|
267
|
+
accu8 = accu8 + tmp * tmp;
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
FAISS_ALWAYS_INLINE float result_8() {
|
|
271
|
+
return horizontal_add(accu8);
|
|
272
|
+
}
|
|
273
|
+
};
|
|
274
|
+
|
|
275
|
+
template <>
|
|
276
|
+
struct SimilarityIP<SIMDLevel::ARM_NEON> {
|
|
277
|
+
static constexpr int simdwidth = 8;
|
|
278
|
+
static constexpr SIMDLevel simd_level = SIMDLevel::ARM_NEON;
|
|
279
|
+
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
|
280
|
+
|
|
281
|
+
const float *y, *yi;
|
|
282
|
+
|
|
283
|
+
explicit SimilarityIP(const float* y) : y(y), yi(nullptr) {}
|
|
284
|
+
|
|
285
|
+
simd8float32 accu8;
|
|
286
|
+
|
|
287
|
+
FAISS_ALWAYS_INLINE void begin_8() {
|
|
288
|
+
accu8.clear();
|
|
289
|
+
yi = y;
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
FAISS_ALWAYS_INLINE void add_8_components(simd8float32 x) {
|
|
293
|
+
simd8float32 yiv(yi);
|
|
294
|
+
yi += 8;
|
|
295
|
+
accu8 = accu8 + yiv * x;
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
FAISS_ALWAYS_INLINE void add_8_components_2(
|
|
299
|
+
simd8float32 x1,
|
|
300
|
+
simd8float32 x2) {
|
|
301
|
+
accu8 = accu8 + x1 * x2;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
FAISS_ALWAYS_INLINE float result_8() {
|
|
305
|
+
return horizontal_add(accu8);
|
|
306
|
+
}
|
|
307
|
+
};
|
|
308
|
+
|
|
309
|
+
/**********************************************************
|
|
310
|
+
* Distance Computers
|
|
311
|
+
**********************************************************/
|
|
312
|
+
|
|
313
|
+
template <class Quantizer, class Similarity>
|
|
314
|
+
struct DCTemplate<Quantizer, Similarity, SIMDLevel::ARM_NEON>
|
|
315
|
+
: SQDistanceComputer {
|
|
316
|
+
using Sim = Similarity;
|
|
317
|
+
|
|
318
|
+
Quantizer quant;
|
|
319
|
+
|
|
320
|
+
DCTemplate(size_t d, const std::vector<float>& trained)
|
|
321
|
+
: quant(d, trained) {}
|
|
322
|
+
|
|
323
|
+
float compute_distance(const float* x, const uint8_t* code) const {
|
|
324
|
+
Similarity sim(x);
|
|
325
|
+
sim.begin_8();
|
|
326
|
+
for (size_t i = 0; i < quant.d; i += 8) {
|
|
327
|
+
simd8float32 xi = quant.reconstruct_8_components(code, i);
|
|
328
|
+
sim.add_8_components(xi);
|
|
329
|
+
}
|
|
330
|
+
return sim.result_8();
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
334
|
+
const {
|
|
335
|
+
Similarity sim(nullptr);
|
|
336
|
+
sim.begin_8();
|
|
337
|
+
for (size_t i = 0; i < quant.d; i += 8) {
|
|
338
|
+
simd8float32 x1 = quant.reconstruct_8_components(code1, i);
|
|
339
|
+
simd8float32 x2 = quant.reconstruct_8_components(code2, i);
|
|
340
|
+
sim.add_8_components_2(x1, x2);
|
|
341
|
+
}
|
|
342
|
+
return sim.result_8();
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
void set_query(const float* x) final {
|
|
346
|
+
q = x;
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
350
|
+
return compute_code_distance(
|
|
351
|
+
codes + i * code_size, codes + j * code_size);
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
float query_to_code(const uint8_t* code) const final {
|
|
355
|
+
return compute_distance(q, code);
|
|
356
|
+
}
|
|
357
|
+
};
|
|
358
|
+
|
|
359
|
+
template <class Similarity>
|
|
360
|
+
struct DistanceComputerByte<Similarity, SIMDLevel::ARM_NEON>
|
|
361
|
+
: SQDistanceComputer {
|
|
362
|
+
using Sim = Similarity;
|
|
363
|
+
|
|
364
|
+
int d;
|
|
365
|
+
std::vector<uint8_t> tmp;
|
|
366
|
+
|
|
367
|
+
DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
|
|
368
|
+
|
|
369
|
+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
|
|
370
|
+
const {
|
|
371
|
+
int accu = 0;
|
|
372
|
+
for (int i = 0; i < d; i++) {
|
|
373
|
+
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
374
|
+
accu += int(code1[i]) * code2[i];
|
|
375
|
+
} else {
|
|
376
|
+
int diff = int(code1[i]) - code2[i];
|
|
377
|
+
accu += diff * diff;
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
return accu;
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
void set_query(const float* x) final {
|
|
384
|
+
for (int i = 0; i < d; i++) {
|
|
385
|
+
tmp[i] = int(x[i]);
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
int compute_distance(const float* x, const uint8_t* code) {
|
|
390
|
+
set_query(x);
|
|
391
|
+
return compute_code_distance(tmp.data(), code);
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
float symmetric_dis(idx_t i, idx_t j) override {
|
|
395
|
+
return compute_code_distance(
|
|
396
|
+
codes + i * code_size, codes + j * code_size);
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
float query_to_code(const uint8_t* code) const final {
|
|
400
|
+
return compute_code_distance(tmp.data(), code);
|
|
401
|
+
}
|
|
402
|
+
};
|
|
403
|
+
|
|
404
|
+
} // namespace scalar_quantizer
|
|
405
|
+
} // namespace faiss
|
|
406
|
+
|
|
407
|
+
#define THE_LEVEL_TO_DISPATCH SIMDLevel::ARM_NEON
|
|
408
|
+
#include <faiss/impl/scalar_quantizer/sq-dispatch.h>
|
|
409
|
+
|
|
410
|
+
#ifdef COMPILE_SIMD_ARM_SVE
|
|
411
|
+
|
|
412
|
+
// ARM_SVE: SVE is a superset of NEON. Forward to the NEON implementation
|
|
413
|
+
// until a dedicated SVE specialization is written.
|
|
414
|
+
|
|
415
|
+
namespace faiss {
|
|
416
|
+
namespace scalar_quantizer {
|
|
417
|
+
|
|
418
|
+
// NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
|
|
419
|
+
template <>
|
|
420
|
+
ScalarQuantizer::SQuantizer* sq_select_quantizer<SIMDLevel::ARM_SVE>(
|
|
421
|
+
QuantizerType qtype,
|
|
422
|
+
size_t d,
|
|
423
|
+
const std::vector<float>& trained) {
|
|
424
|
+
return sq_select_quantizer<SIMDLevel::ARM_NEON>(qtype, d, trained);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
// NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
|
|
428
|
+
template <>
|
|
429
|
+
SQDistanceComputer* sq_select_distance_computer<SIMDLevel::ARM_SVE>(
|
|
430
|
+
MetricType metric,
|
|
431
|
+
ScalarQuantizer::QuantizerType qtype,
|
|
432
|
+
size_t d,
|
|
433
|
+
const std::vector<float>& trained) {
|
|
434
|
+
return sq_select_distance_computer<SIMDLevel::ARM_NEON>(
|
|
435
|
+
metric, qtype, d, trained);
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
// NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
|
|
439
|
+
template <>
|
|
440
|
+
InvertedListScanner* sq_select_InvertedListScanner<SIMDLevel::ARM_SVE>(
|
|
441
|
+
QuantizerType qtype,
|
|
442
|
+
MetricType mt,
|
|
443
|
+
size_t d,
|
|
444
|
+
size_t code_size,
|
|
445
|
+
const std::vector<float>& trained,
|
|
446
|
+
const Index* quantizer,
|
|
447
|
+
bool store_pairs,
|
|
448
|
+
const IDSelector* sel,
|
|
449
|
+
bool by_residual) {
|
|
450
|
+
return sq_select_InvertedListScanner<SIMDLevel::ARM_NEON>(
|
|
451
|
+
qtype,
|
|
452
|
+
mt,
|
|
453
|
+
d,
|
|
454
|
+
code_size,
|
|
455
|
+
trained,
|
|
456
|
+
quantizer,
|
|
457
|
+
store_pairs,
|
|
458
|
+
sel,
|
|
459
|
+
by_residual);
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
} // namespace scalar_quantizer
|
|
463
|
+
} // namespace faiss
|
|
464
|
+
|
|
465
|
+
#endif // COMPILE_SIMD_ARM_SVE
|
|
466
|
+
|
|
467
|
+
#endif // COMPILE_SIMD_ARM_NEON
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/scalar_quantizer/training.h>
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <algorithm>
|
|
12
|
+
#include <cmath>
|
|
13
|
+
|
|
14
|
+
namespace faiss {
|
|
15
|
+
|
|
16
|
+
namespace scalar_quantizer {
|
|
17
|
+
/*******************************************************************
|
|
18
|
+
* Quantizer range training
|
|
19
|
+
*/
|
|
20
|
+
|
|
21
|
+
// Square of a float; used by the RS_optim range fit.
static float sqr(float x) {
    const float squared = x * x;
    return squared;
}
|
|
24
|
+
|
|
25
|
+
/// Train one quantization interval shared by all n scalar values in x.
///
/// @param rs      range statistic:
///                RS_minmax    - [min, max] widened by rs_arg * (max - min)
///                               on each side;
///                RS_meanstd   - mean -/+ rs_arg standard deviations
///                               (a std of 1 is used if the variance is <= 0);
///                RS_quantiles - the values at sorted positions
///                               o = rs_arg * n and n - 1 - o
///                               (o clamped to [0, n / 2]);
///                RS_optim     - alternating least-squares fit of k evenly
///                               spaced levels b + a * j (rs_arg unused).
/// @param rs_arg  statistic-specific argument (see above)
/// @param n       number of values in x; must be > 0
/// @param k       number of quantization levels (used by RS_optim only)
/// @param x       the n training values
/// @param trained output, resized to 2: trained[0] = vmin and
///                trained[1] = vmax - vmin (the interval width)
///
/// Throws via FAISS_THROW_IF_NOT / FAISS_THROW_MSG when n <= 0 or rs is not
/// one of the statistics above.
void train_Uniform(
        RangeStat rs,
        float rs_arg,
        idx_t n,
        int k,
        const float* x,
        std::vector<float>& trained) {
    FAISS_THROW_IF_NOT(n > 0);
    trained.resize(2);
    float& vmin = trained[0];
    float& vmax = trained[1];

    if (rs == ScalarQuantizer::RS_minmax) {
        vmin = HUGE_VAL;
        vmax = -HUGE_VAL;
        for (idx_t i = 0; i < n; i++) {
            if (x[i] < vmin) {
                vmin = x[i];
            }
            if (x[i] > vmax) {
                vmax = x[i];
            }
        }
        float vexp = (vmax - vmin) * rs_arg;
        vmin -= vexp;
        vmax += vexp;
    } else if (rs == ScalarQuantizer::RS_meanstd) {
        double sum = 0, sum2 = 0;
        for (idx_t i = 0; i < n; i++) {
            sum += x[i];
            sum2 += x[i] * x[i];
        }
        float mean = sum / n;
        float var = sum2 / n - mean * mean;
        // a non-positive variance (constant data / rounding) falls back to 1
        float stddev = var <= 0 ? 1.0 : std::sqrt(var);

        vmin = mean - stddev * rs_arg;
        vmax = mean + stddev * rs_arg;
    } else if (rs == ScalarQuantizer::RS_quantiles) {
        std::vector<float> x_copy(x, x + n);
        idx_t o = static_cast<idx_t>(rs_arg * n);
        if (o < 0) {
            o = 0;
        }
        if (o > n - o) {
            o = n / 2;
        }
        std::nth_element(x_copy.begin(), x_copy.begin() + o, x_copy.end());
        vmin = x_copy[o];
        std::nth_element(
                x_copy.begin(), x_copy.begin() + (n - 1 - o), x_copy.end());
        vmax = x_copy[n - 1 - o];

    } else if (rs == ScalarQuantizer::RS_optim) {
        float sx = 0;
        vmin = HUGE_VAL;
        vmax = -HUGE_VAL;
        for (idx_t i = 0; i < n; i++) {
            if (x[i] < vmin) {
                vmin = x[i];
            }
            if (x[i] > vmax) {
                vmax = x[i];
            }
            sx += x[i];
        }
        // With constant data (vmax == vmin) or a single level (k == 1) the
        // initial step `a` would be 0 or a division by zero, and every level
        // assignment below would become NaN. In that case keep [min, max]
        // (width 0 for constant data, as RS_minmax produces).
        if (k > 1 && vmax > vmin) {
            float b = vmin;
            float a = (vmax - vmin) / (k - 1);
            constexpr bool verbose = false;
            constexpr int niter = 2000;
            float last_err = -1;
            int iter_last_err = 0;
            for (int it = 0; it < niter; it++) {
                // assignment step: nearest level index ni in [0, k - 1]
                float sn = 0, sn2 = 0, sxn = 0, err1 = 0;

                for (idx_t i = 0; i < n; i++) {
                    float xi = x[i];
                    float ni = floor((xi - b) / a + 0.5);
                    if (ni < 0) {
                        ni = 0;
                    }
                    if (ni >= k) {
                        ni = k - 1;
                    }
                    err1 += sqr(xi - (ni * a + b));
                    sn += ni;
                    sn2 += ni * ni;
                    sxn += ni * xi;
                }

                // stop after 16 consecutive iterations with an unchanged error
                if (err1 == last_err) {
                    iter_last_err++;
                    if (iter_last_err == 16) {
                        break;
                    }
                } else {
                    last_err = err1;
                    iter_last_err = 0;
                }

                // update step: least-squares fit of x ~ a * ni + b
                float det = sqr(sn) - sn2 * n;
                if (det == 0) {
                    // singular system (e.g. every value got the same level):
                    // keep the current (a, b) instead of dividing by zero
                    break;
                }

                b = (sn * sxn - sn2 * sx) / det;
                a = (sn * sx - n * sxn) / det;
                if (verbose) {
                    printf("it %d, err1=%g \r", it, err1);
                    fflush(stdout);
                }
            }
            if (verbose) {
                printf("\n");
            }

            vmin = b;
            vmax = b + a * (k - 1);
        }

    } else {
        FAISS_THROW_MSG("Invalid qtype");
    }
    // store the interval width rather than its upper bound
    vmax -= vmin;
}
|
|
149
|
+
|
|
150
|
+
/// Train one quantization interval per dimension.
///
/// @param rs, rs_arg, k  same meaning as for train_Uniform, applied
///                       independently to each of the d dimensions
/// @param n              number of training vectors; must be > 0
/// @param d              vector dimension
/// @param x              n vectors of d floats stored contiguously
///                       (vector i starts at x + i * d)
/// @param trained        output, resized to 2 * d: trained[0 .. d) holds
///                       the per-dimension vmin and trained[d .. 2d) the
///                       per-dimension width vmax - vmin
///
/// Throws via FAISS_THROW_IF_NOT when n <= 0 (checked up front so it is never
/// raised from inside the OpenMP region).
void train_NonUniform(
        RangeStat rs,
        float rs_arg,
        idx_t n,
        int d,
        int k,
        const float* x,
        std::vector<float>& trained) {
    FAISS_THROW_IF_NOT(n > 0);
    trained.resize(static_cast<size_t>(2) * d);
    float* vmin = trained.data();
    float* vmax = trained.data() + d;
    if (rs == ScalarQuantizer::RS_minmax) {
        // seed with vector 0, then scan the remaining vectors
        memcpy(vmin, x, sizeof(*x) * d);
        memcpy(vmax, x, sizeof(*x) * d);
        for (idx_t i = 1; i < n; i++) {
            const float* xi = x + i * d;
            for (int j = 0; j < d; j++) {
                if (xi[j] < vmin[j]) {
                    vmin[j] = xi[j];
                }
                if (xi[j] > vmax[j]) {
                    vmax[j] = xi[j];
                }
            }
        }
        // widen each range by rs_arg on both sides and store the width
        // in place of vmax
        float* vdiff = vmax;
        for (int j = 0; j < d; j++) {
            float vexp = (vmax[j] - vmin[j]) * rs_arg;
            vmin[j] -= vexp;
            vmax[j] += vexp;
            vdiff[j] = vmax[j] - vmin[j];
        }
    } else {
        // transpose so that each dimension's n values are contiguous;
        // every vector, including vector 0, must be copied
        std::vector<float> xt(static_cast<size_t>(n) * d);
        for (idx_t i = 0; i < n; i++) {
            const float* xi = x + i * d;
            for (int j = 0; j < d; j++) {
                xt[j * n + i] = xi[j];
            }
        }
#pragma omp parallel for
        for (int j = 0; j < d; j++) {
            // per-iteration buffer: a single vector shared by all threads
            // would be written concurrently by train_Uniform
            std::vector<float> trained_d(2);
            train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d);
            vmin[j] = trained_d[0];
            vmax[j] = trained_d[1];
        }
    }
}
|
|
200
|
+
|
|
201
|
+
} // namespace scalar_quantizer
|
|
202
|
+
|
|
203
|
+
} // namespace faiss
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
/*******************************************************************
|
|
11
|
+
 * Quantizer range training for the scalar quantizer. This code is independent
 * of the search code and does not need to be heavily optimized (scalar
 * quantizer training is already fast).
|
|
14
|
+
*/
|
|
15
|
+
|
|
16
|
+
#include <faiss/impl/ScalarQuantizer.h>
|
|
17
|
+
|
|
18
|
+
namespace faiss {
|
|
19
|
+
|
|
20
|
+
namespace scalar_quantizer {
|
|
21
|
+
|
|
22
|
+
// Shorthand for the range-statistic enum declared in ScalarQuantizer.
using RangeStat = ScalarQuantizer::RangeStat;

/// Train a single quantization interval shared by all n scalar values in x.
/// rs selects the statistic (RS_minmax, RS_meanstd, RS_quantiles, RS_optim),
/// rs_arg is its argument, and k is the number of levels (RS_optim only).
/// On return trained has size 2: trained[0] = vmin, trained[1] = vmax - vmin.
/// Throws if n <= 0 or rs is not a supported statistic.
void train_Uniform(
        RangeStat rs,
        float rs_arg,
        idx_t n,
        int k,
        const float* x,
        std::vector<float>& trained);

/// Train one quantization interval per dimension for n vectors of dimension d
/// stored contiguously in x (vector i starts at x + i * d).
/// On return trained has size 2 * d: the first d entries are the per-dimension
/// vmin, the next d the per-dimension width vmax - vmin.
void train_NonUniform(
        RangeStat rs,
        float rs_arg,
        idx_t n,
        int d,
        int k,
        const float* x,
        std::vector<float>& trained);
|
|
40
|
+
} // namespace scalar_quantizer
|
|
41
|
+
|
|
42
|
+
} // namespace faiss
|