faiss 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <cstddef>
|
|
11
|
+
#include <cstdint>
|
|
12
|
+
#include <type_traits>
|
|
13
|
+
|
|
14
|
+
#include <faiss/impl/ProductQuantizer.h>
|
|
15
|
+
#include <faiss/impl/platform_macros.h>
|
|
16
|
+
#include <faiss/utils/simd_levels.h>
|
|
17
|
+
|
|
18
|
+
namespace faiss {
|
|
19
|
+
namespace pq_code_distance {
|
|
20
|
+
|
|
21
|
+
/*********************************************************************
|
|
22
|
+
* PQCodeDistance — SIMD-dispatched PQ code distance
|
|
23
|
+
*
|
|
24
|
+
* Computes the distance from a PQ-encoded vector to a query vector,
|
|
25
|
+
* given a precomputed table of sub-distances (one per subquantizer
|
|
26
|
+
* per centroid). Originally extracted from IndexIVFPQ.cpp.
|
|
27
|
+
*
|
|
28
|
+
* DESIGN:
|
|
29
|
+
*
|
|
30
|
+
* PQCodeDistance<PQDecoderT, SL> computes PQ code distances at a given
|
|
31
|
+
* SIMD level. The dispatch site (IndexIVFPQ.cpp, IndexPQ.cpp) uses
|
|
32
|
+
* DISPATCH_SIMDLevel to select SL at runtime, which instantiates
|
|
33
|
+
* PQCodeDistance for ALL decoder types (PQDecoder8, PQDecoder16,
|
|
34
|
+
* PQDecoderGeneric) at the chosen level.
|
|
35
|
+
*
|
|
36
|
+
* Only PQDecoder8 has SIMD-optimized implementations (AVX2, AVX512,
|
|
37
|
+
* ARM_SVE). The other decoders always use scalar code — their decode()
|
|
38
|
+
* method is inherently sequential, so SIMD doesn't help.
|
|
39
|
+
*
|
|
40
|
+
* The primary template is always complete (no forward declarations
|
|
41
|
+
* needed). For PQDecoder8, it delegates to _impl dispatch bridge
|
|
42
|
+
* functions whose specializations are defined in per-SIMD .cpp files
|
|
43
|
+
* and resolved at link time. For other decoders, it uses scalar.
|
|
44
|
+
*
|
|
45
|
+
* ADDING A NEW SIMD LEVEL:
|
|
46
|
+
*
|
|
47
|
+
* 1. Add the level to SIMDLevel enum (simd_levels.h)
|
|
48
|
+
* 2. Add dispatch_config entry (simd_dispatch.bzl)
|
|
49
|
+
* 3. Define pq_code_distance_single_impl<NEW_LEVEL> and
|
|
50
|
+
* pq_code_distance_four_impl<NEW_LEVEL> specializations in a
|
|
51
|
+
* new .cpp file compiled with appropriate SIMD flags
|
|
52
|
+
* 4. Add the .cpp to the build (CMakeLists.txt, xplat.bzl)
|
|
53
|
+
*********************************************************************/
|
|
54
|
+
|
|
55
|
+
/// Scalar PQ code distance implementation.
|
|
56
|
+
/// Templated only on decoder type, independent of SIMD level.
|
|
57
|
+
/// Used directly by non-PQDecoder8 decoders (PQDecoder16,
|
|
58
|
+
/// PQDecoderGeneric) and as fallback for PQDecoder8 at NONE/NEON.
|
|
59
|
+
template <typename PQDecoderT>
|
|
60
|
+
struct PQCodeDistanceScalar {
|
|
61
|
+
using PQDecoder = PQDecoderT;
|
|
62
|
+
|
|
63
|
+
static float distance_single_code(
|
|
64
|
+
// number of subquantizers
|
|
65
|
+
size_t M,
|
|
66
|
+
size_t nbits,
|
|
67
|
+
// precomputed distances, layout (M, ksub)
|
|
68
|
+
const float* sim_table,
|
|
69
|
+
const uint8_t* code) {
|
|
70
|
+
PQDecoderT decoder(code, nbits);
|
|
71
|
+
const size_t ksub = 1 << nbits;
|
|
72
|
+
|
|
73
|
+
const float* tab = sim_table;
|
|
74
|
+
float result = 0;
|
|
75
|
+
|
|
76
|
+
for (size_t m = 0; m < M; m++) {
|
|
77
|
+
result += tab[decoder.decode()];
|
|
78
|
+
tab += ksub;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
return result;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
static void distance_four_codes(
|
|
85
|
+
size_t M,
|
|
86
|
+
size_t nbits,
|
|
87
|
+
const float* sim_table,
|
|
88
|
+
const uint8_t* __restrict code0,
|
|
89
|
+
const uint8_t* __restrict code1,
|
|
90
|
+
const uint8_t* __restrict code2,
|
|
91
|
+
const uint8_t* __restrict code3,
|
|
92
|
+
float& result0,
|
|
93
|
+
float& result1,
|
|
94
|
+
float& result2,
|
|
95
|
+
float& result3) {
|
|
96
|
+
PQDecoderT decoder0(code0, nbits);
|
|
97
|
+
PQDecoderT decoder1(code1, nbits);
|
|
98
|
+
PQDecoderT decoder2(code2, nbits);
|
|
99
|
+
PQDecoderT decoder3(code3, nbits);
|
|
100
|
+
const size_t ksub = 1 << nbits;
|
|
101
|
+
|
|
102
|
+
const float* tab = sim_table;
|
|
103
|
+
result0 = 0;
|
|
104
|
+
result1 = 0;
|
|
105
|
+
result2 = 0;
|
|
106
|
+
result3 = 0;
|
|
107
|
+
|
|
108
|
+
for (size_t m = 0; m < M; m++) {
|
|
109
|
+
result0 += tab[decoder0.decode()];
|
|
110
|
+
result1 += tab[decoder1.decode()];
|
|
111
|
+
result2 += tab[decoder2.decode()];
|
|
112
|
+
result3 += tab[decoder3.decode()];
|
|
113
|
+
tab += ksub;
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
/*********************************************************************
|
|
119
|
+
* Dispatch bridge — function templates for PQDecoder8 SIMD dispatch.
|
|
120
|
+
*
|
|
121
|
+
* Primary declarations only; specializations are defined in per-SIMD
|
|
122
|
+
* .cpp files (AVX2, AVX512, ARM_SVE) and pq_code_distance-generic.cpp
|
|
123
|
+
* (NONE, ARM_NEON). Same pattern as fvec_L2sqr et al. in distances.h.
|
|
124
|
+
*********************************************************************/
|
|
125
|
+
|
|
126
|
+
template <SIMDLevel SL>
|
|
127
|
+
float pq_code_distance_single_impl(
|
|
128
|
+
size_t M,
|
|
129
|
+
size_t nbits,
|
|
130
|
+
const float* sim_table,
|
|
131
|
+
const uint8_t* code);
|
|
132
|
+
|
|
133
|
+
template <SIMDLevel SL>
|
|
134
|
+
void pq_code_distance_four_impl(
|
|
135
|
+
size_t M,
|
|
136
|
+
size_t nbits,
|
|
137
|
+
const float* sim_table,
|
|
138
|
+
const uint8_t* __restrict code0,
|
|
139
|
+
const uint8_t* __restrict code1,
|
|
140
|
+
const uint8_t* __restrict code2,
|
|
141
|
+
const uint8_t* __restrict code3,
|
|
142
|
+
float& result0,
|
|
143
|
+
float& result1,
|
|
144
|
+
float& result2,
|
|
145
|
+
float& result3);
|
|
146
|
+
|
|
147
|
+
/// Primary template — always complete.
|
|
148
|
+
/// For PQDecoder8, delegates to _impl dispatch bridges (resolved at
|
|
149
|
+
/// link time to per-SIMD implementations). For other decoders, uses
|
|
150
|
+
/// scalar — their sequential decode() methods don't benefit from SIMD.
|
|
151
|
+
template <typename PQDecoderT, SIMDLevel SL>
|
|
152
|
+
struct PQCodeDistance {
|
|
153
|
+
using PQDecoder = PQDecoderT;
|
|
154
|
+
|
|
155
|
+
static float distance_single_code(
|
|
156
|
+
size_t M,
|
|
157
|
+
size_t nbits,
|
|
158
|
+
const float* sim_table,
|
|
159
|
+
const uint8_t* code) {
|
|
160
|
+
if constexpr (std::is_same_v<PQDecoderT, PQDecoder8>) {
|
|
161
|
+
return pq_code_distance_single_impl<SL>(M, nbits, sim_table, code);
|
|
162
|
+
} else {
|
|
163
|
+
return PQCodeDistanceScalar<PQDecoderT>::distance_single_code(
|
|
164
|
+
M, nbits, sim_table, code);
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
static void distance_four_codes(
|
|
169
|
+
size_t M,
|
|
170
|
+
size_t nbits,
|
|
171
|
+
const float* sim_table,
|
|
172
|
+
const uint8_t* __restrict code0,
|
|
173
|
+
const uint8_t* __restrict code1,
|
|
174
|
+
const uint8_t* __restrict code2,
|
|
175
|
+
const uint8_t* __restrict code3,
|
|
176
|
+
float& result0,
|
|
177
|
+
float& result1,
|
|
178
|
+
float& result2,
|
|
179
|
+
float& result3) {
|
|
180
|
+
if constexpr (std::is_same_v<PQDecoderT, PQDecoder8>) {
|
|
181
|
+
pq_code_distance_four_impl<SL>(
|
|
182
|
+
M,
|
|
183
|
+
nbits,
|
|
184
|
+
sim_table,
|
|
185
|
+
code0,
|
|
186
|
+
code1,
|
|
187
|
+
code2,
|
|
188
|
+
code3,
|
|
189
|
+
result0,
|
|
190
|
+
result1,
|
|
191
|
+
result2,
|
|
192
|
+
result3);
|
|
193
|
+
} else {
|
|
194
|
+
PQCodeDistanceScalar<PQDecoderT>::distance_four_codes(
|
|
195
|
+
M,
|
|
196
|
+
nbits,
|
|
197
|
+
sim_table,
|
|
198
|
+
code0,
|
|
199
|
+
code1,
|
|
200
|
+
code2,
|
|
201
|
+
code3,
|
|
202
|
+
result0,
|
|
203
|
+
result1,
|
|
204
|
+
result2,
|
|
205
|
+
result3);
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
};
|
|
209
|
+
|
|
210
|
+
/*********************************************************************
|
|
211
|
+
* Non-templated PQ code distance dispatch (PQDecoder8 only).
|
|
212
|
+
*
|
|
213
|
+
* These follow the same pattern as distances.h: the caller does not
|
|
214
|
+
* name a SIMDLevel. Internally they dispatch via DISPATCH_SIMDLevel
|
|
215
|
+
* to the best available SIMD implementation (DD: runtime detection,
|
|
216
|
+
* static: compile-time selection). Definitions are in
|
|
217
|
+
* pq_code_distance-generic.cpp.
|
|
218
|
+
*********************************************************************/
|
|
219
|
+
|
|
220
|
+
/// Compute PQ distance for a single code, dispatching to the best
|
|
221
|
+
/// available SIMD level.
|
|
222
|
+
FAISS_API float pq_code_distance_single(
|
|
223
|
+
size_t M,
|
|
224
|
+
size_t nbits,
|
|
225
|
+
const float* sim_table,
|
|
226
|
+
const uint8_t* code);
|
|
227
|
+
|
|
228
|
+
/// Compute PQ distances for four codes simultaneously, dispatching
|
|
229
|
+
/// to the best available SIMD level.
|
|
230
|
+
FAISS_API void pq_code_distance_four(
|
|
231
|
+
size_t M,
|
|
232
|
+
size_t nbits,
|
|
233
|
+
const float* sim_table,
|
|
234
|
+
const uint8_t* __restrict code0,
|
|
235
|
+
const uint8_t* __restrict code1,
|
|
236
|
+
const uint8_t* __restrict code2,
|
|
237
|
+
const uint8_t* __restrict code3,
|
|
238
|
+
float& result0,
|
|
239
|
+
float& result1,
|
|
240
|
+
float& result2,
|
|
241
|
+
float& result3);
|
|
242
|
+
|
|
243
|
+
} // namespace pq_code_distance
|
|
244
|
+
|
|
245
|
+
// Re-export public API into namespace faiss for convenience
|
|
246
|
+
using pq_code_distance::pq_code_distance_four;
|
|
247
|
+
using pq_code_distance::pq_code_distance_single;
|
|
248
|
+
using pq_code_distance::PQCodeDistance;
|
|
249
|
+
using pq_code_distance::PQCodeDistanceScalar;
|
|
250
|
+
|
|
251
|
+
} // namespace faiss
|
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <cstddef>
|
|
11
11
|
#include <cstdint>
|
|
12
|
+
#include <cstring>
|
|
12
13
|
|
|
13
14
|
// Only include x86 SIMD intrinsics on x86/x86_64 architectures
|
|
14
15
|
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \
|
|
@@ -423,3 +424,262 @@ inline uint64_t popcount(const uint8_t* data, size_t size) {
|
|
|
423
424
|
}
|
|
424
425
|
|
|
425
426
|
} // namespace faiss::rabitq
|
|
427
|
+
|
|
428
|
+
/*********************************************************
|
|
429
|
+
* Multi-bit RaBitQ inner product kernels.
|
|
430
|
+
*
|
|
431
|
+
* Compute: sum_i rotated_q[i] * ((sign_bit_i << ex_bits) + ex_code_val_i + cb)
|
|
432
|
+
*
|
|
433
|
+
* Strategy:
|
|
434
|
+
* ex_bits == 1: Specialized kernel — both sign_bits and ex_code are
|
|
435
|
+
* 1-bit-per-dim packed, enabling direct bit→mask→float
|
|
436
|
+
* conversion with zero per-element extraction.
|
|
437
|
+
* ex_bits >= 2: Bit-plane decomposition (BMI2 required) — PEXT extracts
|
|
438
|
+
* each bit plane in one instruction, then the same
|
|
439
|
+
* bit→mask→float kernel computes each plane's dot product.
|
|
440
|
+
* Fallback: Scalar extraction via 64-bit window read + shift + mask.
|
|
441
|
+
*********************************************************/
|
|
442
|
+
namespace faiss::rabitq::multibit {
|
|
443
|
+
|
|
444
|
+
/// Scalar inner product for multi-bit RaBitQ.
|
|
445
|
+
/// Extracts each code value in O(1) via 64-bit window read + shift + mask.
|
|
446
|
+
/// Also serves as the tail handler for SIMD kernels via the @p start parameter.
|
|
447
|
+
inline float ip_scalar(
|
|
448
|
+
const uint8_t* __restrict sign_bits,
|
|
449
|
+
const uint8_t* __restrict ex_code,
|
|
450
|
+
const float* __restrict rotated_q,
|
|
451
|
+
size_t start,
|
|
452
|
+
size_t d,
|
|
453
|
+
size_t ex_bits,
|
|
454
|
+
float cb) {
|
|
455
|
+
float result = 0.0f;
|
|
456
|
+
const int sign_shift = static_cast<int>(ex_bits);
|
|
457
|
+
const uint64_t code_mask = (1ULL << ex_bits) - 1;
|
|
458
|
+
for (size_t i = start; i < d; i++) {
|
|
459
|
+
int sb = (sign_bits[i / 8] >> (i % 8)) & 1;
|
|
460
|
+
size_t bit_pos = i * ex_bits;
|
|
461
|
+
size_t byte_idx = bit_pos / 8;
|
|
462
|
+
size_t bit_offset = bit_pos % 8;
|
|
463
|
+
uint64_t raw = 0;
|
|
464
|
+
memcpy(&raw, ex_code + byte_idx, sizeof(uint64_t));
|
|
465
|
+
int ex_val = static_cast<int>((raw >> bit_offset) & code_mask);
|
|
466
|
+
result += rotated_q[i] *
|
|
467
|
+
(static_cast<float>((sb << sign_shift) + ex_val) + cb);
|
|
468
|
+
}
|
|
469
|
+
return result;
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
#if defined(__x86_64__) || defined(_M_X64)
|
|
473
|
+
|
|
474
|
+
#if defined(__AVX2__)
|
|
475
|
+
/// Horizontal sum of 8 floats in a __m256 register.
|
|
476
|
+
inline float hsum_avx2(__m256 v) {
|
|
477
|
+
__m128 hi = _mm256_extractf128_ps(v, 1);
|
|
478
|
+
__m128 lo = _mm256_castps256_ps128(v);
|
|
479
|
+
lo = _mm_add_ps(lo, hi);
|
|
480
|
+
__m128 shuf = _mm_movehdup_ps(lo);
|
|
481
|
+
lo = _mm_add_ps(lo, shuf);
|
|
482
|
+
shuf = _mm_movehl_ps(shuf, lo);
|
|
483
|
+
return _mm_cvtss_f32(_mm_add_ss(lo, shuf));
|
|
484
|
+
}
|
|
485
|
+
#endif // __AVX2__
|
|
486
|
+
|
|
487
|
+
/*********************************************************
|
|
488
|
+
* Specialized 1-bit kernels (ex_bits == 1).
|
|
489
|
+
*
|
|
490
|
+
* For 1 extra bit, both sign_bits and ex_code are 1-bit-per-dim packed,
|
|
491
|
+
* so we convert bits to floats directly — no extraction loops needed.
|
|
492
|
+
*********************************************************/
|
|
493
|
+
|
|
494
|
+
#if defined(__AVX512F__)
|
|
495
|
+
/// AVX-512: 16 dims/iter, ex_bits == 1.
|
|
496
|
+
inline float ip_1exbit_avx512(
|
|
497
|
+
const uint8_t* __restrict sign_bits,
|
|
498
|
+
const uint8_t* __restrict ex_code,
|
|
499
|
+
const float* __restrict rotated_q,
|
|
500
|
+
size_t d,
|
|
501
|
+
float cb) {
|
|
502
|
+
__m512 acc = _mm512_setzero_ps();
|
|
503
|
+
const __m512 v_cb = _mm512_set1_ps(cb);
|
|
504
|
+
const __m512 v_two = _mm512_set1_ps(2.0f);
|
|
505
|
+
const __m512 v_one = _mm512_set1_ps(1.0f);
|
|
506
|
+
|
|
507
|
+
size_t i = 0;
|
|
508
|
+
for (; i + 16 <= d; i += 16) {
|
|
509
|
+
uint16_t sb16;
|
|
510
|
+
memcpy(&sb16, sign_bits + i / 8, sizeof(uint16_t));
|
|
511
|
+
uint16_t eb16;
|
|
512
|
+
memcpy(&eb16, ex_code + i / 8, sizeof(uint16_t));
|
|
513
|
+
|
|
514
|
+
__m512 sb_f = _mm512_maskz_mov_ps(_cvtu32_mask16(sb16), v_one);
|
|
515
|
+
__m512 eb_f = _mm512_maskz_mov_ps(_cvtu32_mask16(eb16), v_one);
|
|
516
|
+
|
|
517
|
+
__m512 recon = _mm512_add_ps(_mm512_fmadd_ps(sb_f, v_two, eb_f), v_cb);
|
|
518
|
+
__m512 rq = _mm512_loadu_ps(rotated_q + i);
|
|
519
|
+
acc = _mm512_fmadd_ps(rq, recon, acc);
|
|
520
|
+
}
|
|
521
|
+
|
|
522
|
+
float result = _mm512_reduce_add_ps(acc);
|
|
523
|
+
result += ip_scalar(sign_bits, ex_code, rotated_q, i, d, 1, cb);
|
|
524
|
+
return result;
|
|
525
|
+
}
|
|
526
|
+
#endif // __AVX512F__
|
|
527
|
+
|
|
528
|
+
#if defined(__AVX2__)
|
|
529
|
+
/// AVX2: 8 dims/iter, ex_bits == 1.
|
|
530
|
+
inline float ip_1exbit_avx2(
|
|
531
|
+
const uint8_t* __restrict sign_bits,
|
|
532
|
+
const uint8_t* __restrict ex_code,
|
|
533
|
+
const float* __restrict rotated_q,
|
|
534
|
+
size_t d,
|
|
535
|
+
float cb) {
|
|
536
|
+
__m256 acc = _mm256_setzero_ps();
|
|
537
|
+
const __m256 v_cb = _mm256_set1_ps(cb);
|
|
538
|
+
const __m256 v_two = _mm256_set1_ps(2.0f);
|
|
539
|
+
const __m256 v_one = _mm256_set1_ps(1.0f);
|
|
540
|
+
const __m256i bit_pos = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
|
|
541
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
542
|
+
|
|
543
|
+
size_t i = 0;
|
|
544
|
+
for (; i + 8 <= d; i += 8) {
|
|
545
|
+
uint8_t sb = sign_bits[i / 8];
|
|
546
|
+
uint8_t eb = ex_code[i / 8];
|
|
547
|
+
|
|
548
|
+
__m256i sb_cmp = _mm256_cmpgt_epi32(
|
|
549
|
+
_mm256_and_si256(_mm256_set1_epi32(sb), bit_pos), zero);
|
|
550
|
+
__m256 sb_f = _mm256_and_ps(_mm256_castsi256_ps(sb_cmp), v_one);
|
|
551
|
+
|
|
552
|
+
__m256i eb_cmp = _mm256_cmpgt_epi32(
|
|
553
|
+
_mm256_and_si256(_mm256_set1_epi32(eb), bit_pos), zero);
|
|
554
|
+
__m256 eb_f = _mm256_and_ps(_mm256_castsi256_ps(eb_cmp), v_one);
|
|
555
|
+
|
|
556
|
+
__m256 recon = _mm256_add_ps(_mm256_fmadd_ps(sb_f, v_two, eb_f), v_cb);
|
|
557
|
+
__m256 rq = _mm256_loadu_ps(rotated_q + i);
|
|
558
|
+
acc = _mm256_fmadd_ps(rq, recon, acc);
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
float result = hsum_avx2(acc);
|
|
562
|
+
result += ip_scalar(sign_bits, ex_code, rotated_q, i, d, 1, cb);
|
|
563
|
+
return result;
|
|
564
|
+
}
|
|
565
|
+
#endif // __AVX2__
|
|
566
|
+
|
|
567
|
+
/*********************************************************
|
|
568
|
+
* Bit-plane decomposition kernels (ex_bits >= 2, BMI2 required).
|
|
569
|
+
*
|
|
570
|
+
* Decomposes the inner product as:
|
|
571
|
+
* ex_ip = (1 << ex_bits) * sign_dot
|
|
572
|
+
* + Σ_{b=0}^{ex_bits-1} (1 << b) * plane_dot_b
|
|
573
|
+
* + cb * total_q
|
|
574
|
+
*
|
|
575
|
+
* Each plane_dot_b is a float × bit-vector dot product, computed using
|
|
576
|
+
* the same bit→mask→float conversion as the 1-bit kernel. PEXT
|
|
577
|
+
* extracts each bit plane from the packed ex_code in one instruction
|
|
578
|
+
* per 8 dimensions.
|
|
579
|
+
*********************************************************/
|
|
580
|
+
|
|
581
|
+
#if defined(__AVX2__) && defined(__BMI2__)
|
|
582
|
+
/// AVX2 + BMI2 bit-plane decomposition: 8 dims/iter, ex_bits in [2, 7].
|
|
583
|
+
/// Caller must ensure ex_bits <= 7 (pext_masks[7] / v_weights[8]).
|
|
584
|
+
inline float ip_bitplane_avx2(
|
|
585
|
+
const uint8_t* __restrict sign_bits,
|
|
586
|
+
const uint8_t* __restrict ex_code,
|
|
587
|
+
const float* __restrict rotated_q,
|
|
588
|
+
size_t d,
|
|
589
|
+
size_t ex_bits,
|
|
590
|
+
float cb) {
|
|
591
|
+
__m256 acc = _mm256_setzero_ps();
|
|
592
|
+
const __m256 v_one = _mm256_set1_ps(1.0f);
|
|
593
|
+
const __m256i bit_pos = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
|
|
594
|
+
const __m256i zero = _mm256_setzero_si256();
|
|
595
|
+
const __m256 v_cb = _mm256_set1_ps(cb);
|
|
596
|
+
|
|
597
|
+
// Precompute PEXT masks and plane weights
|
|
598
|
+
uint64_t pext_masks[7];
|
|
599
|
+
__m256 v_weights[8];
|
|
600
|
+
for (size_t b = 0; b < ex_bits; b++) {
|
|
601
|
+
uint64_t m = 0;
|
|
602
|
+
for (int j = 0; j < 8; j++) {
|
|
603
|
+
m |= (1ULL << (b + j * ex_bits));
|
|
604
|
+
}
|
|
605
|
+
pext_masks[b] = m;
|
|
606
|
+
v_weights[b] = _mm256_set1_ps(static_cast<float>(1u << b));
|
|
607
|
+
}
|
|
608
|
+
v_weights[ex_bits] = _mm256_set1_ps(static_cast<float>(1u << ex_bits));
|
|
609
|
+
|
|
610
|
+
size_t i = 0;
|
|
611
|
+
for (; i + 8 <= d; i += 8) {
|
|
612
|
+
// Sign bit → float via bit mask comparison
|
|
613
|
+
__m256i sb_cmp = _mm256_cmpgt_epi32(
|
|
614
|
+
_mm256_and_si256(_mm256_set1_epi32(sign_bits[i / 8]), bit_pos),
|
|
615
|
+
zero);
|
|
616
|
+
__m256 recon = _mm256_mul_ps(
|
|
617
|
+
_mm256_and_ps(_mm256_castsi256_ps(sb_cmp), v_one),
|
|
618
|
+
v_weights[ex_bits]);
|
|
619
|
+
|
|
620
|
+
// Load packed ex_code for 8 dims (8 × ex_bits bits = ex_bits bytes)
|
|
621
|
+
uint64_t ex64 = 0;
|
|
622
|
+
memcpy(&ex64, ex_code + (i / 8) * ex_bits, sizeof(uint64_t));
|
|
623
|
+
|
|
624
|
+
// Extract each bit plane via PEXT → bit mask → float
|
|
625
|
+
for (size_t b = 0; b < ex_bits; b++) {
|
|
626
|
+
auto plane = static_cast<uint8_t>(_pext_u64(ex64, pext_masks[b]));
|
|
627
|
+
__m256i p_cmp = _mm256_cmpgt_epi32(
|
|
628
|
+
_mm256_and_si256(_mm256_set1_epi32(plane), bit_pos), zero);
|
|
629
|
+
__m256 p_f = _mm256_and_ps(_mm256_castsi256_ps(p_cmp), v_one);
|
|
630
|
+
recon = _mm256_fmadd_ps(p_f, v_weights[b], recon);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
__m256 rq = _mm256_loadu_ps(rotated_q + i);
|
|
634
|
+
acc = _mm256_fmadd_ps(rq, _mm256_add_ps(recon, v_cb), acc);
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
float result = hsum_avx2(acc);
|
|
638
|
+
result += ip_scalar(sign_bits, ex_code, rotated_q, i, d, ex_bits, cb);
|
|
639
|
+
return result;
|
|
640
|
+
}
|
|
641
|
+
#endif // __AVX2__ && __BMI2__
|
|
642
|
+
|
|
643
|
+
#endif // x86_64
|
|
644
|
+
|
|
645
|
+
/**
|
|
646
|
+
* Dispatch to the best available kernel for the given ex_bits.
|
|
647
|
+
*
|
|
648
|
+
* Routing (compile-time):
|
|
649
|
+
* ex_bits == 1: specialized 1-bit kernel (AVX-512 > AVX2 > scalar)
|
|
650
|
+
* ex_bits >= 2: bit-plane decomposition (AVX2+BMI2 > scalar)
|
|
651
|
+
*
|
|
652
|
+
* @param sign_bits packed sign bits (1 bit/dim, standard byte packing)
|
|
653
|
+
* @param ex_code packed extra-bit codes (ex_bits bits/dim)
|
|
654
|
+
* @param rotated_q rotated query vector (float[d])
|
|
655
|
+
* @param d dimensionality
|
|
656
|
+
* @param ex_bits number of extra bits per dimension (nb_bits - 1)
|
|
657
|
+
* @param cb constant bias: -(2^ex_bits - 0.5)
|
|
658
|
+
* @return inner product value
|
|
659
|
+
*/
|
|
660
|
+
inline float compute_inner_product(
|
|
661
|
+
const uint8_t* __restrict sign_bits,
|
|
662
|
+
const uint8_t* __restrict ex_code,
|
|
663
|
+
const float* __restrict rotated_q,
|
|
664
|
+
size_t d,
|
|
665
|
+
size_t ex_bits,
|
|
666
|
+
float cb) {
|
|
667
|
+
if (ex_bits == 1) {
|
|
668
|
+
#if defined(__AVX512F__)
|
|
669
|
+
return ip_1exbit_avx512(sign_bits, ex_code, rotated_q, d, cb);
|
|
670
|
+
#elif defined(__AVX2__)
|
|
671
|
+
return ip_1exbit_avx2(sign_bits, ex_code, rotated_q, d, cb);
|
|
672
|
+
#else
|
|
673
|
+
return ip_scalar(sign_bits, ex_code, rotated_q, 0, d, 1, cb);
|
|
674
|
+
#endif
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
#if defined(__AVX2__) && defined(__BMI2__)
|
|
678
|
+
if (ex_bits <= 7) {
|
|
679
|
+
return ip_bitplane_avx2(sign_bits, ex_code, rotated_q, d, ex_bits, cb);
|
|
680
|
+
}
|
|
681
|
+
#endif
|
|
682
|
+
return ip_scalar(sign_bits, ex_code, rotated_q, 0, d, ex_bits, cb);
|
|
683
|
+
}
|
|
684
|
+
|
|
685
|
+
} // namespace faiss::rabitq::multibit
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/utils/distances.h>
|
|
9
|
+
|
|
10
|
+
#ifdef __aarch64__
|
|
11
|
+
|
|
12
|
+
#include <arm_neon.h>
|
|
13
|
+
#include <limits>
|
|
14
|
+
|
|
15
|
+
#define AUTOVEC_LEVEL SIMDLevel::ARM_NEON
|
|
16
|
+
#include <faiss/utils/simd_impl/distances_autovec-inl.h>
|
|
17
|
+
|
|
18
|
+
namespace faiss {
|
|
19
|
+
|
|
20
|
+
template <>
|
|
21
|
+
void fvec_madd<SIMDLevel::ARM_NEON>(
|
|
22
|
+
size_t n,
|
|
23
|
+
const float* a,
|
|
24
|
+
float bf,
|
|
25
|
+
const float* b,
|
|
26
|
+
float* c) {
|
|
27
|
+
const size_t n_simd = n - (n & 3);
|
|
28
|
+
const float32x4_t bfv = vdupq_n_f32(bf);
|
|
29
|
+
size_t i;
|
|
30
|
+
for (i = 0; i < n_simd; i += 4) {
|
|
31
|
+
const float32x4_t ai = vld1q_f32(a + i);
|
|
32
|
+
const float32x4_t bi = vld1q_f32(b + i);
|
|
33
|
+
const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
|
|
34
|
+
vst1q_f32(c + i, ci);
|
|
35
|
+
}
|
|
36
|
+
for (; i < n; ++i)
|
|
37
|
+
c[i] = a[i] + bf * b[i];
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
template <>
|
|
41
|
+
void fvec_L2sqr_ny_transposed<SIMDLevel::ARM_NEON>(
|
|
42
|
+
float* dis,
|
|
43
|
+
const float* x,
|
|
44
|
+
const float* y,
|
|
45
|
+
const float* y_sqlen,
|
|
46
|
+
size_t d,
|
|
47
|
+
size_t d_offset,
|
|
48
|
+
size_t ny) {
|
|
49
|
+
// Use autovectorized implementation
|
|
50
|
+
fvec_L2sqr_ny_transposed<SIMDLevel::NONE>(
|
|
51
|
+
dis, x, y, y_sqlen, d, d_offset, ny);
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
template <>
|
|
55
|
+
void fvec_inner_products_ny<SIMDLevel::ARM_NEON>(
|
|
56
|
+
float* dis,
|
|
57
|
+
const float* x,
|
|
58
|
+
const float* y,
|
|
59
|
+
size_t d,
|
|
60
|
+
size_t ny) {
|
|
61
|
+
fvec_inner_products_ny<SIMDLevel::NONE>(dis, x, y, d, ny);
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
template <>
|
|
65
|
+
void fvec_L2sqr_ny<SIMDLevel::ARM_NEON>(
|
|
66
|
+
float* dis,
|
|
67
|
+
const float* x,
|
|
68
|
+
const float* y,
|
|
69
|
+
size_t d,
|
|
70
|
+
size_t ny) {
|
|
71
|
+
fvec_L2sqr_ny<SIMDLevel::NONE>(dis, x, y, d, ny);
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
template <>
|
|
75
|
+
size_t fvec_L2sqr_ny_nearest<SIMDLevel::ARM_NEON>(
|
|
76
|
+
float* distances_tmp_buffer,
|
|
77
|
+
const float* x,
|
|
78
|
+
const float* y,
|
|
79
|
+
size_t d,
|
|
80
|
+
size_t ny) {
|
|
81
|
+
return fvec_L2sqr_ny_nearest<SIMDLevel::NONE>(
|
|
82
|
+
distances_tmp_buffer, x, y, d, ny);
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
template <>
|
|
86
|
+
size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::ARM_NEON>(
|
|
87
|
+
float* distances_tmp_buffer,
|
|
88
|
+
const float* x,
|
|
89
|
+
const float* y,
|
|
90
|
+
const float* y_sqlen,
|
|
91
|
+
size_t d,
|
|
92
|
+
size_t d_offset,
|
|
93
|
+
size_t ny) {
|
|
94
|
+
return fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::NONE>(
|
|
95
|
+
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
// Computes c[i] = a[i] + bf * b[i] for all i, and returns the index of the
// smallest resulting c[i]. Vectorized min/argmin tracking over 4 lanes, then
// a horizontal reduction, then a scalar tail for the last n % 4 elements.
template <>
int fvec_madd_and_argmin<SIMDLevel::ARM_NEON>(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    // Per-lane running minimum (seeded with a large sentinel) and the
    // corresponding per-lane index (seeded with 0xFFFFFFFF = "none yet").
    float32x4_t vminv = vdupq_n_f32(1e20);
    uint32x4_t iminv = vdupq_n_u32(static_cast<uint32_t>(-1));
    size_t i;
    {
        const size_t n_simd = n - (n & 3);
        // iv holds the element indices of the current 4-wide group.
        const uint32_t iota[] = {0, 1, 2, 3};
        uint32x4_t iv = vld1q_u32(iota);
        const uint32x4_t incv = vdupq_n_u32(4);
        const float32x4_t bfv = vdupq_n_f32(bf);
        for (i = 0; i < n_simd; i += 4) {
            const float32x4_t ai = vld1q_f32(a + i);
            const float32x4_t bi = vld1q_f32(b + i);
            // ci = ai + bfv * bi (fused multiply-add), stored to output.
            const float32x4_t ci = vfmaq_f32(ai, bfv, bi);
            vst1q_f32(c + i, ci);
            // Lane mask of where the new values beat the running minimum;
            // used to blend the matching indices into iminv.
            const uint32x4_t less_than = vcltq_f32(ci, vminv);
            vminv = vminq_f32(ci, vminv);
            iminv = vorrq_u32(
                    vandq_u32(less_than, iv),
                    vandq_u32(vmvnq_u32(less_than), iminv));
            iv = vaddq_u32(iv, incv);
        }
    }
    // Horizontal reduce: overall minimum value across the 4 lanes.
    float vmin = vminvq_f32(vminv);
    uint32_t imin;
    {
        // Among lanes whose minimum equals the overall minimum, take the
        // smallest index; non-matching lanes are masked to UINT32_MAX so
        // they lose the vminvq_u32 reduction.
        const float32x4_t vminy = vdupq_n_f32(vmin);
        const uint32x4_t equals = vceqq_f32(vminv, vminy);
        imin = vminvq_u32(vorrq_u32(
                vandq_u32(equals, iminv),
                vandq_u32(
                        vmvnq_u32(equals),
                        vdupq_n_u32(std::numeric_limits<uint32_t>::max()))));
    }
    // Scalar tail: also updates vmin/imin so the remainder can win.
    for (; i < n; ++i) {
        c[i] = a[i] + bf * b[i];
        if (c[i] < vmin) {
            vmin = c[i];
            imin = static_cast<uint32_t>(i);
        }
    }
    // NOTE(review): if every value is >= 1e20 (or n == 0), imin stays
    // 0xFFFFFFFF and the returned int is -1 — presumably callers guard
    // against this; confirm upstream contract.
    return static_cast<int>(imin);
}
|
|
147
|
+
|
|
148
|
+
} // namespace faiss
|
|
149
|
+
|
|
150
|
+
#endif // __aarch64__
|