faiss 0.4.1 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +39 -29
- data/vendor/faiss/faiss/Clustering.cpp +4 -2
- data/vendor/faiss/faiss/IVFlib.cpp +14 -7
- data/vendor/faiss/faiss/Index.h +72 -3
- data/vendor/faiss/faiss/Index2Layer.cpp +2 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +0 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +1 -0
- data/vendor/faiss/faiss/IndexBinary.h +46 -3
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +118 -4
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +41 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +0 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +18 -7
- data/vendor/faiss/faiss/IndexBinaryIVF.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +6 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +65 -24
- data/vendor/faiss/faiss/IndexHNSW.h +10 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +96 -18
- data/vendor/faiss/faiss/IndexIDMap.h +20 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +28 -10
- data/vendor/faiss/faiss/IndexIVF.h +16 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +18 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +33 -21
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +16 -6
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +24 -15
- data/vendor/faiss/faiss/IndexIVFFastScan.h +4 -2
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +59 -43
- data/vendor/faiss/faiss/IndexIVFFlat.h +10 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +16 -3
- data/vendor/faiss/faiss/IndexIVFPQ.h +8 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +14 -6
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +14 -4
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +28 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +8 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +9 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +8 -4
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -7
- data/vendor/faiss/faiss/IndexNSG.cpp +3 -3
- data/vendor/faiss/faiss/IndexPQ.cpp +0 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -0
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +0 -2
- data/vendor/faiss/faiss/IndexPreTransform.cpp +4 -2
- data/vendor/faiss/faiss/IndexRefine.cpp +11 -6
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +16 -4
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -3
- data/vendor/faiss/faiss/IndexShards.cpp +7 -6
- data/vendor/faiss/faiss/MatrixStats.cpp +16 -8
- data/vendor/faiss/faiss/MetaIndexes.cpp +12 -6
- data/vendor/faiss/faiss/MetricType.h +5 -3
- data/vendor/faiss/faiss/clone_index.cpp +2 -4
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +6 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +9 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +32 -10
- data/vendor/faiss/faiss/gpu/GpuIndex.h +88 -0
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +125 -0
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +39 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +3 -3
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +3 -2
- data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +41 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +6 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +34 -19
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +2 -3
- data/vendor/faiss/faiss/impl/NNDescent.cpp +17 -9
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +42 -21
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +6 -24
- data/vendor/faiss/faiss/impl/ResultHandler.h +56 -47
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +28 -15
- data/vendor/faiss/faiss/impl/index_read.cpp +36 -11
- data/vendor/faiss/faiss/impl/index_write.cpp +19 -6
- data/vendor/faiss/faiss/impl/io.cpp +9 -5
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +18 -11
- data/vendor/faiss/faiss/impl/mapped_io.cpp +4 -7
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +0 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +0 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +6 -6
- data/vendor/faiss/faiss/impl/zerocopy_io.cpp +1 -1
- data/vendor/faiss/faiss/impl/zerocopy_io.h +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +49 -33
- data/vendor/faiss/faiss/index_factory.h +8 -2
- data/vendor/faiss/faiss/index_io.h +0 -3
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +2 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +12 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +8 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +15 -8
- data/vendor/faiss/faiss/utils/Heap.h +23 -12
- data/vendor/faiss/faiss/utils/distances.cpp +42 -21
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -3
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +27 -4
- data/vendor/faiss/faiss/utils/extra_distances.cpp +8 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +20 -10
- data/vendor/faiss/faiss/utils/partitioning.cpp +8 -4
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +17 -9
- data/vendor/faiss/faiss/utils/rabitq_simd.h +539 -0
- data/vendor/faiss/faiss/utils/random.cpp +14 -7
- data/vendor/faiss/faiss/utils/utils.cpp +0 -3
- metadata +5 -2
@@ -30,8 +30,16 @@ IndexIVFPQR::IndexIVFPQR(
|
|
30
30
|
size_t M,
|
31
31
|
size_t nbits_per_idx,
|
32
32
|
size_t M_refine,
|
33
|
-
size_t nbits_per_idx_refine
|
34
|
-
|
33
|
+
size_t nbits_per_idx_refine,
|
34
|
+
bool own_invlists)
|
35
|
+
: IndexIVFPQ(
|
36
|
+
quantizer,
|
37
|
+
d,
|
38
|
+
nlist,
|
39
|
+
M,
|
40
|
+
nbits_per_idx,
|
41
|
+
METRIC_L2,
|
42
|
+
own_invlists),
|
35
43
|
refine_pq(d, M_refine, nbits_per_idx_refine),
|
36
44
|
k_factor(4) {
|
37
45
|
by_residual = true;
|
@@ -160,8 +168,9 @@ void IndexIVFPQR::search_preassigned(
|
|
160
168
|
for (int j = 0; j < k_coarse; j++) {
|
161
169
|
idx_t sl = shortlist[j];
|
162
170
|
|
163
|
-
if (sl == -1)
|
171
|
+
if (sl == -1) {
|
164
172
|
continue;
|
173
|
+
}
|
165
174
|
|
166
175
|
int list_no = lo_listno(sl);
|
167
176
|
int ofs = lo_offset(sl);
|
@@ -176,8 +185,9 @@ void IndexIVFPQR::search_preassigned(
|
|
176
185
|
const uint8_t* l2code = invlists->get_single_code(list_no, ofs);
|
177
186
|
|
178
187
|
pq.decode(l2code, residual_2);
|
179
|
-
for (int l = 0; l < d; l++)
|
188
|
+
for (int l = 0; l < d; l++) {
|
180
189
|
residual_2[l] = residual_1[l] - residual_2[l];
|
190
|
+
}
|
181
191
|
|
182
192
|
// 3rd level residual's approximation
|
183
193
|
idx_t id = invlists->get_single_id(list_no, ofs);
|
@@ -23,10 +23,14 @@ IndexIVFRaBitQ::IndexIVFRaBitQ(
|
|
23
23
|
Index* quantizer,
|
24
24
|
const size_t d,
|
25
25
|
const size_t nlist,
|
26
|
-
MetricType metric
|
27
|
-
|
26
|
+
MetricType metric,
|
27
|
+
bool own_invlists)
|
28
|
+
: IndexIVF(quantizer, d, nlist, 0, metric, own_invlists),
|
29
|
+
rabitq(d, metric) {
|
28
30
|
code_size = rabitq.code_size;
|
29
|
-
|
31
|
+
if (own_invlists) {
|
32
|
+
invlists->code_size = code_size;
|
33
|
+
}
|
30
34
|
is_trained = false;
|
31
35
|
|
32
36
|
by_residual = true;
|
@@ -76,6 +80,27 @@ void IndexIVFRaBitQ::encode_vectors(
|
|
76
80
|
}
|
77
81
|
}
|
78
82
|
|
83
|
+
void IndexIVFRaBitQ::decode_vectors(
|
84
|
+
idx_t n,
|
85
|
+
const uint8_t* codes,
|
86
|
+
const idx_t* listnos,
|
87
|
+
float* x) const {
|
88
|
+
#pragma omp parallel
|
89
|
+
{
|
90
|
+
std::vector<float> centroid(d);
|
91
|
+
|
92
|
+
#pragma omp for
|
93
|
+
for (idx_t i = 0; i < n; i++) {
|
94
|
+
const uint8_t* code = codes + i * code_size;
|
95
|
+
int64_t list_no = listnos[i];
|
96
|
+
float* xi = x + i * d;
|
97
|
+
|
98
|
+
quantizer->reconstruct(list_no, centroid.data());
|
99
|
+
rabitq.decode_core(code, xi, 1, centroid.data());
|
100
|
+
}
|
101
|
+
}
|
102
|
+
}
|
103
|
+
|
79
104
|
void IndexIVFRaBitQ::add_core(
|
80
105
|
idx_t n,
|
81
106
|
const float* x,
|
@@ -33,7 +33,8 @@ struct IndexIVFRaBitQ : IndexIVF {
|
|
33
33
|
Index* quantizer,
|
34
34
|
const size_t d,
|
35
35
|
const size_t nlist,
|
36
|
-
MetricType metric = METRIC_L2
|
36
|
+
MetricType metric = METRIC_L2,
|
37
|
+
bool own_invlists = true);
|
37
38
|
|
38
39
|
IndexIVFRaBitQ();
|
39
40
|
|
@@ -46,6 +47,12 @@ struct IndexIVFRaBitQ : IndexIVF {
|
|
46
47
|
uint8_t* codes,
|
47
48
|
bool include_listnos = false) const override;
|
48
49
|
|
50
|
+
void decode_vectors(
|
51
|
+
idx_t n,
|
52
|
+
const uint8_t* codes,
|
53
|
+
const idx_t* list_nos,
|
54
|
+
float* x) const override;
|
55
|
+
|
49
56
|
void add_core(
|
50
57
|
idx_t n,
|
51
58
|
const float* x,
|
@@ -27,8 +27,15 @@ IndexIVFSpectralHash::IndexIVFSpectralHash(
|
|
27
27
|
size_t d,
|
28
28
|
size_t nlist,
|
29
29
|
int nbit,
|
30
|
-
float period
|
31
|
-
|
30
|
+
float period,
|
31
|
+
bool own_invlists)
|
32
|
+
: IndexIVF(
|
33
|
+
quantizer,
|
34
|
+
d,
|
35
|
+
nlist,
|
36
|
+
(nbit + 7) / 8,
|
37
|
+
METRIC_L2,
|
38
|
+
own_invlists),
|
32
39
|
nbit(nbit),
|
33
40
|
period(period) {
|
34
41
|
RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
|
@@ -47,10 +47,12 @@ void IndexLattice::train(idx_t n, const float* x) {
|
|
47
47
|
for (idx_t i = 0; i < n; i++) {
|
48
48
|
for (int sq = 0; sq < nsq; sq++) {
|
49
49
|
float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
|
50
|
-
if (norm2 > maxs[sq])
|
50
|
+
if (norm2 > maxs[sq]) {
|
51
51
|
maxs[sq] = norm2;
|
52
|
-
|
52
|
+
}
|
53
|
+
if (norm2 < mins[sq]) {
|
53
54
|
mins[sq] = norm2;
|
55
|
+
}
|
54
56
|
}
|
55
57
|
}
|
56
58
|
|
@@ -79,10 +81,12 @@ void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
|
|
79
81
|
for (int j = 0; j < nsq; j++) {
|
80
82
|
float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
|
81
83
|
(maxs[j] - mins[j]);
|
82
|
-
if (nj < 0)
|
84
|
+
if (nj < 0) {
|
83
85
|
nj = 0;
|
84
|
-
|
86
|
+
}
|
87
|
+
if (nj >= sc) {
|
85
88
|
nj = sc - 1;
|
89
|
+
}
|
86
90
|
wr.write((int64_t)nj, scale_nbit);
|
87
91
|
wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
|
88
92
|
xi += dsq;
|
@@ -9,24 +9,17 @@
|
|
9
9
|
|
10
10
|
#include <faiss/IndexNNDescent.h>
|
11
11
|
|
12
|
-
#include <omp.h>
|
13
|
-
|
14
12
|
#include <cinttypes>
|
15
13
|
#include <cstdio>
|
16
14
|
#include <cstdlib>
|
17
15
|
|
18
|
-
#include <queue>
|
19
|
-
#include <unordered_set>
|
20
|
-
|
21
16
|
#ifdef __SSE__
|
22
17
|
#endif
|
23
18
|
|
24
19
|
#include <faiss/IndexFlat.h>
|
25
20
|
#include <faiss/impl/AuxIndexStructures.h>
|
26
21
|
#include <faiss/impl/FaissAssert.h>
|
27
|
-
#include <faiss/utils/Heap.h>
|
28
22
|
#include <faiss/utils/distances.h>
|
29
|
-
#include <faiss/utils/random.h>
|
30
23
|
|
31
24
|
extern "C" {
|
32
25
|
|
@@ -101,7 +101,7 @@ void IndexNSG::search(
|
|
101
101
|
}
|
102
102
|
}
|
103
103
|
|
104
|
-
void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int GK_2) {
|
104
|
+
void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int gk) {
|
105
105
|
FAISS_THROW_IF_NOT_MSG(
|
106
106
|
storage,
|
107
107
|
"Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
|
@@ -112,9 +112,9 @@ void IndexNSG::build(idx_t n, const float* x, idx_t* knn_graph, int GK_2) {
|
|
112
112
|
ntotal = storage->ntotal;
|
113
113
|
|
114
114
|
// check the knn graph
|
115
|
-
check_knn_graph(knn_graph, n, GK_2);
|
115
|
+
check_knn_graph(knn_graph, n, gk);
|
116
116
|
|
117
|
-
const nsg::Graph<idx_t> knng(knn_graph, n, GK_2);
|
117
|
+
const nsg::Graph<idx_t> knng(knn_graph, n, gk);
|
118
118
|
nsg.build(storage, n, knng, verbose);
|
119
119
|
is_built = true;
|
120
120
|
}
|
@@ -48,8 +48,9 @@ void IndexPreTransform::prepend_transform(VectorTransform* ltrans) {
|
|
48
48
|
|
49
49
|
IndexPreTransform::~IndexPreTransform() {
|
50
50
|
if (own_fields) {
|
51
|
-
for (int i = 0; i < chain.size(); i++)
|
51
|
+
for (int i = 0; i < chain.size(); i++) {
|
52
52
|
delete chain[i];
|
53
|
+
}
|
53
54
|
delete index;
|
54
55
|
}
|
55
56
|
}
|
@@ -94,8 +95,9 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
94
95
|
}
|
95
96
|
index->train(n, prev_x);
|
96
97
|
}
|
97
|
-
if (i == last_untrained)
|
98
|
+
if (i == last_untrained) {
|
98
99
|
break;
|
100
|
+
}
|
99
101
|
if (verbose) {
|
100
102
|
printf(" Applying transform %d/%zd\n", i, chain.size());
|
101
103
|
}
|
@@ -129,10 +129,11 @@ void IndexRefine::search(
|
|
129
129
|
base_index->search(
|
130
130
|
n, x, k_base, base_distances, base_labels, base_index_params);
|
131
131
|
|
132
|
-
for (int i = 0; i < n * k_base; i++)
|
132
|
+
for (int i = 0; i < n * k_base; i++) {
|
133
133
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
134
|
+
}
|
134
135
|
|
135
|
-
|
136
|
+
// parallelize over queries
|
136
137
|
#pragma omp parallel if (n > 1)
|
137
138
|
{
|
138
139
|
std::unique_ptr<DistanceComputer> dc(
|
@@ -143,8 +144,9 @@ void IndexRefine::search(
|
|
143
144
|
idx_t ij = i * k_base;
|
144
145
|
for (idx_t j = 0; j < k_base; j++) {
|
145
146
|
idx_t idx = base_labels[ij];
|
146
|
-
if (idx < 0)
|
147
|
+
if (idx < 0) {
|
147
148
|
break;
|
149
|
+
}
|
148
150
|
base_distances[ij] = (*dc)(idx);
|
149
151
|
ij++;
|
150
152
|
}
|
@@ -238,10 +240,12 @@ void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
238
240
|
}
|
239
241
|
|
240
242
|
IndexRefine::~IndexRefine() {
|
241
|
-
if (own_fields)
|
243
|
+
if (own_fields) {
|
242
244
|
delete base_index;
|
243
|
-
|
245
|
+
}
|
246
|
+
if (own_refine_index) {
|
244
247
|
delete refine_index;
|
248
|
+
}
|
245
249
|
}
|
246
250
|
|
247
251
|
/***************************************************
|
@@ -312,8 +316,9 @@ void IndexRefineFlat::search(
|
|
312
316
|
base_index->search(
|
313
317
|
n, x, k_base, base_distances, base_labels, base_index_params);
|
314
318
|
|
315
|
-
for (int i = 0; i < n * k_base; i++)
|
319
|
+
for (int i = 0; i < n * k_base; i++) {
|
316
320
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
321
|
+
}
|
317
322
|
|
318
323
|
// compute refined distances
|
319
324
|
auto rf = dynamic_cast<const IndexFlat*>(refine_index);
|
@@ -122,12 +122,15 @@ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
|
|
122
122
|
size_t nlist,
|
123
123
|
ScalarQuantizer::QuantizerType qtype,
|
124
124
|
MetricType metric,
|
125
|
-
bool by_residual
|
126
|
-
|
125
|
+
bool by_residual,
|
126
|
+
bool own_invlists)
|
127
|
+
: IndexIVF(quantizer, d, nlist, 0, metric, own_invlists), sq(d, qtype) {
|
127
128
|
code_size = sq.code_size;
|
128
129
|
this->by_residual = by_residual;
|
129
|
-
|
130
|
-
|
130
|
+
if (invlists) {
|
131
|
+
// was not known at construction time
|
132
|
+
invlists->code_size = code_size;
|
133
|
+
}
|
131
134
|
is_trained = false;
|
132
135
|
}
|
133
136
|
|
@@ -179,6 +182,15 @@ void IndexIVFScalarQuantizer::encode_vectors(
|
|
179
182
|
}
|
180
183
|
}
|
181
184
|
|
185
|
+
void IndexIVFScalarQuantizer::decode_vectors(
|
186
|
+
idx_t n,
|
187
|
+
const uint8_t* codes,
|
188
|
+
const idx_t*,
|
189
|
+
float* x) const {
|
190
|
+
FAISS_THROW_IF_NOT(is_trained);
|
191
|
+
return sq.decode(codes, x, n);
|
192
|
+
}
|
193
|
+
|
182
194
|
void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
|
183
195
|
const {
|
184
196
|
std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
|
@@ -29,8 +29,8 @@ struct IndexScalarQuantizer : IndexFlatCodes {
|
|
29
29
|
/** Constructor.
|
30
30
|
*
|
31
31
|
* @param d dimensionality of the input vectors
|
32
|
-
* @param
|
33
|
-
* @param
|
32
|
+
* @param qtype type of scalar quantizer (e.g., QT_4bit)
|
33
|
+
* @param metric distance metric used for search (default: METRIC_L2)
|
34
34
|
*/
|
35
35
|
IndexScalarQuantizer(
|
36
36
|
int d,
|
@@ -72,7 +72,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
|
|
72
72
|
size_t nlist,
|
73
73
|
ScalarQuantizer::QuantizerType qtype,
|
74
74
|
MetricType metric = METRIC_L2,
|
75
|
-
bool by_residual = true
|
75
|
+
bool by_residual = true,
|
76
|
+
bool own_invlists = true);
|
76
77
|
|
77
78
|
IndexIVFScalarQuantizer();
|
78
79
|
|
@@ -87,6 +88,12 @@ struct IndexIVFScalarQuantizer : IndexIVF {
|
|
87
88
|
uint8_t* codes,
|
88
89
|
bool include_listnos = false) const override;
|
89
90
|
|
91
|
+
void decode_vectors(
|
92
|
+
idx_t n,
|
93
|
+
const uint8_t* codes,
|
94
|
+
const idx_t* list_nos,
|
95
|
+
float* x) const override;
|
96
|
+
|
90
97
|
void add_core(
|
91
98
|
idx_t n,
|
92
99
|
const float* x,
|
@@ -31,11 +31,13 @@ void sync_d(IndexBinary* index) {
|
|
31
31
|
|
32
32
|
// add translation to all valid labels
|
33
33
|
void translate_labels(int64_t n, idx_t* labels, int64_t translation) {
|
34
|
-
if (translation == 0)
|
34
|
+
if (translation == 0) {
|
35
35
|
return;
|
36
|
+
}
|
36
37
|
for (int64_t i = 0; i < n; i++) {
|
37
|
-
if (labels[i] < 0)
|
38
|
+
if (labels[i] < 0) {
|
38
39
|
continue;
|
40
|
+
}
|
39
41
|
labels[i] += translation;
|
40
42
|
}
|
41
43
|
}
|
@@ -199,8 +201,6 @@ void IndexShardsTemplate<IndexT>::search(
|
|
199
201
|
distance_t* distances,
|
200
202
|
idx_t* labels,
|
201
203
|
const SearchParameters* params) const {
|
202
|
-
FAISS_THROW_IF_NOT_MSG(
|
203
|
-
!params, "search params not supported for this index");
|
204
204
|
FAISS_THROW_IF_NOT(k > 0);
|
205
205
|
|
206
206
|
int64_t nshard = this->count();
|
@@ -219,7 +219,7 @@ void IndexShardsTemplate<IndexT>::search(
|
|
219
219
|
}
|
220
220
|
}
|
221
221
|
|
222
|
-
auto fn = [n, k, x, &all_distances, &all_labels, &translations](
|
222
|
+
auto fn = [n, k, x, params, &all_distances, &all_labels, &translations](
|
223
223
|
int no, const IndexT* index) {
|
224
224
|
if (index->verbose) {
|
225
225
|
printf("begin query shard %d on %" PRId64 " points\n", no, n);
|
@@ -230,7 +230,8 @@ void IndexShardsTemplate<IndexT>::search(
|
|
230
230
|
x,
|
231
231
|
k,
|
232
232
|
all_distances.data() + no * k * n,
|
233
|
-
all_labels.data() + no * k * n
|
233
|
+
all_labels.data() + no * k * n,
|
234
|
+
params);
|
234
235
|
|
235
236
|
translate_labels(
|
236
237
|
n * k, all_labels.data() + no * k * n, translations[no]);
|
@@ -32,12 +32,15 @@ void MatrixStats::PerDimStats::add(float x) {
|
|
32
32
|
n_inf++;
|
33
33
|
return;
|
34
34
|
}
|
35
|
-
if (x == 0)
|
35
|
+
if (x == 0) {
|
36
36
|
n0++;
|
37
|
-
|
37
|
+
}
|
38
|
+
if (x < min) {
|
38
39
|
min = x;
|
39
|
-
|
40
|
+
}
|
41
|
+
if (x > max) {
|
40
42
|
max = x;
|
43
|
+
}
|
41
44
|
sum += x;
|
42
45
|
sum2 += (double)x * (double)x;
|
43
46
|
}
|
@@ -46,8 +49,9 @@ void MatrixStats::PerDimStats::compute_mean_std() {
|
|
46
49
|
n_valid = n - n_nan - n_inf;
|
47
50
|
mean = sum / n_valid;
|
48
51
|
double var = sum2 / n_valid - mean * mean;
|
49
|
-
if (var < 0)
|
52
|
+
if (var < 0) {
|
50
53
|
var = 0;
|
54
|
+
}
|
51
55
|
stddev = sqrt(var);
|
52
56
|
}
|
53
57
|
|
@@ -95,10 +99,12 @@ MatrixStats::MatrixStats(size_t n, size_t d, const float* x) : n(n), d(d) {
|
|
95
99
|
if (sum2 == 0) {
|
96
100
|
n0++;
|
97
101
|
} else {
|
98
|
-
if (sum2 < min_norm2)
|
102
|
+
if (sum2 < min_norm2) {
|
99
103
|
min_norm2 = sum2;
|
100
|
-
|
104
|
+
}
|
105
|
+
if (sum2 > max_norm2) {
|
101
106
|
max_norm2 = sum2;
|
107
|
+
}
|
102
108
|
}
|
103
109
|
}
|
104
110
|
|
@@ -194,10 +200,12 @@ MatrixStats::MatrixStats(size_t n, size_t d, const float* x) : n(n), d(d) {
|
|
194
200
|
n_dangerous_range++;
|
195
201
|
}
|
196
202
|
|
197
|
-
if (st.stddev > max_std)
|
203
|
+
if (st.stddev > max_std) {
|
198
204
|
max_std = st.stddev;
|
199
|
-
|
205
|
+
}
|
206
|
+
if (st.stddev < min_std) {
|
200
207
|
min_std = st.stddev;
|
208
|
+
}
|
201
209
|
}
|
202
210
|
|
203
211
|
if (n0_2 == 0) {
|
@@ -36,8 +36,9 @@ void IndexSplitVectors::add_sub_index(Index* index) {
|
|
36
36
|
}
|
37
37
|
|
38
38
|
void IndexSplitVectors::sync_with_sub_indexes() {
|
39
|
-
if (sub_indexes.empty())
|
39
|
+
if (sub_indexes.empty()) {
|
40
40
|
return;
|
41
|
+
}
|
41
42
|
Index* index0 = sub_indexes[0];
|
42
43
|
sum_d = index0->d;
|
43
44
|
metric_type = index0->metric_type;
|
@@ -81,24 +82,28 @@ void IndexSplitVectors::search(
|
|
81
82
|
no == 0 ? distances : all_distances.get() + no * k * n;
|
82
83
|
idx_t* labels1 =
|
83
84
|
no == 0 ? labels : all_labels.get() + no * k * n;
|
84
|
-
if (index->verbose)
|
85
|
+
if (index->verbose) {
|
85
86
|
printf("begin query shard %d on %" PRId64 " points\n",
|
86
87
|
no,
|
87
88
|
n);
|
89
|
+
}
|
88
90
|
const Index* sub_index = index->sub_indexes[no];
|
89
91
|
int64_t sub_d = sub_index->d, d = index->d;
|
90
92
|
idx_t ofs = 0;
|
91
|
-
for (int i = 0; i < no; i++)
|
93
|
+
for (int i = 0; i < no; i++) {
|
92
94
|
ofs += index->sub_indexes[i]->d;
|
95
|
+
}
|
93
96
|
|
94
97
|
std::unique_ptr<float[]> sub_x(new float[sub_d * n]);
|
95
|
-
for (idx_t i = 0; i < n; i++)
|
98
|
+
for (idx_t i = 0; i < n; i++) {
|
96
99
|
memcpy(sub_x.get() + i * sub_d,
|
97
100
|
x + ofs + i * d,
|
98
101
|
sub_d * sizeof(float));
|
102
|
+
}
|
99
103
|
sub_index->search(n, sub_x.get(), k, distances1, labels1);
|
100
|
-
if (index->verbose)
|
104
|
+
if (index->verbose) {
|
101
105
|
printf("end query shard %d\n", no);
|
106
|
+
}
|
102
107
|
};
|
103
108
|
|
104
109
|
if (!threaded) {
|
@@ -150,8 +155,9 @@ void IndexSplitVectors::reset() {
|
|
150
155
|
|
151
156
|
IndexSplitVectors::~IndexSplitVectors() {
|
152
157
|
if (own_fields) {
|
153
|
-
for (int s = 0; s < sub_indexes.size(); s++)
|
158
|
+
for (int s = 0; s < sub_indexes.size(); s++) {
|
154
159
|
delete sub_indexes[s];
|
160
|
+
}
|
155
161
|
}
|
156
162
|
}
|
157
163
|
|
@@ -10,7 +10,8 @@
|
|
10
10
|
#ifndef FAISS_METRIC_TYPE_H
|
11
11
|
#define FAISS_METRIC_TYPE_H
|
12
12
|
|
13
|
-
#include <
|
13
|
+
#include <cstdint>
|
14
|
+
#include <cstdio>
|
14
15
|
|
15
16
|
namespace faiss {
|
16
17
|
|
@@ -36,8 +37,9 @@ enum MetricType {
|
|
36
37
|
METRIC_Jaccard,
|
37
38
|
/// Squared Eucliden distance, ignoring NaNs
|
38
39
|
METRIC_NaNEuclidean,
|
39
|
-
///
|
40
|
-
|
40
|
+
/// Gower's distance - numeric dimensions are in [0,1] and categorical
|
41
|
+
/// dimensions are negative integers
|
42
|
+
METRIC_GOWER,
|
41
43
|
};
|
42
44
|
|
43
45
|
/// all vector indices are this type
|
@@ -9,9 +9,6 @@
|
|
9
9
|
|
10
10
|
#include <faiss/clone_index.h>
|
11
11
|
|
12
|
-
#include <cstdio>
|
13
|
-
#include <cstdlib>
|
14
|
-
|
15
12
|
#include <faiss/impl/FaissAssert.h>
|
16
13
|
|
17
14
|
#include <faiss/Index2Layer.h>
|
@@ -315,8 +312,9 @@ Index* Cloner::clone_Index(const Index* index) {
|
|
315
312
|
res->metric_arg = ipt->metric_arg;
|
316
313
|
|
317
314
|
res->index = clone_Index(ipt->index);
|
318
|
-
for (int i = 0; i < ipt->chain.size(); i++)
|
315
|
+
for (int i = 0; i < ipt->chain.size(); i++) {
|
319
316
|
res->chain.push_back(clone_VectorTransform(ipt->chain[i]));
|
317
|
+
}
|
320
318
|
res->own_fields = true;
|
321
319
|
return res;
|
322
320
|
} else if (
|
@@ -18,10 +18,12 @@
|
|
18
18
|
#include <faiss/IndexIDMap.h>
|
19
19
|
#include <faiss/IndexIVFFlat.h>
|
20
20
|
#include <faiss/IndexIVFPQFastScan.h>
|
21
|
+
#include <faiss/IndexIVFRaBitQ.h>
|
21
22
|
#include <faiss/IndexLSH.h>
|
22
23
|
#include <faiss/IndexNSG.h>
|
23
24
|
#include <faiss/IndexPQFastScan.h>
|
24
25
|
#include <faiss/IndexPreTransform.h>
|
26
|
+
#include <faiss/IndexRaBitQ.h>
|
25
27
|
#include <faiss/IndexRefine.h>
|
26
28
|
|
27
29
|
namespace faiss {
|
@@ -103,6 +105,8 @@ std::string reverse_index_factory(const faiss::Index* index) {
|
|
103
105
|
ivf_index)) {
|
104
106
|
return prefix + ",PQ" + std::to_string(ivfpqfs_index->pq.M) + "x" +
|
105
107
|
std::to_string(ivfpqfs_index->pq.nbits) + "fs";
|
108
|
+
} else if (dynamic_cast<const faiss::IndexIVFRaBitQ*>(ivf_index)) {
|
109
|
+
return prefix + ",RaBitQ";
|
106
110
|
}
|
107
111
|
} else if (
|
108
112
|
const faiss::IndexPreTransform* pretransform_index =
|
@@ -175,6 +179,8 @@ std::string reverse_index_factory(const faiss::Index* index) {
|
|
175
179
|
const faiss::IndexIDMap* idmap =
|
176
180
|
dynamic_cast<const faiss::IndexIDMap*>(index)) {
|
177
181
|
return std::string("IDMap,") + reverse_index_factory(idmap->index);
|
182
|
+
} else if (dynamic_cast<const faiss::IndexRaBitQ*>(index)) {
|
183
|
+
return "RaBitQ";
|
178
184
|
}
|
179
185
|
// Avoid runtime error, just return empty string for logging.
|
180
186
|
return "";
|