faiss 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +36 -33
- data/vendor/faiss/faiss/AutoTune.h +6 -3
- data/vendor/faiss/faiss/Clustering.cpp +16 -12
- data/vendor/faiss/faiss/Index.cpp +3 -4
- data/vendor/faiss/faiss/Index.h +3 -3
- data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
- data/vendor/faiss/faiss/IndexBinary.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
- data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
- data/vendor/faiss/faiss/IndexFlat.h +0 -51
- data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
- data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
- data/vendor/faiss/faiss/IndexIVF.h +22 -15
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
- data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
- data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
- data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
- data/vendor/faiss/faiss/IndexRefine.h +73 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
- data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
- data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
- data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
- data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
- data/vendor/faiss/faiss/impl/io.cpp +33 -2
- data/vendor/faiss/faiss/impl/io.h +7 -2
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
- data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
- data/vendor/faiss/faiss/index_factory.cpp +112 -7
- data/vendor/faiss/faiss/index_io.h +1 -48
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
- data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
- data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
- data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
- data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
- data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
- data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
- data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
- data/vendor/faiss/faiss/utils/Heap.h +61 -50
- data/vendor/faiss/faiss/utils/distances.cpp +164 -319
- data/vendor/faiss/faiss/utils/distances.h +28 -20
- data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
- data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
- data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
- data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
- data/vendor/faiss/faiss/utils/hamming.h +2 -7
- data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
- data/vendor/faiss/faiss/utils/partitioning.h +69 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
- data/vendor/faiss/faiss/utils/simdlib.h +31 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
- metadata +43 -141
- data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
- data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
- data/vendor/faiss/c_api/AutoTune_c.h +0 -66
- data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
- data/vendor/faiss/c_api/Clustering_c.h +0 -123
- data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
- data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
- data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
- data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
- data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
- data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
- data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
- data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
- data/vendor/faiss/c_api/IndexShards_c.h +0 -39
- data/vendor/faiss/c_api/Index_c.cpp +0 -105
- data/vendor/faiss/c_api/Index_c.h +0 -183
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
- data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
- data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
- data/vendor/faiss/c_api/clone_index_c.h +0 -32
- data/vendor/faiss/c_api/error_c.h +0 -42
- data/vendor/faiss/c_api/error_impl.cpp +0 -27
- data/vendor/faiss/c_api/error_impl.h +0 -16
- data/vendor/faiss/c_api/faiss_c.h +0 -58
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
- data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
- data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
- data/vendor/faiss/c_api/index_factory_c.h +0 -30
- data/vendor/faiss/c_api/index_io_c.cpp +0 -42
- data/vendor/faiss/c_api/index_io_c.h +0 -50
- data/vendor/faiss/c_api/macros_impl.h +0 -110
- data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
- data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
- data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
- data/vendor/faiss/misc/test_blas.cpp +0 -87
- data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
- data/vendor/faiss/tests/test_merge.cpp +0 -260
- data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
- data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
- data/vendor/faiss/tests/test_params_override.cpp +0 -236
- data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
- data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
- data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
- data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,166 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <memory>
|
11
|
+
|
12
|
+
#include <faiss/IndexIVFPQ.h>
|
13
|
+
#include <faiss/impl/ProductQuantizer.h>
|
14
|
+
#include <faiss/utils/AlignedTable.h>
|
15
|
+
|
16
|
+
namespace faiss {
|
17
|
+
|
18
|
+
|
19
|
+
/** Fast scan version of IVFPQ. Works for 4-bit PQ for now.
|
20
|
+
*
|
21
|
+
* The codes in the inverted lists are not stored sequentially but
|
22
|
+
* grouped in blocks of size bbs. This makes it possible to very quickly
|
23
|
+
* compute distances with SIMD instructions.
|
24
|
+
*
|
25
|
+
* Implementations (implem):
|
26
|
+
* 0: auto-select implementation (default)
|
27
|
+
* 1: orig's search, re-implemented
|
28
|
+
* 2: orig's search, re-ordered by invlist
|
29
|
+
* 10: optimized int16 search, collect results in heap, no qbs
|
30
|
+
* 11: idem, collect results in reservoir
|
31
|
+
* 12: optimized int16 search, collect results in heap, uses qbs
|
32
|
+
* 13: idem, collect results in reservoir
|
33
|
+
*/
|
34
|
+
|
35
|
+
struct IndexIVFPQFastScan: IndexIVF {
|
36
|
+
|
37
|
+
bool by_residual; ///< Encode residual or plain vector?
|
38
|
+
ProductQuantizer pq; ///< produces the codes
|
39
|
+
|
40
|
+
// size of the kernel
|
41
|
+
int bbs; // set at build time
|
42
|
+
|
43
|
+
// M rounded up to a multiple of 2
|
44
|
+
size_t M2;
|
45
|
+
|
46
|
+
/// precomputed tables management
|
47
|
+
int use_precomputed_table = 0;
|
48
|
+
/// if use_precomputed_table is set: table of size (nlist, pq.M, pq.ksub)
|
49
|
+
AlignedTable<float> precomputed_table;
|
50
|
+
|
51
|
+
// search-time implementation
|
52
|
+
int implem = 0;
|
53
|
+
// skip some parts of the computation (for timing)
|
54
|
+
int skip = 0;
|
55
|
+
|
56
|
+
// batching factors at search time (0 = default)
|
57
|
+
int qbs = 0;
|
58
|
+
size_t qbs2 = 0;
|
59
|
+
|
60
|
+
IndexIVFPQFastScan (
|
61
|
+
Index * quantizer, size_t d, size_t nlist,
|
62
|
+
size_t M, size_t nbits_per_idx,
|
63
|
+
MetricType metric = METRIC_L2, int bbs = 32);
|
64
|
+
|
65
|
+
IndexIVFPQFastScan ();
|
66
|
+
|
67
|
+
// built from an IndexIVFPQ
|
68
|
+
explicit IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs = 32);
|
69
|
+
|
70
|
+
/// orig's inverted lists (for debugging)
|
71
|
+
InvertedLists * orig_invlists = nullptr;
|
72
|
+
|
73
|
+
void train_residual (idx_t n, const float *x) override;
|
74
|
+
|
75
|
+
/// build precomputed table, possibly updating use_precomputed_table
|
76
|
+
void precompute_table ();
|
77
|
+
|
78
|
+
/// same as the regular IVFPQ encoder. The codes are not reorganized by
|
79
|
+
/// blocks at that point
|
80
|
+
void encode_vectors(
|
81
|
+
idx_t n, const float* x,
|
82
|
+
const idx_t *list_nos, uint8_t * codes,
|
83
|
+
bool include_listno = false) const override;
|
84
|
+
|
85
|
+
void add_with_ids (
|
86
|
+
idx_t n, const float * x, const idx_t *xids) override;
|
87
|
+
|
88
|
+
void search(
|
89
|
+
idx_t n, const float* x, idx_t k,
|
90
|
+
float* distances, idx_t* labels) const override;
|
91
|
+
|
92
|
+
// prepare look-up tables
|
93
|
+
|
94
|
+
void compute_LUT(
|
95
|
+
size_t n, const float *x,
|
96
|
+
const idx_t *coarse_ids, const float *coarse_dis,
|
97
|
+
AlignedTable<float> & dis_tables,
|
98
|
+
AlignedTable<float> & biases
|
99
|
+
) const;
|
100
|
+
|
101
|
+
void compute_LUT_uint8(
|
102
|
+
size_t n, const float *x,
|
103
|
+
const idx_t *coarse_ids, const float *coarse_dis,
|
104
|
+
AlignedTable<uint8_t> & dis_tables,
|
105
|
+
AlignedTable<uint16_t> & biases,
|
106
|
+
float * normalizers
|
107
|
+
) const;
|
108
|
+
|
109
|
+
// internal search funcs
|
110
|
+
|
111
|
+
template<bool is_max>
|
112
|
+
void search_dispatch_implem(
|
113
|
+
idx_t n, const float* x, idx_t k,
|
114
|
+
float* distances, idx_t* labels) const;
|
115
|
+
|
116
|
+
template<class C>
|
117
|
+
void search_implem_1(
|
118
|
+
idx_t n, const float* x, idx_t k,
|
119
|
+
float* distances, idx_t* labels) const;
|
120
|
+
|
121
|
+
template<class C>
|
122
|
+
void search_implem_2(
|
123
|
+
idx_t n, const float* x, idx_t k,
|
124
|
+
float* distances, idx_t* labels) const;
|
125
|
+
|
126
|
+
// implem 10 and 12 are not multithreaded internally, so
|
127
|
+
// export search stats
|
128
|
+
template<class C>
|
129
|
+
void search_implem_10(
|
130
|
+
idx_t n, const float* x, idx_t k,
|
131
|
+
float* distances, idx_t* labels,
|
132
|
+
int impl, size_t *ndis_out, size_t *nlist_out) const;
|
133
|
+
|
134
|
+
template<class C>
|
135
|
+
void search_implem_12(
|
136
|
+
idx_t n, const float* x, idx_t k,
|
137
|
+
float* distances, idx_t* labels,
|
138
|
+
int impl, size_t *ndis_out, size_t *nlist_out) const;
|
139
|
+
|
140
|
+
|
141
|
+
|
142
|
+
};
|
143
|
+
|
144
|
+
struct IVFFastScanStats {
|
145
|
+
uint64_t times[10];
|
146
|
+
uint64_t t_compute_distance_tables, t_round;
|
147
|
+
uint64_t t_copy_pack, t_scan, t_to_flat;
|
148
|
+
uint64_t reservoir_times[4];
|
149
|
+
|
150
|
+
double Mcy_at(int i) {
|
151
|
+
return times[i] / (1000*1000.0);
|
152
|
+
}
|
153
|
+
|
154
|
+
double Mcy_reservoir_at(int i) {
|
155
|
+
return reservoir_times[i] / (1000*1000.0);
|
156
|
+
}
|
157
|
+
IVFFastScanStats() {reset();}
|
158
|
+
void reset() {
|
159
|
+
memset(this, 0, sizeof(*this));
|
160
|
+
}
|
161
|
+
};
|
162
|
+
|
163
|
+
FAISS_API extern IVFFastScanStats IVFFastScan_stats;
|
164
|
+
|
165
|
+
|
166
|
+
} // namespace faiss
|
@@ -97,13 +97,13 @@ void IndexIVFPQR::add_core (idx_t n, const float *x, const idx_t *xids,
|
|
97
97
|
#define TOC get_cycles () - t0
|
98
98
|
|
99
99
|
|
100
|
-
void IndexIVFPQR::search_preassigned (
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
100
|
+
void IndexIVFPQR::search_preassigned (
|
101
|
+
idx_t n, const float *x, idx_t k,
|
102
|
+
const idx_t *idx, const float *L1_dis,
|
103
|
+
float *distances, idx_t *labels,
|
104
|
+
bool store_pairs,
|
105
|
+
const IVFSearchParameters *params, IndexIVFStats *stats
|
106
|
+
) const
|
107
107
|
{
|
108
108
|
uint64_t t0;
|
109
109
|
TIC;
|
@@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
172
172
|
float dis = fvec_L2sqr (residual_1, residual_2, d);
|
173
173
|
|
174
174
|
if (dis < heap_sim[0]) {
|
175
|
-
maxheap_pop (k, heap_sim, heap_ids);
|
176
175
|
idx_t id_or_pair = store_pairs ? sl : id;
|
177
|
-
|
176
|
+
maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair);
|
178
177
|
}
|
179
178
|
n_refine ++;
|
180
179
|
}
|
@@ -55,7 +55,8 @@ struct IndexIVFPQR: IndexIVFPQ {
|
|
55
55
|
const float *centroid_dis,
|
56
56
|
float *distances, idx_t *labels,
|
57
57
|
bool store_pairs,
|
58
|
-
const IVFSearchParameters *params=nullptr
|
58
|
+
const IVFSearchParameters *params=nullptr,
|
59
|
+
IndexIVFStats *stats=nullptr
|
59
60
|
) const override;
|
60
61
|
|
61
62
|
IndexIVFPQR();
|
@@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner {
|
|
269
269
|
float dis = hc.hamming (codes);
|
270
270
|
|
271
271
|
if (dis < simi [0]) {
|
272
|
-
maxheap_pop (k, simi, idxi);
|
273
272
|
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
274
|
-
|
273
|
+
maxheap_replace_top (k, simi, idxi, dis, id);
|
275
274
|
nup++;
|
276
275
|
}
|
277
276
|
codes += code_size;
|
@@ -129,9 +129,10 @@ void IndexPQ::reconstruct (idx_t key, float * recons) const
|
|
129
129
|
|
130
130
|
namespace {
|
131
131
|
|
132
|
-
|
133
|
-
struct
|
132
|
+
template<class PQDecoder>
|
133
|
+
struct PQDistanceComputer: DistanceComputer {
|
134
134
|
size_t d;
|
135
|
+
MetricType metric;
|
135
136
|
Index::idx_t nb;
|
136
137
|
const uint8_t *codes;
|
137
138
|
size_t code_size;
|
@@ -144,10 +145,11 @@ struct PQDis: DistanceComputer {
|
|
144
145
|
{
|
145
146
|
const uint8_t *code = codes + i * code_size;
|
146
147
|
const float *dt = precomputed_table.data();
|
148
|
+
PQDecoder decoder(code, pq.nbits);
|
147
149
|
float accu = 0;
|
148
150
|
for (int j = 0; j < pq.M; j++) {
|
149
|
-
accu += dt[
|
150
|
-
dt +=
|
151
|
+
accu += dt[decoder.decode()];
|
152
|
+
dt += 1 << decoder.nbits;
|
151
153
|
}
|
152
154
|
ndis++;
|
153
155
|
return accu;
|
@@ -155,33 +157,43 @@ struct PQDis: DistanceComputer {
|
|
155
157
|
|
156
158
|
float symmetric_dis(idx_t i, idx_t j) override
|
157
159
|
{
|
160
|
+
FAISS_THROW_IF_NOT(sdc);
|
158
161
|
const float * sdci = sdc;
|
159
162
|
float accu = 0;
|
160
|
-
|
161
|
-
|
163
|
+
PQDecoder codei (codes + i * code_size, pq.nbits);
|
164
|
+
PQDecoder codej (codes + j * code_size, pq.nbits);
|
162
165
|
|
163
166
|
for (int l = 0; l < pq.M; l++) {
|
164
|
-
accu += sdci[(
|
165
|
-
sdci +=
|
167
|
+
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
|
168
|
+
sdci += uint64_t(1) << (2 * codei.nbits);
|
166
169
|
}
|
170
|
+
ndis++;
|
167
171
|
return accu;
|
168
172
|
}
|
169
173
|
|
170
|
-
explicit
|
171
|
-
|
174
|
+
explicit PQDistanceComputer(const IndexPQ& storage)
|
175
|
+
: pq(storage.pq) {
|
172
176
|
precomputed_table.resize(pq.M * pq.ksub);
|
173
177
|
nb = storage.ntotal;
|
174
178
|
d = storage.d;
|
179
|
+
metric = storage.metric_type;
|
175
180
|
codes = storage.codes.data();
|
176
181
|
code_size = pq.code_size;
|
177
|
-
|
178
|
-
|
179
|
-
|
182
|
+
if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
|
183
|
+
sdc = pq.sdc_table.data();
|
184
|
+
} else {
|
185
|
+
sdc = nullptr;
|
186
|
+
}
|
180
187
|
ndis = 0;
|
181
188
|
}
|
182
189
|
|
183
190
|
void set_query(const float *x) override {
|
184
|
-
|
191
|
+
if (metric == METRIC_L2) {
|
192
|
+
pq.compute_distance_table(x, precomputed_table.data());
|
193
|
+
} else {
|
194
|
+
pq.compute_inner_prod_table(x, precomputed_table.data());
|
195
|
+
}
|
196
|
+
|
185
197
|
}
|
186
198
|
};
|
187
199
|
|
@@ -190,8 +202,13 @@ struct PQDis: DistanceComputer {
|
|
190
202
|
|
191
203
|
|
192
204
|
DistanceComputer * IndexPQ::get_distance_computer() const {
|
193
|
-
|
194
|
-
|
205
|
+
if (pq.nbits == 8) {
|
206
|
+
return new PQDistanceComputer<PQDecoder8>(*this);
|
207
|
+
} else if (pq.nbits == 16) {
|
208
|
+
return new PQDistanceComputer<PQDecoder16>(*this);
|
209
|
+
} else {
|
210
|
+
return new PQDistanceComputer<PQDecoderGeneric>(*this);
|
211
|
+
}
|
195
212
|
}
|
196
213
|
|
197
214
|
|
@@ -329,8 +346,7 @@ static size_t polysemous_inner_loop (
|
|
329
346
|
}
|
330
347
|
|
331
348
|
if (dis < heap_dis[0]) {
|
332
|
-
|
333
|
-
maxheap_push (k, heap_dis, heap_ids, dis, bi);
|
349
|
+
maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
|
334
350
|
}
|
335
351
|
}
|
336
352
|
b_code += code_size;
|
@@ -0,0 +1,536 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
|
9
|
+
#include <faiss/IndexPQFastScan.h>
|
10
|
+
|
11
|
+
#include <cassert>
|
12
|
+
#include <memory>
|
13
|
+
#include <limits.h>
|
14
|
+
|
15
|
+
#include <omp.h>
|
16
|
+
|
17
|
+
|
18
|
+
#include <faiss/impl/FaissAssert.h>
|
19
|
+
#include <faiss/utils/utils.h>
|
20
|
+
#include <faiss/utils/random.h>
|
21
|
+
|
22
|
+
#include <faiss/impl/simd_result_handlers.h>
|
23
|
+
#include <faiss/utils/quantize_lut.h>
|
24
|
+
#include <faiss/impl/pq4_fast_scan.h>
|
25
|
+
|
26
|
+
|
27
|
+
namespace faiss {
|
28
|
+
|
29
|
+
using namespace simd_result_handlers;
|
30
|
+
|
31
|
+
inline size_t roundup(size_t a, size_t b) {
|
32
|
+
return (a + b - 1) / b * b;
|
33
|
+
}
|
34
|
+
|
35
|
+
IndexPQFastScan::IndexPQFastScan(
|
36
|
+
int d, size_t M, size_t nbits,
|
37
|
+
MetricType metric,
|
38
|
+
int bbs):
|
39
|
+
Index(d, metric), pq(d, M, nbits),
|
40
|
+
bbs(bbs), ntotal2(0), M2(roundup(M, 2))
|
41
|
+
{
|
42
|
+
FAISS_THROW_IF_NOT(nbits == 4);
|
43
|
+
is_trained = false;
|
44
|
+
}
|
45
|
+
|
46
|
+
IndexPQFastScan::IndexPQFastScan():
|
47
|
+
bbs(0), ntotal2(0), M2(0)
|
48
|
+
{}
|
49
|
+
|
50
|
+
IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
|
51
|
+
Index(orig.d, orig.metric_type),
|
52
|
+
pq(orig.pq),
|
53
|
+
bbs(bbs)
|
54
|
+
{
|
55
|
+
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
|
56
|
+
ntotal = orig.ntotal;
|
57
|
+
is_trained = orig.is_trained;
|
58
|
+
orig_codes = orig.codes.data();
|
59
|
+
|
60
|
+
qbs = 0; // means use default
|
61
|
+
|
62
|
+
// pack the codes
|
63
|
+
|
64
|
+
size_t M = pq.M;
|
65
|
+
|
66
|
+
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
67
|
+
M2 = roundup(M, 2);
|
68
|
+
ntotal2 = roundup(ntotal, bbs);
|
69
|
+
|
70
|
+
codes.resize(ntotal2 * M2 / 2);
|
71
|
+
|
72
|
+
// printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
|
73
|
+
pq4_pack_codes(
|
74
|
+
orig.codes.data(),
|
75
|
+
ntotal, M,
|
76
|
+
ntotal2, bbs, M2,
|
77
|
+
codes.get()
|
78
|
+
);
|
79
|
+
}
|
80
|
+
|
81
|
+
void IndexPQFastScan::train (idx_t n, const float *x)
|
82
|
+
{
|
83
|
+
if (is_trained) {
|
84
|
+
return;
|
85
|
+
}
|
86
|
+
pq.train(n, x);
|
87
|
+
is_trained = true;
|
88
|
+
}
|
89
|
+
|
90
|
+
|
91
|
+
void IndexPQFastScan::add (idx_t n, const float *x) {
|
92
|
+
FAISS_THROW_IF_NOT (is_trained);
|
93
|
+
AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
|
94
|
+
pq.compute_codes (x, tmp_codes.get(), n);
|
95
|
+
ntotal2 = roundup(ntotal + n, bbs);
|
96
|
+
size_t new_size = ntotal2 * M2 / 2;
|
97
|
+
size_t old_size = codes.size();
|
98
|
+
if (new_size > old_size) {
|
99
|
+
codes.resize(new_size);
|
100
|
+
memset(codes.get() + old_size, 0, new_size - old_size);
|
101
|
+
}
|
102
|
+
pq4_pack_codes_range(
|
103
|
+
tmp_codes.get(), pq.M, ntotal, ntotal + n,
|
104
|
+
bbs, M2, codes.get()
|
105
|
+
);
|
106
|
+
ntotal += n;
|
107
|
+
}
|
108
|
+
|
109
|
+
void IndexPQFastScan::reset()
|
110
|
+
{
|
111
|
+
codes.resize(0);
|
112
|
+
ntotal = 0;
|
113
|
+
}
|
114
|
+
|
115
|
+
|
116
|
+
|
117
|
+
namespace {
|
118
|
+
|
119
|
+
// from impl/ProductQuantizer.cpp
|
120
|
+
template <class C, typename dis_t>
|
121
|
+
void pq_estimators_from_tables_generic(
|
122
|
+
const ProductQuantizer& pq, size_t nbits,
|
123
|
+
const uint8_t *codes, size_t ncodes,
|
124
|
+
const dis_t *dis_table, size_t k,
|
125
|
+
typename C::T *heap_dis, int64_t *heap_ids)
|
126
|
+
{
|
127
|
+
using accu_t = typename C::T;
|
128
|
+
const size_t M = pq.M;
|
129
|
+
const size_t ksub = pq.ksub;
|
130
|
+
for (size_t j = 0; j < ncodes; ++j) {
|
131
|
+
PQDecoderGeneric decoder(
|
132
|
+
codes + j * pq.code_size, nbits
|
133
|
+
);
|
134
|
+
accu_t dis = 0;
|
135
|
+
const dis_t * __restrict dt = dis_table;
|
136
|
+
for (size_t m = 0; m < M; m++) {
|
137
|
+
uint64_t c = decoder.decode();
|
138
|
+
dis += dt[c];
|
139
|
+
dt += ksub;
|
140
|
+
}
|
141
|
+
|
142
|
+
if (C::cmp(heap_dis[0], dis)) {
|
143
|
+
heap_pop<C>(k, heap_dis, heap_ids);
|
144
|
+
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
145
|
+
}
|
146
|
+
}
|
147
|
+
}
|
148
|
+
|
149
|
+
|
150
|
+
} // anonymous namespace
|
151
|
+
|
152
|
+
|
153
|
+
using namespace quantize_lut;
|
154
|
+
|
155
|
+
void IndexPQFastScan::compute_quantized_LUT(
|
156
|
+
idx_t n, const float* x,
|
157
|
+
uint8_t *lut, float *normalizers) const
|
158
|
+
{
|
159
|
+
size_t dim12 = pq.ksub * pq.M;
|
160
|
+
std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
|
161
|
+
if (metric_type == METRIC_L2) {
|
162
|
+
pq.compute_distance_tables (n, x, dis_tables.get());
|
163
|
+
} else {
|
164
|
+
pq.compute_inner_prod_tables (n, x, dis_tables.get());
|
165
|
+
}
|
166
|
+
|
167
|
+
for(uint64_t i = 0; i < n; i++) {
|
168
|
+
round_uint8_per_column(
|
169
|
+
dis_tables.get() + i * dim12, pq.M, pq.ksub,
|
170
|
+
&normalizers[2 * i], &normalizers[2 * i + 1]
|
171
|
+
);
|
172
|
+
}
|
173
|
+
|
174
|
+
for(uint64_t i = 0; i < n; i++) {
|
175
|
+
const float *t_in = dis_tables.get() + i * dim12;
|
176
|
+
uint8_t *t_out = lut + i * M2 * pq.ksub;
|
177
|
+
|
178
|
+
for(int j = 0; j < dim12; j++) {
|
179
|
+
t_out[j] = int(t_in[j]);
|
180
|
+
}
|
181
|
+
memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
|
182
|
+
}
|
183
|
+
}
|
184
|
+
|
185
|
+
|
186
|
+
|
187
|
+
/******************************************************************************
|
188
|
+
* Search driver routine
|
189
|
+
******************************************************************************/
|
190
|
+
|
191
|
+
|
192
|
+
void IndexPQFastScan::search(
|
193
|
+
idx_t n, const float* x, idx_t k,
|
194
|
+
float* distances, idx_t* labels) const
|
195
|
+
{
|
196
|
+
if (metric_type == METRIC_L2) {
|
197
|
+
search_dispatch_implem<true>(n, x, k, distances, labels);
|
198
|
+
} else {
|
199
|
+
search_dispatch_implem<false>(n, x, k, distances, labels);
|
200
|
+
}
|
201
|
+
}
|
202
|
+
|
203
|
+
|
204
|
+
template<bool is_max>
|
205
|
+
void IndexPQFastScan::search_dispatch_implem(
|
206
|
+
idx_t n,
|
207
|
+
const float* x,
|
208
|
+
idx_t k,
|
209
|
+
float* distances,
|
210
|
+
idx_t* labels) const
|
211
|
+
{
|
212
|
+
using Cfloat = typename std::conditional<is_max,
|
213
|
+
CMax<float, int64_t>, CMin<float, int64_t> >::type;
|
214
|
+
|
215
|
+
using C = typename std::conditional<is_max,
|
216
|
+
CMax<uint16_t, int>, CMin<uint16_t, int> >::type;
|
217
|
+
|
218
|
+
if (n == 0) {
|
219
|
+
return;
|
220
|
+
}
|
221
|
+
|
222
|
+
// actual implementation used
|
223
|
+
int impl = implem;
|
224
|
+
|
225
|
+
if (impl == 0) {
|
226
|
+
if (bbs == 32) {
|
227
|
+
impl = 12;
|
228
|
+
} else {
|
229
|
+
impl = 14;
|
230
|
+
}
|
231
|
+
if (k > 20) {
|
232
|
+
impl ++;
|
233
|
+
}
|
234
|
+
}
|
235
|
+
|
236
|
+
|
237
|
+
if (implem == 1) {
|
238
|
+
FAISS_THROW_IF_NOT(orig_codes);
|
239
|
+
FAISS_THROW_IF_NOT(is_max);
|
240
|
+
float_maxheap_array_t res = {
|
241
|
+
size_t(n), size_t(k), labels, distances };
|
242
|
+
pq.search (x, n, orig_codes, ntotal, &res, true);
|
243
|
+
} else if (implem == 2 || implem == 3 || implem == 4) {
|
244
|
+
FAISS_THROW_IF_NOT(orig_codes);
|
245
|
+
|
246
|
+
size_t dim12 = pq.ksub * pq.M;
|
247
|
+
std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
|
248
|
+
if (is_max) {
|
249
|
+
pq.compute_distance_tables (n, x, dis_tables.get());
|
250
|
+
} else {
|
251
|
+
pq.compute_inner_prod_tables (n, x, dis_tables.get());
|
252
|
+
}
|
253
|
+
|
254
|
+
std::vector<float> normalizers(n * 2);
|
255
|
+
|
256
|
+
if (implem == 2) {
|
257
|
+
// default float
|
258
|
+
} else if (implem == 3 || implem == 4) {
|
259
|
+
for(uint64_t i = 0; i < n; i++) {
|
260
|
+
round_uint8_per_column(
|
261
|
+
dis_tables.get() + i * dim12, pq.M,
|
262
|
+
pq.ksub,
|
263
|
+
&normalizers[2 * i], &normalizers[2 * i + 1]
|
264
|
+
);
|
265
|
+
}
|
266
|
+
}
|
267
|
+
|
268
|
+
for (int64_t i = 0; i < n; i++) {
|
269
|
+
int64_t *heap_ids = labels + i * k;
|
270
|
+
float *heap_dis = distances + i * k;
|
271
|
+
|
272
|
+
heap_heapify<Cfloat> (k, heap_dis, heap_ids);
|
273
|
+
|
274
|
+
pq_estimators_from_tables_generic<Cfloat>(
|
275
|
+
pq, pq.nbits, orig_codes, ntotal,
|
276
|
+
dis_tables.get() + i * dim12,
|
277
|
+
k, heap_dis, heap_ids
|
278
|
+
);
|
279
|
+
|
280
|
+
heap_reorder<Cfloat> (k, heap_dis, heap_ids);
|
281
|
+
|
282
|
+
if (implem == 4) {
|
283
|
+
float a = normalizers[2 * i];
|
284
|
+
float b = normalizers[2 * i + 1];
|
285
|
+
|
286
|
+
for(int j = 0; j < k; j++) {
|
287
|
+
heap_dis[j] = heap_dis[j] / a + b;
|
288
|
+
}
|
289
|
+
}
|
290
|
+
}
|
291
|
+
} else if (impl >= 12 && impl <= 15) {
|
292
|
+
FAISS_THROW_IF_NOT(ntotal < INT_MAX);
|
293
|
+
int nt = std::min(omp_get_max_threads(), int(n));
|
294
|
+
if (nt < 2) {
|
295
|
+
if (impl == 12 || impl == 13) {
|
296
|
+
search_implem_12<C>(n, x, k, distances, labels, impl);
|
297
|
+
} else {
|
298
|
+
search_implem_14<C>(n, x, k, distances, labels, impl);
|
299
|
+
}
|
300
|
+
} else {
|
301
|
+
// explicitly slice over threads
|
302
|
+
#pragma omp parallel for num_threads(nt)
|
303
|
+
for (int slice = 0; slice < nt; slice++) {
|
304
|
+
idx_t i0 = n * slice / nt;
|
305
|
+
idx_t i1 = n * (slice + 1) / nt;
|
306
|
+
float *dis_i = distances + i0 * k;
|
307
|
+
idx_t *lab_i = labels + i0 * k;
|
308
|
+
if (impl == 12 || impl == 13) {
|
309
|
+
search_implem_12<C>(
|
310
|
+
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
|
311
|
+
} else {
|
312
|
+
search_implem_14<C>(
|
313
|
+
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
|
314
|
+
}
|
315
|
+
}
|
316
|
+
}
|
317
|
+
} else {
|
318
|
+
FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
|
319
|
+
}
|
320
|
+
|
321
|
+
}
|
322
|
+
|
323
|
+
template<class C>
|
324
|
+
void IndexPQFastScan::search_implem_12(
|
325
|
+
idx_t n, const float* x, idx_t k,
|
326
|
+
float* distances, idx_t* labels,
|
327
|
+
int impl) const
|
328
|
+
{
|
329
|
+
|
330
|
+
FAISS_THROW_IF_NOT(bbs == 32);
|
331
|
+
|
332
|
+
// handle qbs2 blocking by recursive call
|
333
|
+
int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
|
334
|
+
if (n > qbs2) {
|
335
|
+
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
|
336
|
+
int64_t i1 = std::min(i0 + qbs2, n);
|
337
|
+
search_implem_12<C>(
|
338
|
+
i1 - i0, x + d * i0, k,
|
339
|
+
distances + i0 * k, labels + i0 * k, impl
|
340
|
+
);
|
341
|
+
}
|
342
|
+
return;
|
343
|
+
}
|
344
|
+
|
345
|
+
size_t dim12 = pq.ksub * M2;
|
346
|
+
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
|
347
|
+
std::unique_ptr<float []> normalizers(new float[2 * n]);
|
348
|
+
|
349
|
+
if (skip & 1) {
|
350
|
+
quantized_dis_tables.clear();
|
351
|
+
} else {
|
352
|
+
compute_quantized_LUT(
|
353
|
+
n, x, quantized_dis_tables.get(), normalizers.get()
|
354
|
+
);
|
355
|
+
}
|
356
|
+
|
357
|
+
AlignedTable<uint8_t> LUT(n * dim12);
|
358
|
+
|
359
|
+
// block sizes are encoded in qbs, 4 bits at a time
|
360
|
+
|
361
|
+
// caution: we override an object field
|
362
|
+
int qbs = this->qbs;
|
363
|
+
|
364
|
+
if (n != pq4_qbs_to_nq(qbs)) {
|
365
|
+
qbs = pq4_preferred_qbs(n);
|
366
|
+
}
|
367
|
+
|
368
|
+
int LUT_nq = pq4_pack_LUT_qbs(
|
369
|
+
qbs, M2, quantized_dis_tables.get(), LUT.get()
|
370
|
+
);
|
371
|
+
FAISS_THROW_IF_NOT(LUT_nq == n);
|
372
|
+
|
373
|
+
if (k == 1) {
|
374
|
+
SingleResultHandler<C> handler(n, ntotal);
|
375
|
+
if (skip & 4) {
|
376
|
+
// pass
|
377
|
+
} else {
|
378
|
+
handler.disable = bool(skip & 2);
|
379
|
+
pq4_accumulate_loop_qbs(
|
380
|
+
qbs, ntotal2, M2,
|
381
|
+
codes.get(), LUT.get(),
|
382
|
+
handler
|
383
|
+
);
|
384
|
+
}
|
385
|
+
|
386
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
387
|
+
|
388
|
+
} else if (impl == 12) {
|
389
|
+
|
390
|
+
std::vector<uint16_t> tmp_dis(n * k);
|
391
|
+
std::vector<int32_t> tmp_ids(n * k);
|
392
|
+
|
393
|
+
if (skip & 4) {
|
394
|
+
// skip
|
395
|
+
} else {
|
396
|
+
HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
397
|
+
handler.disable = bool(skip & 2);
|
398
|
+
|
399
|
+
pq4_accumulate_loop_qbs(
|
400
|
+
qbs, ntotal2, M2,
|
401
|
+
codes.get(), LUT.get(),
|
402
|
+
handler
|
403
|
+
);
|
404
|
+
|
405
|
+
if (!(skip & 8)) {
|
406
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
407
|
+
}
|
408
|
+
}
|
409
|
+
|
410
|
+
|
411
|
+
} else { // impl == 13
|
412
|
+
|
413
|
+
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
414
|
+
handler.disable = bool(skip & 2);
|
415
|
+
|
416
|
+
if (skip & 4) {
|
417
|
+
// skip
|
418
|
+
} else {
|
419
|
+
pq4_accumulate_loop_qbs(
|
420
|
+
qbs, ntotal2, M2,
|
421
|
+
codes.get(), LUT.get(),
|
422
|
+
handler
|
423
|
+
);
|
424
|
+
}
|
425
|
+
|
426
|
+
if (!(skip & 8)) {
|
427
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
428
|
+
}
|
429
|
+
|
430
|
+
FastScan_stats.t0 += handler.times[0];
|
431
|
+
FastScan_stats.t1 += handler.times[1];
|
432
|
+
FastScan_stats.t2 += handler.times[2];
|
433
|
+
FastScan_stats.t3 += handler.times[3];
|
434
|
+
|
435
|
+
}
|
436
|
+
}
|
437
|
+
|
438
|
+
// Global statistics object for the fast-scan search code.
// search_implem_12 adds the reservoir handler's times[0..3] into the
// t0..t3 fields on its impl == 13 path (presumably timing counters --
// confirm units against the handler implementation).
// NOTE(review): this is one namespace-scope instance shared by every index
// and every thread that runs a search.
FastScanStats FastScan_stats;
|
439
|
+
|
440
|
+
template<class C>
|
441
|
+
void IndexPQFastScan::search_implem_14(
|
442
|
+
idx_t n, const float* x, idx_t k,
|
443
|
+
float* distances, idx_t* labels, int impl) const
|
444
|
+
{
|
445
|
+
|
446
|
+
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
447
|
+
|
448
|
+
int qbs2 = qbs == 0 ? 4 : qbs;
|
449
|
+
|
450
|
+
// handle qbs2 blocking by recursive call
|
451
|
+
if (n > qbs2) {
|
452
|
+
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
|
453
|
+
int64_t i1 = std::min(i0 + qbs2, n);
|
454
|
+
search_implem_14<C>(
|
455
|
+
i1 - i0, x + d * i0, k,
|
456
|
+
distances + i0 * k, labels + i0 * k, impl
|
457
|
+
);
|
458
|
+
}
|
459
|
+
return;
|
460
|
+
}
|
461
|
+
|
462
|
+
size_t dim12 = pq.ksub * M2;
|
463
|
+
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
|
464
|
+
std::unique_ptr<float []> normalizers(new float[2 * n]);
|
465
|
+
|
466
|
+
if (skip & 1) {
|
467
|
+
quantized_dis_tables.clear();
|
468
|
+
} else {
|
469
|
+
compute_quantized_LUT(
|
470
|
+
n, x, quantized_dis_tables.get(), normalizers.get()
|
471
|
+
);
|
472
|
+
}
|
473
|
+
|
474
|
+
AlignedTable<uint8_t> LUT(n * dim12);
|
475
|
+
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
|
476
|
+
|
477
|
+
if (k == 1) {
|
478
|
+
SingleResultHandler<C> handler(n, ntotal);
|
479
|
+
if (skip & 4) {
|
480
|
+
// pass
|
481
|
+
} else {
|
482
|
+
handler.disable = bool(skip & 2);
|
483
|
+
pq4_accumulate_loop (
|
484
|
+
n, ntotal2, bbs, M2,
|
485
|
+
codes.get(), LUT.get(),
|
486
|
+
handler
|
487
|
+
);
|
488
|
+
}
|
489
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
490
|
+
|
491
|
+
} else if (impl == 14) {
|
492
|
+
|
493
|
+
std::vector<uint16_t> tmp_dis(n * k);
|
494
|
+
std::vector<int32_t> tmp_ids(n * k);
|
495
|
+
|
496
|
+
if (skip & 4) {
|
497
|
+
// skip
|
498
|
+
} else if (k > 1) {
|
499
|
+
HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
500
|
+
handler.disable = bool(skip & 2);
|
501
|
+
|
502
|
+
pq4_accumulate_loop (
|
503
|
+
n, ntotal2, bbs, M2,
|
504
|
+
codes.get(), LUT.get(),
|
505
|
+
handler
|
506
|
+
);
|
507
|
+
|
508
|
+
if (!(skip & 8)) {
|
509
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
510
|
+
}
|
511
|
+
}
|
512
|
+
|
513
|
+
|
514
|
+
} else { // impl == 15
|
515
|
+
|
516
|
+
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
517
|
+
handler.disable = bool(skip & 2);
|
518
|
+
|
519
|
+
if (skip & 4) {
|
520
|
+
// skip
|
521
|
+
} else {
|
522
|
+
pq4_accumulate_loop (
|
523
|
+
n, ntotal2, bbs, M2,
|
524
|
+
codes.get(), LUT.get(),
|
525
|
+
handler
|
526
|
+
);
|
527
|
+
}
|
528
|
+
|
529
|
+
if (!(skip & 8)) {
|
530
|
+
handler.to_flat_arrays(distances, labels, normalizers.get());
|
531
|
+
}
|
532
|
+
}
|
533
|
+
}
|
534
|
+
|
535
|
+
|
536
|
+
} // namespace faiss
|