faiss 0.2.4 → 0.2.5
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
data/vendor/faiss/faiss/impl/kmeans1d.cpp

@@ -278,13 +278,15 @@ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
      ****************************************************/
 
     // for imbalance factor
-    double tot = 0.0
+    double tot = 0.0;
+    double uf = 0.0;
 
     idx_t end = n;
     for (idx_t k = nclusters - 1; k >= 0; k--) {
-        idx_t start = T.at(k, end - 1);
-        float sum =
-
+        const idx_t start = T.at(k, end - 1);
+        const float sum =
+                std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
+        const idx_t size = end - start;
         FAISS_THROW_IF_NOT_FMT(
                 size > 0, "Cluster %d: size %d", int(k), int(size));
         centroids[k] = sum / size;
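The `tot`/`uf` pair accumulated here is the usual balance statistic: the total count and the sum of squared cluster sizes. As a sanity check, here is a minimal standalone sketch of the statistic these accumulators feed; the helper name and the final `uf * k / (tot * tot)` formula follow faiss's imbalance-factor convention and are assumptions on my part, not content of this diff:

#include <cstdio>
#include <vector>

// Sketch: imbalance of a clustering from its cluster sizes, accumulated
// like `tot` / `uf` above. 1.0 means perfectly balanced clusters;
// larger values mean more skew.
double imbalance_factor(const std::vector<long>& cluster_sizes) {
    double tot = 0.0;
    double uf = 0.0;
    for (long size : cluster_sizes) {
        tot += size;
        uf += double(size) * double(size);
    }
    return uf * cluster_sizes.size() / (tot * tot);
}

int main() {
    std::printf("%g\n", imbalance_factor({25, 25, 25, 25})); // 1 (balanced)
    std::printf("%g\n", imbalance_factor({70, 10, 10, 10})); // 2.08 (skewed)
}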
data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp

@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
     }
 }
 
+namespace {
+
+// get the specific address of the vector inside a block
+// shift is used for determine the if the saved in bits 0..3 (false) or
+// bits 4..7 (true)
+uint8_t get_vector_specific_address(
+        size_t bbs,
+        size_t vector_id,
+        size_t sq,
+        bool& shift) {
+    // get the vector_id inside the block
+    vector_id = vector_id % bbs;
+    shift = vector_id > 15;
+    vector_id = vector_id & 15;
+
+    // get the address of the vector in sq
+    size_t address;
+    if (vector_id < 8) {
+        address = vector_id << 1;
+    } else {
+        address = ((vector_id - 8) << 1) + 1;
+    }
+    if (sq & 1) {
+        address += 16;
+    }
+    return (sq >> 1) * bbs + address;
+}
+
+} // anonymous namespace
+
 uint8_t pq4_get_packed_element(
         const uint8_t* data,
         size_t bbs,
         size_t nsq,
-        size_t
+        size_t vector_id,
         size_t sq) {
     // move to correct bbs-sized block
-
-
-
-
-
-
-
-
-    if (sq == 1) {
-        data += 16;
+    // number of blocks * block size
+    data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
+    bool shift;
+    size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
+    if (shift) {
+        return data[address] >> 4;
+    } else {
+        return data[address] & 15;
     }
-
-
-
-
+}
+
+void pq4_set_packed_element(
+        uint8_t* data,
+        uint8_t code,
+        size_t bbs,
+        size_t nsq,
+        size_t vector_id,
+        size_t sq) {
+    // move to correct bbs-sized block
+    // number of blocks * block size
+    data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
+    bool shift;
+    size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
+    if (shift) {
+        data[address] = (code << 4) | (data[address] & 15);
     } else {
-
+        data[address] = code | (data[address] & ~15);
     }
 }
 
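The new helper pins down the fast-scan memory layout: a block packs the 4-bit codes of `bbs` vectors so that one aligned 32-byte load yields the codes of 32 consecutive vectors for one pair of sub-quantizers, with vectors 0..15 in the low nibbles and 16..31 in the high nibbles. A self-contained round-trip sketch of the same addressing, on a toy buffer and with a hypothetical `nibble_address` helper rather than the faiss API:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors get_vector_specific_address above: map (vector within a
// block, subquantizer) to a byte offset; `shift` selects the low
// (bits 0..3) or high (bits 4..7) nibble of that byte.
static size_t nibble_address(size_t bbs, size_t vector_id, size_t sq, bool& shift) {
    vector_id %= bbs;
    shift = vector_id > 15;
    vector_id &= 15;
    size_t address = vector_id < 8 ? vector_id << 1 : ((vector_id - 8) << 1) + 1;
    if (sq & 1) {
        address += 16;
    }
    return (sq >> 1) * bbs + address;
}

int main() {
    const size_t bbs = 32, nsq = 4, nvec = 64;
    std::vector<uint8_t> packed((nvec / bbs) * ((nsq + 1) / 2) * bbs, 0);

    // write one 4-bit code per (vector, subquantizer), then read it back
    for (size_t v = 0; v < nvec; v++) {
        for (size_t sq = 0; sq < nsq; sq++) {
            uint8_t code = (v + sq) & 15;
            uint8_t* block = packed.data() + (v / bbs) * (((nsq + 1) / 2) * bbs);
            bool shift;
            size_t addr = nibble_address(bbs, v, sq, shift);
            if (shift) {
                block[addr] = (code << 4) | (block[addr] & 15);
            } else {
                block[addr] = code | (block[addr] & ~15);
            }
            uint8_t got = shift ? block[addr] >> 4 : block[addr] & 15;
            assert(got == code);
        }
    }
    std::printf("round-trip OK\n");
}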
data/vendor/faiss/faiss/impl/pq4_fast_scan.h

@@ -26,7 +26,7 @@ namespace faiss {
  * The unused bytes are set to 0.
  *
  * @param codes input codes, size (ntotal, ceil(M / 2))
- * @param
+ * @param ntotal number of input codes
  * @param nb output number of codes (ntotal rounded up to a multiple of
  *           bbs)
  * @param M2 number of sub-quantizers (=M rounded up to a muliple of 2)
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
 
 /** get a single element from a packed codes table
  *
- * @param
+ * @param vector_id vector id
  * @param sq subquantizer (< nsq)
  */
 uint8_t pq4_get_packed_element(
         const uint8_t* data,
         size_t bbs,
         size_t nsq,
-        size_t
+        size_t vector_id,
+        size_t sq);
+
+/** set a single element "code" into a packed codes table
+ *
+ * @param vector_id vector id
+ * @param sq subquantizer (< nsq)
+ */
+void pq4_set_packed_element(
+        uint8_t* data,
+        uint8_t code,
+        size_t bbs,
+        size_t nsq,
+        size_t vector_id,
         size_t sq);
 
 /** Pack Look-up table for consumption by the kernel.
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
  * @param nsq number of sub-quantizers (muliple of 2)
  * @param codes packed codes array
  * @param LUT packed look-up table
+ * @param scaler scaler to scale the encoded norm
  */
-template <class ResultHandler>
+template <class ResultHandler, class Scaler>
 void pq4_accumulate_loop(
         int nq,
         size_t nb,
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res);
+        ResultHandler& res,
+        const Scaler& scaler);
 
 /* qbs versions, supported only for bbs=32.
  *
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
 
 /** Run accumulation loop.
  *
- * @param qbs 4-bit
+ * @param qbs 4-bit encoded number of queries
  * @param nb number of database codes (mutliple of bbs)
  * @param nsq number of sub-quantizers
 * @param codes encoded database vectors (packed)
 * @param LUT look-up table (packed)
 * @param res call-back for the resutls
+ * @param scaler scaler to scale the encoded norm
 */
-template <class ResultHandler>
+template <class ResultHandler, class Scaler>
 void pq4_accumulate_loop_qbs(
         int qbs,
         size_t nb,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res);
+        ResultHandler& res,
+        const Scaler& scaler);
 
 } // namespace faiss
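Together, the getter and the new setter allow touching a single code in the packed table without unpacking a whole block. A minimal usage sketch, assuming the vendored faiss headers are on the include path and a single zero-initialized bbs-sized block:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

#include <faiss/impl/pq4_fast_scan.h>

int main() {
    const size_t bbs = 32; // vectors per packed block
    const size_t nsq = 4;  // number of 4-bit sub-quantizers
    // one packed block occupies bbs * ceil(nsq / 2) bytes
    std::vector<uint8_t> codes(bbs * ((nsq + 1) / 2), 0);

    faiss::pq4_set_packed_element(
            codes.data(), /*code=*/9, bbs, nsq, /*vector_id=*/7, /*sq=*/3);
    assert(faiss::pq4_get_packed_element(codes.data(), bbs, nsq, 7, 3) == 9);
}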
data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp

@@ -8,6 +8,7 @@
 #include <faiss/impl/pq4_fast_scan.h>
 
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/LookupTableScaler.h>
 #include <faiss/impl/simd_result_handlers.h>
 
 namespace faiss {
@@ -26,12 +27,13 @@ namespace {
  * writes results in a ResultHandler
  */
 
-template <int NQ, int BB, class ResultHandler>
+template <int NQ, int BB, class ResultHandler, class Scaler>
 void kernel_accumulate_block(
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     // distance accumulators
     simd16uint16 accu[NQ][BB][4];
 
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
         }
     }
 
-    for (int sq = 0; sq < nsq; sq += 2) {
+    for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
         simd32uint8 lut_cache[NQ];
         for (int q = 0; q < NQ; q++) {
             lut_cache[q] = simd32uint8(LUT);
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
         }
     }
 
+    for (int sq = 0; sq < scaler.nscale; sq += 2) {
+        simd32uint8 lut_cache[NQ];
+        for (int q = 0; q < NQ; q++) {
+            lut_cache[q] = simd32uint8(LUT);
+            LUT += 32;
+        }
+
+        for (int b = 0; b < BB; b++) {
+            simd32uint8 c = simd32uint8(codes);
+            codes += 32;
+            simd32uint8 mask(15);
+            simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
+            simd32uint8 clo = c & mask;
+
+            for (int q = 0; q < NQ; q++) {
+                simd32uint8 lut = lut_cache[q];
+
+                simd32uint8 res0 = scaler.lookup(lut, clo);
+                accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
+                accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
+
+                simd32uint8 res1 = scaler.lookup(lut, chi);
+                accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
+                accu[q][b][3] +=
+                        scaler.scale_hi(res1); // handle vectors 24..31
+            }
+        }
+    }
+
     for (int q = 0; q < NQ; q++) {
         for (int b = 0; b < BB; b++) {
             accu[q][b][0] -= accu[q][b][1] << 8;
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
     }
 }
 
-template <int NQ, int BB, class ResultHandler>
+template <int NQ, int BB, class ResultHandler, class Scaler>
 void accumulate_fixed_blocks(
         size_t nb,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     constexpr int bbs = 32 * BB;
     for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
         FixedStorageHandler<NQ, 2 * BB> res2;
-        kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
+        kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
         res.set_block_origin(0, j0);
         res2.to_other_handler(res);
         codes += bbs * nsq / 2;
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
 
 } // anonymous namespace
 
-template <class ResultHandler>
+template <class ResultHandler, class Scaler>
 void pq4_accumulate_loop(
         int nq,
         size_t nb,
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
     FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
     FAISS_THROW_IF_NOT(bbs % 32 == 0);
    FAISS_THROW_IF_NOT(nb % bbs == 0);
 
-#define DISPATCH(NQ, BB)                                           \
-    case NQ * 1000 + BB:                                           \
-        accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
+#define DISPATCH(NQ, BB)                                                   \
+    case NQ * 1000 + BB:                                                   \
+        accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
         break
 
     switch (nq * 1000 + bbs / 32) {
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
 
 // explicit template instantiations
 
-#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map)         \
-    template void pq4_accumulate_loop<TH<C, with_id_map>>( \
-            int,                                           \
-            size_t,                                        \
-            int,                                           \
-            int,                                           \
-            const uint8_t*,                                \
-            const uint8_t*,                                \
-            TH<C, with_id_map>&);
-
-#define INSTANTIATE_3(C, with_id_map)                           \
-    INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
-    INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map)         \
-    INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
+#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S)          \
+    template void pq4_accumulate_loop<TH<C, with_id_map>, S>(  \
+            int,                                               \
+            size_t,                                            \
+            int,                                               \
+            int,                                               \
+            const uint8_t*,                                    \
+            const uint8_t*,                                    \
+            TH<C, with_id_map>&,                               \
+            const S&);
+
+using DS = DummyScaler;
+using NS = NormTableScaler;
+
+#define INSTANTIATE_3(C, with_id_map)                               \
+    INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
+    INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS)         \
+    INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS)    \
+                                                                    \
+    INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
+    INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS)         \
+    INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
 
 using Csi = CMax<uint16_t, int>;
 INSTANTIATE_3(Csi, false);
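Read off these kernels, the Scaler contract is small: `nscale` says how many trailing sub-quantizers the scaler owns, and `lookup` / `scale_lo` / `scale_hi` intercept their LUT values and scale them before accumulation. This is plausibly how the fast-scan additive quantizers fold an encoded norm term into the distance, with DummyScaler presumably scaling nothing. A scalar, non-SIMD model of that contract; the struct below is illustrative, not the contents of the new LookupTableScaler.h:

#include <cstdint>
#include <cstdio>

// Scalar stand-in for the SIMD Scaler contract: the scaler owns the
// last `nscale` sub-quantizers and multiplies their looked-up LUT
// values by a factor before they are accumulated.
struct ToyNormScaler {
    int nscale; // trailing sub-quantizers handled by the scaler
    int factor; // scale applied to their LUT values

    uint8_t lookup(const uint8_t* lut, uint8_t code) const {
        return lut[code];
    }
    int scale(uint8_t looked_up) const {
        return int(looked_up) * factor;
    }
};

int main() {
    const int nsq = 4;
    uint8_t lut[nsq][16];
    for (int sq = 0; sq < nsq; sq++)
        for (int c = 0; c < 16; c++)
            lut[sq][c] = uint8_t(sq + c);

    uint8_t codes[nsq] = {1, 2, 3, 4};
    ToyNormScaler scaler{/*nscale=*/2, /*factor=*/3};

    int accu = 0;
    // plain sub-quantizers, as in the kernels' first loop
    for (int sq = 0; sq < nsq - scaler.nscale; sq++)
        accu += lut[sq][codes[sq]];
    // scaled sub-quantizers (e.g. an encoded norm), as in the second loop
    for (int sq = nsq - scaler.nscale; sq < nsq; sq++)
        accu += scaler.scale(scaler.lookup(lut[sq], codes[sq]));
    std::printf("distance estimate: %d\n", accu);
}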
data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp

@@ -8,6 +8,7 @@
 #include <faiss/impl/pq4_fast_scan.h>
 
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/LookupTableScaler.h>
 #include <faiss/impl/simd_result_handlers.h>
 #include <faiss/utils/simdlib.h>
 
@@ -27,15 +28,17 @@ namespace {
  * writes results in a ResultHandler
  */
 
-template <int NQ, class ResultHandler>
+template <int NQ, class ResultHandler, class Scaler>
 void kernel_accumulate_block(
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     // dummy alloc to keep the windows compiler happy
     constexpr int NQA = NQ > 0 ? NQ : 1;
     // distance accumulators
+    // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
     simd16uint16 accu[NQA][4];
 
     for (int q = 0; q < NQ; q++) {
@@ -45,7 +48,7 @@ void kernel_accumulate_block(
     }
 
     // _mm_prefetch(codes + 768, 0);
-    for (int sq = 0; sq < nsq; sq += 2) {
+    for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
         // prefetch
         simd32uint8 c(codes);
         codes += 32;
@@ -71,6 +74,31 @@ void kernel_accumulate_block(
         }
     }
 
+    for (int sq = 0; sq < scaler.nscale; sq += 2) {
+        // prefetch
+        simd32uint8 c(codes);
+        codes += 32;
+
+        simd32uint8 mask(0xf);
+        // shift op does not exist for int8...
+        simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
+        simd32uint8 clo = c & mask;
+
+        for (int q = 0; q < NQ; q++) {
+            // load LUTs for 2 quantizers
+            simd32uint8 lut(LUT);
+            LUT += 32;
+
+            simd32uint8 res0 = scaler.lookup(lut, clo);
+            accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
+            accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
+
+            simd32uint8 res1 = scaler.lookup(lut, chi);
+            accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
+            accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
+        }
+    }
+
     for (int q = 0; q < NQ; q++) {
         accu[q][0] -= accu[q][1] << 8;
         simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]);
@@ -81,13 +109,14 @@ void kernel_accumulate_block(
 }
 
 // handle at most 4 blocks of queries
-template <int QBS, class ResultHandler>
+template <int QBS, class ResultHandler, class Scaler>
 void accumulate_q_4step(
         size_t ntotal2,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT0,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     constexpr int Q1 = QBS & 15;
     constexpr int Q2 = (QBS >> 4) & 15;
     constexpr int Q3 = (QBS >> 8) & 15;
@@ -97,21 +126,21 @@ void accumulate_q_4step(
     for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
         FixedStorageHandler<SQ, 2> res2;
         const uint8_t* LUT = LUT0;
-        kernel_accumulate_block<Q1>(nsq, codes, LUT, res2);
+        kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
         LUT += Q1 * nsq * 16;
         if (Q2 > 0) {
             res2.set_block_origin(Q1, 0);
-            kernel_accumulate_block<Q2>(nsq, codes, LUT, res2);
+            kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler);
             LUT += Q2 * nsq * 16;
         }
         if (Q3 > 0) {
             res2.set_block_origin(Q1 + Q2, 0);
-            kernel_accumulate_block<Q3>(nsq, codes, LUT, res2);
+            kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler);
             LUT += Q3 * nsq * 16;
         }
         if (Q4 > 0) {
             res2.set_block_origin(Q1 + Q2 + Q3, 0);
-            kernel_accumulate_block<Q4>(nsq, codes, LUT, res2);
+            kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler);
         }
         res.set_block_origin(0, j0);
         res2.to_other_handler(res);
@@ -119,29 +148,31 @@ void accumulate_q_4step(
     }
 }
 
-template <int NQ, class ResultHandler>
+template <int NQ, class ResultHandler, class Scaler>
 void kernel_accumulate_block_loop(
         size_t ntotal2,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
         res.set_block_origin(0, j0);
         kernel_accumulate_block<NQ, ResultHandler>(
-                nsq, codes + j0 * nsq / 2, LUT, res);
+                nsq, codes + j0 * nsq / 2, LUT, res, scaler);
     }
 }
 
 // non-template version of accumulate kernel -- dispatches dynamically
-template <class ResultHandler>
+template <class ResultHandler, class Scaler>
 void accumulate(
         int nq,
         size_t ntotal2,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     assert(nsq % 2 == 0);
     assert(is_aligned_pointer(codes));
     assert(is_aligned_pointer(LUT));
@@ -149,7 +180,7 @@ void accumulate(
 #define DISPATCH(NQ)                                     \
     case NQ:                                             \
         kernel_accumulate_block_loop<NQ, ResultHandler>( \
-                ntotal2, nsq, codes, LUT, res);          \
+                ntotal2, nsq, codes, LUT, res, scaler);  \
         return
 
     switch (nq) {
@@ -165,23 +196,24 @@ void accumulate(
 
 } // namespace
 
-template <class ResultHandler>
+template <class ResultHandler, class Scaler>
 void pq4_accumulate_loop_qbs(
         int qbs,
         size_t ntotal2,
         int nsq,
         const uint8_t* codes,
         const uint8_t* LUT0,
-        ResultHandler& res) {
+        ResultHandler& res,
+        const Scaler& scaler) {
     assert(nsq % 2 == 0);
     assert(is_aligned_pointer(codes));
     assert(is_aligned_pointer(LUT0));
 
     // try out optimized versions
     switch (qbs) {
-#define DISPATCH(QBS)                                                \
-    case QBS:                                                        \
-        accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res);     \
+#define DISPATCH(QBS)                                                    \
+    case QBS:                                                            \
+        accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res, scaler); \
         return;
         DISPATCH(0x3333); // 12
        DISPATCH(0x2333); // 11
@@ -219,9 +251,10 @@ void pq4_accumulate_loop_qbs(
             int nq = qi & 15;
             qi >>= 4;
             res.set_block_origin(i0, j0);
-#define DISPATCH(NQ)                                                      \
-    case NQ:                                                              \
-        kernel_accumulate_block<NQ, ResultHandler>(nsq, codes, LUT, res); \
+#define DISPATCH(NQ)                                \
+    case NQ:                                        \
+        kernel_accumulate_block<NQ, ResultHandler>( \
+                nsq, codes, LUT, res, scaler);      \
         break
             switch (nq) {
                 DISPATCH(1);
@@ -241,9 +274,23 @@ void pq4_accumulate_loop_qbs(
 
 // explicit template instantiations
 
-#define INSTANTIATE_ACCUMULATE_Q(RH)                            \
-    template void pq4_accumulate_loop_qbs<RH>(                  \
-            int, size_t, int, const uint8_t*, const uint8_t*, RH&);
+#define INSTANTIATE_ACCUMULATE_Q(RH)                            \
+    template void pq4_accumulate_loop_qbs<RH, DummyScaler>(     \
+            int,                                                \
+            size_t,                                             \
+            int,                                                \
+            const uint8_t*,                                     \
+            const uint8_t*,                                     \
+            RH&,                                                \
+            const DummyScaler&);                                \
+    template void pq4_accumulate_loop_qbs<RH, NormTableScaler>( \
+            int,                                                \
+            size_t,                                             \
+            int,                                                \
+            const uint8_t*,                                     \
+            const uint8_t*,                                     \
+            RH&,                                                \
+            const NormTableScaler&);
 
 using Csi = CMax<uint16_t, int>;
 INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi>)
@@ -293,7 +340,8 @@ void accumulate_to_mem(
         uint16_t* accu) {
     FAISS_THROW_IF_NOT(ntotal2 % 32 == 0);
     StoreResultHandler handler(accu, ntotal2);
-    accumulate(nq, ntotal2, nsq, codes, LUT, handler);
+    DummyScaler scaler;
+    accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler);
 }
 
 int pq4_preferred_qbs(int n) {
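Stripped of SIMD, packing, and the qbs query batching, what every kernel in this file computes per (query, database vector) pair is a sum of per-sub-quantizer LUT entries selected by the vector's 4-bit codes. A scalar model with an unpacked layout and hypothetical names, for reference only (the real kernels work on quantized uint8 LUTs with bias and scale handling not shown here):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar model of the fast-scan accumulation: the distance estimate for
// query q and database vector j is the sum over sub-quantizers of the
// LUT entry selected by j's 4-bit code, which the SIMD kernels above
// compute for 32 packed vectors at a time.
std::vector<uint16_t> accumulate_scalar(
        int nq,
        int nb,
        int nsq,
        const std::vector<uint8_t>& codes, // nb * nsq, one 4-bit code each
        const std::vector<uint8_t>& lut) { // nq * nsq * 16 quantized entries
    std::vector<uint16_t> accu(size_t(nq) * nb, 0);
    for (int q = 0; q < nq; q++)
        for (int j = 0; j < nb; j++)
            for (int sq = 0; sq < nsq; sq++)
                accu[size_t(q) * nb + j] +=
                        lut[(size_t(q) * nsq + sq) * 16 +
                            (codes[size_t(j) * nsq + sq] & 15)];
    return accu;
}

int main() {
    // 1 query, 2 database vectors, 2 sub-quantizers
    std::vector<uint8_t> codes = {3, 5, 0, 15};
    std::vector<uint8_t> lut(1 * 2 * 16);
    for (size_t i = 0; i < lut.size(); i++)
        lut[i] = uint8_t(i % 16);
    std::vector<uint16_t> accu = accumulate_scalar(1, 2, 2, codes, lut);
    std::printf("%d %d\n", int(accu[0]), int(accu[1])); // 8 and 15
}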