faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -278,13 +278,15 @@ double kmeans1d(const float* x, size_t n, size_t nclusters, float* centroids) {
|
|
|
278
278
|
****************************************************/
|
|
279
279
|
|
|
280
280
|
// for imbalance factor
|
|
281
|
-
double tot = 0.0
|
|
281
|
+
double tot = 0.0;
|
|
282
|
+
double uf = 0.0;
|
|
282
283
|
|
|
283
284
|
idx_t end = n;
|
|
284
285
|
for (idx_t k = nclusters - 1; k >= 0; k--) {
|
|
285
|
-
idx_t start = T.at(k, end - 1);
|
|
286
|
-
float sum =
|
|
287
|
-
|
|
286
|
+
const idx_t start = T.at(k, end - 1);
|
|
287
|
+
const float sum =
|
|
288
|
+
std::accumulate(arr.data() + start, arr.data() + end, 0.0f);
|
|
289
|
+
const idx_t size = end - start;
|
|
288
290
|
FAISS_THROW_IF_NOT_FMT(
|
|
289
291
|
size > 0, "Cluster %d: size %d", int(k), int(size));
|
|
290
292
|
centroids[k] = sum / size;
|
|
@@ -122,30 +122,70 @@ void pq4_pack_codes_range(
|
|
|
122
122
|
}
|
|
123
123
|
}
|
|
124
124
|
|
|
125
|
+
namespace {
|
|
126
|
+
|
|
127
|
+
// Compute the byte offset, inside one block of packed 4-bit codes, of the
// byte that holds the code of vector `vector_id` for sub-quantizer `sq`.
//
// @param bbs        block size; the vector's position inside its block is
//                   taken as vector_id % bbs
// @param vector_id  vector id
// @param sq         sub-quantizer index
// @param shift      output flag: false when the code is stored in bits 0..3
//                   of the addressed byte, true when it is stored in
//                   bits 4..7
// @return byte offset relative to the start of the block
//
// The offset is returned as size_t rather than uint8_t: the value
// (sq >> 1) * bbs + address already reaches 256 for sq = 16 and bbs = 32,
// so an 8-bit return type would silently wrap around.
size_t get_vector_specific_address(
        size_t bbs,
        size_t vector_id,
        size_t sq,
        bool& shift) {
    // get the vector_id inside the block
    vector_id = vector_id % bbs;
    shift = vector_id > 15;
    vector_id = vector_id & 15;

    // get the address of the vector in sq:
    // positions 0..7 map to even offsets, positions 8..15 to odd offsets
    size_t address;
    if (vector_id < 8) {
        address = vector_id << 1;
    } else {
        address = ((vector_id - 8) << 1) + 1;
    }
    // odd sub-quantizers are placed 16 bytes further
    if (sq & 1) {
        address += 16;
    }
    // every pair of sub-quantizers advances the offset by bbs bytes
    return (sq >> 1) * bbs + address;
}
|
|
152
|
+
|
|
153
|
+
} // anonymous namespace
|
|
154
|
+
|
|
125
155
|
uint8_t pq4_get_packed_element(
|
|
126
156
|
const uint8_t* data,
|
|
127
157
|
size_t bbs,
|
|
128
158
|
size_t nsq,
|
|
129
|
-
size_t
|
|
159
|
+
size_t vector_id,
|
|
130
160
|
size_t sq) {
|
|
131
161
|
// move to correct bbs-sized block
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
if (sq == 1) {
|
|
141
|
-
data += 16;
|
|
162
|
+
// number of blocks * block size
|
|
163
|
+
data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
|
|
164
|
+
bool shift;
|
|
165
|
+
size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
|
|
166
|
+
if (shift) {
|
|
167
|
+
return data[address] >> 4;
|
|
168
|
+
} else {
|
|
169
|
+
return data[address] & 15;
|
|
142
170
|
}
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
void pq4_set_packed_element(
|
|
174
|
+
uint8_t* data,
|
|
175
|
+
uint8_t code,
|
|
176
|
+
size_t bbs,
|
|
177
|
+
size_t nsq,
|
|
178
|
+
size_t vector_id,
|
|
179
|
+
size_t sq) {
|
|
180
|
+
// move to correct bbs-sized block
|
|
181
|
+
// number of blocks * block size
|
|
182
|
+
data += (vector_id / bbs) * (((nsq + 1) / 2) * bbs);
|
|
183
|
+
bool shift;
|
|
184
|
+
size_t address = get_vector_specific_address(bbs, vector_id, sq, shift);
|
|
185
|
+
if (shift) {
|
|
186
|
+
data[address] = (code << 4) | (data[address] & 15);
|
|
147
187
|
} else {
|
|
148
|
-
|
|
188
|
+
data[address] = code | (data[address] & ~15);
|
|
149
189
|
}
|
|
150
190
|
}
|
|
151
191
|
|
|
@@ -26,7 +26,7 @@ namespace faiss {
|
|
|
26
26
|
* The unused bytes are set to 0.
|
|
27
27
|
*
|
|
28
28
|
* @param codes input codes, size (ntotal, ceil(M / 2))
|
|
29
|
-
* @param
|
|
29
|
+
* @param ntotal number of input codes
|
|
30
30
|
* @param nb output number of codes (ntotal rounded up to a multiple of
|
|
31
31
|
* bbs)
|
|
32
32
|
* @param M2 number of sub-quantizers (=M rounded up to a multiple of 2)
|
|
@@ -61,14 +61,27 @@ void pq4_pack_codes_range(
|
|
|
61
61
|
|
|
62
62
|
/** get a single element from a packed codes table
|
|
63
63
|
*
|
|
64
|
-
* @param
|
|
64
|
+
* @param vector_id vector id
|
|
65
65
|
* @param sq subquantizer (< nsq)
|
|
66
66
|
*/
|
|
67
67
|
uint8_t pq4_get_packed_element(
|
|
68
68
|
const uint8_t* data,
|
|
69
69
|
size_t bbs,
|
|
70
70
|
size_t nsq,
|
|
71
|
-
size_t
|
|
71
|
+
size_t vector_id,
|
|
72
|
+
size_t sq);
|
|
73
|
+
|
|
74
|
+
/** set a single element "code" into a packed codes table
|
|
75
|
+
*
|
|
76
|
+
* @param vector_id vector id
|
|
77
|
+
* @param sq subquantizer (< nsq)
|
|
78
|
+
*/
|
|
79
|
+
void pq4_set_packed_element(
|
|
80
|
+
uint8_t* data,
|
|
81
|
+
uint8_t code,
|
|
82
|
+
size_t bbs,
|
|
83
|
+
size_t nsq,
|
|
84
|
+
size_t vector_id,
|
|
72
85
|
size_t sq);
|
|
73
86
|
|
|
74
87
|
/** Pack Look-up table for consumption by the kernel.
|
|
@@ -88,8 +101,9 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
|
|
|
88
101
|
* @param nsq number of sub-quantizers (multiple of 2)
|
|
89
102
|
* @param codes packed codes array
|
|
90
103
|
* @param LUT packed look-up table
|
|
104
|
+
* @param scaler scaler to scale the encoded norm
|
|
91
105
|
*/
|
|
92
|
-
template <class ResultHandler>
|
|
106
|
+
template <class ResultHandler, class Scaler>
|
|
93
107
|
void pq4_accumulate_loop(
|
|
94
108
|
int nq,
|
|
95
109
|
size_t nb,
|
|
@@ -97,7 +111,8 @@ void pq4_accumulate_loop(
|
|
|
97
111
|
int nsq,
|
|
98
112
|
const uint8_t* codes,
|
|
99
113
|
const uint8_t* LUT,
|
|
100
|
-
ResultHandler& res
|
|
114
|
+
ResultHandler& res,
|
|
115
|
+
const Scaler& scaler);
|
|
101
116
|
|
|
102
117
|
/* qbs versions, supported only for bbs=32.
|
|
103
118
|
*
|
|
@@ -141,20 +156,22 @@ int pq4_pack_LUT_qbs_q_map(
|
|
|
141
156
|
|
|
142
157
|
/** Run accumulation loop.
|
|
143
158
|
*
|
|
144
|
-
* @param qbs 4-bit
|
|
159
|
+
* @param qbs 4-bit encoded number of queries
|
|
145
160
|
* @param nb number of database codes (multiple of bbs)
|
|
146
161
|
* @param nsq number of sub-quantizers
|
|
147
162
|
* @param codes encoded database vectors (packed)
|
|
148
163
|
* @param LUT look-up table (packed)
|
|
149
164
|
* @param res call-back for the results
|
|
165
|
+
* @param scaler scaler to scale the encoded norm
|
|
150
166
|
*/
|
|
151
|
-
template <class ResultHandler>
|
|
167
|
+
template <class ResultHandler, class Scaler>
|
|
152
168
|
void pq4_accumulate_loop_qbs(
|
|
153
169
|
int qbs,
|
|
154
170
|
size_t nb,
|
|
155
171
|
int nsq,
|
|
156
172
|
const uint8_t* codes,
|
|
157
173
|
const uint8_t* LUT,
|
|
158
|
-
ResultHandler& res
|
|
174
|
+
ResultHandler& res,
|
|
175
|
+
const Scaler& scaler);
|
|
159
176
|
|
|
160
177
|
} // namespace faiss
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/impl/LookupTableScaler.h>
|
|
11
12
|
#include <faiss/impl/simd_result_handlers.h>
|
|
12
13
|
|
|
13
14
|
namespace faiss {
|
|
@@ -26,12 +27,13 @@ namespace {
|
|
|
26
27
|
* writes results in a ResultHandler
|
|
27
28
|
*/
|
|
28
29
|
|
|
29
|
-
template <int NQ, int BB, class ResultHandler>
|
|
30
|
+
template <int NQ, int BB, class ResultHandler, class Scaler>
|
|
30
31
|
void kernel_accumulate_block(
|
|
31
32
|
int nsq,
|
|
32
33
|
const uint8_t* codes,
|
|
33
34
|
const uint8_t* LUT,
|
|
34
|
-
ResultHandler& res
|
|
35
|
+
ResultHandler& res,
|
|
36
|
+
const Scaler& scaler) {
|
|
35
37
|
// distance accumulators
|
|
36
38
|
simd16uint16 accu[NQ][BB][4];
|
|
37
39
|
|
|
@@ -44,7 +46,7 @@ void kernel_accumulate_block(
|
|
|
44
46
|
}
|
|
45
47
|
}
|
|
46
48
|
|
|
47
|
-
for (int sq = 0; sq < nsq; sq += 2) {
|
|
49
|
+
for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
|
|
48
50
|
simd32uint8 lut_cache[NQ];
|
|
49
51
|
for (int q = 0; q < NQ; q++) {
|
|
50
52
|
lut_cache[q] = simd32uint8(LUT);
|
|
@@ -72,6 +74,35 @@ void kernel_accumulate_block(
|
|
|
72
74
|
}
|
|
73
75
|
}
|
|
74
76
|
|
|
77
|
+
for (int sq = 0; sq < scaler.nscale; sq += 2) {
|
|
78
|
+
simd32uint8 lut_cache[NQ];
|
|
79
|
+
for (int q = 0; q < NQ; q++) {
|
|
80
|
+
lut_cache[q] = simd32uint8(LUT);
|
|
81
|
+
LUT += 32;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
for (int b = 0; b < BB; b++) {
|
|
85
|
+
simd32uint8 c = simd32uint8(codes);
|
|
86
|
+
codes += 32;
|
|
87
|
+
simd32uint8 mask(15);
|
|
88
|
+
simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
|
|
89
|
+
simd32uint8 clo = c & mask;
|
|
90
|
+
|
|
91
|
+
for (int q = 0; q < NQ; q++) {
|
|
92
|
+
simd32uint8 lut = lut_cache[q];
|
|
93
|
+
|
|
94
|
+
simd32uint8 res0 = scaler.lookup(lut, clo);
|
|
95
|
+
accu[q][b][0] += scaler.scale_lo(res0); // handle vectors 0..7
|
|
96
|
+
accu[q][b][1] += scaler.scale_hi(res0); // handle vectors 8..15
|
|
97
|
+
|
|
98
|
+
simd32uint8 res1 = scaler.lookup(lut, chi);
|
|
99
|
+
accu[q][b][2] += scaler.scale_lo(res1); // handle vectors 16..23
|
|
100
|
+
accu[q][b][3] +=
|
|
101
|
+
scaler.scale_hi(res1); // handle vectors 24..31
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
}
|
|
105
|
+
|
|
75
106
|
for (int q = 0; q < NQ; q++) {
|
|
76
107
|
for (int b = 0; b < BB; b++) {
|
|
77
108
|
accu[q][b][0] -= accu[q][b][1] << 8;
|
|
@@ -85,17 +116,18 @@ void kernel_accumulate_block(
|
|
|
85
116
|
}
|
|
86
117
|
}
|
|
87
118
|
|
|
88
|
-
template <int NQ, int BB, class ResultHandler>
|
|
119
|
+
template <int NQ, int BB, class ResultHandler, class Scaler>
|
|
89
120
|
void accumulate_fixed_blocks(
|
|
90
121
|
size_t nb,
|
|
91
122
|
int nsq,
|
|
92
123
|
const uint8_t* codes,
|
|
93
124
|
const uint8_t* LUT,
|
|
94
|
-
ResultHandler& res
|
|
125
|
+
ResultHandler& res,
|
|
126
|
+
const Scaler& scaler) {
|
|
95
127
|
constexpr int bbs = 32 * BB;
|
|
96
128
|
for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
|
|
97
129
|
FixedStorageHandler<NQ, 2 * BB> res2;
|
|
98
|
-
kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
|
|
130
|
+
kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
|
|
99
131
|
res.set_block_origin(0, j0);
|
|
100
132
|
res2.to_other_handler(res);
|
|
101
133
|
codes += bbs * nsq / 2;
|
|
@@ -104,7 +136,7 @@ void accumulate_fixed_blocks(
|
|
|
104
136
|
|
|
105
137
|
} // anonymous namespace
|
|
106
138
|
|
|
107
|
-
template <class ResultHandler>
|
|
139
|
+
template <class ResultHandler, class Scaler>
|
|
108
140
|
void pq4_accumulate_loop(
|
|
109
141
|
int nq,
|
|
110
142
|
size_t nb,
|
|
@@ -112,15 +144,16 @@ void pq4_accumulate_loop(
|
|
|
112
144
|
int nsq,
|
|
113
145
|
const uint8_t* codes,
|
|
114
146
|
const uint8_t* LUT,
|
|
115
|
-
ResultHandler& res
|
|
147
|
+
ResultHandler& res,
|
|
148
|
+
const Scaler& scaler) {
|
|
116
149
|
FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
|
|
117
150
|
FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
|
|
118
151
|
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
119
152
|
FAISS_THROW_IF_NOT(nb % bbs == 0);
|
|
120
153
|
|
|
121
|
-
#define DISPATCH(NQ, BB)
|
|
122
|
-
case NQ * 1000 + BB:
|
|
123
|
-
accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
|
|
154
|
+
#define DISPATCH(NQ, BB) \
|
|
155
|
+
case NQ * 1000 + BB: \
|
|
156
|
+
accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
|
|
124
157
|
break
|
|
125
158
|
|
|
126
159
|
switch (nq * 1000 + bbs / 32) {
|
|
@@ -141,20 +174,28 @@ void pq4_accumulate_loop(
|
|
|
141
174
|
|
|
142
175
|
// explicit template instantiations
|
|
143
176
|
|
|
144
|
-
#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
|
|
145
|
-
template void pq4_accumulate_loop<TH<C, with_id_map
|
|
146
|
-
int,
|
|
147
|
-
size_t,
|
|
148
|
-
int,
|
|
149
|
-
int,
|
|
150
|
-
const uint8_t*,
|
|
151
|
-
const uint8_t*,
|
|
152
|
-
TH<C, with_id_map
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
177
|
+
#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map, S) \
|
|
178
|
+
template void pq4_accumulate_loop<TH<C, with_id_map>, S>( \
|
|
179
|
+
int, \
|
|
180
|
+
size_t, \
|
|
181
|
+
int, \
|
|
182
|
+
int, \
|
|
183
|
+
const uint8_t*, \
|
|
184
|
+
const uint8_t*, \
|
|
185
|
+
TH<C, with_id_map>&, \
|
|
186
|
+
const S&);
|
|
187
|
+
|
|
188
|
+
using DS = DummyScaler;
|
|
189
|
+
using NS = NormTableScaler;
|
|
190
|
+
|
|
191
|
+
#define INSTANTIATE_3(C, with_id_map) \
|
|
192
|
+
INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, DS) \
|
|
193
|
+
INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, DS) \
|
|
194
|
+
INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, DS) \
|
|
195
|
+
\
|
|
196
|
+
INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map, NS) \
|
|
197
|
+
INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map, NS) \
|
|
198
|
+
INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map, NS)
|
|
158
199
|
|
|
159
200
|
using Csi = CMax<uint16_t, int>;
|
|
160
201
|
INSTANTIATE_3(Csi, false);
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/impl/LookupTableScaler.h>
|
|
11
12
|
#include <faiss/impl/simd_result_handlers.h>
|
|
12
13
|
#include <faiss/utils/simdlib.h>
|
|
13
14
|
|
|
@@ -27,15 +28,17 @@ namespace {
|
|
|
27
28
|
* writes results in a ResultHandler
|
|
28
29
|
*/
|
|
29
30
|
|
|
30
|
-
template <int NQ, class ResultHandler>
|
|
31
|
+
template <int NQ, class ResultHandler, class Scaler>
|
|
31
32
|
void kernel_accumulate_block(
|
|
32
33
|
int nsq,
|
|
33
34
|
const uint8_t* codes,
|
|
34
35
|
const uint8_t* LUT,
|
|
35
|
-
ResultHandler& res
|
|
36
|
+
ResultHandler& res,
|
|
37
|
+
const Scaler& scaler) {
|
|
36
38
|
// dummy alloc to keep the windows compiler happy
|
|
37
39
|
constexpr int NQA = NQ > 0 ? NQ : 1;
|
|
38
40
|
// distance accumulators
|
|
41
|
+
// layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
|
|
39
42
|
simd16uint16 accu[NQA][4];
|
|
40
43
|
|
|
41
44
|
for (int q = 0; q < NQ; q++) {
|
|
@@ -45,7 +48,7 @@ void kernel_accumulate_block(
|
|
|
45
48
|
}
|
|
46
49
|
|
|
47
50
|
// _mm_prefetch(codes + 768, 0);
|
|
48
|
-
for (int sq = 0; sq < nsq; sq += 2) {
|
|
51
|
+
for (int sq = 0; sq < nsq - scaler.nscale; sq += 2) {
|
|
49
52
|
// prefetch
|
|
50
53
|
simd32uint8 c(codes);
|
|
51
54
|
codes += 32;
|
|
@@ -71,6 +74,31 @@ void kernel_accumulate_block(
|
|
|
71
74
|
}
|
|
72
75
|
}
|
|
73
76
|
|
|
77
|
+
for (int sq = 0; sq < scaler.nscale; sq += 2) {
|
|
78
|
+
// prefetch
|
|
79
|
+
simd32uint8 c(codes);
|
|
80
|
+
codes += 32;
|
|
81
|
+
|
|
82
|
+
simd32uint8 mask(0xf);
|
|
83
|
+
// shift op does not exist for int8...
|
|
84
|
+
simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
|
|
85
|
+
simd32uint8 clo = c & mask;
|
|
86
|
+
|
|
87
|
+
for (int q = 0; q < NQ; q++) {
|
|
88
|
+
// load LUTs for 2 quantizers
|
|
89
|
+
simd32uint8 lut(LUT);
|
|
90
|
+
LUT += 32;
|
|
91
|
+
|
|
92
|
+
simd32uint8 res0 = scaler.lookup(lut, clo);
|
|
93
|
+
accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
|
|
94
|
+
accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
|
|
95
|
+
|
|
96
|
+
simd32uint8 res1 = scaler.lookup(lut, chi);
|
|
97
|
+
accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
|
|
98
|
+
accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
74
102
|
for (int q = 0; q < NQ; q++) {
|
|
75
103
|
accu[q][0] -= accu[q][1] << 8;
|
|
76
104
|
simd16uint16 dis0 = combine2x2(accu[q][0], accu[q][1]);
|
|
@@ -81,13 +109,14 @@ void kernel_accumulate_block(
|
|
|
81
109
|
}
|
|
82
110
|
|
|
83
111
|
// handle at most 4 blocks of queries
|
|
84
|
-
template <int QBS, class ResultHandler>
|
|
112
|
+
template <int QBS, class ResultHandler, class Scaler>
|
|
85
113
|
void accumulate_q_4step(
|
|
86
114
|
size_t ntotal2,
|
|
87
115
|
int nsq,
|
|
88
116
|
const uint8_t* codes,
|
|
89
117
|
const uint8_t* LUT0,
|
|
90
|
-
ResultHandler& res
|
|
118
|
+
ResultHandler& res,
|
|
119
|
+
const Scaler& scaler) {
|
|
91
120
|
constexpr int Q1 = QBS & 15;
|
|
92
121
|
constexpr int Q2 = (QBS >> 4) & 15;
|
|
93
122
|
constexpr int Q3 = (QBS >> 8) & 15;
|
|
@@ -97,21 +126,21 @@ void accumulate_q_4step(
|
|
|
97
126
|
for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
|
|
98
127
|
FixedStorageHandler<SQ, 2> res2;
|
|
99
128
|
const uint8_t* LUT = LUT0;
|
|
100
|
-
kernel_accumulate_block<Q1>(nsq, codes, LUT, res2);
|
|
129
|
+
kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
|
|
101
130
|
LUT += Q1 * nsq * 16;
|
|
102
131
|
if (Q2 > 0) {
|
|
103
132
|
res2.set_block_origin(Q1, 0);
|
|
104
|
-
kernel_accumulate_block<Q2>(nsq, codes, LUT, res2);
|
|
133
|
+
kernel_accumulate_block<Q2>(nsq, codes, LUT, res2, scaler);
|
|
105
134
|
LUT += Q2 * nsq * 16;
|
|
106
135
|
}
|
|
107
136
|
if (Q3 > 0) {
|
|
108
137
|
res2.set_block_origin(Q1 + Q2, 0);
|
|
109
|
-
kernel_accumulate_block<Q3>(nsq, codes, LUT, res2);
|
|
138
|
+
kernel_accumulate_block<Q3>(nsq, codes, LUT, res2, scaler);
|
|
110
139
|
LUT += Q3 * nsq * 16;
|
|
111
140
|
}
|
|
112
141
|
if (Q4 > 0) {
|
|
113
142
|
res2.set_block_origin(Q1 + Q2 + Q3, 0);
|
|
114
|
-
kernel_accumulate_block<Q4>(nsq, codes, LUT, res2);
|
|
143
|
+
kernel_accumulate_block<Q4>(nsq, codes, LUT, res2, scaler);
|
|
115
144
|
}
|
|
116
145
|
res.set_block_origin(0, j0);
|
|
117
146
|
res2.to_other_handler(res);
|
|
@@ -119,29 +148,31 @@ void accumulate_q_4step(
|
|
|
119
148
|
}
|
|
120
149
|
}
|
|
121
150
|
|
|
122
|
-
template <int NQ, class ResultHandler>
|
|
151
|
+
template <int NQ, class ResultHandler, class Scaler>
|
|
123
152
|
void kernel_accumulate_block_loop(
|
|
124
153
|
size_t ntotal2,
|
|
125
154
|
int nsq,
|
|
126
155
|
const uint8_t* codes,
|
|
127
156
|
const uint8_t* LUT,
|
|
128
|
-
ResultHandler& res
|
|
157
|
+
ResultHandler& res,
|
|
158
|
+
const Scaler& scaler) {
|
|
129
159
|
for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
|
|
130
160
|
res.set_block_origin(0, j0);
|
|
131
161
|
kernel_accumulate_block<NQ, ResultHandler>(
|
|
132
|
-
nsq, codes + j0 * nsq / 2, LUT, res);
|
|
162
|
+
nsq, codes + j0 * nsq / 2, LUT, res, scaler);
|
|
133
163
|
}
|
|
134
164
|
}
|
|
135
165
|
|
|
136
166
|
// non-template version of accumulate kernel -- dispatches dynamically
|
|
137
|
-
template <class ResultHandler>
|
|
167
|
+
template <class ResultHandler, class Scaler>
|
|
138
168
|
void accumulate(
|
|
139
169
|
int nq,
|
|
140
170
|
size_t ntotal2,
|
|
141
171
|
int nsq,
|
|
142
172
|
const uint8_t* codes,
|
|
143
173
|
const uint8_t* LUT,
|
|
144
|
-
ResultHandler& res
|
|
174
|
+
ResultHandler& res,
|
|
175
|
+
const Scaler& scaler) {
|
|
145
176
|
assert(nsq % 2 == 0);
|
|
146
177
|
assert(is_aligned_pointer(codes));
|
|
147
178
|
assert(is_aligned_pointer(LUT));
|
|
@@ -149,7 +180,7 @@ void accumulate(
|
|
|
149
180
|
#define DISPATCH(NQ) \
|
|
150
181
|
case NQ: \
|
|
151
182
|
kernel_accumulate_block_loop<NQ, ResultHandler>( \
|
|
152
|
-
ntotal2, nsq, codes, LUT, res);
|
|
183
|
+
ntotal2, nsq, codes, LUT, res, scaler); \
|
|
153
184
|
return
|
|
154
185
|
|
|
155
186
|
switch (nq) {
|
|
@@ -165,23 +196,24 @@ void accumulate(
|
|
|
165
196
|
|
|
166
197
|
} // namespace
|
|
167
198
|
|
|
168
|
-
template <class ResultHandler>
|
|
199
|
+
template <class ResultHandler, class Scaler>
|
|
169
200
|
void pq4_accumulate_loop_qbs(
|
|
170
201
|
int qbs,
|
|
171
202
|
size_t ntotal2,
|
|
172
203
|
int nsq,
|
|
173
204
|
const uint8_t* codes,
|
|
174
205
|
const uint8_t* LUT0,
|
|
175
|
-
ResultHandler& res
|
|
206
|
+
ResultHandler& res,
|
|
207
|
+
const Scaler& scaler) {
|
|
176
208
|
assert(nsq % 2 == 0);
|
|
177
209
|
assert(is_aligned_pointer(codes));
|
|
178
210
|
assert(is_aligned_pointer(LUT0));
|
|
179
211
|
|
|
180
212
|
// try out optimized versions
|
|
181
213
|
switch (qbs) {
|
|
182
|
-
#define DISPATCH(QBS)
|
|
183
|
-
case QBS:
|
|
184
|
-
accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res); \
|
|
214
|
+
#define DISPATCH(QBS) \
|
|
215
|
+
case QBS: \
|
|
216
|
+
accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res, scaler); \
|
|
185
217
|
return;
|
|
186
218
|
DISPATCH(0x3333); // 12
|
|
187
219
|
DISPATCH(0x2333); // 11
|
|
@@ -219,9 +251,10 @@ void pq4_accumulate_loop_qbs(
|
|
|
219
251
|
int nq = qi & 15;
|
|
220
252
|
qi >>= 4;
|
|
221
253
|
res.set_block_origin(i0, j0);
|
|
222
|
-
#define DISPATCH(NQ)
|
|
223
|
-
case NQ:
|
|
224
|
-
kernel_accumulate_block<NQ, ResultHandler>(
|
|
254
|
+
#define DISPATCH(NQ) \
|
|
255
|
+
case NQ: \
|
|
256
|
+
kernel_accumulate_block<NQ, ResultHandler>( \
|
|
257
|
+
nsq, codes, LUT, res, scaler); \
|
|
225
258
|
break
|
|
226
259
|
switch (nq) {
|
|
227
260
|
DISPATCH(1);
|
|
@@ -241,9 +274,23 @@ void pq4_accumulate_loop_qbs(
|
|
|
241
274
|
|
|
242
275
|
// explicit template instantiations
|
|
243
276
|
|
|
244
|
-
#define INSTANTIATE_ACCUMULATE_Q(RH)
|
|
245
|
-
template void pq4_accumulate_loop_qbs<RH>(
|
|
246
|
-
int,
|
|
277
|
+
#define INSTANTIATE_ACCUMULATE_Q(RH) \
|
|
278
|
+
template void pq4_accumulate_loop_qbs<RH, DummyScaler>( \
|
|
279
|
+
int, \
|
|
280
|
+
size_t, \
|
|
281
|
+
int, \
|
|
282
|
+
const uint8_t*, \
|
|
283
|
+
const uint8_t*, \
|
|
284
|
+
RH&, \
|
|
285
|
+
const DummyScaler&); \
|
|
286
|
+
template void pq4_accumulate_loop_qbs<RH, NormTableScaler>( \
|
|
287
|
+
int, \
|
|
288
|
+
size_t, \
|
|
289
|
+
int, \
|
|
290
|
+
const uint8_t*, \
|
|
291
|
+
const uint8_t*, \
|
|
292
|
+
RH&, \
|
|
293
|
+
const NormTableScaler&);
|
|
247
294
|
|
|
248
295
|
using Csi = CMax<uint16_t, int>;
|
|
249
296
|
INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi>)
|
|
@@ -293,7 +340,8 @@ void accumulate_to_mem(
|
|
|
293
340
|
uint16_t* accu) {
|
|
294
341
|
FAISS_THROW_IF_NOT(ntotal2 % 32 == 0);
|
|
295
342
|
StoreResultHandler handler(accu, ntotal2);
|
|
296
|
-
|
|
343
|
+
DummyScaler scaler;
|
|
344
|
+
accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler);
|
|
297
345
|
}
|
|
298
346
|
|
|
299
347
|
int pq4_preferred_qbs(int n) {
|