faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
data/vendor/faiss/faiss/impl/Panorama.cpp
@@ -0,0 +1,193 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/impl/Panorama.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+namespace faiss {
+
+/**************************************************************
+ * Panorama structure implementation
+ **************************************************************/
+
+Panorama::Panorama(size_t code_size, size_t n_levels, size_t batch_size)
+        : code_size(code_size), n_levels(n_levels), batch_size(batch_size) {
+    set_derived_values();
+}
+
+void Panorama::set_derived_values() {
+    this->d = code_size / sizeof(float);
+    this->level_width_floats = ((d + n_levels - 1) / n_levels);
+    this->level_width = this->level_width_floats * sizeof(float);
+}
+
+/**
+ * @brief Copy codes to level-oriented layout
+ * @param codes The base pointer to codes
+ * @param offset Where to start writing new data (in number of vectors)
+ * @param n_entry The number of new vectors to write
+ * @param code The new vector data
+ */
+void Panorama::copy_codes_to_level_layout(
+        uint8_t* codes,
+        size_t offset,
+        size_t n_entry,
+        const uint8_t* code) {
+    for (size_t entry_idx = 0; entry_idx < n_entry; entry_idx++) {
+        size_t current_pos = offset + entry_idx;
+
+        // Determine which batch we're in and position within that batch.
+        size_t batch_no = current_pos / batch_size;
+        size_t pos_in_batch = current_pos % batch_size;
+
+        // Copy entry into level-oriented layout for this batch.
+        size_t batch_offset = batch_no * batch_size * code_size;
+        for (size_t level = 0; level < n_levels; level++) {
+            size_t level_offset = level * level_width * batch_size;
+            size_t start_byte = level * level_width;
+            size_t actual_level_width =
+                    std::min(level_width, code_size - level * level_width);
+
+            const uint8_t* src = code + entry_idx * code_size + start_byte;
+            uint8_t* dest = codes + batch_offset + level_offset +
+                    pos_in_batch * actual_level_width;
+
+            memcpy(dest, src, actual_level_width);
+        }
+    }
+}
+
+void Panorama::compute_cumulative_sums(
+        float* cumsum_base,
+        size_t offset,
+        size_t n_entry,
+        const float* vectors) {
+    std::vector<float> suffix_sums(d + 1);
+
+    for (size_t entry_idx = 0; entry_idx < n_entry; entry_idx++) {
+        size_t current_pos = offset + entry_idx;
+        size_t batch_no = current_pos / batch_size;
+        size_t pos_in_batch = current_pos % batch_size;
+
+        const float* vector = vectors + entry_idx * d;
+
+        // Compute suffix sums of squared values.
+        suffix_sums[d] = 0.0f;
+        for (int j = d - 1; j >= 0; j--) {
+            float squared_val = vector[j] * vector[j];
+            suffix_sums[j] = suffix_sums[j + 1] + squared_val;
+        }
+
+        // Store cumulative sums in batch-oriented layout.
+        size_t cumsum_batch_offset = batch_no * batch_size * (n_levels + 1);
+
+        for (size_t level = 0; level < n_levels; level++) {
+            size_t start_idx = level * level_width_floats;
+            size_t cumsum_offset =
+                    cumsum_batch_offset + level * batch_size + pos_in_batch;
+            if (start_idx < d) {
+                cumsum_base[cumsum_offset] = std::sqrt(suffix_sums[start_idx]);
+            } else {
+                cumsum_base[cumsum_offset] = 0.0f;
+            }
+        }
+
+        // Last level sum is always 0.
+        size_t cumsum_offset =
+                cumsum_batch_offset + n_levels * batch_size + pos_in_batch;
+        cumsum_base[cumsum_offset] = 0.0f;
+    }
+}
+
+void Panorama::compute_query_cum_sums(const float* query, float* query_cum_sums)
+        const {
+    std::vector<float> suffix_sums(d + 1);
+    suffix_sums[d] = 0.0f;
+
+    for (int j = d - 1; j >= 0; j--) {
+        float squared_val = query[j] * query[j];
+        suffix_sums[j] = suffix_sums[j + 1] + squared_val;
+    }
+
+    for (size_t level = 0; level < n_levels; level++) {
+        size_t start_idx = level * level_width_floats;
+        if (start_idx < d) {
+            query_cum_sums[level] = std::sqrt(suffix_sums[start_idx]);
+        } else {
+            query_cum_sums[level] = 0.0f;
+        }
+    }
+
+    query_cum_sums[n_levels] = 0.0f;
+}
+
+void Panorama::reconstruct(idx_t key, float* recons, const uint8_t* codes_base)
+        const {
+    uint8_t* recons_buffer = reinterpret_cast<uint8_t*>(recons);
+
+    size_t batch_no = key / batch_size;
+    size_t pos_in_batch = key % batch_size;
+    size_t batch_offset = batch_no * batch_size * code_size;
+
+    for (size_t level = 0; level < n_levels; level++) {
+        size_t level_offset = level * level_width * batch_size;
+        const uint8_t* src = codes_base + batch_offset + level_offset +
+                pos_in_batch * level_width;
+        uint8_t* dest = recons_buffer + level * level_width;
+        size_t copy_size =
+                std::min(level_width, code_size - level * level_width);
+        memcpy(dest, src, copy_size);
+    }
+}
+
+void Panorama::copy_entry(
+        uint8_t* dest_codes,
+        uint8_t* src_codes,
+        float* dest_cum_sums,
+        float* src_cum_sums,
+        size_t dest_idx,
+        size_t src_idx) const {
+    // Calculate positions
+    size_t src_batch_no = src_idx / batch_size;
+    size_t src_pos_in_batch = src_idx % batch_size;
+    size_t dest_batch_no = dest_idx / batch_size;
+    size_t dest_pos_in_batch = dest_idx % batch_size;
+
+    // Calculate offsets
+    size_t src_batch_offset = src_batch_no * batch_size * code_size;
+    size_t dest_batch_offset = dest_batch_no * batch_size * code_size;
+    size_t src_cumsum_batch_offset = src_batch_no * batch_size * (n_levels + 1);
+    size_t dest_cumsum_batch_offset =
+            dest_batch_no * batch_size * (n_levels + 1);
+
+    for (size_t level = 0; level < n_levels; level++) {
+        // Copy code
+        size_t level_offset = level * level_width * batch_size;
+        size_t actual_level_width =
+                std::min(level_width, code_size - level * level_width);
+
+        const uint8_t* src = src_codes + src_batch_offset + level_offset +
+                src_pos_in_batch * actual_level_width;
+        uint8_t* dest = dest_codes + dest_batch_offset + level_offset +
+                dest_pos_in_batch * actual_level_width;
+        memcpy(dest, src, actual_level_width);
+
+        // Copy cum_sums
+        size_t cumsum_level_offset = level * batch_size;
+
+        const size_t src_offset = src_cumsum_batch_offset +
+                cumsum_level_offset + src_pos_in_batch;
+        size_t dest_offset = dest_cumsum_batch_offset + cumsum_level_offset +
+                dest_pos_in_batch;
+        dest_cum_sums[dest_offset] = src_cum_sums[src_offset];
+    }
+}
+} // namespace faiss
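For orientation, the offset arithmetic used by copy_codes_to_level_layout can be traced by hand. The sketch below is illustrative only: the parameter values (d = 128, n_levels = 4, batch_size = 64) and the position being located are made up for the example, not defaults taken from the library.

```cpp
// Illustrative sketch: reproduces the level-major layout arithmetic from
// Panorama::copy_codes_to_level_layout for one (vector, level) pair.
// All concrete numbers below are assumptions chosen for the example.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t d = 128;                       // dimensions per vector
    const size_t n_levels = 4;                  // L contiguous levels
    const size_t batch_size = 64;               // vectors per batch
    const size_t code_size = d * sizeof(float); // 512 bytes per vector
    const size_t level_width_floats = (d + n_levels - 1) / n_levels; // 32 floats
    const size_t level_width = level_width_floats * sizeof(float);   // 128 bytes

    const size_t current_pos = 100;             // global position of the vector
    const size_t level = 2;                     // which level slice to locate

    const size_t batch_no = current_pos / batch_size;     // batch 1
    const size_t pos_in_batch = current_pos % batch_size; // slot 36

    // Batches are stored back to back; within a batch, all level-0 slices come
    // first, then all level-1 slices, and so on (here d divides evenly into
    // levels, so actual_level_width == level_width).
    const size_t batch_offset = batch_no * batch_size * code_size; // 32768
    const size_t level_offset = level * level_width * batch_size;  // 16384
    const size_t dest_byte =
            batch_offset + level_offset + pos_in_batch * level_width; // 53760

    std::printf("level-%zu slice of vector %zu starts at byte %zu\n",
                level, current_pos, dest_byte);
    return 0;
}
```

Because every candidate's level-l slice sits contiguously within its batch, the per-level loop in progressive_filter_batch (declared in Panorama.h below) walks memory sequentially rather than jumping between full vectors.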
data/vendor/faiss/faiss/impl/Panorama.h
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#ifndef FAISS_PANORAMA_H
+#define FAISS_PANORAMA_H
+
+#include <faiss/impl/IDSelector.h>
+#include <faiss/impl/PanoramaStats.h>
+#include <faiss/utils/distances.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace faiss {
+
+/**
+ * Implements the core logic of Panorama-based refinement.
+ * arXiv: https://arxiv.org/abs/2510.00566
+ *
+ * Panorama partitions the dimensions of all vectors into L contiguous levels.
+ * During the refinement stage of ANNS, it computes distances between the query
+ * and its candidates level-by-level. After processing each level, it prunes the
+ * candidates whose lower bound exceeds the k-th best distance.
+ *
+ * In order to enable speedups, the dimensions (or codes) of each vector are
+ * stored in a batched, level-major manner. Within each batch of b vectors, the
+ * dimensions corresponding to level 1 will be stored first (for all elements in
+ * that batch), followed by level 2, and so on. This allows for efficient memory
+ * access patterns.
+ *
+ * Coupled with the appropriate orthogonal PreTransform (e.g. PCA, Cayley,
+ * etc.), Panorama can prune the vast majority of dimensions, greatly
+ * accelerating the refinement stage.
+ */
+struct Panorama {
+    size_t d = 0;
+    size_t code_size = 0;
+    size_t n_levels = 0;
+    size_t level_width = 0;
+    size_t level_width_floats = 0;
+    size_t batch_size = 0;
+
+    explicit Panorama(size_t code_size, size_t n_levels, size_t batch_size);
+
+    void set_derived_values();
+
+    /// Helper method to copy codes into level-oriented batch layout at a given
+    /// offset in the list.
+    void copy_codes_to_level_layout(
+            uint8_t* codes,
+            size_t offset,
+            size_t n_entry,
+            const uint8_t* code);
+
+    /// Helper method to compute the cumulative sums of the codes.
+    /// The cumsums also follow the level-oriented batch layout to minimize the
+    /// number of random memory accesses.
+    void compute_cumulative_sums(
+            float* cumsum_base,
+            size_t offset,
+            size_t n_entry,
+            const float* vectors);
+
+    /// Compute the cumulative sums of the query vector.
+    void compute_query_cum_sums(const float* query, float* query_cum_sums)
+            const;
+
+    /// Copy single entry (code and cum_sum) from one location to another.
+    void copy_entry(
+            uint8_t* dest_codes,
+            uint8_t* src_codes,
+            float* dest_cum_sums,
+            float* src_cum_sums,
+            size_t dest_idx,
+            size_t src_idx) const;
+
+    /// Panorama's core progressive filtering algorithm:
+    /// Process vectors in batches for cache efficiency. For each batch:
+    /// 1. Apply ID selection filter and initialize distances
+    ///    (||y||^2 + ||x||^2).
+    /// 2. Maintain an "active set" of candidate indices that haven't been
+    ///    pruned yet.
+    /// 3. For each level, refine distances incrementally and compact the active
+    ///    set:
+    ///    - Compute dot product for current level: exact_dist -= 2*<x,y>.
+    ///    - Use Cauchy-Schwarz bound on remaining levels to get lower bound
+    ///    - Prune candidates whose lower bound exceeds k-th best distance.
+    ///    - Compact active_indices to remove pruned candidates (branchless)
+    /// 4. After all levels, survivors are exact distances; update heap.
+    /// This achieves early termination while maintaining SIMD-friendly
+    /// sequential access patterns in the level-oriented storage layout.
+    template <typename C>
+    size_t progressive_filter_batch(
+            const uint8_t* codes_base,
+            const float* cum_sums,
+            const float* query,
+            const float* query_cum_sums,
+            size_t batch_no,
+            size_t list_size,
+            const IDSelector* sel,
+            const idx_t* ids,
+            bool use_sel,
+            std::vector<uint32_t>& active_indices,
+            std::vector<float>& exact_distances,
+            float threshold,
+            PanoramaStats& local_stats) const;
+
+    void reconstruct(idx_t key, float* recons, const uint8_t* codes_base) const;
+};
+
+template <typename C>
+size_t Panorama::progressive_filter_batch(
+        const uint8_t* codes_base,
+        const float* cum_sums,
+        const float* query,
+        const float* query_cum_sums,
+        size_t batch_no,
+        size_t list_size,
+        const IDSelector* sel,
+        const idx_t* ids,
+        bool use_sel,
+        std::vector<uint32_t>& active_indices,
+        std::vector<float>& exact_distances,
+        float threshold,
+        PanoramaStats& local_stats) const {
+    size_t batch_start = batch_no * batch_size;
+    size_t curr_batch_size = std::min(list_size - batch_start, batch_size);
+
+    size_t cumsum_batch_offset = batch_no * batch_size * (n_levels + 1);
+    const float* batch_cum_sums = cum_sums + cumsum_batch_offset;
+    const float* level_cum_sums = batch_cum_sums + batch_size;
+    float q_norm = query_cum_sums[0] * query_cum_sums[0];
+
+    size_t batch_offset = batch_no * batch_size * code_size;
+    const uint8_t* storage_base = codes_base + batch_offset;
+
+    // Initialize active set with ID-filtered vectors.
+    size_t num_active = 0;
+    for (size_t i = 0; i < curr_batch_size; i++) {
+        size_t global_idx = batch_start + i;
+        idx_t id = (ids == nullptr) ? global_idx : ids[global_idx];
+        bool include = !use_sel || sel->is_member(id);
+
+        active_indices[num_active] = i;
+        float cum_sum = batch_cum_sums[i];
+        exact_distances[i] = cum_sum * cum_sum + q_norm;
+
+        num_active += include;
+    }
+
+    if (num_active == 0) {
+        return 0;
+    }
+
+    size_t total_active = num_active;
+    for (size_t level = 0; level < n_levels; level++) {
+        local_stats.total_dims_scanned += num_active;
+        local_stats.total_dims += total_active;
+
+        float query_cum_norm = query_cum_sums[level + 1];
+
+        size_t level_offset = level * level_width * batch_size;
+        const float* level_storage =
+                (const float*)(storage_base + level_offset);
+
+        size_t next_active = 0;
+        for (size_t i = 0; i < num_active; i++) {
+            uint32_t idx = active_indices[i];
+            size_t actual_level_width = std::min(
+                    level_width_floats, d - level * level_width_floats);
+
+            const float* yj = level_storage + idx * actual_level_width;
+            const float* query_level = query + level * level_width_floats;
+
+            float dot_product =
+                    fvec_inner_product(query_level, yj, actual_level_width);
+
+            exact_distances[idx] -= 2.0f * dot_product;
+
+            float cum_sum = level_cum_sums[idx];
+            float cauchy_schwarz_bound = 2.0f * cum_sum * query_cum_norm;
+            float lower_bound = exact_distances[idx] - cauchy_schwarz_bound;
+
+            active_indices[next_active] = idx;
+            next_active += C::cmp(threshold, lower_bound) ? 1 : 0;
+        }
+
+        num_active = next_active;
+        level_cum_sums += batch_size;
+    }
+
+    return num_active;
+}
+} // namespace faiss
+
+#endif
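To spell out the pruning rule that progressive_filter_batch implements, the bound can be restated in math (this is only a restatement of the code above): x is the query, y a candidate, j indexes levels, and x_{>l}, y_{>l} denote the dimensions in levels after l. In the code, ||x_{>l}|| is query_cum_sums[l+1] and ||y_{>l}|| is the stored per-level cum_sum.

```latex
% running estimate after processing levels 1..l
\mathrm{exact\_dist}_l(x, y) = \|x\|^2 + \|y\|^2 - 2 \sum_{j \le l} \langle x_j, y_j \rangle
% Cauchy-Schwarz on the unprocessed tail gives a lower bound on the true distance
\|x - y\|^2 = \mathrm{exact\_dist}_l(x, y) - 2\,\langle x_{>l}, y_{>l} \rangle
          \;\ge\; \mathrm{exact\_dist}_l(x, y) - 2\,\|x_{>l}\|\,\|y_{>l}\|
```

A candidate is dropped as soon as this lower bound already exceeds the threshold (the current k-th best distance), which is exactly the C::cmp(threshold, lower_bound) test in the inner loop.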
data/vendor/faiss/faiss/impl/RaBitQStats.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/impl/RaBitQStats.h>
+
+namespace faiss {
+
+// NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
+RaBitQStats rabitq_stats;
+
+void RaBitQStats::reset() {
+    n_1bit_evaluations = 0;
+    n_multibit_evaluations = 0;
+}
+
+double RaBitQStats::skip_percentage() const {
+    const size_t copy_n_1bit_evaluations = n_1bit_evaluations;
+    const size_t copy_n_multibit_evaluations = n_multibit_evaluations;
+    return copy_n_1bit_evaluations > 0
+            ? 100.0 * (copy_n_1bit_evaluations - copy_n_multibit_evaluations) /
+                    copy_n_1bit_evaluations
+            : 0.0;
+}
+
+} // namespace faiss
data/vendor/faiss/faiss/impl/RaBitQStats.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/impl/platform_macros.h>
+#include <cstddef>
+
+namespace faiss {
+
+/// Statistics for RaBitQ multi-bit two-stage search.
+///
+/// These stats are ONLY collected for multi-bit mode (nb_bits > 1).
+/// In 1-bit mode, there is no two-stage filtering - all candidates are
+/// evaluated with a single distance computation, so there is nothing
+/// meaningful to track. For 1-bit mode, both counters remain 0.
+///
+/// Multi-bit mode uses a two-stage search:
+/// Stage 1: Compute 1-bit lower bound distance for all candidates
+/// Stage 2: Compute full multi-bit distance only for promising candidates
+///
+/// The skip_percentage() metric measures filtering effectiveness:
+/// how many candidates were filtered out by the 1-bit lower bound
+/// without needing the more expensive multi-bit distance computation.
+///
+/// WARNING: Statistics are not robust to internal threading nor to
+/// concurrent RaBitQ searches. Use these values in a single-threaded
+/// context to accurately gauge RaBitQ's filtering effectiveness.
+/// Call reset() before search, then read stats after search completes.
+struct RaBitQStats {
+    /// Number of candidates evaluated using 1-bit (lower bound) distance.
+    /// This is the first stage of two-stage search in multi-bit mode.
+    /// Always 0 in 1-bit mode (stats not tracked).
+    size_t n_1bit_evaluations = 0;
+
+    /// Number of candidates that passed 1-bit filtering and required
+    /// full multi-bit distance computation (second stage).
+    /// Always 0 in 1-bit mode (stats not tracked).
+    size_t n_multibit_evaluations = 0;
+
+    void reset();
+
+    /// Compute percentage of candidates skipped (filtered out by 1-bit stage).
+    /// Returns 0 if no candidates were evaluated (including 1-bit mode).
+    double skip_percentage() const;
+};
+
+/// Global stats for RaBitQ indexes
+// NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
+FAISS_API extern RaBitQStats rabitq_stats;
+
+} // namespace faiss
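Given the single-threaded caveat in the header above, the intended pattern is to reset the global counters, run one search, then read them back. A minimal sketch follows; only rabitq_stats, reset(), skip_percentage() and the two counters come from the code above, while the index, the query buffer, and the function name are assumptions made for illustration (any trained multi-bit RaBitQ index would play the role of index).

```cpp
// Sketch only: reading the RaBitQStats counters around a single search.
// `index` is assumed to be an already-trained multi-bit RaBitQ index
// (nb_bits > 1); nq, queries and k are placeholders for the example.
#include <faiss/Index.h>
#include <faiss/impl/RaBitQStats.h>

#include <cstdio>
#include <vector>

void report_rabitq_filtering(
        faiss::Index& index,
        faiss::idx_t nq,
        const float* queries,
        faiss::idx_t k) {
    std::vector<float> distances(nq * k);
    std::vector<faiss::idx_t> labels(nq * k);

    faiss::rabitq_stats.reset(); // clear counters before the search
    index.search(nq, queries, k, distances.data(), labels.data());

    // Fraction of candidates rejected by the 1-bit lower bound alone,
    // i.e. candidates that never needed the multi-bit distance.
    std::printf(
            "1-bit evals: %zu, multi-bit evals: %zu, skipped: %.1f%%\n",
            faiss::rabitq_stats.n_1bit_evaluations,
            faiss::rabitq_stats.n_multibit_evaluations,
            faiss::rabitq_stats.skip_percentage());
}
```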
data/vendor/faiss/faiss/impl/RaBitQUtils.cpp
@@ -16,6 +16,18 @@
 namespace faiss {
 namespace rabitq_utils {
 
+// Verify no unexpected padding in structures used for per-vector storage.
+// These checks ensure compute_per_vector_storage_size() remains accurate.
+static_assert(
+        sizeof(SignBitFactors) == 8,
+        "SignBitFactors has unexpected padding");
+static_assert(
+        sizeof(SignBitFactorsWithError) == 12,
+        "SignBitFactorsWithError has unexpected padding");
+static_assert(
+        sizeof(ExtraBitsFactors) == 8,
+        "ExtraBitsFactors has unexpected padding");
+
 // Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
 // L2 reconstruction error.
 const float Z_MAX_BY_QB[8] = {
@@ -54,13 +66,16 @@ void compute_vector_intermediate_values(
     }
 }
 
-FactorsData compute_factors_from_intermediates(
+SignBitFactorsWithError compute_factors_from_intermediates(
         float norm_L2sqr,
         float or_L2sqr,
         float dp_oO,
         size_t d,
-        MetricType metric_type) {
+        MetricType metric_type,
+        bool compute_error) {
     constexpr float epsilon = std::numeric_limits<float>::epsilon();
+    constexpr float kConstEpsilon =
+            1.9f; // Error bound constant from RaBitQ paper
     const float inv_d_sqrt =
             (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast<float>(d)));
 
@@ -72,25 +87,57 @@ FactorsData compute_factors_from_intermediates(
     const float inv_dp_oO =
             (std::abs(normalized_dp) < epsilon) ? 1.0f : (1.0f / normalized_dp);
 
-    FactorsData factors;
+    SignBitFactorsWithError factors;
     factors.or_minus_c_l2sqr = (metric_type == MetricType::METRIC_INNER_PRODUCT)
             ? (norm_L2sqr - or_L2sqr)
             : norm_L2sqr;
     factors.dp_multiplier = inv_dp_oO * sqrt_norm_L2;
 
+    // Compute error bound only if needed (skip for 1-bit mode)
+    if (compute_error) {
+        const float xu_cb_norm_sqr = static_cast<float>(d) * 0.25f;
+        const float ip_resi_xucb = 0.5f * dp_oO;
+
+        float tmp_error = 0.0f;
+        if (std::abs(ip_resi_xucb) > epsilon) {
+            const float ratio_sq = (norm_L2sqr * xu_cb_norm_sqr) /
+                    (ip_resi_xucb * ip_resi_xucb);
+            if (ratio_sq > 1.0f) {
+                if (d == 1) {
+                    tmp_error = sqrt_norm_L2 * kConstEpsilon *
+                            std::sqrt(ratio_sq - 1.0f);
+                } else {
+                    tmp_error = sqrt_norm_L2 * kConstEpsilon *
+                            std::sqrt((ratio_sq - 1.0f) /
+                                      static_cast<float>(d - 1));
+                }
+            }
+        }
+
+        // Apply metric-specific multiplier
+        if (metric_type == MetricType::METRIC_L2) {
+            factors.f_error = 2.0f * tmp_error;
+        } else if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
+            factors.f_error = 1.0f * tmp_error;
+        } else {
+            factors.f_error = 0.0f;
+        }
+    }
+
     return factors;
 }
 
-FactorsData compute_vector_factors(
+SignBitFactorsWithError compute_vector_factors(
         const float* x,
         size_t d,
         const float* centroid,
-        MetricType metric_type) {
+        MetricType metric_type,
+        bool compute_error) {
     float norm_L2sqr, or_L2sqr, dp_oO;
     compute_vector_intermediate_values(
             x, d, centroid, norm_L2sqr, or_L2sqr, dp_oO);
     return compute_factors_from_intermediates(
-            norm_L2sqr, or_L2sqr, dp_oO, d, metric_type);
+            norm_L2sqr, or_L2sqr, dp_oO, d, metric_type, compute_error);
 }
 
 QueryFactorsData compute_query_factors(
@@ -113,6 +160,7 @@ QueryFactorsData compute_query_factors(
     } else {
         query_factors.qr_to_c_L2sqr = fvec_norm_L2sqr(query, d);
     }
+    query_factors.g_error = std::sqrt(query_factors.qr_to_c_L2sqr);
 
     // Rotate the query (subtract centroid)
     rotated_q.resize(d);
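Keeping the code's own variable names (and not interpreting them beyond that), the per-vector error bound added in compute_factors_from_intermediates works out, for d > 1 and when the branch is taken, to:

```latex
\mathrm{ratio\_sq} = \frac{\mathrm{norm\_L2sqr} \cdot d / 4}{(\mathrm{dp\_oO} / 2)^2},
\qquad
\mathrm{f\_error} = m \cdot \mathrm{sqrt\_norm\_L2} \cdot 1.9 \cdot
        \sqrt{\frac{\mathrm{ratio\_sq} - 1}{d - 1}}
```

where m = 2 for METRIC_L2 and m = 1 for METRIC_INNER_PRODUCT. f_error stays 0 when ratio_sq does not exceed 1, when |dp_oO / 2| is below machine epsilon, or for other metrics, and the query-side counterpart set in compute_query_factors is simply g_error = sqrt(qr_to_c_L2sqr).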