faiss 0.1.7 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +7 -7
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +8 -2
- data/ext/faiss/index.cpp +102 -69
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +0 -5
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +26 -12
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
|
@@ -5,49 +5,38 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
|
|
9
8
|
/*
|
|
10
9
|
* Structures that collect search results from distance computations
|
|
11
10
|
*/
|
|
12
11
|
|
|
13
12
|
#pragma once
|
|
14
13
|
|
|
15
|
-
|
|
14
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
16
15
|
#include <faiss/utils/Heap.h>
|
|
17
16
|
#include <faiss/utils/partitioning.h>
|
|
18
|
-
#include <faiss/impl/AuxIndexStructures.h>
|
|
19
|
-
|
|
20
17
|
|
|
21
18
|
namespace faiss {
|
|
22
19
|
|
|
23
|
-
|
|
24
|
-
|
|
25
20
|
/*****************************************************************
|
|
26
21
|
* Heap based result handler
|
|
27
22
|
*****************************************************************/
|
|
28
23
|
|
|
29
|
-
|
|
30
|
-
template<class C>
|
|
24
|
+
template <class C>
|
|
31
25
|
struct HeapResultHandler {
|
|
32
|
-
|
|
33
26
|
using T = typename C::T;
|
|
34
27
|
using TI = typename C::TI;
|
|
35
28
|
|
|
36
29
|
int nq;
|
|
37
|
-
T
|
|
38
|
-
TI
|
|
30
|
+
T* heap_dis_tab;
|
|
31
|
+
TI* heap_ids_tab;
|
|
39
32
|
|
|
40
|
-
int64_t k;
|
|
33
|
+
int64_t k; // number of results to keep
|
|
41
34
|
|
|
42
|
-
HeapResultHandler(
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
|
|
48
|
-
{
|
|
49
|
-
|
|
50
|
-
}
|
|
35
|
+
HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
|
|
36
|
+
: nq(nq),
|
|
37
|
+
heap_dis_tab(heap_dis_tab),
|
|
38
|
+
heap_ids_tab(heap_ids_tab),
|
|
39
|
+
k(k) {}
|
|
51
40
|
|
|
52
41
|
/******************************************************
|
|
53
42
|
* API for 1 result at a time (each SingleResultHandler is
|
|
@@ -55,20 +44,20 @@ struct HeapResultHandler {
|
|
|
55
44
|
*/
|
|
56
45
|
|
|
57
46
|
struct SingleResultHandler {
|
|
58
|
-
HeapResultHandler
|
|
47
|
+
HeapResultHandler& hr;
|
|
59
48
|
size_t k;
|
|
60
49
|
|
|
61
|
-
T
|
|
62
|
-
TI
|
|
50
|
+
T* heap_dis;
|
|
51
|
+
TI* heap_ids;
|
|
63
52
|
T thresh;
|
|
64
53
|
|
|
65
|
-
SingleResultHandler(HeapResultHandler
|
|
54
|
+
SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
|
|
66
55
|
|
|
67
56
|
/// begin results for query # i
|
|
68
57
|
void begin(size_t i) {
|
|
69
58
|
heap_dis = hr.heap_dis_tab + i * k;
|
|
70
59
|
heap_ids = hr.heap_ids_tab + i * k;
|
|
71
|
-
heap_heapify<C>
|
|
60
|
+
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
72
61
|
thresh = heap_dis[0];
|
|
73
62
|
}
|
|
74
63
|
|
|
@@ -82,11 +71,10 @@ struct HeapResultHandler {
|
|
|
82
71
|
|
|
83
72
|
/// series of results for query i is done
|
|
84
73
|
void end() {
|
|
85
|
-
heap_reorder<C>
|
|
74
|
+
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
86
75
|
}
|
|
87
76
|
};
|
|
88
77
|
|
|
89
|
-
|
|
90
78
|
/******************************************************
|
|
91
79
|
* API for multiple results (called from 1 thread)
|
|
92
80
|
*/
|
|
@@ -97,20 +85,21 @@ struct HeapResultHandler {
|
|
|
97
85
|
void begin_multiple(size_t i0, size_t i1) {
|
|
98
86
|
this->i0 = i0;
|
|
99
87
|
this->i1 = i1;
|
|
100
|
-
for(size_t i = i0; i < i1; i++) {
|
|
101
|
-
heap_heapify<C>
|
|
88
|
+
for (size_t i = i0; i < i1; i++) {
|
|
89
|
+
heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
102
90
|
}
|
|
103
91
|
}
|
|
104
92
|
|
|
105
93
|
/// add results for query i0..i1 and j0..j1
|
|
106
|
-
void add_results(size_t j0, size_t j1, const T
|
|
107
|
-
|
|
108
|
-
for (
|
|
109
|
-
T
|
|
110
|
-
TI
|
|
94
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
95
|
+
#pragma omp parallel for
|
|
96
|
+
for (int64_t i = i0; i < i1; i++) {
|
|
97
|
+
T* heap_dis = heap_dis_tab + i * k;
|
|
98
|
+
TI* heap_ids = heap_ids_tab + i * k;
|
|
99
|
+
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
|
111
100
|
T thresh = heap_dis[0];
|
|
112
101
|
for (size_t j = j0; j < j1; j++) {
|
|
113
|
-
T dis =
|
|
102
|
+
T dis = dis_tab_i[j];
|
|
114
103
|
if (C::cmp(thresh, dis)) {
|
|
115
104
|
heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
|
|
116
105
|
thresh = heap_dis[0];
|
|
@@ -122,11 +111,10 @@ struct HeapResultHandler {
|
|
|
122
111
|
/// series of results for queries i0..i1 is done
|
|
123
112
|
void end_multiple() {
|
|
124
113
|
// maybe parallel for
|
|
125
|
-
for(size_t i = i0; i < i1; i++) {
|
|
126
|
-
heap_reorder<C>
|
|
114
|
+
for (size_t i = i0; i < i1; i++) {
|
|
115
|
+
heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
127
116
|
}
|
|
128
117
|
}
|
|
129
|
-
|
|
130
118
|
};
|
|
131
119
|
|
|
132
120
|
/*****************************************************************
|
|
@@ -138,31 +126,25 @@ struct HeapResultHandler {
|
|
|
138
126
|
* distance array.
|
|
139
127
|
*****************************************************************/
|
|
140
128
|
|
|
141
|
-
|
|
142
|
-
|
|
143
129
|
/// Reservoir for a single query
|
|
144
|
-
template<class C>
|
|
130
|
+
template <class C>
|
|
145
131
|
struct ReservoirTopN {
|
|
146
132
|
using T = typename C::T;
|
|
147
133
|
using TI = typename C::TI;
|
|
148
134
|
|
|
149
|
-
T
|
|
150
|
-
TI
|
|
135
|
+
T* vals;
|
|
136
|
+
TI* ids;
|
|
151
137
|
|
|
152
|
-
size_t i;
|
|
153
|
-
size_t n;
|
|
154
|
-
size_t capacity;
|
|
138
|
+
size_t i; // number of stored elements
|
|
139
|
+
size_t n; // number of requested elements
|
|
140
|
+
size_t capacity; // size of storage
|
|
155
141
|
|
|
156
142
|
T threshold; // current threshold
|
|
157
143
|
|
|
158
144
|
ReservoirTopN() {}
|
|
159
145
|
|
|
160
|
-
ReservoirTopN(
|
|
161
|
-
|
|
162
|
-
T *vals, TI *ids
|
|
163
|
-
):
|
|
164
|
-
vals(vals), ids(ids),
|
|
165
|
-
i(0), n(n), capacity(capacity) {
|
|
146
|
+
ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
|
|
147
|
+
: vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
|
|
166
148
|
assert(n < capacity);
|
|
167
149
|
threshold = C::neutral();
|
|
168
150
|
}
|
|
@@ -184,55 +166,47 @@ struct ReservoirTopN {
|
|
|
184
166
|
assert(i == capacity);
|
|
185
167
|
|
|
186
168
|
threshold = partition_fuzzy<C>(
|
|
187
|
-
|
|
188
|
-
&i);
|
|
169
|
+
vals, ids, capacity, n, (capacity + n) / 2, &i);
|
|
189
170
|
}
|
|
190
171
|
|
|
191
|
-
void to_result(T
|
|
192
|
-
|
|
172
|
+
void to_result(T* heap_dis, TI* heap_ids) const {
|
|
193
173
|
for (int j = 0; j < std::min(i, n); j++) {
|
|
194
|
-
heap_push<C>(
|
|
195
|
-
j + 1, heap_dis, heap_ids,
|
|
196
|
-
vals[j], ids[j]
|
|
197
|
-
);
|
|
174
|
+
heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
|
|
198
175
|
}
|
|
199
176
|
|
|
200
177
|
if (i < n) {
|
|
201
|
-
heap_reorder<C>
|
|
178
|
+
heap_reorder<C>(i, heap_dis, heap_ids);
|
|
202
179
|
// add empty results
|
|
203
|
-
heap_heapify<C>
|
|
180
|
+
heap_heapify<C>(n - i, heap_dis + i, heap_ids + i);
|
|
204
181
|
} else {
|
|
205
182
|
// add remaining elements
|
|
206
|
-
heap_addn<C>
|
|
207
|
-
heap_reorder<C>
|
|
183
|
+
heap_addn<C>(n, heap_dis, heap_ids, vals + n, ids + n, i - n);
|
|
184
|
+
heap_reorder<C>(n, heap_dis, heap_ids);
|
|
208
185
|
}
|
|
209
|
-
|
|
210
186
|
}
|
|
211
|
-
|
|
212
187
|
};
|
|
213
188
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
template<class C>
|
|
189
|
+
template <class C>
|
|
217
190
|
struct ReservoirResultHandler {
|
|
218
|
-
|
|
219
191
|
using T = typename C::T;
|
|
220
192
|
using TI = typename C::TI;
|
|
221
193
|
|
|
222
194
|
int nq;
|
|
223
|
-
T
|
|
224
|
-
TI
|
|
195
|
+
T* heap_dis_tab;
|
|
196
|
+
TI* heap_ids_tab;
|
|
225
197
|
|
|
226
|
-
int64_t k;
|
|
198
|
+
int64_t k; // number of results to keep
|
|
227
199
|
size_t capacity; // capacity of the reservoirs
|
|
228
200
|
|
|
229
201
|
ReservoirResultHandler(
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
202
|
+
size_t nq,
|
|
203
|
+
T* heap_dis_tab,
|
|
204
|
+
TI* heap_ids_tab,
|
|
205
|
+
size_t k)
|
|
206
|
+
: nq(nq),
|
|
207
|
+
heap_dis_tab(heap_dis_tab),
|
|
208
|
+
heap_ids_tab(heap_ids_tab),
|
|
209
|
+
k(k) {
|
|
236
210
|
// double then round up to multiple of 16 (for SIMD alignment)
|
|
237
211
|
capacity = (2 * k + 15) & ~15;
|
|
238
212
|
}
|
|
@@ -243,23 +217,26 @@ struct ReservoirResultHandler {
|
|
|
243
217
|
*/
|
|
244
218
|
|
|
245
219
|
struct SingleResultHandler {
|
|
246
|
-
ReservoirResultHandler
|
|
220
|
+
ReservoirResultHandler& hr;
|
|
247
221
|
|
|
248
222
|
std::vector<T> reservoir_dis;
|
|
249
223
|
std::vector<TI> reservoir_ids;
|
|
250
224
|
ReservoirTopN<C> res1;
|
|
251
225
|
|
|
252
|
-
SingleResultHandler(ReservoirResultHandler
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
226
|
+
SingleResultHandler(ReservoirResultHandler& hr)
|
|
227
|
+
: hr(hr),
|
|
228
|
+
reservoir_dis(hr.capacity),
|
|
229
|
+
reservoir_ids(hr.capacity) {}
|
|
256
230
|
|
|
257
231
|
size_t i;
|
|
258
232
|
|
|
259
233
|
/// begin results for query # i
|
|
260
234
|
void begin(size_t i) {
|
|
261
235
|
res1 = ReservoirTopN<C>(
|
|
262
|
-
|
|
236
|
+
hr.k,
|
|
237
|
+
hr.capacity,
|
|
238
|
+
reservoir_dis.data(),
|
|
239
|
+
reservoir_ids.data());
|
|
263
240
|
this->i = i;
|
|
264
241
|
}
|
|
265
242
|
|
|
@@ -270,8 +247,8 @@ struct ReservoirResultHandler {
|
|
|
270
247
|
|
|
271
248
|
/// series of results for query i is done
|
|
272
249
|
void end() {
|
|
273
|
-
T
|
|
274
|
-
TI
|
|
250
|
+
T* heap_dis = hr.heap_dis_tab + i * hr.k;
|
|
251
|
+
TI* heap_ids = hr.heap_ids_tab + i * hr.k;
|
|
275
252
|
res1.to_result(heap_dis, heap_ids);
|
|
276
253
|
}
|
|
277
254
|
};
|
|
@@ -295,20 +272,22 @@ struct ReservoirResultHandler {
|
|
|
295
272
|
reservoirs.clear();
|
|
296
273
|
for (size_t i = i0; i < i1; i++) {
|
|
297
274
|
reservoirs.emplace_back(
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
275
|
+
k,
|
|
276
|
+
capacity,
|
|
277
|
+
reservoir_dis.data() + (i - i0) * capacity,
|
|
278
|
+
reservoir_ids.data() + (i - i0) * capacity);
|
|
302
279
|
}
|
|
303
280
|
}
|
|
304
281
|
|
|
305
282
|
/// add results for query i0..i1 and j0..j1
|
|
306
|
-
void add_results(size_t j0, size_t j1, const T
|
|
283
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
307
284
|
// maybe parallel for
|
|
308
|
-
|
|
309
|
-
|
|
285
|
+
#pragma omp parallel for
|
|
286
|
+
for (int64_t i = i0; i < i1; i++) {
|
|
287
|
+
ReservoirTopN<C>& reservoir = reservoirs[i - i0];
|
|
288
|
+
const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
|
|
310
289
|
for (size_t j = j0; j < j1; j++) {
|
|
311
|
-
T dis =
|
|
290
|
+
T dis = dis_tab_i[j];
|
|
312
291
|
reservoir.add(dis, j);
|
|
313
292
|
}
|
|
314
293
|
}
|
|
@@ -317,32 +296,27 @@ struct ReservoirResultHandler {
|
|
|
317
296
|
/// series of results for queries i0..i1 is done
|
|
318
297
|
void end_multiple() {
|
|
319
298
|
// maybe parallel for
|
|
320
|
-
for(size_t i = i0; i < i1; i++) {
|
|
299
|
+
for (size_t i = i0; i < i1; i++) {
|
|
321
300
|
reservoirs[i - i0].to_result(
|
|
322
|
-
|
|
301
|
+
heap_dis_tab + i * k, heap_ids_tab + i * k);
|
|
323
302
|
}
|
|
324
303
|
}
|
|
325
|
-
|
|
326
304
|
};
|
|
327
305
|
|
|
328
|
-
|
|
329
306
|
/*****************************************************************
|
|
330
307
|
* Result handler for range searches
|
|
331
308
|
*****************************************************************/
|
|
332
309
|
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
template<class C>
|
|
310
|
+
template <class C>
|
|
336
311
|
struct RangeSearchResultHandler {
|
|
337
312
|
using T = typename C::T;
|
|
338
313
|
using TI = typename C::TI;
|
|
339
314
|
|
|
340
|
-
RangeSearchResult
|
|
315
|
+
RangeSearchResult* res;
|
|
341
316
|
float radius;
|
|
342
317
|
|
|
343
|
-
RangeSearchResultHandler(RangeSearchResult
|
|
344
|
-
|
|
345
|
-
{}
|
|
318
|
+
RangeSearchResultHandler(RangeSearchResult* res, float radius)
|
|
319
|
+
: res(res), radius(radius) {}
|
|
346
320
|
|
|
347
321
|
/******************************************************
|
|
348
322
|
* API for 1 result at a time (each SingleResultHandler is
|
|
@@ -353,11 +327,10 @@ struct RangeSearchResultHandler {
|
|
|
353
327
|
// almost the same interface as RangeSearchResultHandler
|
|
354
328
|
RangeSearchPartialResult pres;
|
|
355
329
|
float radius;
|
|
356
|
-
RangeQueryResult
|
|
330
|
+
RangeQueryResult* qr = nullptr;
|
|
357
331
|
|
|
358
|
-
SingleResultHandler(RangeSearchResultHandler
|
|
359
|
-
|
|
360
|
-
{}
|
|
332
|
+
SingleResultHandler(RangeSearchResultHandler& rh)
|
|
333
|
+
: pres(rh.res), radius(rh.radius) {}
|
|
361
334
|
|
|
362
335
|
/// begin results for query # i
|
|
363
336
|
void begin(size_t i) {
|
|
@@ -366,15 +339,13 @@ struct RangeSearchResultHandler {
|
|
|
366
339
|
|
|
367
340
|
/// add one result for query i
|
|
368
341
|
void add_result(T dis, TI idx) {
|
|
369
|
-
|
|
370
342
|
if (C::cmp(radius, dis)) {
|
|
371
343
|
qr->add(dis, idx);
|
|
372
344
|
}
|
|
373
345
|
}
|
|
374
346
|
|
|
375
347
|
/// series of results for query i is done
|
|
376
|
-
void end() {
|
|
377
|
-
}
|
|
348
|
+
void end() {}
|
|
378
349
|
|
|
379
350
|
~SingleResultHandler() {
|
|
380
351
|
pres.finalize();
|
|
@@ -387,8 +358,8 @@ struct RangeSearchResultHandler {
|
|
|
387
358
|
|
|
388
359
|
size_t i0, i1;
|
|
389
360
|
|
|
390
|
-
std::vector
|
|
391
|
-
std::vector
|
|
361
|
+
std::vector<RangeSearchPartialResult*> partial_results;
|
|
362
|
+
std::vector<size_t> j0s;
|
|
392
363
|
int pr = 0;
|
|
393
364
|
|
|
394
365
|
/// begin
|
|
@@ -399,8 +370,8 @@ struct RangeSearchResultHandler {
|
|
|
399
370
|
|
|
400
371
|
/// add results for query i0..i1 and j0..j1
|
|
401
372
|
|
|
402
|
-
void add_results(size_t j0, size_t j1, const T
|
|
403
|
-
RangeSearchPartialResult
|
|
373
|
+
void add_results(size_t j0, size_t j1, const T* dis_tab) {
|
|
374
|
+
RangeSearchPartialResult* pres;
|
|
404
375
|
// there is one RangeSearchPartialResult structure per j0
|
|
405
376
|
// (= block of columns of the large distance matrix)
|
|
406
377
|
// it is a bit tricky to find the poper PartialResult structure
|
|
@@ -414,39 +385,32 @@ struct RangeSearchResultHandler {
|
|
|
414
385
|
pres = partial_results[pr];
|
|
415
386
|
pr++;
|
|
416
387
|
} else { // did not find this j0
|
|
417
|
-
pres = new RangeSearchPartialResult
|
|
388
|
+
pres = new RangeSearchPartialResult(res);
|
|
418
389
|
partial_results.push_back(pres);
|
|
419
390
|
j0s.push_back(j0);
|
|
420
391
|
pr = partial_results.size();
|
|
421
392
|
}
|
|
422
393
|
|
|
423
394
|
for (size_t i = i0; i < i1; i++) {
|
|
424
|
-
const float
|
|
425
|
-
RangeQueryResult
|
|
395
|
+
const float* ip_line = dis_tab + (i - i0) * (j1 - j0);
|
|
396
|
+
RangeQueryResult& qres = pres->new_result(i);
|
|
426
397
|
|
|
427
398
|
for (size_t j = j0; j < j1; j++) {
|
|
428
399
|
float dis = *ip_line++;
|
|
429
400
|
if (C::cmp(radius, dis)) {
|
|
430
|
-
qres.add
|
|
401
|
+
qres.add(dis, j);
|
|
431
402
|
}
|
|
432
403
|
}
|
|
433
404
|
}
|
|
434
405
|
}
|
|
435
406
|
|
|
436
|
-
void end_multiple() {
|
|
437
|
-
|
|
438
|
-
}
|
|
407
|
+
void end_multiple() {}
|
|
439
408
|
|
|
440
409
|
~RangeSearchResultHandler() {
|
|
441
410
|
if (partial_results.size() > 0) {
|
|
442
|
-
RangeSearchPartialResult::merge
|
|
411
|
+
RangeSearchPartialResult::merge(partial_results);
|
|
443
412
|
}
|
|
444
413
|
}
|
|
445
|
-
|
|
446
414
|
};
|
|
447
415
|
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
} // namespace faiss
|
|
452
|
-
|
|
416
|
+
} // namespace faiss
|