faiss 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
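The headline additions in this vendored FAISS snapshot are the 4-bit PQ "fast scan" search path (IndexPQFastScan, IndexIVFPQFastScan, pq4_fast_scan, simdlib, AlignedTable, partitioning, quantize_lut and the SIMD result handlers shown below), the new IndexRefine wrapper, and the move of the inverted-list code under invlists/; the bundled c_api, demos, tests and tutorials are dropped from the gem. The sketch below is illustrative only: it assumes the IndexPQFastScan constructor follows IndexPQ's (d, M, nbits) convention and the classic search(n, x, k, distances, labels) signature, and it is not the gem's documented Ruby API.

// Illustrative sketch only: exercises the newly added IndexPQFastScan
// (4-bit PQ codes searched with SIMD lookups). Constructor arguments
// (d, M, nbits) are assumed to follow IndexPQ's convention.
#include <faiss/IndexPQFastScan.h>

#include <cstdio>
#include <random>
#include <vector>

int main() {
    int d = 64;                 // vector dimensionality
    size_t nb = 10000, nq = 3;  // database / query sizes

    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    std::vector<float> xb(nb * d), xq(nq * d);
    for (float& x : xb) x = u(rng);
    for (float& x : xq) x = u(rng);

    // 16 sub-quantizers, 4 bits each -- the fast-scan kernels work on 4-bit codes.
    faiss::IndexPQFastScan index(d, 16, 4);
    index.train(nb, xb.data());
    index.add(nb, xb.data());

    const faiss::Index::idx_t k = 5;
    std::vector<float> distances(nq * k);
    std::vector<faiss::Index::idx_t> labels(nq * k);
    index.search(nq, xq.data(), k, distances.data(), labels.data());

    for (size_t q = 0; q < nq; q++) {
        printf("query %zu: nearest id %lld at distance %g\n",
               q, (long long)labels[q * k], distances[q * k]);
    }
    return 0;
}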
data/vendor/faiss/faiss/impl/simd_result_handlers.h (new file)

@@ -0,0 +1,559 @@
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <vector>
#include <algorithm>
#include <type_traits>

#include <faiss/utils/Heap.h>
#include <faiss/utils/simdlib.h>

#include <faiss/utils/AlignedTable.h>
#include <faiss/utils/partitioning.h>
#include <faiss/impl/platform_macros.h>

/** This file contains callbacks for kernels that compute distances.
 *
 * The SIMDResultHandler object is intended to be templated and inlined.
 * Methods:
 * - handle(): called when 32 distances are computed and provided in two
 *   simd16uint16. (q, b) indicate which entry it is in the block.
 * - set_block_origin(): set the sub-matrix that is being computed
 */

namespace faiss {

namespace simd_result_handlers {


/** Dummy structure that just computes a checksum on results
 * (to avoid the computation to be optimized away) */
struct DummyResultHandler {
    size_t cs = 0;

    void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
        cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
    }

    void set_block_origin(size_t, size_t) {
    }
};

/** memorize results in a nq-by-nb matrix.
 *
 * j0 is the current upper-left block of the matrix
 */
struct StoreResultHandler {
    uint16_t *data;
    size_t ld; // total number of columns
    size_t i0 = 0;
    size_t j0 = 0;

    StoreResultHandler(uint16_t *data, size_t ld):
        data(data), ld(ld) {
    }

    void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
        size_t ofs = (q + i0) * ld + j0 + b * 32;
        d0.store(data + ofs);
        d1.store(data + ofs + 16);
    }

    void set_block_origin(size_t i0, size_t j0) {
        this->i0 = i0;
        this->j0 = j0;
    }
};


/** stores results in fixed-size matrix. */
template<int NQ, int BB>
struct FixedStorageHandler {
    simd16uint16 dis[NQ][BB];
    int i0 = 0;

    void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
        dis[q + i0][2 * b] = d0;
        dis[q + i0][2 * b + 1] = d1;
    }

    void set_block_origin(size_t i0, size_t j0) {
        this->i0 = i0;
        assert(j0 == 0);
    }

    template<class OtherResultHandler>
    void to_other_handler(OtherResultHandler & other) const {
        for (int q = 0; q < NQ; q++) {
            for (int b = 0; b < BB; b += 2) {
                other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
            }
        }
    }

};


/** Record origin of current block */
template<class C, bool with_id_map>
struct SIMDResultHandler {
    using TI = typename C::TI;

    bool disable = false;

    int64_t i0 = 0; // query origin
    int64_t j0 = 0; // db origin
    size_t ntotal; // ignore excess elements after ntotal

    /// these fields are used mainly for the IVF variants (with_id_map=true)
    const TI *id_map; // map offset in invlist to vector id
    const int *q_map; // map q to global query
    const uint16_t *dbias; // table of biases to add to each query

    explicit SIMDResultHandler(size_t ntotal):
        ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
    {}

    void set_block_origin(size_t i0, size_t j0) {
        this->i0 = i0;
        this->j0 = j0;
    }


    // adjust handler data for IVF.
    void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
    {
        q += i0;

        if (dbias) {
            simd16uint16 dbias16(dbias[q]);
            d0 += dbias16;
            d1 += dbias16;
        }

        if (with_id_map) { // FIXME test on q_map instead
            q = q_map[q];
        }
    }

    // compute and adjust idx
    int64_t adjust_id(size_t b, size_t j) {
        int64_t idx = j0 + 32 * b + j;
        if (with_id_map) {
            idx = id_map[idx];
        }
        return idx;
    }

    /// return binary mask of elements below thr in (d0, d1)
    /// inverse_test returns elements above
    uint32_t get_lt_mask(
        uint16_t thr, size_t b,
        simd16uint16 d0, simd16uint16 d1
    ) {
        simd16uint16 thr16(thr);
        uint32_t lt_mask;

        constexpr bool keep_min = C::is_max;
        if (keep_min) {
            lt_mask = ~cmp_ge32(d0, d1, thr16);
        } else {
            lt_mask = ~cmp_le32(d0, d1, thr16);
        }

        if (lt_mask == 0) {
            return 0;
        }
        uint64_t idx = j0 + b * 32;
        if (idx + 32 > ntotal) {
            if (idx >= ntotal) {
                return 0;
            }
            int nbit = (ntotal - idx);
            lt_mask &= (uint32_t(1) << nbit) - 1;
        }
        return lt_mask;
    }

    virtual void to_flat_arrays(
        float *distances, int64_t *labels,
        const float *normalizers = nullptr
    ) = 0;

    virtual ~SIMDResultHandler() {}

};


/** Special version for k=1 */
template<class C, bool with_id_map = false>
struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
    using T = typename C::T;
    using TI = typename C::TI;

    struct Result {
        T val;
        TI id;
    };
    std::vector<Result> results;

    SingleResultHandler(size_t nq, size_t ntotal):
        SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
    {
        for (int i = 0; i < nq; i++) {
            Result res = {C::neutral(), -1};
            results[i] = res;
        }
    }

    void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
        if (this->disable) {
            return;
        }

        this->adjust_with_origin(q, d0, d1);

        Result & res = results[q];
        uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
        if (!lt_mask) {
            return;
        }

        ALIGNED(32) uint16_t d32tab[32];
        d0.store(d32tab);
        d1.store(d32tab + 16);

        while (lt_mask) {
            // find first non-zero
            int j = __builtin_ctz(lt_mask);
            lt_mask -= 1 << j;
            T dis = d32tab[j];
            if (C::cmp(res.val, dis)) {
                res.val = dis;
                res.id = this->adjust_id(b, j);
            }
        }
    }

    void to_flat_arrays(
        float *distances, int64_t *labels,
        const float *normalizers = nullptr
    ) override {
        for (int q = 0; q < results.size(); q++) {
            if (!normalizers) {
                distances[q] = results[q].val;
            } else {
                float one_a = 1 / normalizers[2 * q];
                float b = normalizers[2 * q + 1];
                distances[q] = b + results[q].val * one_a;
            }
            labels[q] = results[q].id;
        }
    }

};

/** Structure that collects results in a min- or max-heap */
template<class C, bool with_id_map = false>
struct HeapHandler: SIMDResultHandler<C, with_id_map> {
    using T = typename C::T;
    using TI = typename C::TI;

    int nq;
    T *heap_dis_tab;
    TI *heap_ids_tab;

    int64_t k; // number of results to keep

    HeapHandler(
        int nq,
        T * heap_dis_tab, TI * heap_ids_tab,
        size_t k, size_t ntotal
    ):
        SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
        heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
    {
        for (int q = 0; q < nq; q++) {
            T *heap_dis_in = heap_dis_tab + q * k;
            TI *heap_ids_in = heap_ids_tab + q * k;
            heap_heapify<C> (k, heap_dis_in, heap_ids_in);
        }
    }

    void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
        if (this->disable) {
            return;
        }

        this->adjust_with_origin(q, d0, d1);

        T *heap_dis = heap_dis_tab + q * k;
        TI *heap_ids = heap_ids_tab + q * k;

        uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
            0xffff;

        // here we handle the reverse comparison case as well
        uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);

        if (!lt_mask) {
            return;
        }

        ALIGNED(32) uint16_t d32tab[32];
        d0.store(d32tab);
        d1.store(d32tab + 16);

        while (lt_mask) {
            // find first non-zero
            int j = __builtin_ctz(lt_mask);
            lt_mask -= 1 << j;
            T dis = d32tab[j];
            if (C::cmp(heap_dis[0], dis)) {
                int64_t idx = this->adjust_id(b, j);
                heap_pop<C>(k, heap_dis, heap_ids);
                heap_push<C>(k, heap_dis, heap_ids, dis, idx);
            }
        }

    }

    void to_flat_arrays(
        float *distances, int64_t *labels,
        const float *normalizers = nullptr
    ) override {

        for (int q = 0; q < nq; q++) {
            T *heap_dis_in = heap_dis_tab + q * k;
            TI *heap_ids_in = heap_ids_tab + q * k;
            heap_reorder<C> (k, heap_dis_in, heap_ids_in);
            int64_t *heap_ids = labels + q * k;
            float *heap_dis = distances + q * k;

            float one_a = 1.0, b = 0.0;
            if (normalizers) {
                one_a = 1 / normalizers[2 * q];
                b = normalizers[2 * q + 1];
            }
            for (int j = 0; j < k; j++) {
                heap_ids[j] = heap_ids_in[j];
                heap_dis[j] = heap_dis_in[j] * one_a + b;
            }
        }
    }

};


/** Simple top-N implementation using a reservoir.
 *
 * Results are stored when they are below the threshold until the capacity is
 * reached. Then a partition sort is used to update the threshold. */

namespace {

uint64_t get_cy () {
#ifdef MICRO_BENCHMARK
    uint32_t high, low;
    asm volatile("rdtsc \n\t"
                 : "=a" (low),
                   "=d" (high));
    return ((uint64_t)high << 32) | (low);
#else
    return 0;
#endif
}

} // anonymous namespace

template<class C>
struct ReservoirTopN {
    using T = typename C::T;
    using TI = typename C::TI;

    T *vals;
    TI *ids;

    size_t i; // number of stored elements
    size_t n; // number of requested elements
    size_t capacity; // size of storage
    size_t cycles = 0;

    T threshold; // current threshold

    ReservoirTopN(
        size_t n, size_t capacity,
        T *vals, TI *ids
    ):
        vals(vals), ids(ids),
        i(0), n(n), capacity(capacity) {
        assert(n < capacity);
        threshold = C::neutral();
    }

    void add(T val, TI id) {
        if (C::cmp(threshold, val)) {
            if (i == capacity) {
                shrink_fuzzy();
            }
            vals[i] = val;
            ids[i] = id;
            i++;
        }
    }

    /// shrink number of stored elements to n
    void shrink_xx() {
        uint64_t t0 = get_cy();
        qselect (vals, ids, i, n);
        i = n; // forget all elements above i = n
        threshold = C::Crev::neutral();
        for (size_t j = 0; j < n; j++) {
            if (C::cmp(vals[j], threshold)) {
                threshold = vals[j];
            }
        }
        cycles += get_cy() - t0;
    }

    void shrink() {
        uint64_t t0 = get_cy();
        threshold = partition<C>(vals, ids, i, n);
        i = n;
        cycles += get_cy() - t0;
    }

    void shrink_fuzzy() {
        uint64_t t0 = get_cy();
        assert(i == capacity);
        threshold = partition_fuzzy<C>(
            vals, ids, capacity, n, (capacity + n) / 2,
            &i);
        cycles += get_cy() - t0;
    }
};


/** Handler built from several ReservoirTopN (one per query) */
template<class C, bool with_id_map = false>
struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
    using T = typename C::T;
    using TI = typename C::TI;

    size_t capacity; // rounded up to multiple of 16
    std::vector<TI> all_ids;
    AlignedTable<T> all_vals;

    std::vector<ReservoirTopN<C>> reservoirs;

    uint64_t times[4];

    ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
        SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
        all_ids(nq * capacity), all_vals(nq * capacity)
    {
        assert(capacity % 16 == 0);
        for (size_t i = 0; i < nq; i++) {
            reservoirs.emplace_back(
                n, capacity,
                all_vals.get() + i * capacity,
                all_ids.data() + i * capacity
            );
        }
        times[0] = times[1] = times[2] = times[3] = 0;
    }


    void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
        uint64_t t0 = get_cy();
        if (this->disable) {
            return;
        }
        this->adjust_with_origin(q, d0, d1);

        ReservoirTopN<C> & res = reservoirs[q];
        uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
        uint64_t t1 = get_cy();
        times[0] += t1 - t0;

        if (!lt_mask) {
            return;
        }
        ALIGNED(32) uint16_t d32tab[32];
        d0.store(d32tab);
        d1.store(d32tab + 16);

        while (lt_mask) {
            // find first non-zero
            int j = __builtin_ctz(lt_mask);
            lt_mask -= 1 << j;
            T dis = d32tab[j];
            res.add(dis, this->adjust_id(b, j));
        }
        times[1] += get_cy() - t1;
    }


    void to_flat_arrays(
        float *distances, int64_t *labels,
        const float *normalizers = nullptr
    ) override {
        using Cf = typename std::conditional<
            C::is_max,
            CMax<float, int64_t>, CMin<float, int64_t>>::type;

        uint64_t t0 = get_cy();
        uint64_t t3 = 0;
        std::vector<int> perm(reservoirs[0].n);
        for (int q = 0; q < reservoirs.size(); q++) {
            ReservoirTopN<C> & res = reservoirs[q];
            size_t n = res.n;

            if (res.i > res.n) {
                res.shrink();
            }
            int64_t *heap_ids = labels + q * n;
            float *heap_dis = distances + q * n;

            float one_a = 1.0, b = 0.0;
            if (normalizers) {
                one_a = 1 / normalizers[2 * q];
                b = normalizers[2 * q + 1];
            }
            for (int i = 0; i < res.i; i++) {
                perm[i] = i;
            }
            // indirect sort of result arrays
            std::sort(
                perm.begin(), perm.begin() + res.i,
                [&res](int i, int j) {
                    return C::cmp(res.vals[j], res.vals[i]);
                }
            );
            for (int i = 0; i < res.i; i++) {
                heap_dis[i] = res.vals[perm[i]] * one_a + b;
                heap_ids[i] = res.ids[perm[i]];
            }

            // possibly add empty results
            heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);

            t3 += res.cycles;
        }
        times[2] += get_cy() - t0;
        times[3] += t3;
    }

};


} // namespace simd_result_handlers


} // namespace faiss
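All handlers above implement the callback interface described in the file's header comment: handle() receives 32 uint16 distances per call packed in two simd16uint16 registers, set_block_origin() records which (query, database) sub-block is being scanned, and to_flat_arrays() converts the collected results back to float distances and int64 labels. The following is a minimal sketch of driving HeapHandler directly with synthetic distance blocks; it assumes the broadcast constructor simd16uint16(uint16_t) used above and that CMax<uint16_t, int64_t> is provided by the new faiss/utils/ordered_key_value.h, so treat it as illustrative rather than an official usage pattern.

// Illustrative sketch: feed synthetic 32-wide distance blocks to HeapHandler
// and read the per-query top-k back out. Assumes CMax lives in
// ordered_key_value.h and that simd16uint16(uint16_t) broadcasts a scalar.
#include <faiss/impl/simd_result_handlers.h>
#include <faiss/utils/ordered_key_value.h>

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    using namespace faiss;
    using namespace faiss::simd_result_handlers;

    size_t nq = 2, ntotal = 64, k = 4;
    std::vector<uint16_t> heap_dis(nq * k);
    std::vector<int64_t> heap_ids(nq * k);

    // CMax<uint16_t, int64_t> keeps the k smallest uint16 distances per query.
    HeapHandler<CMax<uint16_t, int64_t>> handler(
        nq, heap_dis.data(), heap_ids.data(), k, ntotal);

    // Block b of query q covers database ids [32 * b, 32 * b + 32).
    for (size_t q = 0; q < nq; q++) {
        for (size_t b = 0; b < ntotal / 32; b++) {
            simd16uint16 d0(uint16_t(100 + b)); // distances for ids 32*b .. 32*b+15
            simd16uint16 d1(uint16_t(200 + b)); // distances for ids 32*b+16 .. 32*b+31
            handler.handle(q, b, d0, d1);
        }
    }

    // Convert the collected heaps to float distances and int64 labels.
    std::vector<float> distances(nq * k);
    std::vector<int64_t> labels(nq * k);
    handler.to_flat_arrays(distances.data(), labels.data());

    for (size_t i = 0; i < nq * k; i++) {
        printf("entry %zu: id=%lld dis=%g\n",
               i, (long long)labels[i], distances[i]);
    }
    return 0;
}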