faiss 0.4.2 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/faiss/index.cpp +36 -10
- data/ext/faiss/index_binary.cpp +19 -6
- data/ext/faiss/kmeans.cpp +6 -6
- data/ext/faiss/numo.hpp +273 -123
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +1 -2
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.h +10 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +107 -7
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
- data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +3 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
- data/vendor/faiss/faiss/impl/HNSW.h +4 -4
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
- data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +43 -1
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +5 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +14 -3
|
@@ -116,7 +116,7 @@ struct IDSelectorBitmap : IDSelector {
|
|
|
116
116
|
/** reverts the membership test of another selector */
|
|
117
117
|
struct IDSelectorNot : IDSelector {
|
|
118
118
|
const IDSelector* sel;
|
|
119
|
-
IDSelectorNot(const IDSelector* sel) : sel(sel) {}
|
|
119
|
+
explicit IDSelectorNot(const IDSelector* sel) : sel(sel) {}
|
|
120
120
|
bool is_member(idx_t id) const final {
|
|
121
121
|
return !sel->is_member(id);
|
|
122
122
|
}
|
|
@@ -30,7 +30,7 @@
|
|
|
30
30
|
#endif
|
|
31
31
|
|
|
32
32
|
extern "C" {
|
|
33
|
-
// LU
|
|
33
|
+
// LU decomposition of a general matrix
|
|
34
34
|
void sgetrf_(
|
|
35
35
|
FINTEGER* m,
|
|
36
36
|
FINTEGER* n,
|
|
@@ -65,7 +65,7 @@ int sgemm_(
|
|
|
65
65
|
float* c,
|
|
66
66
|
FINTEGER* ldc);
|
|
67
67
|
|
|
68
|
-
// LU
|
|
68
|
+
// LU decomposition of a general matrix
|
|
69
69
|
void dgetrf_(
|
|
70
70
|
FINTEGER* m,
|
|
71
71
|
FINTEGER* n,
|
|
@@ -189,7 +189,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
|
|
|
189
189
|
std::vector<int32_t> codes(n * M); // [n, M]
|
|
190
190
|
random_int32(codes, 0, K - 1, gen);
|
|
191
191
|
|
|
192
|
-
// compute standard
|
|
192
|
+
// compute standard deviations of each dimension
|
|
193
193
|
std::vector<float> stddev(d, 0);
|
|
194
194
|
|
|
195
195
|
#pragma omp parallel for
|
|
@@ -487,7 +487,7 @@ void LocalSearchQuantizer::update_codebooks(
|
|
|
487
487
|
* L = (X - \sum cj)^2, j = 1, ..., M
|
|
488
488
|
* L = X^2 - 2X * \sum cj + (\sum cj)^2
|
|
489
489
|
*
|
|
490
|
-
* X^2 is
|
|
490
|
+
* X^2 is negligible since it is the same for all possible value
|
|
491
491
|
* k of the m-th subcode.
|
|
492
492
|
*
|
|
493
493
|
* 2X * \sum cj is the unary term
|
|
@@ -138,7 +138,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
|
|
|
138
138
|
/** Add some perturbation to codebooks
|
|
139
139
|
*
|
|
140
140
|
* @param T temperature of simulated annealing
|
|
141
|
-
* @param stddev standard
|
|
141
|
+
* @param stddev standard deviations of each dimension in training data
|
|
142
142
|
*/
|
|
143
143
|
void perturb_codebooks(
|
|
144
144
|
float T,
|
|
@@ -63,7 +63,7 @@ struct DummyScaler {
|
|
|
63
63
|
};
|
|
64
64
|
|
|
65
65
|
/// consumes 2x4 bits to encode a norm as a scalar additive quantizer
|
|
66
|
-
/// the norm is scaled because its range
|
|
66
|
+
/// the norm is scaled because its range is larger than other components
|
|
67
67
|
struct NormTableScaler {
|
|
68
68
|
static constexpr int nscale = 2;
|
|
69
69
|
int scale_int;
|
|
@@ -177,7 +177,7 @@ void NNDescent::join(DistanceComputer& qdis) {
|
|
|
177
177
|
}
|
|
178
178
|
}
|
|
179
179
|
|
|
180
|
-
/// Sample neighbors for each node to
|
|
180
|
+
/// Sample neighbors for each node to perform local join later
|
|
181
181
|
/// Store them in nn_new and nn_old
|
|
182
182
|
void NNDescent::update() {
|
|
183
183
|
// Step 1.
|
|
@@ -34,7 +34,7 @@ namespace faiss {
|
|
|
34
34
|
*
|
|
35
35
|
* Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
|
|
36
36
|
*
|
|
37
|
-
* This
|
|
37
|
+
* This implementation is heavily influenced by the efanna
|
|
38
38
|
* implementation by Cong Fu and the KGraph library by Wei Dong
|
|
39
39
|
* (https://github.com/ZJULearning/efanna_graph)
|
|
40
40
|
* (https://github.com/aaalgo/kgraph)
|
|
@@ -117,7 +117,7 @@ struct NNDescent {
|
|
|
117
117
|
/// Perform local join on each node
|
|
118
118
|
void join(DistanceComputer& qdis);
|
|
119
119
|
|
|
120
|
-
/// Sample new neighbors for each node to
|
|
120
|
+
/// Sample new neighbors for each node to perform local join later
|
|
121
121
|
void update();
|
|
122
122
|
|
|
123
123
|
/// Sample a small number of points to evaluate the quality of KNNG built
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/PanoramaStats.h>
|
|
11
|
+
|
|
12
|
+
namespace faiss {
|
|
13
|
+
|
|
14
|
+
// Return every counter to its "nothing scanned yet" state. The ratio is
// reset to 1.0f, the same value add() reports when total_dims is zero.
void PanoramaStats::reset() {
    ratio_dims_scanned = 1.0f;
    total_dims = 0;
    total_dims_scanned = 0;
}
|
|
19
|
+
|
|
20
|
+
// Fold the counters of `other` into this object, then recompute the
// scanned/total ratio from the merged totals.
void PanoramaStats::add(const PanoramaStats& other) {
    total_dims_scanned += other.total_dims_scanned;
    total_dims += other.total_dims;

    // With no dimensions recorded there is nothing to divide by; report 1.0f.
    ratio_dims_scanned = (total_dims > 0)
            ? static_cast<float>(total_dims_scanned) / total_dims
            : 1.0f;
}
|
|
30
|
+
|
|
31
|
+
// Definition of the single global statistics object that PanoramaStats.h
// declares as `FAISS_API extern` for all Panorama indexes.
PanoramaStats indexPanorama_stats;
|
|
32
|
+
|
|
33
|
+
} // namespace faiss
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_PANORAMA_STATS_H
|
|
11
|
+
#define FAISS_PANORAMA_STATS_H
|
|
12
|
+
|
|
13
|
+
#include <faiss/impl/platform_macros.h>
|
|
14
|
+
|
|
15
|
+
namespace faiss {
|
|
16
|
+
|
|
17
|
+
/// Counters describing how many vector dimensions a Panorama search touched.
///
/// These statistics are not robust to internal threading nor to concurrent
/// Panorama searches; read them in a single-threaded context to gauge
/// Panorama's pruning effectiveness accurately.
struct PanoramaStats {
    /// total dimensions scanned
    uint64_t total_dims_scanned = 0;
    /// total dimensions
    uint64_t total_dims = 0;
    /// fraction of dimensions actually scanned
    float ratio_dims_scanned = 1.0f;

    PanoramaStats() { reset(); }

    /// Restore all counters to their initial values.
    void reset();

    /// Accumulate the counters of `other` into this object.
    void add(const PanoramaStats& other);
};
|
|
32
|
+
|
|
33
|
+
// Single global var for all Panorama indexes
|
|
34
|
+
FAISS_API extern PanoramaStats indexPanorama_stats;
|
|
35
|
+
|
|
36
|
+
} // namespace faiss
|
|
37
|
+
|
|
38
|
+
#endif
|
|
@@ -178,7 +178,7 @@ struct ReproduceWithHammingObjective : PermutationObjective {
|
|
|
178
178
|
return x * x;
|
|
179
179
|
}
|
|
180
180
|
|
|
181
|
-
//
|
|
181
|
+
// weighting of distances: it is more important to reproduce small
|
|
182
182
|
// distances well
|
|
183
183
|
double dis_weight(double x) const {
|
|
184
184
|
return exp(-dis_weight_factor * x);
|
|
@@ -295,7 +295,7 @@ struct ReproduceWithHammingObjective : PermutationObjective {
|
|
|
295
295
|
|
|
296
296
|
} // anonymous namespace
|
|
297
297
|
|
|
298
|
-
//
|
|
298
|
+
// weighting of distances: it is more important to reproduce small
|
|
299
299
|
// distances well
|
|
300
300
|
double ReproduceDistancesObjective::dis_weight(double x) const {
|
|
301
301
|
return exp(-dis_weight_factor * x);
|
|
@@ -636,7 +636,7 @@ struct Score3Computer : PermutationObjective {
|
|
|
636
636
|
return accu;
|
|
637
637
|
}
|
|
638
638
|
|
|
639
|
-
/// PermutationObjective
|
|
639
|
+
/// PermutationObjective implementation (just negates the scores
|
|
640
640
|
/// for minimization)
|
|
641
641
|
|
|
642
642
|
double compute_cost(const int* perm) const override {
|
|
@@ -689,7 +689,7 @@ struct RankingScore2 : Score3Computer<float, double> {
|
|
|
689
689
|
/// count nb of i, j in a x b st. i < j
|
|
690
690
|
/// a and b should be sorted on input
|
|
691
691
|
/// they are the ranks of j and k respectively.
|
|
692
|
-
/// specific version for diff-of-rank weighting, cannot
|
|
692
|
+
/// specific version for diff-of-rank weighting, cannot optimize
|
|
693
693
|
/// with a cumulative table
|
|
694
694
|
double accum_gt_weight_diff(
|
|
695
695
|
const std::vector<int>& a,
|
|
@@ -985,7 +985,7 @@ size_t PolysemousTraining::memory_usage_per_thread(
|
|
|
985
985
|
return n * n * n * sizeof(float);
|
|
986
986
|
}
|
|
987
987
|
|
|
988
|
-
FAISS_THROW_MSG("Invalid
|
|
988
|
+
FAISS_THROW_MSG("Invalid optimization type");
|
|
989
989
|
return 0;
|
|
990
990
|
}
|
|
991
991
|
|
|
@@ -154,7 +154,7 @@ void ProductAdditiveQuantizer::compute_unpacked_codes(
|
|
|
154
154
|
int32_t* unpacked_codes,
|
|
155
155
|
size_t n,
|
|
156
156
|
const float* centroids) const {
|
|
157
|
-
/// TODO:
|
|
157
|
+
/// TODO: actually we do not need to unpack and pack
|
|
158
158
|
size_t offset_d = 0, offset_m = 0;
|
|
159
159
|
std::vector<float> xsub;
|
|
160
160
|
std::vector<uint8_t> codes;
|
|
@@ -166,7 +166,7 @@ struct ProductQuantizer : Quantizer {
|
|
|
166
166
|
/// Symmetric Distance Table
|
|
167
167
|
std::vector<float> sdc_table;
|
|
168
168
|
|
|
169
|
-
//
|
|
169
|
+
// initialize the SDC table from the centroids
|
|
170
170
|
void compute_sdc_table();
|
|
171
171
|
|
|
172
172
|
void search_sdc(
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/utils/distances.h>
|
|
12
|
+
#include <algorithm>
|
|
13
|
+
#include <cmath>
|
|
14
|
+
#include <limits>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
namespace rabitq_utils {
|
|
18
|
+
|
|
19
|
+
// Ideal quantizer radius for each bit width from 1 to 8, chosen to minimize
// the L2 reconstruction error. Entry [qb - 1] belongs to a qb-bit quantizer.
const float Z_MAX_BY_QB[8] = {
        0.79688, // qb = 1
        1.49375, // qb = 2
        2.05078, // qb = 3
        2.50938, // qb = 4
        2.91250, // qb = 5
        3.26406, // qb = 6
        3.59844, // qb = 7
        3.91016, // qb = 8
};
|
|
31
|
+
|
|
32
|
+
// One pass over x that accumulates three sums into the output references:
//   norm_L2sqr = sum_j (x[j] - c[j])^2
//   or_L2sqr   = sum_j x[j]^2
//   dp_oO      = sum_j |x[j] - c[j]|
// A null centroid is treated as the zero vector.
void compute_vector_intermediate_values(
        const float* x,
        size_t d,
        const float* centroid,
        float& norm_L2sqr,
        float& or_L2sqr,
        float& dp_oO) {
    norm_L2sqr = 0.0f;
    or_L2sqr = 0.0f;
    dp_oO = 0.0f;

    const bool has_centroid = (centroid != nullptr);

    for (size_t dim = 0; dim != d; ++dim) {
        const float value = x[dim];
        const float residual = value - (has_centroid ? centroid[dim] : 0.0f);

        const float residual_sq = residual * residual;
        norm_L2sqr += residual_sq;
        or_L2sqr += value * value;

        // Add the magnitude of the residual: keep it when strictly
        // positive, negate it otherwise.
        if (residual > 0.0f) {
            dp_oO += residual;
        } else {
            dp_oO += -residual;
        }
    }
}
|
|
56
|
+
|
|
57
|
+
// Turn the three per-vector sums into the stored RaBitQ factors.
//
// or_minus_c_l2sqr is norm_L2sqr, minus or_L2sqr for the inner-product
// metric. dp_multiplier is sqrt(norm_L2sqr) divided by the normalized
// absolute-deviation sum dp_oO / (sqrt(norm_L2sqr) * sqrt(d)). Each
// reciprocal falls back to 1.0f when its denominator is below float epsilon
// (or when d == 0).
FactorsData compute_factors_from_intermediates(
        float norm_L2sqr,
        float or_L2sqr,
        float dp_oO,
        size_t d,
        MetricType metric_type) {
    constexpr float tiny = std::numeric_limits<float>::epsilon();

    float inv_sqrt_d = 1.0f;
    if (d != 0) {
        inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
    }

    const float residual_norm = std::sqrt(norm_L2sqr);
    float inv_residual_norm;
    if (norm_L2sqr < tiny) {
        inv_residual_norm = 1.0f;
    } else {
        inv_residual_norm = 1.0f / residual_norm;
    }

    const float dp_normalized = dp_oO * inv_residual_norm * inv_sqrt_d;
    float dp_inverse;
    if (std::abs(dp_normalized) < tiny) {
        dp_inverse = 1.0f;
    } else {
        dp_inverse = 1.0f / dp_normalized;
    }

    FactorsData result;
    if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
        result.or_minus_c_l2sqr = norm_L2sqr - or_L2sqr;
    } else {
        result.or_minus_c_l2sqr = norm_L2sqr;
    }
    result.dp_multiplier = dp_inverse * residual_norm;

    return result;
}
|
|
83
|
+
|
|
84
|
+
// Convenience wrapper: gather the per-vector sums for x, then convert them
// into FactorsData for the requested metric.
FactorsData compute_vector_factors(
        const float* x,
        size_t d,
        const float* centroid,
        MetricType metric_type) {
    float residual_sq = 0.0f;
    float vector_sq = 0.0f;
    float abs_residual_sum = 0.0f;

    compute_vector_intermediate_values(
            x, d, centroid, residual_sq, vector_sq, abs_residual_sum);

    const FactorsData factors = compute_factors_from_intermediates(
            residual_sq, vector_sq, abs_residual_sum, d, metric_type);
    return factors;
}
|
|
95
|
+
|
|
96
|
+
/**
 * Quantize a query vector for RaBitQ distance estimation and compute the
 * scalar factors that accompany the quantized codes.
 *
 * The query is shifted by `centroid` (when given) and scalar-quantized to
 * `qb` bits per dimension over a range [v_min, v_max].
 *
 * @param query       query vector (d floats)
 * @param d           dimensionality; must be > 0
 * @param centroid    vector subtracted from the query, or nullptr for none
 * @param qb          number of quantization bits, in [1, 8]
 * @param centered    if true, the range is symmetric, [-r, r] with
 *                    r = Z_MAX_BY_QB[qb - 1] * sqrt(||q - c||^2 / d);
 *                    otherwise it is the [min, max] of the shifted query
 * @param metric_type qr_norm_L2sqr is filled only for METRIC_INNER_PRODUCT
 * @param rotated_q   output, resized to d: query - centroid
 * @param rotated_qq  output, resized to d: codes in [0, 2^qb - 1]
 * @return the computed query factors
 *
 * Raises through FAISS_THROW_IF_NOT / FAISS_THROW_MSG when qb is outside
 * [1, 8] or when d == 0.
 */
QueryFactorsData compute_query_factors(
        const float* query,
        size_t d,
        const float* centroid,
        uint8_t qb,
        bool centered,
        MetricType metric_type,
        std::vector<float>& rotated_q,
        std::vector<uint8_t>& rotated_qq) {
    FAISS_THROW_IF_NOT(qb <= 8);
    FAISS_THROW_IF_NOT(qb > 0);

    // Reject d == 0 up front, before any of the divisions by d below run.
    if (d == 0) {
        FAISS_THROW_MSG(
                "Arrays unexpectedly empty when d=" + std::to_string(d) +
                " or d is incorrectly set");
    }

    QueryFactorsData query_factors;

    // Squared distance from the query to the centroid (to the origin when
    // no centroid is given).
    if (centroid != nullptr) {
        query_factors.qr_to_c_L2sqr = fvec_L2sqr(query, centroid, d);
    } else {
        query_factors.qr_to_c_L2sqr = fvec_norm_L2sqr(query, d);
    }

    // Shift the query by the centroid.
    rotated_q.resize(d);
    for (size_t i = 0; i < d; i++) {
        rotated_q[i] = query[i] - ((centroid == nullptr) ? 0.0f : centroid[i]);
    }

    const float inv_d_sqrt = 1.0f / std::sqrt(static_cast<float>(d));

    // Compute the quantization range.
    float v_min = std::numeric_limits<float>::max();
    float v_max = std::numeric_limits<float>::lowest();

    if (centered) {
        const float z_max = Z_MAX_BY_QB[qb - 1];
        const float v_radius =
                z_max * std::sqrt(query_factors.qr_to_c_L2sqr / d);
        v_min = -v_radius;
        v_max = v_radius;
    } else {
        for (const float v_q : rotated_q) {
            v_min = std::min(v_min, v_q);
            v_max = std::max(v_max, v_q);
        }
    }

    // Quantize the query.
    const uint8_t max_code = (1 << qb) - 1;
    const float delta = (v_max - v_min) / max_code;
    // A degenerate range (v_max == v_min) gives delta == 0. Using 1 / delta
    // there would produce 0 * inf = NaN below, so map everything to code 0.
    const float inv_delta = (delta > 0.0f) ? (1.0f / delta) : 0.0f;

    rotated_qq.resize(d);
    size_t sum_qq = 0;
    int64_t sum2_signed_odd_int = 0;

    for (size_t i = 0; i < d; i++) {
        // Non-randomized scalar quantization.
        const float scaled = std::round((rotated_q[i] - v_min) * inv_delta);
        // Clamp to [0, max_code]. `scaled > 0` is false for NaN, so a
        // non-finite intermediate becomes code 0 instead of reaching the
        // float -> uint8_t conversion, which would be undefined behaviour.
        const uint8_t v_qq = (scaled > 0.0f)
                ? static_cast<uint8_t>(std::min<float>(scaled, max_code))
                : uint8_t{0};
        rotated_qq[i] = v_qq;
        sum_qq += v_qq;

        if (centered) {
            // 2 * code - max_code is always odd, hence never zero, so the
            // accumulated sum of squares is strictly positive.
            const int64_t signed_odd_int = int64_t(v_qq) * 2 - max_code;
            sum2_signed_odd_int += signed_odd_int * signed_odd_int;
        }
    }

    // Compute query factors.
    query_factors.c1 = 2 * delta * inv_d_sqrt;
    query_factors.c2 = 2 * v_min * inv_d_sqrt;
    query_factors.c34 = inv_d_sqrt * (delta * sum_qq + d * v_min);

    if (centered) {
        query_factors.int_dot_scale = std::sqrt(
                query_factors.qr_to_c_L2sqr / (sum2_signed_odd_int * d));
    } else {
        query_factors.int_dot_scale = 1.0f;
    }

    // The query norm is only needed for the inner product metric.
    query_factors.qr_norm_L2sqr = 0.0f;
    if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
        query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d);
    }

    return query_factors;
}
|
|
203
|
+
|
|
204
|
+
// Read bit `bit_index` from a code stored as consecutive bits, least
// significant bit of each byte first.
bool extract_bit_standard(const uint8_t* code, size_t bit_index) {
    const unsigned containing_byte = code[bit_index / 8];
    const size_t shift = bit_index % 8;
    return ((containing_byte >> shift) & 1u) != 0;
}
|
|
209
|
+
|
|
210
|
+
// Read bit `bit_index` from a code laid out as 4-bit sub-quantizers, two per
// byte: even sub-quantizers use the low nibble, odd ones the high nibble.
bool extract_bit_fastscan(const uint8_t* code, size_t bit_index) {
    const size_t sub_quantizer = bit_index / 4;
    const size_t bit_in_nibble = bit_index % 4;
    const bool high_nibble = (sub_quantizer % 2 != 0);

    const unsigned containing_byte = code[sub_quantizer / 2];
    const size_t shift = bit_in_nibble + (high_nibble ? 4 : 0);
    return ((containing_byte >> shift) & 1u) != 0;
}
|
|
225
|
+
|
|
226
|
+
// Turn on bit `bit_index` in a code stored as consecutive bits, least
// significant bit of each byte first. Other bits are left untouched.
void set_bit_standard(uint8_t* code, size_t bit_index) {
    uint8_t& target = code[bit_index / 8];
    const size_t shift = bit_index % 8;
    target = static_cast<uint8_t>(target | (1 << shift));
}
|
|
231
|
+
|
|
232
|
+
// Turn on bit `bit_index` in a code laid out as 4-bit sub-quantizers, two per
// byte: even sub-quantizers use the low nibble, odd ones the high nibble.
void set_bit_fastscan(uint8_t* code, size_t bit_index) {
    const size_t sub_quantizer = bit_index / 4;
    const size_t bit_in_nibble = bit_index % 4;
    const bool high_nibble = (sub_quantizer % 2 != 0);

    const size_t shift = bit_in_nibble + (high_nibble ? 4 : 0);
    uint8_t& target = code[sub_quantizer / 2];
    target = static_cast<uint8_t>(target | (1 << shift));
}
|
|
244
|
+
|
|
245
|
+
} // namespace rabitq_utils
|
|
246
|
+
} // namespace faiss
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <faiss/MetricType.h>
|
|
11
|
+
#include <faiss/impl/platform_macros.h>
|
|
12
|
+
#include <cstddef>
|
|
13
|
+
#include <cstdint>
|
|
14
|
+
#include <vector>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
namespace rabitq_utils {
|
|
18
|
+
|
|
19
|
+
/** Per-database-vector factors used by RaBitQ distance computation.
 * They may be stored embedded in the codes (IndexRaBitQ) or kept in a
 * separate array (IndexRaBitQFastScan).
 */
struct FactorsData {
    /// ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0)
    float or_minus_c_l2sqr = 0.0f;
    /// multiplier applied to the dot-product term
    float dp_multiplier = 0.0f;
};
|
|
28
|
+
|
|
29
|
+
/** Per-query factors computed at search time for RaBitQ distance
 * computation. Shared by the IndexRaBitQ and IndexRaBitQFastScan
 * implementations.
 */
struct QueryFactorsData {
    // Scalar coefficients derived from the query quantization range.
    float c1 = 0.0f;
    float c2 = 0.0f;
    float c34 = 0.0f;

    /// squared L2 distance from the query to the centroid
    float qr_to_c_L2sqr = 0.0f;
    /// squared L2 norm of the query
    float qr_norm_L2sqr = 0.0f;

    /// scale for the integer dot product; defaults to 1
    float int_dot_scale = 1.0f;
};
|
|
43
|
+
|
|
44
|
+
/** Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
|
|
45
|
+
* L2 reconstruction error. Shared between all RaBitQ implementations.
|
|
46
|
+
*/
|
|
47
|
+
FAISS_API extern const float Z_MAX_BY_QB[8];
|
|
48
|
+
|
|
49
|
+
/** Compute factors for a single database vector using RaBitQ algorithm.
|
|
50
|
+
* This function consolidates the mathematical logic that was duplicated
|
|
51
|
+
* between IndexRaBitQ and IndexRaBitQFastScan.
|
|
52
|
+
*
|
|
53
|
+
* @param x input vector (d dimensions)
|
|
54
|
+
* @param d dimensionality
|
|
55
|
+
* @param centroid database centroid (nullptr if not used)
|
|
56
|
+
* @param metric_type distance metric (L2 or Inner Product)
|
|
57
|
+
* @return computed factors for distance computation
|
|
58
|
+
*/
|
|
59
|
+
FactorsData compute_vector_factors(
|
|
60
|
+
const float* x,
|
|
61
|
+
size_t d,
|
|
62
|
+
const float* centroid,
|
|
63
|
+
MetricType metric_type);
|
|
64
|
+
|
|
65
|
+
/** Compute intermediate values needed for vector factor computation.
|
|
66
|
+
* Separated out to allow different bit packing strategies while sharing
|
|
67
|
+
* the core mathematical computation.
|
|
68
|
+
*
|
|
69
|
+
* @param x input vector (d dimensions)
|
|
70
|
+
* @param d dimensionality
|
|
71
|
+
* @param centroid database centroid (nullptr if not used)
|
|
72
|
+
* @param norm_L2sqr output: ||or - c||^2
|
|
73
|
+
* @param or_L2sqr output: ||or||^2
|
|
74
|
+
* @param dp_oO output: sum of |or_i - c_i| (absolute deviations)
|
|
75
|
+
*/
|
|
76
|
+
void compute_vector_intermediate_values(
|
|
77
|
+
const float* x,
|
|
78
|
+
size_t d,
|
|
79
|
+
const float* centroid,
|
|
80
|
+
float& norm_L2sqr,
|
|
81
|
+
float& or_L2sqr,
|
|
82
|
+
float& dp_oO);
|
|
83
|
+
|
|
84
|
+
/** Compute final factors from intermediate values.
|
|
85
|
+
* @param norm_L2sqr ||or - c||^2
|
|
86
|
+
* @param or_L2sqr ||or||^2
|
|
87
|
+
* @param dp_oO sum of |or_i - c_i|
|
|
88
|
+
* @param d dimensionality
|
|
89
|
+
* @param metric_type distance metric
|
|
90
|
+
* @return computed factors
|
|
91
|
+
*/
|
|
92
|
+
FactorsData compute_factors_from_intermediates(
|
|
93
|
+
float norm_L2sqr,
|
|
94
|
+
float or_L2sqr,
|
|
95
|
+
float dp_oO,
|
|
96
|
+
size_t d,
|
|
97
|
+
MetricType metric_type);
|
|
98
|
+
|
|
99
|
+
/** Compute query factors for RaBitQ distance computation.
|
|
100
|
+
* This consolidates the query processing logic shared between implementations.
|
|
101
|
+
*
|
|
102
|
+
* @param query query vector (d dimensions)
|
|
103
|
+
* @param d dimensionality
|
|
104
|
+
* @param centroid database centroid (nullptr if not used)
|
|
105
|
+
* @param qb number of quantization bits (1-8)
|
|
106
|
+
* @param centered whether to use centered quantization
|
|
107
|
+
* @param metric_type distance metric
|
|
108
|
+
* @param rotated_q output: query - centroid
|
|
109
|
+
* @param rotated_qq output: quantized query values
|
|
110
|
+
* @return computed query factors
|
|
111
|
+
*/
|
|
112
|
+
QueryFactorsData compute_query_factors(
|
|
113
|
+
const float* query,
|
|
114
|
+
size_t d,
|
|
115
|
+
const float* centroid,
|
|
116
|
+
uint8_t qb,
|
|
117
|
+
bool centered,
|
|
118
|
+
MetricType metric_type,
|
|
119
|
+
std::vector<float>& rotated_q,
|
|
120
|
+
std::vector<uint8_t>& rotated_qq);
|
|
121
|
+
|
|
122
|
+
/** Extract bit value from RaBitQ code in standard format.
|
|
123
|
+
* Used by IndexRaBitQ which stores bits sequentially.
|
|
124
|
+
*
|
|
125
|
+
* @param code RaBitQ code data
|
|
126
|
+
* @param bit_index which bit to extract (0 to d-1)
|
|
127
|
+
* @return bit value (true/false)
|
|
128
|
+
*/
|
|
129
|
+
bool extract_bit_standard(const uint8_t* code, size_t bit_index);
|
|
130
|
+
|
|
131
|
+
/** Extract bit value from FastScan code format.
|
|
132
|
+
* Used by IndexRaBitQFastScan which packs bits into 4-bit sub-quantizers.
|
|
133
|
+
*
|
|
134
|
+
* @param code FastScan code data
|
|
135
|
+
* @param bit_index which bit to extract (0 to d-1)
|
|
136
|
+
* @return bit value (true/false)
|
|
137
|
+
*/
|
|
138
|
+
bool extract_bit_fastscan(const uint8_t* code, size_t bit_index);
|
|
139
|
+
|
|
140
|
+
/** Set bit value in standard RaBitQ code format.
|
|
141
|
+
* @param code RaBitQ code data to modify
|
|
142
|
+
* @param bit_index which bit to set (0 to d-1)
|
|
143
|
+
*/
|
|
144
|
+
void set_bit_standard(uint8_t* code, size_t bit_index);
|
|
145
|
+
|
|
146
|
+
/** Set bit value in FastScan code format.
|
|
147
|
+
* @param code FastScan code data to modify
|
|
148
|
+
* @param bit_index which bit to set (0 to d-1)
|
|
149
|
+
*/
|
|
150
|
+
void set_bit_fastscan(uint8_t* code, size_t bit_index);
|
|
151
|
+
|
|
152
|
+
} // namespace rabitq_utils
|
|
153
|
+
} // namespace faiss
|