faiss 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +29 -3
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
|
|
16
16
|
namespace faiss {
|
|
17
17
|
|
|
18
|
+
struct IDSelector;
|
|
19
|
+
|
|
18
20
|
// When offsets list id + offset are encoded in an uint64
|
|
19
21
|
// we call this LO = list-offset
|
|
20
22
|
|
|
@@ -34,8 +36,6 @@ inline uint64_t lo_offset(uint64_t lo) {
|
|
|
34
36
|
* Direct map: a way to map back from ids to inverted lists
|
|
35
37
|
*/
|
|
36
38
|
struct DirectMap {
|
|
37
|
-
typedef Index::idx_t idx_t;
|
|
38
|
-
|
|
39
39
|
enum Type {
|
|
40
40
|
NoMap = 0, // default
|
|
41
41
|
Array = 1, // sequential ids (only for add, no add_with_ids)
|
|
@@ -91,8 +91,6 @@ struct DirectMap {
|
|
|
91
91
|
|
|
92
92
|
/// Thread-safe way of updating the direct_map
|
|
93
93
|
struct DirectMapAdd {
|
|
94
|
-
typedef Index::idx_t idx_t;
|
|
95
|
-
|
|
96
94
|
using Type = DirectMap::Type;
|
|
97
95
|
|
|
98
96
|
DirectMap& direct_map;
|
|
@@ -10,23 +10,32 @@
|
|
|
10
10
|
#include <faiss/invlists/InvertedLists.h>
|
|
11
11
|
|
|
12
12
|
#include <cstdio>
|
|
13
|
+
#include <memory>
|
|
13
14
|
|
|
14
15
|
#include <faiss/impl/FaissAssert.h>
|
|
15
16
|
#include <faiss/utils/utils.h>
|
|
16
17
|
|
|
17
18
|
namespace faiss {
|
|
18
19
|
|
|
20
|
+
// Out-of-line definition of the virtual destructor declared in the header.
InvertedListsIterator::~InvertedListsIterator() = default;
|
|
21
|
+
|
|
19
22
|
/*****************************************
|
|
20
23
|
* InvertedLists implementation
|
|
21
24
|
******************************************/
|
|
22
25
|
|
|
23
26
|
InvertedLists::InvertedLists(size_t nlist, size_t code_size)
|
|
24
|
-
: nlist(nlist), code_size(code_size) {}
|
|
27
|
+
: nlist(nlist), code_size(code_size), use_iterator(false) {}
|
|
25
28
|
|
|
26
29
|
// Out-of-line definition of the virtual destructor.
InvertedLists::~InvertedLists() = default;
|
|
27
30
|
|
|
28
|
-
|
|
29
|
-
|
|
31
|
+
bool InvertedLists::is_empty(size_t list_no) const {
|
|
32
|
+
return use_iterator
|
|
33
|
+
? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
|
|
34
|
+
->is_available()
|
|
35
|
+
: list_size(list_no) == 0;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
|
|
30
39
|
assert(offset < list_size(list_no));
|
|
31
40
|
const idx_t* ids = get_ids(list_no);
|
|
32
41
|
idx_t id = ids[offset];
|
|
@@ -67,6 +76,10 @@ void InvertedLists::reset() {
|
|
|
67
76
|
}
|
|
68
77
|
}
|
|
69
78
|
|
|
79
|
+
/// Default implementation: iterator access is not available on this class;
/// subclasses that set use_iterator must override it.
InvertedListsIterator* InvertedLists::get_iterator(size_t list_no) const {
    (void)list_no; // unused in the base class
    FAISS_THROW_MSG("get_iterator is not supported");
}
|
|
82
|
+
|
|
70
83
|
void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
|
|
71
84
|
#pragma omp parallel for
|
|
72
85
|
for (idx_t i = 0; i < nlist; i++) {
|
|
@@ -87,6 +100,98 @@ void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
|
|
|
87
100
|
}
|
|
88
101
|
}
|
|
89
102
|
|
|
103
|
+
/** Copy a subset of the entries of this object into oivf.
 *
 * The entries are selected by subset_type and the integers a1, a2 (see
 * subset_type_t) and appended to the list with the same number in oivf.
 *
 * @return number of entries copied
 * Throws (via the FAISS_THROW_* macros) when nlist or code_size differ,
 * when subset_type is unknown, or when a1 == 0 for the subset types that
 * divide by a1 (ID_MOD, INVLIST_FRACTION).
 */
size_t InvertedLists::copy_subset_to(
        InvertedLists& oivf,
        subset_type_t subset_type,
        idx_t a1,
        idx_t a2) const {
    FAISS_THROW_IF_NOT(nlist == oivf.nlist);
    FAISS_THROW_IF_NOT(code_size == oivf.code_size);
    FAISS_THROW_IF_NOT_FMT(
            subset_type >= 0 && subset_type <= 4,
            "subset type %d not implemented",
            subset_type);
    // a1 is used as a divisor by these two subset types
    FAISS_THROW_IF_NOT_FMT(
            a1 != 0 ||
                    (subset_type != SUBSET_TYPE_ID_MOD &&
                     subset_type != SUBSET_TYPE_INVLIST_FRACTION),
            "subset type %d requires a1 != 0",
            subset_type);

    size_t ntotal = 0;
    if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
        ntotal = compute_ntotal();
        if (ntotal == 0) {
            // nothing to copy, and the proportional split below would
            // divide by zero
            return 0;
        }
    }

    // copies the entries at offsets [i1, i2) of list list_no and returns
    // how many were copied (0 for an empty or inverted range, instead of a
    // wrapped-around size_t)
    auto copy_offsets = [&](idx_t list_no, size_t i1, size_t i2) -> size_t {
        for (size_t i = i1; i < i2; i++) {
            oivf.add_entry(
                    list_no,
                    get_single_id(list_no, i),
                    ScopedCodes(this, list_no, i).get());
        }
        return i2 > i1 ? i2 - i1 : 0;
    };

    size_t accu_n = 0;  // number of elements in the lists visited so far
    size_t accu_a1 = 0; // ELEMENT_RANGE: share of rank a1 over those lists
    size_t accu_a2 = 0; // ELEMENT_RANGE: share of rank a2 over those lists
    size_t n_added = 0;

    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        size_t n = list_size(list_no);

        if (subset_type == SUBSET_TYPE_ID_RANGE ||
            subset_type == SUBSET_TYPE_ID_MOD) {
            // only the id-based selections need the ids of the list
            ScopedIds ids_in(this, list_no);
            for (size_t i = 0; i < n; i++) {
                idx_t id = ids_in[i];
                bool selected = subset_type == SUBSET_TYPE_ID_RANGE
                        ? (a1 <= id && id < a2)
                        : (id % a1 == a2);
                if (selected) {
                    oivf.add_entry(
                            list_no, id, ScopedCodes(this, list_no, i).get());
                    n_added++;
                }
            }
        } else if (subset_type == SUBSET_TYPE_ELEMENT_RANGE) {
            // see what is allocated to a1 and to a2
            size_t next_accu_n = accu_n + n;
            size_t next_accu_a1 = next_accu_n * a1 / ntotal;
            size_t next_accu_a2 = next_accu_n * a2 / ntotal;
            n_added += copy_offsets(
                    list_no, next_accu_a1 - accu_a1, next_accu_a2 - accu_a2);
            accu_a1 = next_accu_a1;
            accu_a2 = next_accu_a2;
        } else if (subset_type == SUBSET_TYPE_INVLIST_FRACTION) {
            // fraction a2 out of a1 of this list
            n_added += copy_offsets(list_no, n * a2 / a1, n * (a2 + 1) / a1);
        } else if (subset_type == SUBSET_TYPE_INVLIST) {
            if (list_no >= a1 && list_no < a2) {
                oivf.add_entries(
                        list_no,
                        n,
                        ScopedIds(this, list_no).get(),
                        ScopedCodes(this, list_no).get());
                n_added += n;
            }
        }
        accu_n += n;
    }
    return n_added;
}
|
|
194
|
+
|
|
90
195
|
double InvertedLists::imbalance_factor() const {
|
|
91
196
|
std::vector<int> hist(nlist);
|
|
92
197
|
|
|
@@ -109,7 +214,9 @@ void InvertedLists::print_stats() const {
|
|
|
109
214
|
}
|
|
110
215
|
for (size_t i = 0; i < sizes.size(); i++) {
|
|
111
216
|
if (sizes[i]) {
|
|
112
|
-
printf("list size in < %
|
|
217
|
+
printf("list size in < %zu: %d instances\n",
|
|
218
|
+
static_cast<size_t>(1) << i,
|
|
219
|
+
sizes[i]);
|
|
113
220
|
}
|
|
114
221
|
}
|
|
115
222
|
}
|
|
@@ -158,7 +265,7 @@ const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const {
|
|
|
158
265
|
return codes[list_no].data();
|
|
159
266
|
}
|
|
160
267
|
|
|
161
|
-
const
|
|
268
|
+
/// Ids of list list_no, returned as a pointer into the per-list id storage.
const idx_t* ArrayInvertedLists::get_ids(size_t list_no) const {
    assert(list_no < nlist);
    const auto& list_ids = ids[list_no];
    return list_ids.data();
}
|
|
@@ -267,7 +374,7 @@ void HStackInvertedLists::release_codes(size_t, const uint8_t* codes) const {
|
|
|
267
374
|
delete[] codes;
|
|
268
375
|
}
|
|
269
376
|
|
|
270
|
-
const
|
|
377
|
+
const idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
|
|
271
378
|
idx_t *ids = new idx_t[list_size(list_no)], *c = ids;
|
|
272
379
|
|
|
273
380
|
for (int i = 0; i < ils.size(); i++) {
|
|
@@ -281,8 +388,7 @@ const Index::idx_t* HStackInvertedLists::get_ids(size_t list_no) const {
|
|
|
281
388
|
return ids;
|
|
282
389
|
}
|
|
283
390
|
|
|
284
|
-
|
|
285
|
-
const {
|
|
391
|
+
idx_t HStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
|
|
286
392
|
for (int i = 0; i < ils.size(); i++) {
|
|
287
393
|
const InvertedLists* il = ils[i];
|
|
288
394
|
size_t sz = il->list_size(list_no);
|
|
@@ -312,8 +418,6 @@ void HStackInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
|
|
|
312
418
|
|
|
313
419
|
namespace {
|
|
314
420
|
|
|
315
|
-
using idx_t = InvertedLists::idx_t;
|
|
316
|
-
|
|
317
421
|
idx_t translate_list_no(const SliceInvertedLists* sil, idx_t list_no) {
|
|
318
422
|
FAISS_THROW_IF_NOT(list_no >= 0 && list_no < sil->nlist);
|
|
319
423
|
return list_no + sil->i0;
|
|
@@ -349,12 +453,11 @@ void SliceInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
|
|
|
349
453
|
return il->release_codes(translate_list_no(this, list_no), codes);
|
|
350
454
|
}
|
|
351
455
|
|
|
352
|
-
const
|
|
456
|
+
/// Ids of list list_no of the slice, read from the wrapped lists after
/// shifting the list number by the slice offset.
const idx_t* SliceInvertedLists::get_ids(size_t list_no) const {
    const idx_t underlying_list = translate_list_no(this, list_no);
    return il->get_ids(underlying_list);
}
|
|
355
459
|
|
|
356
|
-
|
|
357
|
-
const {
|
|
460
|
+
/// Id at position offset of list list_no of the slice, read from the
/// wrapped lists after shifting the list number by the slice offset.
idx_t SliceInvertedLists::get_single_id(size_t list_no, size_t offset) const {
    const idx_t underlying_list = translate_list_no(this, list_no);
    return il->get_single_id(underlying_list, offset);
}
|
|
360
463
|
|
|
@@ -380,8 +483,6 @@ void SliceInvertedLists::prefetch_lists(const idx_t* list_nos, int nlist)
|
|
|
380
483
|
|
|
381
484
|
namespace {
|
|
382
485
|
|
|
383
|
-
using idx_t = InvertedLists::idx_t;
|
|
384
|
-
|
|
385
486
|
// find the invlist this number belongs to
|
|
386
487
|
int translate_list_no(const VStackInvertedLists* vil, idx_t list_no) {
|
|
387
488
|
FAISS_THROW_IF_NOT(list_no >= 0 && list_no < vil->nlist);
|
|
@@ -449,14 +550,13 @@ void VStackInvertedLists::release_codes(size_t list_no, const uint8_t* codes)
|
|
|
449
550
|
return ils[i]->release_codes(list_no, codes);
|
|
450
551
|
}
|
|
451
552
|
|
|
452
|
-
const
|
|
553
|
+
/// Ids of list list_no: locate the stacked sub-object that holds this list,
/// then query it with the list number relative to that sub-object.
const idx_t* VStackInvertedLists::get_ids(size_t list_no) const {
    const int sub = translate_list_no(this, list_no);
    return ils[sub]->get_ids(list_no - cumsz[sub]);
}
|
|
457
558
|
|
|
458
|
-
|
|
459
|
-
const {
|
|
559
|
+
idx_t VStackInvertedLists::get_single_id(size_t list_no, size_t offset) const {
|
|
460
560
|
int i = translate_list_no(this, list_no);
|
|
461
561
|
list_no -= cumsz[i];
|
|
462
562
|
return ils[i]->get_single_id(list_no, offset);
|
|
@@ -15,11 +15,18 @@
|
|
|
15
15
|
* the interface.
|
|
16
16
|
*/
|
|
17
17
|
|
|
18
|
-
#include <faiss/
|
|
18
|
+
#include <faiss/MetricType.h>
|
|
19
19
|
#include <vector>
|
|
20
20
|
|
|
21
21
|
namespace faiss {
|
|
22
22
|
|
|
23
|
+
struct InvertedListsIterator {
|
|
24
|
+
virtual ~InvertedListsIterator();
|
|
25
|
+
virtual bool is_available() const = 0;
|
|
26
|
+
virtual void next() = 0;
|
|
27
|
+
virtual std::pair<idx_t, const uint8_t*> get_id_and_codes() = 0;
|
|
28
|
+
};
|
|
29
|
+
|
|
23
30
|
/** Table of inverted lists
|
|
24
31
|
* multithreading rules:
|
|
25
32
|
* - concurrent read accesses are allowed
|
|
@@ -28,13 +35,14 @@ namespace faiss {
|
|
|
28
35
|
* are allowed
|
|
29
36
|
*/
|
|
30
37
|
struct InvertedLists {
|
|
31
|
-
typedef Index::idx_t idx_t;
|
|
32
|
-
|
|
33
38
|
size_t nlist; ///< number of possible key values
|
|
34
39
|
size_t code_size; ///< code size per vector in bytes
|
|
40
|
+
bool use_iterator;
|
|
35
41
|
|
|
36
42
|
InvertedLists(size_t nlist, size_t code_size);
|
|
37
43
|
|
|
44
|
+
virtual ~InvertedLists();
|
|
45
|
+
|
|
38
46
|
/// used for BlockInvertedLists, where the codes are packed into groups
|
|
39
47
|
/// and the individual code size is meaningless
|
|
40
48
|
static const size_t INVALID_CODE_SIZE = static_cast<size_t>(-1);
|
|
@@ -42,9 +50,15 @@ struct InvertedLists {
|
|
|
42
50
|
/*************************
|
|
43
51
|
* Read only functions */
|
|
44
52
|
|
|
53
|
+
// check if the list is empty
|
|
54
|
+
bool is_empty(size_t list_no) const;
|
|
55
|
+
|
|
45
56
|
/// get the size of a list
|
|
46
57
|
virtual size_t list_size(size_t list_no) const = 0;
|
|
47
58
|
|
|
59
|
+
/// get iterable for lists that use_iterator
|
|
60
|
+
virtual InvertedListsIterator* get_iterator(size_t list_no) const;
|
|
61
|
+
|
|
48
62
|
/** get the codes for an inverted list
|
|
49
63
|
* must be released by release_codes
|
|
50
64
|
*
|
|
@@ -105,10 +119,36 @@ struct InvertedLists {
|
|
|
105
119
|
|
|
106
120
|
virtual void reset();
|
|
107
121
|
|
|
122
|
+
/*************************
|
|
123
|
+
* high level functions */
|
|
124
|
+
|
|
108
125
|
/// move all entries from oivf (empty on output)
|
|
109
126
|
void merge_from(InvertedLists* oivf, size_t add_id);
|
|
110
127
|
|
|
111
|
-
|
|
128
|
+
/** How copy_subset_to selects the entries to copy.
 * Every type is parameterized by two integers, a1 and a2.
 */
enum subset_type_t : int {
    // selections that depend on the ids
    SUBSET_TYPE_ID_RANGE = 0, // copies ids in [a1, a2)
    SUBSET_TYPE_ID_MOD = 1,   // copies ids if id % a1 == a2 (a1 != 0)
    // selections that depend on the order within the inverted lists
    SUBSET_TYPE_ELEMENT_RANGE =
            2, // copies from each invlist the slice proportional to the
               // ranks [a1, a2) out of the total number of elements
               // ntotal, so that about a1 elements are left before and
               // ntotal - a2 after
    SUBSET_TYPE_INVLIST_FRACTION =
            3, // take fraction a2 out of a1 from each invlist, 0 <= a2 < a1
    // copy only inverted lists a1:a2
    SUBSET_TYPE_INVLIST = 4
};
|
|
143
|
+
|
|
144
|
+
/** copy a subset of the entries index to the other index
|
|
145
|
+
* @return number of entries copied
|
|
146
|
+
*/
|
|
147
|
+
size_t copy_subset_to(
|
|
148
|
+
InvertedLists& other,
|
|
149
|
+
subset_type_t subset_type,
|
|
150
|
+
idx_t a1,
|
|
151
|
+
idx_t a2) const;
|
|
112
152
|
|
|
113
153
|
/*************************
|
|
114
154
|
* statistics */
|
|
@@ -154,7 +154,7 @@ struct OnDiskInvertedLists::OngoingPrefetch {
|
|
|
154
154
|
const OnDiskInvertedLists* od = pf->od;
|
|
155
155
|
od->locks->lock_1(list_no);
|
|
156
156
|
size_t n = od->list_size(list_no);
|
|
157
|
-
const
|
|
157
|
+
const idx_t* idx = od->get_ids(list_no);
|
|
158
158
|
const uint8_t* codes = od->get_codes(list_no);
|
|
159
159
|
int cs = 0;
|
|
160
160
|
for (size_t i = 0; i < n; i++) {
|
|
@@ -389,7 +389,7 @@ const uint8_t* OnDiskInvertedLists::get_codes(size_t list_no) const {
|
|
|
389
389
|
return ptr + lists[list_no].offset;
|
|
390
390
|
}
|
|
391
391
|
|
|
392
|
-
const
|
|
392
|
+
const idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
|
|
393
393
|
if (lists[list_no].offset == INVALID_OFFSET) {
|
|
394
394
|
return nullptr;
|
|
395
395
|
}
|
|
@@ -781,7 +781,7 @@ InvertedLists* OnDiskInvertedListsIOHook::read_ArrayInvertedLists(
|
|
|
781
781
|
OnDiskInvertedLists::List& l = ails->lists[i];
|
|
782
782
|
l.size = l.capacity = sizes[i];
|
|
783
783
|
l.offset = o;
|
|
784
|
-
o += l.size * (sizeof(
|
|
784
|
+
o += l.size * (sizeof(idx_t) + ails->code_size);
|
|
785
785
|
}
|
|
786
786
|
// resume normal reading of file
|
|
787
787
|
fseek(fdesc, o, SEEK_SET);
|
|
@@ -31,7 +31,7 @@ struct OnDiskOneList {
|
|
|
31
31
|
|
|
32
32
|
/** On-disk storage of inverted lists.
|
|
33
33
|
*
|
|
34
|
-
* The data is stored in a mmapped chunk of memory (base
|
|
34
|
+
* The data is stored in a mmapped chunk of memory (base pointer ptr,
|
|
35
35
|
* size totsize). Each list is a range of memory that contains (object
|
|
36
36
|
* List) that contains:
|
|
37
37
|
*
|
|
@@ -118,7 +118,7 @@ PyCallbackIDSelector::PyCallbackIDSelector(PyObject* callback)
|
|
|
118
118
|
Py_INCREF(callback);
|
|
119
119
|
}
|
|
120
120
|
|
|
121
|
-
bool PyCallbackIDSelector::is_member(idx_t id) const {
|
|
121
|
+
bool PyCallbackIDSelector::is_member(faiss::idx_t id) const {
|
|
122
122
|
FAISS_THROW_IF_NOT((id >> 32) == 0);
|
|
123
123
|
PyThreadLock gil;
|
|
124
124
|
PyObject* result = PyObject_CallFunction(callback, "(n)", int(id));
|
|
@@ -98,7 +98,9 @@ struct AlignedTableTightAlloc {
|
|
|
98
98
|
/// Copy assignment: resize to other's number of elements, then copy them.
AlignedTableTightAlloc<T, A>& operator=(
        const AlignedTableTightAlloc<T, A>& other) {
    if (this == &other) {
        // self-assignment: nothing to do, and memcpy must not be called
        // with overlapping (here identical) source and destination
        return *this;
    }
    resize(other.numel);
    // memcpy is skipped for empty tables, where ptr may not be valid
    if (numel > 0) {
        memcpy(ptr, other.ptr, sizeof(T) * numel);
    }
    return *this;
}
|
|
104
106
|
|
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
/* Function for soft heap */
|
|
11
11
|
|
|
12
|
+
#include <faiss/impl/FaissAssert.h>
|
|
12
13
|
#include <faiss/utils/Heap.h>
|
|
13
14
|
|
|
14
15
|
namespace faiss {
|
|
@@ -32,7 +33,7 @@ void HeapArray<C>::addn(size_t nj, const T* vin, TI j0, size_t i0, int64_t ni) {
|
|
|
32
33
|
if (ni == -1)
|
|
33
34
|
ni = nh;
|
|
34
35
|
assert(i0 >= 0 && i0 + ni <= nh);
|
|
35
|
-
#pragma omp parallel for
|
|
36
|
+
#pragma omp parallel for if (ni * nj > 100000)
|
|
36
37
|
for (int64_t i = i0; i < i0 + ni; i++) {
|
|
37
38
|
T* __restrict simi = get_val(i);
|
|
38
39
|
TI* __restrict idxi = get_ids(i);
|
|
@@ -62,7 +63,7 @@ void HeapArray<C>::addn_with_ids(
|
|
|
62
63
|
if (ni == -1)
|
|
63
64
|
ni = nh;
|
|
64
65
|
assert(i0 >= 0 && i0 + ni <= nh);
|
|
65
|
-
#pragma omp parallel for
|
|
66
|
+
#pragma omp parallel for if (ni * nj > 100000)
|
|
66
67
|
for (int64_t i = i0; i < i0 + ni; i++) {
|
|
67
68
|
T* __restrict simi = get_val(i);
|
|
68
69
|
TI* __restrict idxi = get_ids(i);
|
|
@@ -78,9 +79,38 @@ void HeapArray<C>::addn_with_ids(
|
|
|
78
79
|
}
|
|
79
80
|
}
|
|
80
81
|
|
|
82
|
+
template <typename C>
|
|
83
|
+
void HeapArray<C>::addn_query_subset_with_ids(
|
|
84
|
+
size_t nsubset,
|
|
85
|
+
const TI* subset,
|
|
86
|
+
size_t nj,
|
|
87
|
+
const T* vin,
|
|
88
|
+
const TI* id_in,
|
|
89
|
+
int64_t id_stride) {
|
|
90
|
+
FAISS_THROW_IF_NOT_MSG(id_in, "anonymous ids not supported");
|
|
91
|
+
if (id_stride < 0) {
|
|
92
|
+
id_stride = nj;
|
|
93
|
+
}
|
|
94
|
+
#pragma omp parallel for if (nsubset * nj > 100000)
|
|
95
|
+
for (int64_t si = 0; si < nsubset; si++) {
|
|
96
|
+
T i = subset[si];
|
|
97
|
+
T* __restrict simi = get_val(i);
|
|
98
|
+
TI* __restrict idxi = get_ids(i);
|
|
99
|
+
const T* ip_line = vin + si * nj;
|
|
100
|
+
const TI* id_line = id_in + si * id_stride;
|
|
101
|
+
|
|
102
|
+
for (size_t j = 0; j < nj; j++) {
|
|
103
|
+
T ip = ip_line[j];
|
|
104
|
+
if (C::cmp(simi[0], ip)) {
|
|
105
|
+
heap_replace_top<C>(k, simi, idxi, ip, id_line[j]);
|
|
106
|
+
}
|
|
107
|
+
}
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
|
|
81
111
|
template <typename C>
|
|
82
112
|
void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
|
|
83
|
-
#pragma omp parallel for
|
|
113
|
+
#pragma omp parallel for if (nh * k > 100000)
|
|
84
114
|
for (int64_t j = 0; j < nh; j++) {
|
|
85
115
|
int64_t imin = -1;
|
|
86
116
|
typename C::T xval = C::Crev::neutral();
|
|
@@ -109,4 +139,110 @@ template struct HeapArray<CMax<float, int64_t>>;
|
|
|
109
139
|
template struct HeapArray<CMin<int, int64_t>>;
|
|
110
140
|
template struct HeapArray<CMax<int, int64_t>>;
|
|
111
141
|
|
|
142
|
+
/**********************************************************
|
|
143
|
+
* merge knn search results
|
|
144
|
+
**********************************************************/
|
|
145
|
+
|
|
146
|
+
/** Merge result tables from several shards. The per-shard results are assumed
|
|
147
|
+
* to be sorted. Note that the C comparator is reversed w.r.t. the usual top-k
|
|
148
|
+
* element heap because we want the best (ie. lowest for L2) result to be on
|
|
149
|
+
* top, not the worst.
|
|
150
|
+
*
|
|
151
|
+
* @param all_distances size (nshard, n, k)
|
|
152
|
+
* @param all_labels size (nshard, n, k)
|
|
153
|
+
* @param distances output distances, size (n, k)
|
|
154
|
+
* @param labels output labels, size (n, k)
|
|
155
|
+
*/
|
|
156
|
+
template <class idx_t, class C>
|
|
157
|
+
void merge_knn_results(
|
|
158
|
+
size_t n,
|
|
159
|
+
size_t k,
|
|
160
|
+
typename C::TI nshard,
|
|
161
|
+
const typename C::T* all_distances,
|
|
162
|
+
const idx_t* all_labels,
|
|
163
|
+
typename C::T* distances,
|
|
164
|
+
idx_t* labels) {
|
|
165
|
+
using distance_t = typename C::T;
|
|
166
|
+
if (k == 0) {
|
|
167
|
+
return;
|
|
168
|
+
}
|
|
169
|
+
long stride = n * k;
|
|
170
|
+
#pragma omp parallel if (n * nshard * k > 100000)
|
|
171
|
+
{
|
|
172
|
+
std::vector<int> buf(2 * nshard);
|
|
173
|
+
// index in each shard's result list
|
|
174
|
+
int* pointer = buf.data();
|
|
175
|
+
// (shard_ids, heap_vals): heap that indexes
|
|
176
|
+
// shard -> current distance for this shard
|
|
177
|
+
int* shard_ids = pointer + nshard;
|
|
178
|
+
std::vector<distance_t> buf2(nshard);
|
|
179
|
+
distance_t* heap_vals = buf2.data();
|
|
180
|
+
#pragma omp for
|
|
181
|
+
for (long i = 0; i < n; i++) {
|
|
182
|
+
// the heap maps values to the shard where they are
|
|
183
|
+
// produced.
|
|
184
|
+
const distance_t* D_in = all_distances + i * k;
|
|
185
|
+
const idx_t* I_in = all_labels + i * k;
|
|
186
|
+
int heap_size = 0;
|
|
187
|
+
|
|
188
|
+
// push the first element of each shard (if not -1)
|
|
189
|
+
for (long s = 0; s < nshard; s++) {
|
|
190
|
+
pointer[s] = 0;
|
|
191
|
+
if (I_in[stride * s] >= 0) {
|
|
192
|
+
heap_push<C>(
|
|
193
|
+
++heap_size,
|
|
194
|
+
heap_vals,
|
|
195
|
+
shard_ids,
|
|
196
|
+
D_in[stride * s],
|
|
197
|
+
s);
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
distance_t* D = distances + i * k;
|
|
202
|
+
idx_t* I = labels + i * k;
|
|
203
|
+
|
|
204
|
+
int j;
|
|
205
|
+
for (j = 0; j < k && heap_size > 0; j++) {
|
|
206
|
+
// pop element from best shard
|
|
207
|
+
int s = shard_ids[0]; // top of heap
|
|
208
|
+
int& p = pointer[s];
|
|
209
|
+
D[j] = heap_vals[0];
|
|
210
|
+
I[j] = I_in[stride * s + p];
|
|
211
|
+
|
|
212
|
+
// pop from shard, advance pointer for this shard
|
|
213
|
+
heap_pop<C>(heap_size--, heap_vals, shard_ids);
|
|
214
|
+
p++;
|
|
215
|
+
if (p < k && I_in[stride * s + p] >= 0) {
|
|
216
|
+
heap_push<C>(
|
|
217
|
+
++heap_size,
|
|
218
|
+
heap_vals,
|
|
219
|
+
shard_ids,
|
|
220
|
+
D_in[stride * s + p],
|
|
221
|
+
s);
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
for (; j < k; j++) {
|
|
225
|
+
I[j] = -1;
|
|
226
|
+
D[j] = C::Crev::neutral();
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
}
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
// explicit instanciations
|
|
233
|
+
#define INSTANTIATE(C, distance_t) \
|
|
234
|
+
template void merge_knn_results<int64_t, C<distance_t, int>>( \
|
|
235
|
+
size_t, \
|
|
236
|
+
size_t, \
|
|
237
|
+
int, \
|
|
238
|
+
const distance_t*, \
|
|
239
|
+
const int64_t*, \
|
|
240
|
+
distance_t*, \
|
|
241
|
+
int64_t*);
|
|
242
|
+
|
|
243
|
+
INSTANTIATE(CMin, float);
|
|
244
|
+
INSTANTIATE(CMax, float);
|
|
245
|
+
INSTANTIATE(CMin, int32_t);
|
|
246
|
+
INSTANTIATE(CMax, int32_t);
|
|
247
|
+
|
|
112
248
|
} // namespace faiss
|
|
@@ -413,6 +413,19 @@ struct HeapArray {
|
|
|
413
413
|
size_t i0 = 0,
|
|
414
414
|
int64_t ni = -1);
|
|
415
415
|
|
|
416
|
+
/** same as addn_with_ids, but for just a subset of queries
|
|
417
|
+
*
|
|
418
|
+
* @param nsubset number of query entries to update
|
|
419
|
+
* @param subset indexes of queries to update, in 0..nh-1, size nsubset
|
|
420
|
+
*/
|
|
421
|
+
void addn_query_subset_with_ids(
|
|
422
|
+
size_t nsubset,
|
|
423
|
+
const TI* subset,
|
|
424
|
+
size_t nj,
|
|
425
|
+
const T* vin,
|
|
426
|
+
const TI* id_in = nullptr,
|
|
427
|
+
int64_t id_stride = 0);
|
|
428
|
+
|
|
416
429
|
/// reorder all the heaps
|
|
417
430
|
void reorder();
|
|
418
431
|
|
|
@@ -431,7 +444,7 @@ typedef HeapArray<CMin<int, int64_t>> int_minheap_array_t;
|
|
|
431
444
|
typedef HeapArray<CMax<float, int64_t>> float_maxheap_array_t;
|
|
432
445
|
typedef HeapArray<CMax<int, int64_t>> int_maxheap_array_t;
|
|
433
446
|
|
|
434
|
-
// The heap templates are
|
|
447
|
+
// The heap templates are instantiated explicitly in Heap.cpp
|
|
435
448
|
|
|
436
449
|
/*********************************************************************
|
|
437
450
|
* Indirect heaps: instead of having
|
|
@@ -492,6 +505,27 @@ inline void indirect_heap_push(
|
|
|
492
505
|
bh_ids[i] = id;
|
|
493
506
|
}
|
|
494
507
|
|
|
508
|
+
/** Combine the sorted top-k result tables produced by nshard shards into a
 * single top-k table.
 *
 * The comparator C is the reverse of the one used for a usual top-k heap:
 * the merge heap keeps the best result (e.g. the lowest L2 distance) on
 * top. C::TI only needs to hold a shard number, so int32 is usually enough.
 * A negative label ends a shard's list; missing results are reported with
 * label -1.
 *
 * @param n             number of queries
 * @param k             number of results per query
 * @param nshard        number of shards
 * @param all_distances input distances, size (nshard, n, k)
 * @param all_labels    input labels, size (nshard, n, k)
 * @param distances     output distances, size (n, k)
 * @param labels        output labels, size (n, k)
 */
template <class idx_t, class C>
void merge_knn_results(
        size_t n,
        size_t k,
        typename C::TI nshard,
        const typename C::T* all_distances,
        const idx_t* all_labels,
        typename C::T* distances,
        idx_t* labels);
|
|
528
|
+
|
|
495
529
|
} // namespace faiss
|
|
496
530
|
|
|
497
531
|
#endif /* FAISS_Heap_h */
|