faiss 0.2.7 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
@@ -21,49 +21,6 @@ namespace faiss {
|
|
21
21
|
|
22
22
|
struct IndexHNSW;
|
23
23
|
|
24
|
-
struct ReconstructFromNeighbors {
|
25
|
-
typedef HNSW::storage_idx_t storage_idx_t;
|
26
|
-
|
27
|
-
const IndexHNSW& index;
|
28
|
-
size_t M; // number of neighbors
|
29
|
-
size_t k; // number of codebook entries
|
30
|
-
size_t nsq; // number of subvectors
|
31
|
-
size_t code_size;
|
32
|
-
int k_reorder; // nb to reorder. -1 = all
|
33
|
-
|
34
|
-
std::vector<float> codebook; // size nsq * k * (M + 1)
|
35
|
-
|
36
|
-
std::vector<uint8_t> codes; // size ntotal * code_size
|
37
|
-
size_t ntotal;
|
38
|
-
size_t d, dsub; // derived values
|
39
|
-
|
40
|
-
explicit ReconstructFromNeighbors(
|
41
|
-
const IndexHNSW& index,
|
42
|
-
size_t k = 256,
|
43
|
-
size_t nsq = 1);
|
44
|
-
|
45
|
-
/// codes must be added in the correct order and the IndexHNSW
|
46
|
-
/// must be populated and sorted
|
47
|
-
void add_codes(size_t n, const float* x);
|
48
|
-
|
49
|
-
size_t compute_distances(
|
50
|
-
size_t n,
|
51
|
-
const idx_t* shortlist,
|
52
|
-
const float* query,
|
53
|
-
float* distances) const;
|
54
|
-
|
55
|
-
/// called by add_codes
|
56
|
-
void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
|
57
|
-
|
58
|
-
/// called by compute_distances
|
59
|
-
void reconstruct(storage_idx_t i, float* x, float* tmp) const;
|
60
|
-
|
61
|
-
void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
|
62
|
-
|
63
|
-
/// get the M+1 -by-d table for neighbor coordinates for vector i
|
64
|
-
void get_neighbor_table(storage_idx_t i, float* out) const;
|
65
|
-
};
|
66
|
-
|
67
24
|
/** The HNSW index is a normal random-access index with a HNSW
|
68
25
|
* link structure built on top */
|
69
26
|
|
@@ -74,10 +31,8 @@ struct IndexHNSW : Index {
|
|
74
31
|
HNSW hnsw;
|
75
32
|
|
76
33
|
// the sequential storage
|
77
|
-
bool own_fields;
|
78
|
-
Index* storage;
|
79
|
-
|
80
|
-
ReconstructFromNeighbors* reconstruct_from_neighbors;
|
34
|
+
bool own_fields = false;
|
35
|
+
Index* storage = nullptr;
|
81
36
|
|
82
37
|
explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
|
83
38
|
explicit IndexHNSW(Index* storage, int M = 32);
|
@@ -98,6 +53,13 @@ struct IndexHNSW : Index {
|
|
98
53
|
idx_t* labels,
|
99
54
|
const SearchParameters* params = nullptr) const override;
|
100
55
|
|
56
|
+
void range_search(
|
57
|
+
idx_t n,
|
58
|
+
const float* x,
|
59
|
+
float radius,
|
60
|
+
RangeSearchResult* result,
|
61
|
+
const SearchParameters* params = nullptr) const override;
|
62
|
+
|
101
63
|
void reconstruct(idx_t key, float* recons) const override;
|
102
64
|
|
103
65
|
void reset() override;
|
@@ -134,6 +96,8 @@ struct IndexHNSW : Index {
|
|
134
96
|
void reorder_links();
|
135
97
|
|
136
98
|
void link_singletons();
|
99
|
+
|
100
|
+
void permute_entries(const idx_t* perm);
|
137
101
|
};
|
138
102
|
|
139
103
|
/** Flat index topped with with a HNSW structure to access elements
|
@@ -150,7 +114,7 @@ struct IndexHNSWFlat : IndexHNSW {
|
|
150
114
|
*/
|
151
115
|
struct IndexHNSWPQ : IndexHNSW {
|
152
116
|
IndexHNSWPQ();
|
153
|
-
IndexHNSWPQ(int d, int pq_m, int M);
|
117
|
+
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
|
154
118
|
void train(idx_t n, const float* x) override;
|
155
119
|
};
|
156
120
|
|
@@ -9,31 +9,43 @@
|
|
9
9
|
|
10
10
|
#include <faiss/IndexIDMap.h>
|
11
11
|
|
12
|
-
#include <stdint.h>
|
13
12
|
#include <cinttypes>
|
13
|
+
#include <cstdint>
|
14
14
|
#include <cstdio>
|
15
15
|
#include <limits>
|
16
16
|
|
17
17
|
#include <faiss/impl/AuxIndexStructures.h>
|
18
18
|
#include <faiss/impl/FaissAssert.h>
|
19
|
-
#include <faiss/impl/IDSelector.h>
|
20
19
|
#include <faiss/utils/Heap.h>
|
21
20
|
#include <faiss/utils/WorkerThread.h>
|
22
21
|
|
23
22
|
namespace faiss {
|
24
23
|
|
24
|
+
namespace {
|
25
|
+
|
26
|
+
// IndexBinary needs to update the code_size when d is set...
|
27
|
+
|
28
|
+
void sync_d(Index* index) {}
|
29
|
+
|
30
|
+
void sync_d(IndexBinary* index) {
|
31
|
+
FAISS_THROW_IF_NOT(index->d % 8 == 0);
|
32
|
+
index->code_size = index->d / 8;
|
33
|
+
}
|
34
|
+
|
35
|
+
} // anonymous namespace
|
36
|
+
|
25
37
|
/*****************************************************
|
26
38
|
* IndexIDMap implementation
|
27
39
|
*******************************************************/
|
28
40
|
|
29
41
|
template <typename IndexT>
|
30
|
-
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
|
31
|
-
: index(index), own_fields(false) {
|
42
|
+
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
|
32
43
|
FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
|
33
44
|
this->is_trained = index->is_trained;
|
34
45
|
this->metric_type = index->metric_type;
|
35
46
|
this->verbose = index->verbose;
|
36
47
|
this->d = index->d;
|
48
|
+
sync_d(this);
|
37
49
|
}
|
38
50
|
|
39
51
|
template <typename IndexT>
|
@@ -71,6 +83,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
|
|
71
83
|
this->ntotal = index->ntotal;
|
72
84
|
}
|
73
85
|
|
86
|
+
namespace {
|
87
|
+
|
88
|
+
/// RAII object to reset the IDSelector in the params object
|
89
|
+
struct ScopedSelChange {
|
90
|
+
SearchParameters* params = nullptr;
|
91
|
+
IDSelector* old_sel = nullptr;
|
92
|
+
|
93
|
+
void set(SearchParameters* params_2, IDSelector* new_sel) {
|
94
|
+
this->params = params_2;
|
95
|
+
old_sel = params_2->sel;
|
96
|
+
params_2->sel = new_sel;
|
97
|
+
}
|
98
|
+
~ScopedSelChange() {
|
99
|
+
if (params) {
|
100
|
+
params->sel = old_sel;
|
101
|
+
}
|
102
|
+
}
|
103
|
+
};
|
104
|
+
|
105
|
+
} // namespace
|
106
|
+
|
74
107
|
template <typename IndexT>
|
75
108
|
void IndexIDMapTemplate<IndexT>::search(
|
76
109
|
idx_t n,
|
@@ -79,9 +112,26 @@ void IndexIDMapTemplate<IndexT>::search(
|
|
79
112
|
typename IndexT::distance_t* distances,
|
80
113
|
idx_t* labels,
|
81
114
|
const SearchParameters* params) const {
|
82
|
-
|
83
|
-
|
84
|
-
|
115
|
+
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
|
116
|
+
ScopedSelChange sel_change;
|
117
|
+
|
118
|
+
if (params && params->sel) {
|
119
|
+
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
|
120
|
+
|
121
|
+
if (!idtrans) {
|
122
|
+
/*
|
123
|
+
FAISS_THROW_IF_NOT_MSG(
|
124
|
+
idtrans,
|
125
|
+
"IndexIDMap requires an IDSelectorTranslated on input");
|
126
|
+
*/
|
127
|
+
// then make an idtrans and force it into the SearchParameters
|
128
|
+
// (hence the const_cast)
|
129
|
+
auto params_non_const = const_cast<SearchParameters*>(params);
|
130
|
+
this_idtrans.sel = params->sel;
|
131
|
+
sel_change.set(params_non_const, &this_idtrans);
|
132
|
+
}
|
133
|
+
}
|
134
|
+
index->search(n, x, k, distances, labels, params);
|
85
135
|
idx_t* li = labels;
|
86
136
|
#pragma omp parallel for
|
87
137
|
for (idx_t i = 0; i < n * k; i++) {
|
@@ -96,9 +146,16 @@ void IndexIDMapTemplate<IndexT>::range_search(
|
|
96
146
|
typename IndexT::distance_t radius,
|
97
147
|
RangeSearchResult* result,
|
98
148
|
const SearchParameters* params) const {
|
99
|
-
|
100
|
-
|
101
|
-
|
149
|
+
if (params) {
|
150
|
+
SearchParameters internal_search_parameters;
|
151
|
+
IDSelectorTranslated id_selector_translated(id_map, params->sel);
|
152
|
+
internal_search_parameters.sel = &id_selector_translated;
|
153
|
+
|
154
|
+
index->range_search(n, x, radius, result, &internal_search_parameters);
|
155
|
+
} else {
|
156
|
+
index->range_search(n, x, radius, result);
|
157
|
+
}
|
158
|
+
|
102
159
|
#pragma omp parallel for
|
103
160
|
for (idx_t i = 0; i < result->lims[result->nq]; i++) {
|
104
161
|
result->labels[i] = result->labels[i] < 0 ? result->labels[i]
|
@@ -106,26 +163,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
|
|
106
163
|
}
|
107
164
|
}
|
108
165
|
|
109
|
-
namespace {
|
110
|
-
|
111
|
-
struct IDTranslatedSelector : IDSelector {
|
112
|
-
const std::vector<int64_t>& id_map;
|
113
|
-
const IDSelector& sel;
|
114
|
-
IDTranslatedSelector(
|
115
|
-
const std::vector<int64_t>& id_map,
|
116
|
-
const IDSelector& sel)
|
117
|
-
: id_map(id_map), sel(sel) {}
|
118
|
-
bool is_member(idx_t id) const override {
|
119
|
-
return sel.is_member(id_map[id]);
|
120
|
-
}
|
121
|
-
};
|
122
|
-
|
123
|
-
} // namespace
|
124
|
-
|
125
166
|
template <typename IndexT>
|
126
167
|
size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
|
127
168
|
// remove in sub-index first
|
128
|
-
|
169
|
+
IDSelectorTranslated sel2(id_map, &sel);
|
129
170
|
size_t nremove = index->remove_ids(sel2);
|
130
171
|
|
131
172
|
int64_t j = 0;
|
@@ -232,7 +273,7 @@ void IndexIDMap2Template<IndexT>::reconstruct(
|
|
232
273
|
typename IndexT::component_t* recons) const {
|
233
274
|
try {
|
234
275
|
this->index->reconstruct(rev_map.at(key), recons);
|
235
|
-
} catch (const std::out_of_range&
|
276
|
+
} catch (const std::out_of_range&) {
|
236
277
|
FAISS_THROW_FMT("key %" PRId64 " not found", key);
|
237
278
|
}
|
238
279
|
}
|
@@ -9,6 +9,7 @@
|
|
9
9
|
|
10
10
|
#include <faiss/Index.h>
|
11
11
|
#include <faiss/IndexBinary.h>
|
12
|
+
#include <faiss/impl/IDSelector.h>
|
12
13
|
|
13
14
|
#include <unordered_map>
|
14
15
|
#include <vector>
|
@@ -21,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
|
|
21
22
|
using component_t = typename IndexT::component_t;
|
22
23
|
using distance_t = typename IndexT::distance_t;
|
23
24
|
|
24
|
-
IndexT* index;
|
25
|
-
bool own_fields; ///! whether pointers are deleted in destructo
|
25
|
+
IndexT* index = nullptr; ///! the sub-index
|
26
|
+
bool own_fields = false; ///! whether pointers are deleted in destructo
|
26
27
|
std::vector<idx_t> id_map;
|
27
28
|
|
28
29
|
explicit IndexIDMapTemplate(IndexT* index);
|
@@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
|
|
102
103
|
using IndexIDMap2 = IndexIDMap2Template<Index>;
|
103
104
|
using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
|
104
105
|
|
106
|
+
// IDSelector that translates the ids using an IDMap
|
107
|
+
struct IDSelectorTranslated : IDSelector {
|
108
|
+
const std::vector<int64_t>& id_map;
|
109
|
+
const IDSelector* sel;
|
110
|
+
|
111
|
+
IDSelectorTranslated(
|
112
|
+
const std::vector<int64_t>& id_map,
|
113
|
+
const IDSelector* sel)
|
114
|
+
: id_map(id_map), sel(sel) {}
|
115
|
+
|
116
|
+
IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
|
117
|
+
: id_map(index_idmap.id_map), sel(sel) {}
|
118
|
+
|
119
|
+
IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
|
120
|
+
: id_map(index_idmap.id_map), sel(sel) {}
|
121
|
+
|
122
|
+
bool is_member(idx_t id) const override {
|
123
|
+
return sel->is_member(id_map[id]);
|
124
|
+
}
|
125
|
+
};
|
126
|
+
|
105
127
|
} // namespace faiss
|
@@ -11,6 +11,7 @@
|
|
11
11
|
|
12
12
|
#include <omp.h>
|
13
13
|
#include <cstdint>
|
14
|
+
#include <memory>
|
14
15
|
#include <mutex>
|
15
16
|
|
16
17
|
#include <algorithm>
|
@@ -45,7 +46,7 @@ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
|
|
45
46
|
cp.niter = 10;
|
46
47
|
}
|
47
48
|
|
48
|
-
Level1Quantizer::Level1Quantizer()
|
49
|
+
Level1Quantizer::Level1Quantizer() = default;
|
49
50
|
|
50
51
|
Level1Quantizer::~Level1Quantizer() {
|
51
52
|
if (own_fields) {
|
@@ -172,7 +173,7 @@ IndexIVF::IndexIVF(
|
|
172
173
|
}
|
173
174
|
}
|
174
175
|
|
175
|
-
IndexIVF::IndexIVF()
|
176
|
+
IndexIVF::IndexIVF() = default;
|
176
177
|
|
177
178
|
void IndexIVF::add(idx_t n, const float* x) {
|
178
179
|
add_with_ids(n, x, nullptr);
|
@@ -202,7 +203,8 @@ void IndexIVF::add_core(
|
|
202
203
|
idx_t n,
|
203
204
|
const float* x,
|
204
205
|
const idx_t* xids,
|
205
|
-
const idx_t* coarse_idx
|
206
|
+
const idx_t* coarse_idx,
|
207
|
+
void* inverted_list_context) {
|
206
208
|
// do some blocking to avoid excessive allocs
|
207
209
|
idx_t bs = 65536;
|
208
210
|
if (n > bs) {
|
@@ -217,7 +219,8 @@ void IndexIVF::add_core(
|
|
217
219
|
i1 - i0,
|
218
220
|
x + i0 * d,
|
219
221
|
xids ? xids + i0 : nullptr,
|
220
|
-
coarse_idx + i0
|
222
|
+
coarse_idx + i0,
|
223
|
+
inverted_list_context);
|
221
224
|
}
|
222
225
|
return;
|
223
226
|
}
|
@@ -248,7 +251,10 @@ void IndexIVF::add_core(
|
|
248
251
|
if (list_no >= 0 && list_no % nt == rank) {
|
249
252
|
idx_t id = xids ? xids[i] : ntotal + i;
|
250
253
|
size_t ofs = invlists->add_entry(
|
251
|
-
list_no,
|
254
|
+
list_no,
|
255
|
+
id,
|
256
|
+
flat_codes.get() + i * code_size,
|
257
|
+
inverted_list_context);
|
252
258
|
|
253
259
|
dm_adder.add(i, list_no, ofs);
|
254
260
|
|
@@ -375,7 +381,7 @@ void IndexIVF::search(
|
|
375
381
|
indexIVF_stats.add(stats[slice]);
|
376
382
|
}
|
377
383
|
} else {
|
378
|
-
// handle
|
384
|
+
// handle parallelization at level below (or don't run in parallel at
|
379
385
|
// all)
|
380
386
|
sub_search_func(n, x, distances, labels, &indexIVF_stats);
|
381
387
|
}
|
@@ -444,11 +450,13 @@ void IndexIVF::search_preassigned(
|
|
444
450
|
: pmode == 1 ? nprobe > 1
|
445
451
|
: nprobe * n > 1);
|
446
452
|
|
453
|
+
void* inverted_list_context =
|
454
|
+
params ? params->inverted_list_context : nullptr;
|
455
|
+
|
447
456
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
|
448
457
|
{
|
449
|
-
InvertedListScanner
|
450
|
-
get_InvertedListScanner(store_pairs, sel);
|
451
|
-
ScopeDeleter1<InvertedListScanner> del(scanner);
|
458
|
+
std::unique_ptr<InvertedListScanner> scanner(
|
459
|
+
get_InvertedListScanner(store_pairs, sel));
|
452
460
|
|
453
461
|
/*****************************************************
|
454
462
|
* Depending on parallel_mode, there are two possible ways
|
@@ -507,7 +515,7 @@ void IndexIVF::search_preassigned(
|
|
507
515
|
nlist);
|
508
516
|
|
509
517
|
// don't waste time on empty lists
|
510
|
-
if (invlists->is_empty(key)) {
|
518
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
511
519
|
return (size_t)0;
|
512
520
|
}
|
513
521
|
|
@@ -520,7 +528,7 @@ void IndexIVF::search_preassigned(
|
|
520
528
|
size_t list_size = 0;
|
521
529
|
|
522
530
|
std::unique_ptr<InvertedListsIterator> it(
|
523
|
-
invlists->get_iterator(key));
|
531
|
+
invlists->get_iterator(key, inverted_list_context));
|
524
532
|
|
525
533
|
nheap += scanner->iterate_codes(
|
526
534
|
it.get(), simi, idxi, k, list_size);
|
@@ -539,7 +547,8 @@ void IndexIVF::search_preassigned(
|
|
539
547
|
const idx_t* ids = nullptr;
|
540
548
|
|
541
549
|
if (!store_pairs) {
|
542
|
-
sids
|
550
|
+
sids = std::make_unique<InvertedLists::ScopedIds>(
|
551
|
+
invlists, key);
|
543
552
|
ids = sids->get();
|
544
553
|
}
|
545
554
|
|
@@ -659,7 +668,6 @@ void IndexIVF::search_preassigned(
|
|
659
668
|
#pragma omp for schedule(dynamic)
|
660
669
|
for (int64_t ij = 0; ij < n * nprobe; ij++) {
|
661
670
|
size_t i = ij / nprobe;
|
662
|
-
size_t j = ij % nprobe;
|
663
671
|
|
664
672
|
scanner->set_query(x + i * d);
|
665
673
|
init_result(local_dis.data(), local_idx.data());
|
@@ -696,12 +704,13 @@ void IndexIVF::search_preassigned(
|
|
696
704
|
}
|
697
705
|
}
|
698
706
|
|
699
|
-
if (ivf_stats) {
|
700
|
-
ivf_stats
|
701
|
-
ivf_stats->nlist += nlistv;
|
702
|
-
ivf_stats->ndis += ndis;
|
703
|
-
ivf_stats->nheap_updates += nheap;
|
707
|
+
if (ivf_stats == nullptr) {
|
708
|
+
ivf_stats = &indexIVF_stats;
|
704
709
|
}
|
710
|
+
ivf_stats->nq += n;
|
711
|
+
ivf_stats->nlist += nlistv;
|
712
|
+
ivf_stats->ndis += ndis;
|
713
|
+
ivf_stats->nheap_updates += nheap;
|
705
714
|
}
|
706
715
|
|
707
716
|
void IndexIVF::range_search(
|
@@ -781,6 +790,9 @@ void IndexIVF::range_search_preassigned(
|
|
781
790
|
: pmode == 1 ? nprobe > 1
|
782
791
|
: nprobe * nx > 1);
|
783
792
|
|
793
|
+
void* inverted_list_context =
|
794
|
+
params ? params->inverted_list_context : nullptr;
|
795
|
+
|
784
796
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
|
785
797
|
{
|
786
798
|
RangeSearchPartialResult pres(result);
|
@@ -802,7 +814,7 @@ void IndexIVF::range_search_preassigned(
|
|
802
814
|
ik,
|
803
815
|
nlist);
|
804
816
|
|
805
|
-
if (invlists->is_empty(key)) {
|
817
|
+
if (invlists->is_empty(key, inverted_list_context)) {
|
806
818
|
return;
|
807
819
|
}
|
808
820
|
|
@@ -811,7 +823,7 @@ void IndexIVF::range_search_preassigned(
|
|
811
823
|
scanner->set_list(key, coarse_dis[i * nprobe + ik]);
|
812
824
|
if (invlists->use_iterator) {
|
813
825
|
std::unique_ptr<InvertedListsIterator> it(
|
814
|
-
invlists->get_iterator(key));
|
826
|
+
invlists->get_iterator(key, inverted_list_context));
|
815
827
|
|
816
828
|
scanner->iterate_codes_range(
|
817
829
|
it.get(), radius, qres, list_size);
|
@@ -891,17 +903,18 @@ void IndexIVF::range_search_preassigned(
|
|
891
903
|
}
|
892
904
|
}
|
893
905
|
|
894
|
-
if (stats) {
|
895
|
-
stats
|
896
|
-
stats->nlist += nlistv;
|
897
|
-
stats->ndis += ndis;
|
906
|
+
if (stats == nullptr) {
|
907
|
+
stats = &indexIVF_stats;
|
898
908
|
}
|
909
|
+
stats->nq += nx;
|
910
|
+
stats->nlist += nlistv;
|
911
|
+
stats->ndis += ndis;
|
899
912
|
}
|
900
913
|
|
901
914
|
InvertedListScanner* IndexIVF::get_InvertedListScanner(
|
902
915
|
bool /*store_pairs*/,
|
903
916
|
const IDSelector* /* sel */) const {
|
904
|
-
|
917
|
+
FAISS_THROW_MSG("get_InvertedListScanner not implemented");
|
905
918
|
}
|
906
919
|
|
907
920
|
void IndexIVF::reconstruct(idx_t key, float* recons) const {
|
@@ -973,14 +986,12 @@ void IndexIVF::search_and_reconstruct(
|
|
973
986
|
std::min(nlist, params ? params->nprobe : this->nprobe);
|
974
987
|
FAISS_THROW_IF_NOT(nprobe > 0);
|
975
988
|
|
976
|
-
idx_t
|
977
|
-
|
978
|
-
float* coarse_dis = new float[n * nprobe];
|
979
|
-
ScopeDeleter<float> del2(coarse_dis);
|
989
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
990
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
980
991
|
|
981
|
-
quantizer->search(n, x, nprobe, coarse_dis, idx);
|
992
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
|
982
993
|
|
983
|
-
invlists->prefetch_lists(idx, n * nprobe);
|
994
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
984
995
|
|
985
996
|
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
986
997
|
// and offset into `codes` for reconstruction
|
@@ -988,29 +999,94 @@ void IndexIVF::search_and_reconstruct(
|
|
988
999
|
n,
|
989
1000
|
x,
|
990
1001
|
k,
|
991
|
-
idx,
|
992
|
-
coarse_dis,
|
1002
|
+
idx.get(),
|
1003
|
+
coarse_dis.get(),
|
993
1004
|
distances,
|
994
1005
|
labels,
|
995
1006
|
true /* store_pairs */,
|
996
1007
|
params);
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
+
#pragma omp parallel for if (n * k > 1000)
|
1009
|
+
for (idx_t ij = 0; ij < n * k; ij++) {
|
1010
|
+
idx_t key = labels[ij];
|
1011
|
+
float* reconstructed = recons + ij * d;
|
1012
|
+
if (key < 0) {
|
1013
|
+
// Fill with NaNs
|
1014
|
+
memset(reconstructed, -1, sizeof(*reconstructed) * d);
|
1015
|
+
} else {
|
1016
|
+
int list_no = lo_listno(key);
|
1017
|
+
int offset = lo_offset(key);
|
1018
|
+
|
1019
|
+
// Update label to the actual id
|
1020
|
+
labels[ij] = invlists->get_single_id(list_no, offset);
|
1021
|
+
|
1022
|
+
reconstruct_from_offset(list_no, offset, reconstructed);
|
1023
|
+
}
|
1024
|
+
}
|
1025
|
+
}
|
1026
|
+
|
1027
|
+
void IndexIVF::search_and_return_codes(
|
1028
|
+
idx_t n,
|
1029
|
+
const float* x,
|
1030
|
+
idx_t k,
|
1031
|
+
float* distances,
|
1032
|
+
idx_t* labels,
|
1033
|
+
uint8_t* codes,
|
1034
|
+
bool include_listno,
|
1035
|
+
const SearchParameters* params_in) const {
|
1036
|
+
const IVFSearchParameters* params = nullptr;
|
1037
|
+
if (params_in) {
|
1038
|
+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
1039
|
+
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
|
1040
|
+
}
|
1041
|
+
const size_t nprobe =
|
1042
|
+
std::min(nlist, params ? params->nprobe : this->nprobe);
|
1043
|
+
FAISS_THROW_IF_NOT(nprobe > 0);
|
1044
|
+
|
1045
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
1046
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
1047
|
+
|
1048
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
|
1049
|
+
|
1050
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
1051
|
+
|
1052
|
+
// search_preassigned() with `store_pairs` enabled to obtain the list_no
|
1053
|
+
// and offset into `codes` for reconstruction
|
1054
|
+
search_preassigned(
|
1055
|
+
n,
|
1056
|
+
x,
|
1057
|
+
k,
|
1058
|
+
idx.get(),
|
1059
|
+
coarse_dis.get(),
|
1060
|
+
distances,
|
1061
|
+
labels,
|
1062
|
+
true /* store_pairs */,
|
1063
|
+
params);
|
1064
|
+
|
1065
|
+
size_t code_size_1 = code_size;
|
1066
|
+
if (include_listno) {
|
1067
|
+
code_size_1 += coarse_code_size();
|
1068
|
+
}
|
1008
1069
|
|
1009
|
-
|
1010
|
-
|
1070
|
+
#pragma omp parallel for if (n * k > 1000)
|
1071
|
+
for (idx_t ij = 0; ij < n * k; ij++) {
|
1072
|
+
idx_t key = labels[ij];
|
1073
|
+
uint8_t* code1 = codes + ij * code_size_1;
|
1011
1074
|
|
1012
|
-
|
1075
|
+
if (key < 0) {
|
1076
|
+
// Fill with 0xff
|
1077
|
+
memset(code1, -1, code_size_1);
|
1078
|
+
} else {
|
1079
|
+
int list_no = lo_listno(key);
|
1080
|
+
int offset = lo_offset(key);
|
1081
|
+
const uint8_t* cc = invlists->get_single_code(list_no, offset);
|
1082
|
+
|
1083
|
+
labels[ij] = invlists->get_single_id(list_no, offset);
|
1084
|
+
|
1085
|
+
if (include_listno) {
|
1086
|
+
encode_listno(list_no, code1);
|
1087
|
+
code1 += code_size_1 - code_size;
|
1013
1088
|
}
|
1089
|
+
memcpy(code1, cc, code_size);
|
1014
1090
|
}
|
1015
1091
|
}
|
1016
1092
|
}
|
@@ -1061,22 +1137,52 @@ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
|
|
1061
1137
|
}
|
1062
1138
|
|
1063
1139
|
void IndexIVF::train(idx_t n, const float* x) {
|
1064
|
-
if (verbose)
|
1140
|
+
if (verbose) {
|
1065
1141
|
printf("Training level-1 quantizer\n");
|
1142
|
+
}
|
1066
1143
|
|
1067
1144
|
train_q1(n, x, verbose, metric_type);
|
1068
1145
|
|
1069
|
-
if (verbose)
|
1146
|
+
if (verbose) {
|
1070
1147
|
printf("Training IVF residual\n");
|
1148
|
+
}
|
1149
|
+
|
1150
|
+
// optional subsampling
|
1151
|
+
idx_t max_nt = train_encoder_num_vectors();
|
1152
|
+
if (max_nt <= 0) {
|
1153
|
+
max_nt = (size_t)1 << 35;
|
1154
|
+
}
|
1155
|
+
|
1156
|
+
TransformedVectors tv(
|
1157
|
+
x, fvecs_maybe_subsample(d, (size_t*)&n, max_nt, x, verbose));
|
1158
|
+
|
1159
|
+
if (by_residual) {
|
1160
|
+
std::vector<idx_t> assign(n);
|
1161
|
+
quantizer->assign(n, tv.x, assign.data());
|
1162
|
+
|
1163
|
+
std::vector<float> residuals(n * d);
|
1164
|
+
quantizer->compute_residual_n(n, tv.x, residuals.data(), assign.data());
|
1165
|
+
|
1166
|
+
train_encoder(n, residuals.data(), assign.data());
|
1167
|
+
} else {
|
1168
|
+
train_encoder(n, tv.x, nullptr);
|
1169
|
+
}
|
1071
1170
|
|
1072
|
-
train_residual(n, x);
|
1073
1171
|
is_trained = true;
|
1074
1172
|
}
|
1075
1173
|
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1174
|
+
idx_t IndexIVF::train_encoder_num_vectors() const {
|
1175
|
+
return 0;
|
1176
|
+
}
|
1177
|
+
|
1178
|
+
void IndexIVF::train_encoder(
|
1179
|
+
idx_t /*n*/,
|
1180
|
+
const float* /*x*/,
|
1181
|
+
const idx_t* assign) {
|
1079
1182
|
// does nothing by default
|
1183
|
+
if (verbose) {
|
1184
|
+
printf("IndexIVF: no residual training\n");
|
1185
|
+
}
|
1080
1186
|
}
|
1081
1187
|
|
1082
1188
|
bool check_compatible_for_merge_expensive_check = true;
|