faiss 0.6.1 → 0.6.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
- data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
- data/vendor/faiss/faiss/factory_tools.cpp +4 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
- data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
- data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
- data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
- data/vendor/faiss/faiss/impl/HNSW.h +51 -13
- data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
- data/vendor/faiss/faiss/impl/Panorama.h +11 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
- data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
- data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
- data/vendor/faiss/faiss/impl/io_macros.h +25 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
- data/vendor/faiss/faiss/index_factory.cpp +5 -1
- data/vendor/faiss/faiss/index_io.h +16 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
- data/vendor/faiss/faiss/utils/bf16.h +34 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
- metadata +12 -2
|
@@ -25,7 +25,9 @@ namespace faiss {
|
|
|
25
25
|
// Forward declarations to avoid circular dependency.
|
|
26
26
|
struct IndexHNSW;
|
|
27
27
|
struct IndexHNSWFlatPanorama;
|
|
28
|
-
|
|
28
|
+
template <class HC_>
|
|
29
|
+
struct MinimaxHeapT;
|
|
30
|
+
using MinimaxHeap = MinimaxHeapT<CMax<float, int32_t>>;
|
|
29
31
|
class LockVector;
|
|
30
32
|
|
|
31
33
|
/** Implementation of the Hierarchical Navigable Small World
|
|
@@ -60,30 +62,52 @@ struct HNSW {
|
|
|
60
62
|
/// internal storage of vectors (32 bits: this is expensive)
|
|
61
63
|
using storage_idx_t = int32_t;
|
|
62
64
|
|
|
63
|
-
//
|
|
64
|
-
|
|
65
|
+
// The two comparator flavors HNSW supports. CMax (smaller-is-better)
|
|
66
|
+
// is the default; CMin (larger-is-better) is used when `is_similarity`
|
|
67
|
+
// is set on the owning index.
|
|
68
|
+
using C_distance = CMax<float, int64_t>;
|
|
69
|
+
using C_similarity = CMin<float, int64_t>;
|
|
70
|
+
|
|
71
|
+
// Back-compat alias: keeps `HNSW::C` resolving to the distance
|
|
72
|
+
// (CMax) comparator everywhere the type is referenced directly.
|
|
73
|
+
using C = C_distance;
|
|
65
74
|
|
|
66
75
|
typedef std::pair<float, storage_idx_t> Node;
|
|
67
76
|
|
|
68
77
|
/// to sort pairs of (id, distance) from nearest to farthest or the reverse
|
|
69
|
-
|
|
78
|
+
template <class CT>
|
|
79
|
+
struct NodeDistCloserT {
|
|
70
80
|
float d;
|
|
71
81
|
int id;
|
|
72
|
-
|
|
73
|
-
bool operator<(const
|
|
74
|
-
|
|
82
|
+
NodeDistCloserT(float d_in, int id_in) : d(d_in), id(id_in) {}
|
|
83
|
+
bool operator<(const NodeDistCloserT& obj1) const {
|
|
84
|
+
// priority_queue keeps the "worst" element at the top so that
|
|
85
|
+
// when the queue is full we can pop it. For CMax (distance) the
|
|
86
|
+
// worst element is the largest d; for CMin (similarity) it is
|
|
87
|
+
// the smallest d. Equivalent to: obj1.d "better than" d.
|
|
88
|
+
return CT::cmp(obj1.d, d);
|
|
75
89
|
}
|
|
76
90
|
};
|
|
77
91
|
|
|
78
|
-
|
|
92
|
+
template <class CT>
|
|
93
|
+
struct NodeDistFartherT {
|
|
79
94
|
float d;
|
|
80
95
|
int id;
|
|
81
|
-
|
|
82
|
-
bool operator<(const
|
|
83
|
-
|
|
96
|
+
NodeDistFartherT(float d_in, int id_in) : d(d_in), id(id_in) {}
|
|
97
|
+
bool operator<(const NodeDistFartherT& obj1) const {
|
|
98
|
+
// priority_queue here keeps the "best" element at the top so we
|
|
99
|
+
// can process the nearest candidate first. For CMax (distance)
|
|
100
|
+
// the best is the smallest d; for CMin (similarity) the best is
|
|
101
|
+
// the largest d. Equivalent to: d "better than" obj1.d.
|
|
102
|
+
return CT::cmp(d, obj1.d);
|
|
84
103
|
}
|
|
85
104
|
};
|
|
86
105
|
|
|
106
|
+
// Back-compat aliases: default to the distance (CMax) comparator so
|
|
107
|
+
// existing call sites that mention `HNSW::NodeDist*` keep working.
|
|
108
|
+
using NodeDistCloser = NodeDistCloserT<C_distance>;
|
|
109
|
+
using NodeDistFarther = NodeDistFartherT<C_distance>;
|
|
110
|
+
|
|
87
111
|
/// assignment probability to each layer (sum=1)
|
|
88
112
|
std::vector<double> assign_probas;
|
|
89
113
|
|
|
@@ -131,6 +155,12 @@ struct HNSW {
|
|
|
131
155
|
/// use Panorama progressive pruning in search
|
|
132
156
|
bool is_panorama = false;
|
|
133
157
|
|
|
158
|
+
/// distance comparison semantics: when true, distances are treated as
|
|
159
|
+
/// similarity scores (larger is better). Default false matches the
|
|
160
|
+
/// historical L2/Hamming behavior (smaller is better).
|
|
161
|
+
/// Not serialized: must be re-set by the owning Index after loading.
|
|
162
|
+
bool is_similarity = false;
|
|
163
|
+
|
|
134
164
|
// See impl/VisitedTable.h.
|
|
135
165
|
std::optional<bool> use_visited_hashset;
|
|
136
166
|
|
|
@@ -216,10 +246,11 @@ struct HNSW {
|
|
|
216
246
|
|
|
217
247
|
int prepare_level_tab(size_t n, bool preset_levels = false);
|
|
218
248
|
|
|
249
|
+
template <class C = C_distance>
|
|
219
250
|
static void shrink_neighbor_list(
|
|
220
251
|
DistanceComputer& qdis,
|
|
221
|
-
std::priority_queue<
|
|
222
|
-
std::vector<
|
|
252
|
+
std::priority_queue<NodeDistFartherT<C>>& input,
|
|
253
|
+
std::vector<NodeDistFartherT<C>>& output,
|
|
223
254
|
size_t max_size,
|
|
224
255
|
bool keep_max_size_level0 = false);
|
|
225
256
|
|
|
@@ -250,6 +281,11 @@ struct HNSWStats {
|
|
|
250
281
|
// global var that collects them all
|
|
251
282
|
FAISS_API extern HNSWStats hnsw_stats;
|
|
252
283
|
|
|
284
|
+
/// Internal HNSW algorithm helpers. These are not part of the public API; they
|
|
285
|
+
/// are exposed here only so that unit tests (and a few cross-TU callers such as
|
|
286
|
+
/// the Panorama search variant) can reach them.
|
|
287
|
+
namespace hnsw_detail {
|
|
288
|
+
|
|
253
289
|
int search_from_candidates(
|
|
254
290
|
const HNSW& hnsw,
|
|
255
291
|
DistanceComputer& qdis,
|
|
@@ -302,4 +338,6 @@ void search_neighbors_to_add(
|
|
|
302
338
|
VisitedTable& vt,
|
|
303
339
|
bool reference_version = false);
|
|
304
340
|
|
|
341
|
+
} // namespace hnsw_detail
|
|
342
|
+
|
|
305
343
|
} // namespace faiss
|
|
@@ -234,10 +234,11 @@ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
|
|
|
234
234
|
std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
|
|
235
235
|
|
|
236
236
|
dis->set_query(center.get());
|
|
237
|
-
VisitedTable vt
|
|
237
|
+
std::unique_ptr<VisitedTable> vt =
|
|
238
|
+
VisitedTable::create(ntotal, use_visited_hashset);
|
|
238
239
|
|
|
239
240
|
// Do not collect the visited nodes
|
|
240
|
-
search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
|
|
241
|
+
search_on_graph<false>(knn_graph, *dis, *vt, ep, L, retset, tmpset);
|
|
241
242
|
|
|
242
243
|
// set enterpoint
|
|
243
244
|
enterpoint = retset[0].id;
|
|
@@ -344,7 +345,8 @@ void NSG::link(
|
|
|
344
345
|
std::vector<Node> pool;
|
|
345
346
|
std::vector<Neighbor> tmp;
|
|
346
347
|
|
|
347
|
-
VisitedTable vt
|
|
348
|
+
std::unique_ptr<VisitedTable> vt =
|
|
349
|
+
VisitedTable::create(ntotal, use_visited_hashset);
|
|
348
350
|
std::unique_ptr<DistanceComputer> dis(
|
|
349
351
|
storage_distance_computer(storage));
|
|
350
352
|
|
|
@@ -355,13 +357,13 @@ void NSG::link(
|
|
|
355
357
|
|
|
356
358
|
// Collect the visited nodes into pool
|
|
357
359
|
search_on_graph<true>(
|
|
358
|
-
knn_graph, *dis, vt, enterpoint, L, tmp, pool);
|
|
360
|
+
knn_graph, *dis, *vt, enterpoint, L, tmp, pool);
|
|
359
361
|
|
|
360
|
-
sync_prune(i, pool, *dis, vt, knn_graph, graph);
|
|
362
|
+
sync_prune(i, pool, *dis, *vt, knn_graph, graph);
|
|
361
363
|
|
|
362
364
|
pool.clear();
|
|
363
365
|
tmp.clear();
|
|
364
|
-
vt
|
|
366
|
+
vt->advance();
|
|
365
367
|
}
|
|
366
368
|
} // omp parallel
|
|
367
369
|
|
|
@@ -531,19 +533,21 @@ void NSG::add_reverse_links(
|
|
|
531
533
|
|
|
532
534
|
int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
|
|
533
535
|
int root = enterpoint;
|
|
534
|
-
VisitedTable vt
|
|
535
|
-
|
|
536
|
+
std::unique_ptr<VisitedTable> vt =
|
|
537
|
+
VisitedTable::create(ntotal, use_visited_hashset);
|
|
538
|
+
std::unique_ptr<VisitedTable> vt2 =
|
|
539
|
+
VisitedTable::create(ntotal, use_visited_hashset);
|
|
536
540
|
|
|
537
541
|
int num_attached = 0;
|
|
538
542
|
int cnt = 0;
|
|
539
543
|
while (true) {
|
|
540
|
-
cnt = dfs(vt, root, cnt);
|
|
544
|
+
cnt = dfs(*vt, root, cnt);
|
|
541
545
|
if (cnt >= ntotal) {
|
|
542
546
|
break;
|
|
543
547
|
}
|
|
544
548
|
|
|
545
|
-
root = attach_unlinked(storage, vt, vt2, degrees);
|
|
546
|
-
vt2
|
|
549
|
+
root = attach_unlinked(storage, *vt, *vt2, degrees);
|
|
550
|
+
vt2->advance();
|
|
547
551
|
num_attached += 1;
|
|
548
552
|
}
|
|
549
553
|
|
|
@@ -36,7 +36,12 @@ namespace faiss {
|
|
|
36
36
|
/// from active_indices (subsequent levels after pruning).
|
|
37
37
|
/// @tparam LevelWidth Compile-time level width in floats (0 = use runtime
|
|
38
38
|
/// level_width_dims). Enables full loop unrolling.
|
|
39
|
+
// Skip pragmas under nvcc: its EDG frontend warns on `#pragma GCC optimize`
|
|
40
|
+
// (#1675-D) for every `.cu` that transitively includes this header. These
|
|
41
|
+
// templates are CPU-only, so the hint is irrelevant during nvcc parse.
|
|
42
|
+
#if !defined(__NVCC__)
|
|
39
43
|
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
44
|
+
#endif
|
|
40
45
|
template <bool AllActive = false, size_t LevelWidth = 0>
|
|
41
46
|
static inline void compute_level_dot_kernel(
|
|
42
47
|
const float* FAISS_RESTRICT query_level,
|
|
@@ -83,7 +88,9 @@ static inline void compute_level_dot_kernel(
|
|
|
83
88
|
dot_products[i] = dp;
|
|
84
89
|
}
|
|
85
90
|
}
|
|
91
|
+
#if !defined(__NVCC__)
|
|
86
92
|
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
93
|
+
#endif
|
|
87
94
|
|
|
88
95
|
/// Update exact distances with the current level's dot products, then apply
|
|
89
96
|
/// Panorama pruning: for each active vector, compute a lower bound on
|
|
@@ -92,7 +99,9 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
|
92
99
|
///
|
|
93
100
|
/// Uses `if constexpr` on C::is_max rather than C::cmp() to ensure the
|
|
94
101
|
/// comparison autovectorizes (C::cmp generates scalar function calls).
|
|
102
|
+
#if !defined(__NVCC__)
|
|
95
103
|
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
|
|
104
|
+
#endif
|
|
96
105
|
template <bool AllActive, typename C, MetricType M>
|
|
97
106
|
static inline void prune_kernel(
|
|
98
107
|
float* FAISS_RESTRICT exact_distances,
|
|
@@ -128,7 +137,9 @@ static inline void prune_kernel(
|
|
|
128
137
|
}
|
|
129
138
|
}
|
|
130
139
|
}
|
|
140
|
+
#if !defined(__NVCC__)
|
|
131
141
|
FAISS_PRAGMA_IMPRECISE_FUNCTION_END
|
|
142
|
+
#endif
|
|
132
143
|
|
|
133
144
|
/// Compact active_indices in-place, removing entries where active_byteset[i]
|
|
134
145
|
/// is zero. Returns the new count of active elements. Uses a branchless BMI2 +
|
|
@@ -13,12 +13,15 @@
|
|
|
13
13
|
#include <cstdio>
|
|
14
14
|
#include <cstring>
|
|
15
15
|
#include <memory>
|
|
16
|
+
#include <type_traits>
|
|
16
17
|
|
|
17
18
|
#include <algorithm>
|
|
18
19
|
|
|
19
20
|
#include <faiss/IndexFlat.h>
|
|
20
21
|
#include <faiss/VectorTransform.h>
|
|
21
22
|
#include <faiss/impl/FaissAssert.h>
|
|
23
|
+
// NOLINTNEXTLINE(facebook-hte-InlineHeader)
|
|
24
|
+
#include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
|
|
22
25
|
#include <faiss/impl/simd_dispatch.h>
|
|
23
26
|
#include <faiss/utils/distances.h>
|
|
24
27
|
|
|
@@ -719,8 +722,28 @@ void pq_knn_search_with_tables(
|
|
|
719
722
|
|
|
720
723
|
switch (nbits) {
|
|
721
724
|
case 8:
|
|
722
|
-
|
|
723
|
-
|
|
725
|
+
if (ksub == 256) {
|
|
726
|
+
constexpr bool max_heap =
|
|
727
|
+
std::is_same_v<C, CMax<float, int64_t>>;
|
|
728
|
+
pq_code_distance::pq_scan_8bit(
|
|
729
|
+
M,
|
|
730
|
+
dis_table,
|
|
731
|
+
codes,
|
|
732
|
+
ncodes,
|
|
733
|
+
k,
|
|
734
|
+
heap_dis,
|
|
735
|
+
heap_ids,
|
|
736
|
+
max_heap);
|
|
737
|
+
} else {
|
|
738
|
+
pq_estimators_from_tables<uint8_t, C>(
|
|
739
|
+
pq,
|
|
740
|
+
codes,
|
|
741
|
+
ncodes,
|
|
742
|
+
dis_table,
|
|
743
|
+
k,
|
|
744
|
+
heap_dis,
|
|
745
|
+
heap_ids);
|
|
746
|
+
}
|
|
724
747
|
break;
|
|
725
748
|
|
|
726
749
|
case 16:
|
|
@@ -321,7 +321,7 @@ float compute_full_multibit_distance(
|
|
|
321
321
|
size_t d,
|
|
322
322
|
size_t ex_bits,
|
|
323
323
|
MetricType metric_type) {
|
|
324
|
-
return with_selected_simd_levels<
|
|
324
|
+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
|
|
325
325
|
[&]<SIMDLevel SL>() {
|
|
326
326
|
return compute_full_multibit_distance<SL>(
|
|
327
327
|
sign_bits,
|
|
@@ -551,7 +551,13 @@ FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
|
|
|
551
551
|
// Dispatch on SIMDLevel once here so the distance computer methods
|
|
552
552
|
// call the SIMD-specialized rabitq functions directly (no per-call
|
|
553
553
|
// with_simd_level overhead).
|
|
554
|
-
|
|
554
|
+
//
|
|
555
|
+
// Use A0_SPR (which includes AVX512_SPR) so that on Sapphire Rapids
|
|
556
|
+
// and later x86 microarchitectures the VPOPCNTDQ-based RaBitQ
|
|
557
|
+
// specialization in rabitq_avx512_spr.cpp is selected. On AVX-512
|
|
558
|
+
// CPUs without VPOPCNTDQ, dispatch falls through to the AVX512
|
|
559
|
+
// specialization in rabitq_avx512.cpp.
|
|
560
|
+
return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
|
|
555
561
|
[&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
|
|
556
562
|
if (qb == 0) {
|
|
557
563
|
auto dc =
|
|
@@ -25,6 +25,218 @@
|
|
|
25
25
|
|
|
26
26
|
namespace faiss {
|
|
27
27
|
|
|
28
|
+
namespace {
|
|
29
|
+
|
|
30
|
+
// Gaussian Lloyd-Max optimal quantizer centroids and boundaries for N(0,1).
|
|
31
|
+
// clang-format off
|
|
32
|
+
const float kLloydMaxCentroids1[] = {
|
|
33
|
+
-0.797884560802865f, 0.797884560802865f
|
|
34
|
+
};
|
|
35
|
+
const float kLloydMaxBoundaries1[] = {
|
|
36
|
+
0.000000000000000f
|
|
37
|
+
};
|
|
38
|
+
const float kLloydMaxCentroids2[] = {
|
|
39
|
+
-1.510417608499078f, -0.452780034636484f,
|
|
40
|
+
0.452780034636483f, 1.510417608499078f
|
|
41
|
+
};
|
|
42
|
+
const float kLloydMaxBoundaries2[] = {
|
|
43
|
+
-0.981598821567781f, 0.000000000000000f, 0.981598821567781f
|
|
44
|
+
};
|
|
45
|
+
const float kLloydMaxCentroids3[] = {
|
|
46
|
+
-2.151945704536914f, -1.343909278504930f,
|
|
47
|
+
-0.756005281205826f, -0.245094178944203f,
|
|
48
|
+
0.245094178944203f, 0.756005281205825f,
|
|
49
|
+
1.343909278504930f, 2.151945704536914f
|
|
50
|
+
};
|
|
51
|
+
const float kLloydMaxBoundaries3[] = {
|
|
52
|
+
-1.747927491520922f, -1.049957279855378f,
|
|
53
|
+
-0.500549730075014f, 0.000000000000000f,
|
|
54
|
+
0.500549730075014f, 1.049957279855378f,
|
|
55
|
+
1.747927491520922f
|
|
56
|
+
};
|
|
57
|
+
const float kLloydMaxCentroids4[] = {
|
|
58
|
+
-2.732589570994957f, -2.069017226531159f,
|
|
59
|
+
-1.618046386021649f, -1.256231197346957f,
|
|
60
|
+
-0.942340456486774f, -0.656759118532318f,
|
|
61
|
+
-0.388048299490198f, -0.128395029851116f,
|
|
62
|
+
0.128395029851116f, 0.388048299490198f,
|
|
63
|
+
0.656759118532318f, 0.942340456486773f,
|
|
64
|
+
1.256231197346959f, 1.618046386021649f,
|
|
65
|
+
2.069017226531160f, 2.732589570994943f
|
|
66
|
+
};
|
|
67
|
+
const float kLloydMaxBoundaries4[] = {
|
|
68
|
+
-2.400803398763058f, -1.843531806276404f,
|
|
69
|
+
-1.437138791684303f, -1.099285826916865f,
|
|
70
|
+
-0.799549787509546f, -0.522403709011258f,
|
|
71
|
+
-0.258221664670657f, 0.000000000000000f,
|
|
72
|
+
0.258221664670657f, 0.522403709011258f,
|
|
73
|
+
0.799549787509546f, 1.099285826916866f,
|
|
74
|
+
1.437138791684304f, 1.843531806276404f,
|
|
75
|
+
2.400803398763051f
|
|
76
|
+
};
|
|
77
|
+
const float kLloydMaxCentroids8[] = {
|
|
78
|
+
-4.2734901319f, -3.8270895246f, -3.5457169520f, -3.3354593381f,
|
|
79
|
+
-3.1655721017f, -3.0219515320f, -2.8969009924f, -2.7857394515f,
|
|
80
|
+
-2.6853990170f, -2.5937556343f, -2.5092755166f, -2.4308135619f,
|
|
81
|
+
-2.3574913691f, -2.2886197969f, -2.2236478246f, -2.1621276457f,
|
|
82
|
+
-2.1036901632f, -2.0480273642f, -1.9948793740f, -1.9440247677f,
|
|
83
|
+
-1.8952732015f, -1.8484597247f, -1.8034403315f, -1.7600884415f,
|
|
84
|
+
-1.7182920846f, -1.6779516274f, -1.6389779215f, -1.6012907825f,
|
|
85
|
+
-1.5648177311f, -1.5294929453f, -1.4952563823f, -1.4620530375f,
|
|
86
|
+
-1.4298323186f, -1.3985475108f, -1.3681553217f, -1.3386154890f,
|
|
87
|
+
-1.3098904444f, -1.2819450217f, -1.2547462051f, -1.2282629097f,
|
|
88
|
+
-1.2024657910f, -1.1773270781f, -1.1528204287f, -1.1289208010f,
|
|
89
|
+
-1.1056043421f, -1.0828482901f, -1.0606308873f, -1.0389313043f,
|
|
90
|
+
-1.0177295729f, -0.9970065268f, -0.9767437492f, -0.9569235264f,
|
|
91
|
+
-0.9375288069f, -0.9185431646f, -0.8999507663f, -0.8817363426f,
|
|
92
|
+
-0.8638851621f, -0.8463830081f, -0.8292161569f, -0.8123713596f,
|
|
93
|
+
-0.7958358242f, -0.7795971999f, -0.7636435625f, -0.7479634007f,
|
|
94
|
+
-0.7325456038f, -0.7173794494f, -0.7024545929f, -0.6877610560f,
|
|
95
|
+
-0.6732892172f, -0.6590298016f, -0.6449738716f, -0.6311128174f,
|
|
96
|
+
-0.6174383481f, -0.6039424829f, -0.5906175419f, -0.5774561379f,
|
|
97
|
+
-0.5644511676f, -0.5515958029f, -0.5388834832f, -0.5263079060f,
|
|
98
|
+
-0.5138630194f, -0.5015430136f, -0.4893423125f, -0.4772555660f,
|
|
99
|
+
-0.4652776416f, -0.4534036165f, -0.4416287701f, -0.4299485757f,
|
|
100
|
+
-0.4183586932f, -0.4068549615f, -0.3954333909f, -0.3840901561f,
|
|
101
|
+
-0.3728215889f, -0.3616241712f, -0.3504945283f, -0.3394294221f,
|
|
102
|
+
-0.3284257446f, -0.3174805116f, -0.3065908567f, -0.2957540250f,
|
|
103
|
+
-0.2849673675f, -0.2742283355f, -0.2635344752f, -0.2528834222f,
|
|
104
|
+
-0.2422728967f, -0.2317006985f, -0.2211647022f, -0.2106628526f,
|
|
105
|
+
-0.2001931607f, -0.1897536989f, -0.1793425974f, -0.1689580400f,
|
|
106
|
+
-0.1585982605f, -0.1482615390f, -0.1379461985f, -0.1276506012f,
|
|
107
|
+
-0.1173731457f, -0.1071122637f, -0.0968664166f, -0.0866340933f,
|
|
108
|
+
-0.0764138065f, -0.0662040909f, -0.0560034994f, -0.0458106014f,
|
|
109
|
+
-0.0356239797f, -0.0254422284f, -0.0152639496f, -0.0050877521f,
|
|
110
|
+
0.0050877521f, 0.0152639496f, 0.0254422284f, 0.0356239797f,
|
|
111
|
+
0.0458106014f, 0.0560034994f, 0.0662040909f, 0.0764138065f,
|
|
112
|
+
0.0866340933f, 0.0968664166f, 0.1071122637f, 0.1173731457f,
|
|
113
|
+
0.1276506012f, 0.1379461985f, 0.1482615390f, 0.1585982605f,
|
|
114
|
+
0.1689580400f, 0.1793425974f, 0.1897536989f, 0.2001931607f,
|
|
115
|
+
0.2106628526f, 0.2211647022f, 0.2317006985f, 0.2422728967f,
|
|
116
|
+
0.2528834222f, 0.2635344752f, 0.2742283355f, 0.2849673675f,
|
|
117
|
+
0.2957540250f, 0.3065908567f, 0.3174805116f, 0.3284257446f,
|
|
118
|
+
0.3394294221f, 0.3504945283f, 0.3616241712f, 0.3728215889f,
|
|
119
|
+
0.3840901561f, 0.3954333909f, 0.4068549615f, 0.4183586932f,
|
|
120
|
+
0.4299485757f, 0.4416287701f, 0.4534036165f, 0.4652776416f,
|
|
121
|
+
0.4772555660f, 0.4893423125f, 0.5015430136f, 0.5138630194f,
|
|
122
|
+
0.5263079060f, 0.5388834832f, 0.5515958029f, 0.5644511676f,
|
|
123
|
+
0.5774561379f, 0.5906175419f, 0.6039424829f, 0.6174383481f,
|
|
124
|
+
0.6311128174f, 0.6449738716f, 0.6590298016f, 0.6732892172f,
|
|
125
|
+
0.6877610560f, 0.7024545929f, 0.7173794494f, 0.7325456038f,
|
|
126
|
+
0.7479634007f, 0.7636435625f, 0.7795971999f, 0.7958358242f,
|
|
127
|
+
0.8123713596f, 0.8292161569f, 0.8463830081f, 0.8638851621f,
|
|
128
|
+
0.8817363426f, 0.8999507663f, 0.9185431646f, 0.9375288069f,
|
|
129
|
+
0.9569235264f, 0.9767437492f, 0.9970065268f, 1.0177295729f,
|
|
130
|
+
1.0389313043f, 1.0606308873f, 1.0828482901f, 1.1056043421f,
|
|
131
|
+
1.1289208010f, 1.1528204287f, 1.1773270781f, 1.2024657910f,
|
|
132
|
+
1.2282629097f, 1.2547462051f, 1.2819450217f, 1.3098904444f,
|
|
133
|
+
1.3386154890f, 1.3681553217f, 1.3985475108f, 1.4298323186f,
|
|
134
|
+
1.4620530375f, 1.4952563823f, 1.5294929453f, 1.5648177311f,
|
|
135
|
+
1.6012907825f, 1.6389779215f, 1.6779516274f, 1.7182920846f,
|
|
136
|
+
1.7600884415f, 1.8034403315f, 1.8484597247f, 1.8952732015f,
|
|
137
|
+
1.9440247677f, 1.9948793740f, 2.0480273642f, 2.1036901632f,
|
|
138
|
+
2.1621276457f, 2.2236478246f, 2.2886197969f, 2.3574913691f,
|
|
139
|
+
2.4308135619f, 2.5092755166f, 2.5937556343f, 2.6853990170f,
|
|
140
|
+
2.7857394515f, 2.8969009924f, 3.0219515320f, 3.1655721017f,
|
|
141
|
+
3.3354593381f, 3.5457169520f, 3.8270895246f, 4.2734901319f
|
|
142
|
+
};
|
|
143
|
+
const float kLloydMaxBoundaries8[] = {
|
|
144
|
+
-4.0502898282f, -3.6864032383f, -3.4405881450f, -3.2505157199f,
|
|
145
|
+
-3.0937618168f, -2.9594262622f, -2.8413202220f, -2.7355692343f,
|
|
146
|
+
-2.6395773257f, -2.5515155755f, -2.4700445392f, -2.3941524655f,
|
|
147
|
+
-2.3230555830f, -2.2561338107f, -2.1928877352f, -2.1329089044f,
|
|
148
|
+
-2.0758587637f, -2.0214533691f, -1.9694520708f, -1.9196489846f,
|
|
149
|
+
-1.8718664631f, -1.8259500281f, -1.7817643865f, -1.7391902630f,
|
|
150
|
+
-1.6981218560f, -1.6584647744f, -1.6201343520f, -1.5830542568f,
|
|
151
|
+
-1.5471553382f, -1.5123746638f, -1.4786547099f, -1.4459426781f,
|
|
152
|
+
-1.4141899147f, -1.3833514163f, -1.3533854053f, -1.3242529667f,
|
|
153
|
+
-1.2959177331f, -1.2683456134f, -1.2415045574f, -1.2153643503f,
|
|
154
|
+
-1.1898964346f, -1.1650737534f, -1.1408706148f, -1.1172625715f,
|
|
155
|
+
-1.0942263161f, -1.0717395887f, -1.0497810958f, -1.0283304386f,
|
|
156
|
+
-1.0073680499f, -0.9868751380f, -0.9668336378f, -0.9472261667f,
|
|
157
|
+
-0.9280359858f, -0.9092469654f, -0.8908435544f, -0.8728107524f,
|
|
158
|
+
-0.8551340851f, -0.8377995825f, -0.8207937582f, -0.8041035919f,
|
|
159
|
+
-0.7877165121f, -0.7716203812f, -0.7558034816f, -0.7402545023f,
|
|
160
|
+
-0.7249625266f, -0.7099170212f, -0.6951078244f, -0.6805251366f,
|
|
161
|
+
-0.6661595094f, -0.6520018366f, -0.6380433445f, -0.6242755828f,
|
|
162
|
+
-0.6106904155f, -0.5972800124f, -0.5840368399f, -0.5709536527f,
|
|
163
|
+
-0.5580234853f, -0.5452396431f, -0.5325956946f, -0.5200854627f,
|
|
164
|
+
-0.5077030165f, -0.4954426631f, -0.4832989393f, -0.4712666038f,
|
|
165
|
+
-0.4593406291f, -0.4475161933f, -0.4357886729f, -0.4241536345f,
|
|
166
|
+
-0.4126068274f, -0.4011441762f, -0.3897617735f, -0.3784558725f,
|
|
167
|
+
-0.3672228800f, -0.3560593498f, -0.3449619752f, -0.3339275834f,
|
|
168
|
+
-0.3229531281f, -0.3120356842f, -0.3011724408f, -0.2903606962f,
|
|
169
|
+
-0.2795978515f, -0.2688814053f, -0.2582089487f, -0.2475781595f,
|
|
170
|
+
-0.2369867976f, -0.2264327004f, -0.2159137774f, -0.2054280067f,
|
|
171
|
+
-0.1949734298f, -0.1845481481f, -0.1741503187f, -0.1637781502f,
|
|
172
|
+
-0.1534298998f, -0.1431038688f, -0.1327983999f, -0.1225118735f,
|
|
173
|
+
-0.1122427047f, -0.1019893401f, -0.0917502549f, -0.0815239499f,
|
|
174
|
+
-0.0713089487f, -0.0611037951f, -0.0509070504f, -0.0407172906f,
|
|
175
|
+
-0.0305331041f, -0.0203530890f, -0.0101758509f, 0.0000000000f,
|
|
176
|
+
0.0101758509f, 0.0203530890f, 0.0305331041f, 0.0407172906f,
|
|
177
|
+
0.0509070504f, 0.0611037951f, 0.0713089487f, 0.0815239499f,
|
|
178
|
+
0.0917502549f, 0.1019893401f, 0.1122427047f, 0.1225118735f,
|
|
179
|
+
0.1327983999f, 0.1431038688f, 0.1534298998f, 0.1637781502f,
|
|
180
|
+
0.1741503187f, 0.1845481481f, 0.1949734298f, 0.2054280067f,
|
|
181
|
+
0.2159137774f, 0.2264327004f, 0.2369867976f, 0.2475781595f,
|
|
182
|
+
0.2582089487f, 0.2688814053f, 0.2795978515f, 0.2903606962f,
|
|
183
|
+
0.3011724408f, 0.3120356842f, 0.3229531281f, 0.3339275834f,
|
|
184
|
+
0.3449619752f, 0.3560593498f, 0.3672228800f, 0.3784558725f,
|
|
185
|
+
0.3897617735f, 0.4011441762f, 0.4126068274f, 0.4241536345f,
|
|
186
|
+
0.4357886729f, 0.4475161933f, 0.4593406291f, 0.4712666038f,
|
|
187
|
+
0.4832989393f, 0.4954426631f, 0.5077030165f, 0.5200854627f,
|
|
188
|
+
0.5325956946f, 0.5452396431f, 0.5580234853f, 0.5709536527f,
|
|
189
|
+
0.5840368399f, 0.5972800124f, 0.6106904155f, 0.6242755828f,
|
|
190
|
+
0.6380433445f, 0.6520018366f, 0.6661595094f, 0.6805251366f,
|
|
191
|
+
0.6951078244f, 0.7099170212f, 0.7249625266f, 0.7402545023f,
|
|
192
|
+
0.7558034816f, 0.7716203812f, 0.7877165121f, 0.8041035919f,
|
|
193
|
+
0.8207937582f, 0.8377995825f, 0.8551340851f, 0.8728107524f,
|
|
194
|
+
0.8908435544f, 0.9092469654f, 0.9280359858f, 0.9472261667f,
|
|
195
|
+
0.9668336378f, 0.9868751380f, 1.0073680499f, 1.0283304386f,
|
|
196
|
+
1.0497810958f, 1.0717395887f, 1.0942263161f, 1.1172625715f,
|
|
197
|
+
1.1408706148f, 1.1650737534f, 1.1898964346f, 1.2153643503f,
|
|
198
|
+
1.2415045574f, 1.2683456134f, 1.2959177331f, 1.3242529667f,
|
|
199
|
+
1.3533854053f, 1.3833514163f, 1.4141899147f, 1.4459426781f,
|
|
200
|
+
1.4786547099f, 1.5123746638f, 1.5471553382f, 1.5830542568f,
|
|
201
|
+
1.6201343520f, 1.6584647744f, 1.6981218560f, 1.7391902630f,
|
|
202
|
+
1.7817643865f, 1.8259500281f, 1.8718664631f, 1.9196489846f,
|
|
203
|
+
1.9694520708f, 2.0214533691f, 2.0758587637f, 2.1329089044f,
|
|
204
|
+
2.1928877352f, 2.2561338107f, 2.3230555830f, 2.3941524655f,
|
|
205
|
+
2.4700445392f, 2.5515155755f, 2.6395773257f, 2.7355692343f,
|
|
206
|
+
2.8413202220f, 2.9594262622f, 3.0937618168f, 3.2505157199f,
|
|
207
|
+
3.4405881450f, 3.6864032383f, 4.0502898282f
|
|
208
|
+
};
|
|
209
|
+
// clang-format on
|
|
210
|
+
|
|
211
|
+
struct LloydMaxTable {
|
|
212
|
+
const float* centroids;
|
|
213
|
+
const float* boundaries;
|
|
214
|
+
};
|
|
215
|
+
|
|
216
|
+
const LloydMaxTable kLloydMaxTables[] = {
|
|
217
|
+
{nullptr, nullptr}, // 0
|
|
218
|
+
{kLloydMaxCentroids1, kLloydMaxBoundaries1}, // 1
|
|
219
|
+
{kLloydMaxCentroids2, kLloydMaxBoundaries2}, // 2
|
|
220
|
+
{kLloydMaxCentroids3, kLloydMaxBoundaries3}, // 3
|
|
221
|
+
{kLloydMaxCentroids4, kLloydMaxBoundaries4}, // 4
|
|
222
|
+
{nullptr, nullptr}, // 5 (unused)
|
|
223
|
+
{nullptr, nullptr}, // 6 (unused)
|
|
224
|
+
{nullptr, nullptr}, // 7 (unused)
|
|
225
|
+
{kLloydMaxCentroids8, kLloydMaxBoundaries8}, // 8
|
|
226
|
+
};
|
|
227
|
+
|
|
228
|
+
void populate_lloyd_max_trained(size_t mse_bits, std::vector<float>& trained) {
|
|
229
|
+
FAISS_THROW_IF_NOT(mse_bits >= 1 && mse_bits <= 8);
|
|
230
|
+
FAISS_THROW_IF_NOT(kLloydMaxTables[mse_bits].centroids != nullptr);
|
|
231
|
+
size_t k = size_t(1) << mse_bits;
|
|
232
|
+
const auto& t = kLloydMaxTables[mse_bits];
|
|
233
|
+
trained.resize(k + (k - 1));
|
|
234
|
+
std::copy(t.centroids, t.centroids + k, trained.begin());
|
|
235
|
+
std::copy(t.boundaries, t.boundaries + k - 1, trained.begin() + k);
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
} // namespace
|
|
239
|
+
|
|
28
240
|
/*******************************************************************
|
|
29
241
|
* ScalarQuantizer implementation
|
|
30
242
|
********************************************************************/
|
|
@@ -80,6 +292,24 @@ void ScalarQuantizer::set_derived_sizes() {
|
|
|
80
292
|
code_size = 0;
|
|
81
293
|
bits = 0;
|
|
82
294
|
break;
|
|
295
|
+
case QT_2bit_tq:
|
|
296
|
+
case QT_3bit_tq:
|
|
297
|
+
case QT_4bit_tq:
|
|
298
|
+
case QT_5bit_tq: {
|
|
299
|
+
size_t nb_bits = (qtype == QT_2bit_tq) ? 2
|
|
300
|
+
: (qtype == QT_3bit_tq) ? 3
|
|
301
|
+
: (qtype == QT_4bit_tq) ? 4
|
|
302
|
+
: (qtype == QT_5bit_tq) ? 5
|
|
303
|
+
: 0;
|
|
304
|
+
FAISS_THROW_IF_NOT_MSG(nb_bits > 0, "unexpected TurboQ qtype");
|
|
305
|
+
size_t mse_bits = nb_bits - 1;
|
|
306
|
+
size_t mse_bytes = mse_bits * ((d + 7) / 8);
|
|
307
|
+
size_t qjl_bytes = (d + 7) / 8;
|
|
308
|
+
code_size = mse_bytes + qjl_bytes +
|
|
309
|
+
sizeof(scalar_quantizer::SQTurboQFactors);
|
|
310
|
+
bits = nb_bits;
|
|
311
|
+
break;
|
|
312
|
+
}
|
|
83
313
|
default:
|
|
84
314
|
break;
|
|
85
315
|
}
|
|
@@ -134,27 +364,60 @@ void ScalarQuantizer::train(size_t n, const float* x) {
|
|
|
134
364
|
// no training necessary
|
|
135
365
|
break;
|
|
136
366
|
case QT_1bit_tqmse:
|
|
137
|
-
|
|
367
|
+
populate_lloyd_max_trained(1, trained);
|
|
138
368
|
break;
|
|
139
369
|
case QT_2bit_tqmse:
|
|
140
|
-
|
|
370
|
+
populate_lloyd_max_trained(2, trained);
|
|
141
371
|
break;
|
|
142
372
|
case QT_3bit_tqmse:
|
|
143
|
-
|
|
373
|
+
populate_lloyd_max_trained(3, trained);
|
|
144
374
|
break;
|
|
145
375
|
case QT_4bit_tqmse:
|
|
146
|
-
|
|
376
|
+
populate_lloyd_max_trained(4, trained);
|
|
147
377
|
break;
|
|
148
378
|
case QT_8bit_tqmse:
|
|
149
|
-
|
|
379
|
+
populate_lloyd_max_trained(8, trained);
|
|
380
|
+
break;
|
|
381
|
+
case QT_2bit_tq:
|
|
382
|
+
case QT_3bit_tq:
|
|
383
|
+
case QT_4bit_tq:
|
|
384
|
+
case QT_5bit_tq: {
|
|
385
|
+
size_t mse_bits = bits - 1;
|
|
386
|
+
populate_lloyd_max_trained(mse_bits, trained);
|
|
387
|
+
// Pack seed and qjl_type at end of trained for dispatch
|
|
388
|
+
float seed_f[2];
|
|
389
|
+
TurboQuantRefine::pack_seed(turboq_refine.seed, seed_f);
|
|
390
|
+
trained.push_back(seed_f[0]);
|
|
391
|
+
trained.push_back(seed_f[1]);
|
|
392
|
+
trained.push_back(static_cast<float>(turboq_refine.qjl_type));
|
|
393
|
+
turboq_refine.init_projection(d);
|
|
150
394
|
break;
|
|
395
|
+
}
|
|
151
396
|
default:
|
|
152
397
|
break;
|
|
153
398
|
}
|
|
154
399
|
}
|
|
155
400
|
|
|
401
|
+
void ScalarQuantizer::TurboQuantRefine::init_projection(size_t d) {
|
|
402
|
+
if (use_fwht()) {
|
|
403
|
+
padded_d = 1;
|
|
404
|
+
while (padded_d < d) {
|
|
405
|
+
padded_d <<= 1;
|
|
406
|
+
}
|
|
407
|
+
fwht_signs.resize(padded_d);
|
|
408
|
+
RandomGenerator rng(seed);
|
|
409
|
+
for (size_t i = 0; i < padded_d; i++) {
|
|
410
|
+
fwht_signs[i] = (rng.rand_int(2) == 0) ? 1.0f : -1.0f;
|
|
411
|
+
}
|
|
412
|
+
} else {
|
|
413
|
+
rr_matrix.resize(d * d);
|
|
414
|
+
float_randn(rr_matrix.data(), d * d, seed);
|
|
415
|
+
matrix_qr(static_cast<int>(d), static_cast<int>(d), rr_matrix.data());
|
|
416
|
+
}
|
|
417
|
+
}
|
|
418
|
+
|
|
156
419
|
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
|
|
157
|
-
return
|
|
420
|
+
return with_simd_level_spr([&]<SIMDLevel SL>() -> SQuantizer* {
|
|
158
421
|
if constexpr (SL != SIMDLevel::NONE) {
|
|
159
422
|
auto* q = scalar_quantizer::sq_select_quantizer<SL>(
|
|
160
423
|
qtype, d, trained);
|
|
@@ -197,7 +460,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
|
|
|
197
460
|
ScalarQuantizer::SQDistanceComputer* ScalarQuantizer::get_distance_computer(
|
|
198
461
|
MetricType metric) const {
|
|
199
462
|
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
|
|
200
|
-
return
|
|
463
|
+
return with_simd_level_spr([&]<SIMDLevel SL>() -> SQDistanceComputer* {
|
|
201
464
|
if constexpr (SL != SIMDLevel::NONE) {
|
|
202
465
|
auto* dc = scalar_quantizer::sq_select_distance_computer<SL>(
|
|
203
466
|
metric, qtype, d, trained);
|
|
@@ -216,7 +479,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
|
|
|
216
479
|
bool store_pairs,
|
|
217
480
|
const IDSelector* sel,
|
|
218
481
|
bool by_residual) const {
|
|
219
|
-
return
|
|
482
|
+
return with_simd_level_spr([&]<SIMDLevel SL>() -> InvertedListScanner* {
|
|
220
483
|
if constexpr (SL != SIMDLevel::NONE) {
|
|
221
484
|
auto* s = scalar_quantizer::sq_select_InvertedListScanner<SL>(
|
|
222
485
|
qtype,
|