faiss 0.6.1 → 0.6.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
- data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
- data/vendor/faiss/faiss/factory_tools.cpp +4 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
- data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
- data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
- data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
- data/vendor/faiss/faiss/impl/HNSW.h +51 -13
- data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
- data/vendor/faiss/faiss/impl/Panorama.h +11 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
- data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
- data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
- data/vendor/faiss/faiss/impl/io_macros.h +25 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
- data/vendor/faiss/faiss/index_factory.cpp +5 -1
- data/vendor/faiss/faiss/index_io.h +16 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
- data/vendor/faiss/faiss/utils/bf16.h +34 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
- metadata +12 -2
|
@@ -7,6 +7,8 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <cstring>
|
|
11
|
+
|
|
10
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
11
13
|
#include <faiss/impl/DistanceComputer.h>
|
|
12
14
|
#include <faiss/impl/Quantizer.h>
|
|
@@ -39,6 +41,10 @@ struct ScalarQuantizer : Quantizer {
|
|
|
39
41
|
QT_3bit_tqmse, ///< TurboQuant MSE-optimized, 3 bits per component
|
|
40
42
|
QT_4bit_tqmse, ///< TurboQuant MSE-optimized, 4 bits per component
|
|
41
43
|
QT_8bit_tqmse, ///< TurboQuant MSE-optimized, 8 bits per component
|
|
44
|
+
QT_2bit_tq, ///< Full TurboQuant (1-bit MSE + 1-bit QJL + factors)
|
|
45
|
+
QT_3bit_tq, ///< Full TurboQuant (2-bit MSE + 1-bit QJL + factors)
|
|
46
|
+
QT_4bit_tq, ///< Full TurboQuant (3-bit MSE + 1-bit QJL + factors)
|
|
47
|
+
QT_5bit_tq, ///< Full TurboQuant (4-bit MSE + 1-bit QJL + factors)
|
|
42
48
|
QT_count
|
|
43
49
|
};
|
|
44
50
|
|
|
@@ -131,6 +137,50 @@ struct ScalarQuantizer : Quantizer {
|
|
|
131
137
|
}
|
|
132
138
|
};
|
|
133
139
|
|
|
140
|
+
/// TurboQuant full (QT_*_tq) refinement state, isolated from the
|
|
141
|
+
/// main ScalarQuantizer to avoid polluting it with TQ-specific data.
|
|
142
|
+
struct TurboQuantRefine {
|
|
143
|
+
static bool is_turboq_full(QuantizerType qt) {
|
|
144
|
+
return qt >= QT_2bit_tq && qt <= QT_5bit_tq;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
static void pack_seed(uint64_t seed, float out[2]) {
|
|
148
|
+
static_assert(sizeof(uint64_t) == 2 * sizeof(float));
|
|
149
|
+
std::memcpy(out, &seed, sizeof(uint64_t));
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
static uint64_t unpack_seed(float lo, float hi) {
|
|
153
|
+
float tmp[2] = {lo, hi};
|
|
154
|
+
uint64_t s;
|
|
155
|
+
static_assert(sizeof(uint64_t) == 2 * sizeof(float));
|
|
156
|
+
std::memcpy(&s, tmp, sizeof(uint64_t));
|
|
157
|
+
return s;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
uint8_t qjl_type = 0;
|
|
161
|
+
uint64_t seed = 42;
|
|
162
|
+
size_t padded_d = 0;
|
|
163
|
+
std::vector<float> fwht_signs;
|
|
164
|
+
std::vector<float> rr_matrix;
|
|
165
|
+
size_t nb_bits_lo = 0;
|
|
166
|
+
size_t n_hi_dims = 0;
|
|
167
|
+
|
|
168
|
+
void init_projection(size_t d);
|
|
169
|
+
bool use_fwht() const {
|
|
170
|
+
return qjl_type == 0;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
struct DistanceComputer : SQDistanceComputer {
|
|
174
|
+
virtual void configure(uint8_t qb, bool int_qjl) = 0;
|
|
175
|
+
virtual void set_prescreen_threshold(
|
|
176
|
+
const float* t,
|
|
177
|
+
bool minimize) = 0;
|
|
178
|
+
virtual void clear_prescreen_threshold() = 0;
|
|
179
|
+
};
|
|
180
|
+
};
|
|
181
|
+
|
|
182
|
+
TurboQuantRefine turboq_refine;
|
|
183
|
+
|
|
134
184
|
SQDistanceComputer* get_distance_computer(
|
|
135
185
|
MetricType metric = METRIC_L2) const;
|
|
136
186
|
|
|
@@ -18,19 +18,19 @@ namespace faiss {
|
|
|
18
18
|
// A size of ~1M seems to be the threshold where the hash set wins.
|
|
19
19
|
size_t visited_table_hashset_threshold = 500000;
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
21
|
+
std::unique_ptr<VisitedTable> VisitedTable::create(
|
|
22
|
+
size_t size,
|
|
23
|
+
std::optional<bool> use_hashset) {
|
|
24
|
+
bool use_set =
|
|
25
|
+
use_hashset.value_or(size >= visited_table_hashset_threshold);
|
|
26
|
+
if (use_set) {
|
|
27
|
+
return std::make_unique<VisitedTableSet>();
|
|
27
28
|
}
|
|
29
|
+
return std::make_unique<VisitedTableVector>(size);
|
|
28
30
|
}
|
|
29
31
|
|
|
30
|
-
void
|
|
31
|
-
if (visno
|
|
32
|
-
visited_set.clear();
|
|
33
|
-
} else if (visno < 254) {
|
|
32
|
+
void VisitedTableVector::advance() {
|
|
33
|
+
if (visno < 254) {
|
|
34
34
|
// 254 rather than 255 because sometimes we use visno and visno+1
|
|
35
35
|
++visno;
|
|
36
36
|
} else {
|
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
|
|
11
11
|
#include <stdint.h>
|
|
12
12
|
|
|
13
|
+
#include <memory>
|
|
13
14
|
#include <optional>
|
|
14
15
|
#include <unordered_set>
|
|
15
16
|
#include <vector>
|
|
@@ -21,54 +22,88 @@ namespace faiss {
|
|
|
21
22
|
|
|
22
23
|
FAISS_API extern size_t visited_table_hashset_threshold;
|
|
23
24
|
|
|
24
|
-
///
|
|
25
|
+
/// Abstract base class for a fast, reusable Visited Set for graph search
|
|
26
|
+
/// algorithms.
|
|
25
27
|
struct VisitedTable {
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
28
|
+
virtual ~VisitedTable() = default;
|
|
29
|
+
|
|
30
|
+
/// set flag #no to true, return whether this changed it.
|
|
31
|
+
virtual bool set(size_t no) = 0;
|
|
32
|
+
|
|
33
|
+
/// get flag #no
|
|
34
|
+
virtual bool get(size_t no) const = 0;
|
|
35
|
+
|
|
36
|
+
/// prefetch flag #no
|
|
37
|
+
virtual void prefetch(size_t no) const = 0;
|
|
38
|
+
|
|
39
|
+
/// pre-allocate bucket space to avoid rehashing during repeated set() calls
|
|
40
|
+
virtual void reserve(size_t /*n*/) {}
|
|
29
41
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
42
|
+
/// reset all flags to false
|
|
43
|
+
virtual void advance() = 0;
|
|
44
|
+
|
|
45
|
+
/// Factory method to create appropriate implementation.
|
|
46
|
+
/// If use_hashset is nullopt, the use of a hashset will be determined by
|
|
47
|
+
/// size >= visited_table_hashset_threshold.
|
|
48
|
+
static std::unique_ptr<VisitedTable> create(
|
|
33
49
|
size_t size,
|
|
34
50
|
std::optional<bool> use_hashset = std::nullopt);
|
|
51
|
+
};
|
|
35
52
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
}
|
|
53
|
+
/// Set-based implementation using unordered_set.
|
|
54
|
+
/// O(1) to construct and O(visits) to advance.
|
|
55
|
+
struct VisitedTableSet FAISS_FINAL : VisitedTable {
|
|
56
|
+
std::unordered_set<size_t> visited_set;
|
|
57
|
+
|
|
58
|
+
VisitedTableSet() = default;
|
|
59
|
+
|
|
60
|
+
bool set(size_t no) final {
|
|
61
|
+
return visited_set.insert(no).second;
|
|
46
62
|
}
|
|
47
63
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if (visno == 0) {
|
|
51
|
-
visited_set.reserve(n);
|
|
52
|
-
}
|
|
64
|
+
bool get(size_t no) const final {
|
|
65
|
+
return visited_set.count(no) != 0;
|
|
53
66
|
}
|
|
54
67
|
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
if (visno == 0) {
|
|
58
|
-
return visited_set.count(no) != 0;
|
|
59
|
-
} else {
|
|
60
|
-
return visited[no] == visno;
|
|
61
|
-
}
|
|
68
|
+
void prefetch(size_t /*no*/) const final {
|
|
69
|
+
// No-op for set-based implementation
|
|
62
70
|
}
|
|
63
71
|
|
|
64
|
-
void
|
|
65
|
-
|
|
66
|
-
|
|
72
|
+
void reserve(size_t n) final {
|
|
73
|
+
visited_set.reserve(n);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
void advance() final {
|
|
77
|
+
visited_set.clear();
|
|
78
|
+
}
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
/// Vector-based implementation using a versioned byte array.
|
|
82
|
+
/// Faster for get()/set(), but O(size) to initialize.
|
|
83
|
+
/// advance() is O(1) except every 250 calls, which are O(size).
|
|
84
|
+
struct VisitedTableVector FAISS_FINAL : VisitedTable {
|
|
85
|
+
std::vector<uint8_t> visited;
|
|
86
|
+
uint8_t visno{1}; // Version number, 1..254
|
|
87
|
+
|
|
88
|
+
explicit VisitedTableVector(size_t size) : visited(size, 0) {}
|
|
89
|
+
|
|
90
|
+
bool set(size_t no) final {
|
|
91
|
+
if (visited[no] == visno) {
|
|
92
|
+
return false;
|
|
67
93
|
}
|
|
94
|
+
visited[no] = visno;
|
|
95
|
+
return true;
|
|
68
96
|
}
|
|
69
97
|
|
|
70
|
-
|
|
71
|
-
|
|
98
|
+
bool get(size_t no) const final {
|
|
99
|
+
return visited[no] == visno;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
void prefetch(size_t no) const final {
|
|
103
|
+
prefetch_L2(&visited[no]);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
void advance() final;
|
|
72
107
|
};
|
|
73
108
|
|
|
74
109
|
} // namespace faiss
|
|
@@ -48,7 +48,9 @@ using namespace simd_result_handlers;
|
|
|
48
48
|
* so callers don't need to know the handler type.
|
|
49
49
|
***************************************************************/
|
|
50
50
|
|
|
51
|
-
|
|
51
|
+
// SIMDLevel SL = THE_LEVEL_TO_DISPATCH added to make the mangled
|
|
52
|
+
// symbol name unique per translation unit.
|
|
53
|
+
template <class Handler, SIMDLevel SL = THE_LEVEL_TO_DISPATCH>
|
|
52
54
|
struct ScannerMixIn : FastScanCodeScanner {
|
|
53
55
|
Handler handler_;
|
|
54
56
|
|
|
@@ -5,39 +5,32 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
#include <cmath>
|
|
9
|
-
|
|
10
8
|
#include <faiss/impl/hnsw/MinimaxHeap.h>
|
|
11
9
|
|
|
12
|
-
#include <cassert>
|
|
13
|
-
|
|
14
10
|
#include <faiss/impl/simd_dispatch.h>
|
|
15
11
|
|
|
16
12
|
namespace faiss {
|
|
17
13
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
--nvalid;
|
|
29
|
-
}
|
|
30
|
-
faiss::heap_pop<HC>(k--, dis.data(), ids.data());
|
|
31
|
-
}
|
|
32
|
-
faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
|
|
33
|
-
++nvalid;
|
|
14
|
+
// Runtime-dispatched pop_min (NONE + AVX2 + AVX512 only).
|
|
15
|
+
constexpr int MINIMAX_HEAP_SIMD_LEVELS = (1 << int(SIMDLevel::NONE)) |
|
|
16
|
+
(1 << int(SIMDLevel::AVX2)) | (1 << int(SIMDLevel::AVX512));
|
|
17
|
+
|
|
18
|
+
template <class HC_>
|
|
19
|
+
int MinimaxHeapT<HC_>::pop_min(float* vmin_out) {
|
|
20
|
+
return with_selected_simd_levels<MINIMAX_HEAP_SIMD_LEVELS>(
|
|
21
|
+
[&]<SIMDLevel SL>() {
|
|
22
|
+
return pop_min_tpl<HC_, SL>(this, vmin_out);
|
|
23
|
+
});
|
|
34
24
|
}
|
|
35
25
|
|
|
36
|
-
//
|
|
37
|
-
template
|
|
38
|
-
int
|
|
26
|
+
// Primary-template scalar implementation. Used directly when SL==NONE
|
|
27
|
+
template <class HC>
|
|
28
|
+
int pop_min_simd_none(MinimaxHeapT<HC>* heap, float* vmin_out) {
|
|
29
|
+
int k = heap->k;
|
|
30
|
+
int* ids = heap->ids.data();
|
|
31
|
+
float* dis = heap->dis.data();
|
|
39
32
|
assert(k > 0);
|
|
40
|
-
//
|
|
33
|
+
// Returns the "best" entry. This is an O(n) operation.
|
|
41
34
|
int i = k - 1;
|
|
42
35
|
while (i >= 0) {
|
|
43
36
|
if (ids[i] != -1) {
|
|
@@ -52,7 +45,8 @@ int MinimaxHeap::pop_min_tpl<SIMDLevel::NONE>(float* vmin_out) {
|
|
|
52
45
|
float vmin = dis[i];
|
|
53
46
|
i--;
|
|
54
47
|
while (i >= 0) {
|
|
55
|
-
|
|
48
|
+
// HC::cmp(vmin, dis[i]) → "dis[i] is better than vmin".
|
|
49
|
+
if (ids[i] != -1 && HC::cmp(vmin, dis[i])) {
|
|
56
50
|
vmin = dis[i];
|
|
57
51
|
imin = i;
|
|
58
52
|
}
|
|
@@ -63,29 +57,27 @@ int MinimaxHeap::pop_min_tpl<SIMDLevel::NONE>(float* vmin_out) {
|
|
|
63
57
|
}
|
|
64
58
|
int ret = ids[imin];
|
|
65
59
|
ids[imin] = -1;
|
|
66
|
-
--nvalid;
|
|
67
|
-
|
|
60
|
+
--heap->nvalid;
|
|
68
61
|
return ret;
|
|
69
62
|
}
|
|
70
63
|
|
|
71
|
-
//
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
return
|
|
77
|
-
[&]<SIMDLevel SL>() { return pop_min_tpl<SL>(vmin_out); });
|
|
64
|
+
// declare for min and max heap at simd level NONE
|
|
65
|
+
template <>
|
|
66
|
+
int pop_min_tpl<CMin<float, int32_t>, SIMDLevel::NONE>(
|
|
67
|
+
MinimaxHeapT<CMin<float, int32_t>>* heap,
|
|
68
|
+
float* vmin_out) {
|
|
69
|
+
return pop_min_simd_none(heap, vmin_out);
|
|
78
70
|
}
|
|
79
71
|
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
}
|
|
86
|
-
}
|
|
87
|
-
|
|
88
|
-
return n_below;
|
|
72
|
+
template <>
|
|
73
|
+
int pop_min_tpl<CMax<float, int32_t>, SIMDLevel::NONE>(
|
|
74
|
+
MinimaxHeapT<CMax<float, int32_t>>* heap,
|
|
75
|
+
float* vmin_out) {
|
|
76
|
+
return pop_min_simd_none(heap, vmin_out);
|
|
89
77
|
}
|
|
90
78
|
|
|
79
|
+
// Explicit instantiations of pop_min for the two HC variants
|
|
80
|
+
template int MinimaxHeapT<CMax<float, int32_t>>::pop_min(float*);
|
|
81
|
+
template int MinimaxHeapT<CMin<float, int32_t>>::pop_min(float*);
|
|
82
|
+
|
|
91
83
|
} // namespace faiss
|
|
@@ -7,21 +7,30 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <cassert>
|
|
11
|
+
#include <cmath>
|
|
10
12
|
#include <cstdint>
|
|
11
13
|
#include <vector>
|
|
12
14
|
|
|
13
15
|
#include <faiss/utils/Heap.h>
|
|
16
|
+
#include <faiss/utils/ordered_key_value.h>
|
|
14
17
|
#include <faiss/utils/simd_levels.h>
|
|
15
18
|
|
|
16
19
|
namespace faiss {
|
|
17
20
|
|
|
18
21
|
/** Heap structure that allows fast access and updates.
|
|
19
22
|
*
|
|
20
|
-
*
|
|
21
|
-
*
|
|
22
|
-
*
|
|
23
|
+
* Templated on the comparator HC_ so that the same data structure can
|
|
24
|
+
* service both distance-style searches (HC_ = CMax<float, int32_t>, smaller
|
|
25
|
+
* is better) and similarity-style searches (HC_ = CMin<float, int32_t>,
|
|
26
|
+
* larger is better). For the distance variant the underlying heap is a
|
|
27
|
+
* max-heap and "pop_min" returns the closest element; for similarity the
|
|
28
|
+
* underlying heap is a min-heap and "pop_min" returns the most similar
|
|
29
|
+
* element.
|
|
23
30
|
*/
|
|
24
|
-
|
|
31
|
+
template <class HC_ = CMax<float, int32_t>>
|
|
32
|
+
struct MinimaxHeapT {
|
|
33
|
+
using HC = HC_;
|
|
25
34
|
using storage_idx_t = int32_t;
|
|
26
35
|
|
|
27
36
|
int n;
|
|
@@ -30,12 +39,34 @@ struct MinimaxHeap {
|
|
|
30
39
|
|
|
31
40
|
std::vector<storage_idx_t> ids;
|
|
32
41
|
std::vector<float> dis;
|
|
33
|
-
using HC = faiss::CMax<float, storage_idx_t>;
|
|
34
42
|
|
|
35
|
-
explicit
|
|
43
|
+
explicit MinimaxHeapT(int n_in)
|
|
36
44
|
: n(n_in), k(0), nvalid(0), ids(n_in), dis(n_in) {}
|
|
37
45
|
|
|
38
|
-
void push(storage_idx_t i, float v)
|
|
46
|
+
void push(storage_idx_t i, float v) {
|
|
47
|
+
// Treat NaN distances as the "worst" value so heap ordering is
|
|
48
|
+
// preserved (insertion is then guaranteed to fall through the
|
|
49
|
+
// not-better-than-top early-reject branch when the heap is full).
|
|
50
|
+
if (std::isnan(v)) {
|
|
51
|
+
v = HC::neutral();
|
|
52
|
+
}
|
|
53
|
+
if (k == n) {
|
|
54
|
+
// top of the heap is the "worst" entry under HC. If the new
|
|
55
|
+
// value is not strictly better than the worst, drop it.
|
|
56
|
+
// HC::cmp(top, v) means "v is better than top" for both CMax
|
|
57
|
+
// (cmp = a > b → top > v → v < top) and CMin (cmp = a < b →
|
|
58
|
+
// top < v → v > top).
|
|
59
|
+
if (!HC::cmp(dis[0], v)) {
|
|
60
|
+
return;
|
|
61
|
+
}
|
|
62
|
+
if (ids[0] != -1) {
|
|
63
|
+
--nvalid;
|
|
64
|
+
}
|
|
65
|
+
faiss::heap_pop<HC>(k--, dis.data(), ids.data());
|
|
66
|
+
}
|
|
67
|
+
faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
|
|
68
|
+
++nvalid;
|
|
69
|
+
}
|
|
39
70
|
|
|
40
71
|
float max() const {
|
|
41
72
|
return dis[0];
|
|
@@ -49,16 +80,34 @@ struct MinimaxHeap {
|
|
|
49
80
|
nvalid = k = 0;
|
|
50
81
|
}
|
|
51
82
|
|
|
52
|
-
///
|
|
53
|
-
/// Specializations exist for NONE, AVX2, and AVX512.
|
|
54
|
-
template <SIMDLevel SL>
|
|
55
|
-
int pop_min_tpl(float* vmin_out = nullptr);
|
|
56
|
-
|
|
57
|
-
/// Runtime-dispatched pop_min (calls pop_min_tpl with best available
|
|
58
|
-
/// SIMD level).
|
|
83
|
+
/// Runtime-dispatched best-element extraction (NONE + AVX2 + AVX512).
|
|
59
84
|
int pop_min(float* vmin_out = nullptr);
|
|
60
85
|
|
|
61
|
-
int count_below(float thresh)
|
|
86
|
+
int count_below(float thresh) {
|
|
87
|
+
int n_below = 0;
|
|
88
|
+
for (int i = 0; i < k; i++) {
|
|
89
|
+
// Count entries that are strictly "better than" thresh.
|
|
90
|
+
// HC::cmp(thresh, dis[i]) → for CMax: thresh > dis[i]
|
|
91
|
+
// (i.e., dis[i] < thresh, the historical L2 semantics);
|
|
92
|
+
// for CMin: thresh < dis[i] (similarity above threshold).
|
|
93
|
+
if (HC::cmp(thresh, dis[i])) {
|
|
94
|
+
n_below++;
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
return n_below;
|
|
98
|
+
}
|
|
62
99
|
};
|
|
63
100
|
|
|
101
|
+
// Default `MinimaxHeap` keeps the historical max-heap semantics (smaller
|
|
102
|
+
// distance is better). The CMin instantiation is used when the owning
|
|
103
|
+
// HNSW has `is_similarity = true`. The alias itself is declared once,
|
|
104
|
+
// alongside the forward declaration in HNSW.h, to avoid duplicate
|
|
105
|
+
// `using` declarations that SWIG treats as redundant.
|
|
106
|
+
|
|
107
|
+
// Forward declarations of the SIMD specializations. The actual bodies live
|
|
108
|
+
// in the SIMD-specific translation units (avx2.cpp, avx512.cpp) and are
|
|
109
|
+
// resolved at link time.
|
|
110
|
+
template <class HC_, SIMDLevel SL>
|
|
111
|
+
int pop_min_tpl(MinimaxHeapT<HC_>* heap, float* vmin_out);
|
|
112
|
+
|
|
64
113
|
} // namespace faiss
|
|
@@ -16,89 +16,135 @@
|
|
|
16
16
|
|
|
17
17
|
namespace faiss {
|
|
18
18
|
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
19
|
+
namespace {
|
|
20
|
+
|
|
21
|
+
/// Templated AVX2 implementation of "pop best" for both CMax (returns
|
|
22
|
+
/// the smallest distance) and CMin (returns the largest similarity).
|
|
23
|
+
/// The only differences between the two flavors are: (1) the initial
|
|
24
|
+
/// "worst possible" value, (2) the running-best update comparison
|
|
25
|
+
/// (`_CMP_LT_OS` vs `_CMP_GT_OS`), and (3) the tiebreaker direction.
|
|
26
|
+
template <class HC>
|
|
27
|
+
int pop_best_avx2(MinimaxHeapT<HC>& heap, float* vmin_out) {
|
|
28
|
+
using storage_idx_t = typename MinimaxHeapT<HC>::storage_idx_t;
|
|
22
29
|
static_assert(
|
|
23
30
|
std::is_same<storage_idx_t, int32_t>::value,
|
|
24
31
|
"This code expects storage_idx_t to be int32_t");
|
|
32
|
+
assert(heap.k > 0);
|
|
33
|
+
|
|
34
|
+
// For CMax (distance) the "best" candidate is the smallest value, so
|
|
35
|
+
// we initialize the running best to +inf. For CMin (similarity) the
|
|
36
|
+
// best is the largest value, so we initialize to -inf.
|
|
37
|
+
constexpr float worst_v = HC::is_max
|
|
38
|
+
? std::numeric_limits<float>::infinity()
|
|
39
|
+
: -std::numeric_limits<float>::infinity();
|
|
25
40
|
|
|
26
|
-
int32_t
|
|
27
|
-
float
|
|
41
|
+
int32_t best_idx = -1;
|
|
42
|
+
float best_dis = worst_v;
|
|
28
43
|
|
|
29
44
|
size_t iii = 0;
|
|
30
45
|
|
|
31
|
-
__m256i
|
|
32
|
-
__m256
|
|
33
|
-
_mm256_set1_ps(std::numeric_limits<float>::infinity());
|
|
46
|
+
__m256i best_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
|
|
47
|
+
__m256 best_distances = _mm256_set1_ps(worst_v);
|
|
34
48
|
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
|
|
35
49
|
__m256i offset = _mm256_set1_epi32(8);
|
|
36
50
|
|
|
37
|
-
//
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
// -1 index values are ignored.
|
|
41
|
-
const size_t k8 = (k / 8) * 8;
|
|
51
|
+
// Track the rightmost index whose distance equals the running best.
|
|
52
|
+
// -1 index values are filtered out via m1mask.
|
|
53
|
+
const size_t k8 = (heap.k / 8) * 8;
|
|
42
54
|
for (; iii < k8; iii += 8) {
|
|
43
55
|
__m256i indices =
|
|
44
|
-
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
|
|
45
|
-
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
|
|
56
|
+
_mm256_loadu_si256((const __m256i*)(heap.ids.data() + iii));
|
|
57
|
+
__m256 distances = _mm256_loadu_ps(heap.dis.data() + iii);
|
|
46
58
|
|
|
47
|
-
//
|
|
59
|
+
// Mask out -1 indices (invalid entries).
|
|
48
60
|
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
|
|
49
61
|
|
|
50
|
-
|
|
51
|
-
|
|
62
|
+
// dmask is "true where best is already (strictly) better than the
|
|
63
|
+
// candidate" — entries the candidate should NOT update. For CMax,
|
|
64
|
+
// best < candidate means we keep best (we want the smallest);
|
|
65
|
+
// for CMin we keep best when best > candidate (we want the largest).
|
|
66
|
+
__m256i dmask;
|
|
67
|
+
if constexpr (HC::is_max) {
|
|
68
|
+
dmask = _mm256_castps_si256(
|
|
69
|
+
_mm256_cmp_ps(best_distances, distances, _CMP_LT_OS));
|
|
70
|
+
} else {
|
|
71
|
+
dmask = _mm256_castps_si256(
|
|
72
|
+
_mm256_cmp_ps(best_distances, distances, _CMP_GT_OS));
|
|
73
|
+
}
|
|
52
74
|
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
|
|
53
75
|
|
|
54
|
-
const __m256i
|
|
76
|
+
const __m256i best_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
|
|
55
77
|
_mm256_castsi256_ps(current_indices),
|
|
56
|
-
_mm256_castsi256_ps(
|
|
78
|
+
_mm256_castsi256_ps(best_indices),
|
|
57
79
|
finalmask));
|
|
58
80
|
|
|
59
|
-
const __m256
|
|
60
|
-
_mm256_blendv_ps(distances,
|
|
81
|
+
const __m256 best_distances_new =
|
|
82
|
+
_mm256_blendv_ps(distances, best_distances, finalmask);
|
|
61
83
|
|
|
62
|
-
|
|
63
|
-
|
|
84
|
+
best_indices = best_indices_new;
|
|
85
|
+
best_distances = best_distances_new;
|
|
64
86
|
|
|
65
87
|
current_indices = _mm256_add_epi32(current_indices, offset);
|
|
66
88
|
}
|
|
67
89
|
|
|
68
|
-
// Vectorizing is doable
|
|
90
|
+
// Vectorizing the horizontal reduction is doable but not practical.
|
|
69
91
|
int32_t vidx8[8];
|
|
70
92
|
float vdis8[8];
|
|
71
|
-
_mm256_storeu_ps(vdis8,
|
|
72
|
-
_mm256_storeu_si256((__m256i*)vidx8,
|
|
93
|
+
_mm256_storeu_ps(vdis8, best_distances);
|
|
94
|
+
_mm256_storeu_si256((__m256i*)vidx8, best_indices);
|
|
73
95
|
|
|
74
96
|
for (size_t j = 0; j < 8; j++) {
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
97
|
+
const bool strictly_better =
|
|
98
|
+
HC::is_max ? (best_dis > vdis8[j]) : (best_dis < vdis8[j]);
|
|
99
|
+
if (strictly_better || (best_dis == vdis8[j] && best_idx < vidx8[j])) {
|
|
100
|
+
best_idx = vidx8[j];
|
|
101
|
+
best_dis = vdis8[j];
|
|
78
102
|
}
|
|
79
103
|
}
|
|
80
104
|
|
|
81
|
-
//
|
|
82
|
-
for (; iii < static_cast<size_t>(k); iii++) {
|
|
83
|
-
if (ids[iii]
|
|
84
|
-
|
|
85
|
-
|
|
105
|
+
// Tail (under 8 entries). Vectorizing is doable but not practical.
|
|
106
|
+
for (; iii < static_cast<size_t>(heap.k); iii++) {
|
|
107
|
+
if (heap.ids[iii] == -1) {
|
|
108
|
+
continue;
|
|
109
|
+
}
|
|
110
|
+
const bool weakly_better = HC::is_max ? (best_dis >= heap.dis[iii])
|
|
111
|
+
: (best_dis <= heap.dis[iii]);
|
|
112
|
+
if (weakly_better) {
|
|
113
|
+
best_dis = heap.dis[iii];
|
|
114
|
+
best_idx = iii;
|
|
86
115
|
}
|
|
87
116
|
}
|
|
88
117
|
|
|
89
|
-
if (
|
|
118
|
+
if (best_idx == -1) {
|
|
90
119
|
return -1;
|
|
91
120
|
}
|
|
92
121
|
|
|
93
122
|
if (vmin_out) {
|
|
94
|
-
*vmin_out =
|
|
123
|
+
*vmin_out = best_dis;
|
|
95
124
|
}
|
|
96
|
-
int ret = ids[
|
|
97
|
-
ids[
|
|
98
|
-
--nvalid;
|
|
125
|
+
int ret = heap.ids[best_idx];
|
|
126
|
+
heap.ids[best_idx] = -1;
|
|
127
|
+
--heap.nvalid;
|
|
99
128
|
return ret;
|
|
100
129
|
}
|
|
101
130
|
|
|
131
|
+
} // namespace
|
|
132
|
+
|
|
133
|
+
// Explicit specializations for AVX2
|
|
134
|
+
template <>
|
|
135
|
+
int pop_min_tpl<CMax<float, int32_t>, SIMDLevel::AVX2>(
|
|
136
|
+
MinimaxHeapT<CMax<float, int32_t>>* heap,
|
|
137
|
+
float* vmin_out) {
|
|
138
|
+
return pop_best_avx2<CMax<float, int32_t>>(*heap, vmin_out);
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
template <>
|
|
142
|
+
int pop_min_tpl<CMin<float, int32_t>, SIMDLevel::AVX2>(
|
|
143
|
+
MinimaxHeapT<CMin<float, int32_t>>* heap,
|
|
144
|
+
float* vmin_out) {
|
|
145
|
+
return pop_best_avx2<CMin<float, int32_t>>(*heap, vmin_out);
|
|
146
|
+
}
|
|
147
|
+
|
|
102
148
|
} // namespace faiss
|
|
103
149
|
|
|
104
150
|
#endif // COMPILE_SIMD_AVX2
|