uni-quant-cuda 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Jakub Grula
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@@ -0,0 +1,3 @@
1
+ include README.md
2
+ include LICENSE
3
+ include *.cpp
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: uni-quant-cuda
3
+ Version: 0.2.0
4
+ Summary: Uni-Quant: CUDA-accelerated quantization/dequantization for Keras and XGBoost models
5
+ Author-email: Jakub Grula <ramsters110@gmail.com>
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: POSIX :: Linux
10
+ Requires-Python: >=3.11
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Dynamic: license-file
14
+
15
+ Uni-Quant
16
+ ========
17
+
18
+ Small library to quantize/dequantize Keras and XGBoost models using PyTorch CUDA kernels.
19
+
20
+ Notes
21
+ - This package compiles CUDA kernels at runtime using `torch.utils.cpp_extension.load_inline`.
22
+ - Installing and using the CUDA compilation requires a compatible PyTorch build and CUDA toolkit on the target machine.
23
+
24
+ Dependencies are listed in `requirements.txt` and synchronized with `pyproject.toml`.
25
+
26
+ Quick publish test
27
+
28
+ Build a source/wheel and check locally:
29
+
30
+ ```
31
+ python -m pip install --upgrade build twine
32
+ python -m build
33
+ python -m twine check dist/*
34
+ ```
35
+
36
+ Upload (example):
37
+
38
+ ```
39
+ python -m twine upload dist/*
40
+ ```
@@ -0,0 +1,26 @@
1
+ Uni-Quant
2
+ ========
3
+
4
+ A small library for quantizing and dequantizing Keras and XGBoost models with PyTorch CUDA kernels.
5
+
6
+ Notes
7
+ - This package compiles CUDA kernels at runtime using `torch.utils.cpp_extension.load_inline`.
8
+ - Compiling the kernels requires a CUDA-enabled PyTorch build and a compatible CUDA toolkit on the target machine.
9
+
10
+ Dependencies are listed in `requirements.txt` and synchronized with `pyproject.toml`.
11
+
12
+ Quick publish test
13
+
14
+ Build a source/wheel and check locally:
15
+
16
+ ```
17
+ python -m pip install --upgrade build twine
18
+ python -m build
19
+ python -m twine check dist/*
20
+ ```
21
+
22
+ Upload (example):
23
+
24
+ ```
25
+ python -m twine upload dist/*
26
+ ```
@@ -0,0 +1,357 @@
1
+ #include <torch/types.h>
2
+ #include <c10/cuda/CUDAException.h>
3
+ #include <stdio.h>
4
+ #include <cuda.h>
5
+ #include <cuda_runtime.h>
6
+
7
// Argument-validation helpers for the host entry points below.
// Macro arguments are parenthesized so expressions (not just names) are safe.
#define CHECK_CPU(x) TORCH_CHECK(!(x).is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_U8(x) TORCH_CHECK((x).scalar_type() == torch::kUInt8, #x " must be uint8")
// Wrapped in do/while(0) so the three checks form a single statement and stay
// together even inside an unbraced if/else body.
#define CHECK_INPUT_CPU_HEX(x) \
  do {                         \
    CHECK_CPU(x);              \
    CHECK_CONTIGUOUS(x);       \
    CHECK_U8(x);               \
  } while (0)
11
+
12
// Ceiling division: rounds num / den up. Intended for num >= 0 and den > 0
// (used to size kernel grids).
static inline int64_t cdiv_i64(int64_t num, int64_t den) {
  const int64_t bumped = num + (den - 1);
  return bumped / den;
}
13
+
14
// Converts one ASCII hex digit to its 4-bit value. Digits '0'-'9' and
// lowercase 'a'-'f' are recognized; any other character is decoded as if it
// were an uppercase letter ('A'-'F'). Input is not validated.
__device__ __forceinline__ uint8_t hex_nibble(uint8_t ch) {
  if (ch >= '0' && ch <= '9') {
    return (uint8_t)(ch - '0');
  }
  const uint8_t letter_base = (ch >= 'a' && ch <= 'f') ? (uint8_t)'a' : (uint8_t)'A';
  return (uint8_t)(10 + (ch - letter_base));
}

// Decodes two consecutive hex characters (high digit first) into one byte.
__device__ __forceinline__ uint8_t hex_byte(const uint8_t* pair) {
  const uint8_t high = hex_nibble(pair[0]);
  const uint8_t low = hex_nibble(pair[1]);
  return (uint8_t)((high << 4) | low);
}

// Reads 8 hex characters as 4 big-endian bytes and reinterprets the resulting
// 32-bit pattern as a float32.
__device__ __forceinline__ float read_be_f32_from_hex8(const uint8_t* s8) {
  uint32_t bits = 0;
  for (int i = 0; i < 4; ++i) {
    bits = (bits << 8) | (uint32_t)hex_byte(s8 + 2 * i);
  }
  return __uint_as_float(bits);
}
32
+
33
// Decodes the value at position `within` of one hex-encoded record starting at
// hex + rec_hex_off. A record is an 8-hex-char big-endian float32 scale followed
// by the payload: 2 hex chars per value for 8-bit, or one nibble per value for
// 4-bit (high nibble = even position). An odd-position 4-bit nibble equal to 0
// decodes to exactly 0.0f. `balanced` subtracts half_point from the raw code;
// `literal` returns the (possibly shifted) integer code without the scale.
__device__ __forceinline__ float decode_hex_record_value(
    const uint8_t* hex, int64_t rec_hex_off, int64_t within,
    int quant_size, int half_point, int balanced, int literal) {
  const float scale = read_be_f32_from_hex8(hex + rec_hex_off);
  const uint8_t* payload = hex + rec_hex_off + 8;

  int q = 0;
  if (quant_size == 8) {
    q = (int)hex_byte(payload + within * 2);
  } else {
    const uint8_t b = hex_byte(payload + (within >> 1) * 2);
    q = (within & 1) ? (int)(b & 0x0F) : (int)((b >> 4) & 0x0F);
    if ((within & 1) && q == 0) return 0.0f;
  }
  const int n = q - (balanced ? half_point : 0);
  return literal ? (float)n : ((float)n * scale);
}

// Dequantizes a hex-encoded buffer into float32 outputs.
//
// Launch: 1-D grid, one thread per output element (any block size); threads
// with idx >= total_elems exit immediately. No shared memory is used.
//
// Layout: values are viewed as rows of length `inner` (d2 for modes 0/4, d3 for
// mode 1, d4 for mode 2, d1 otherwise). There are `rows` weight rows, optionally
// followed by further rows of the same length (the bias in modes 0-2, the second
// vector in mode 3, three more vectors in mode 5). Every row is encoded as
// `row_hex` hex characters: floor(inner / pack_size) full records plus one
// partial record when inner % pack_size != 0. Element idx therefore lives in
// encoded row idx / inner at column idx % inner.
//
// out_a receives every element. out_b additionally receives the elements of
// the row at index `rows` (bias in modes 0-2, second vector in mode 3). Only
// that row writes out_b, so two threads never store to the same out_b slot.
// Mode 4 has no such row, and mode 5 never writes out_b.
__global__ void dequant_hex_kernel(
    int mode,
    const uint8_t* hex,
    float* out_a,
    float* out_b,
    int64_t d1, int64_t d2, int64_t d3, int64_t d4,
    int64_t pack_size,
    int quant_size,
    int half_point,
    int balanced,
    int literal,
    int64_t total_elems) {
  const int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= total_elems) return;

  int64_t inner = 0;  // length of the packed (innermost) dimension
  int64_t rows = 0;   // number of weight rows before the bias/second row
  switch (mode) {
    case 0:
    case 4: inner = d2; rows = d1; break;
    case 1: inner = d3; rows = d1 * d2; break;
    case 2: inner = d4; rows = d1 * d2 * d3; break;
    default: inner = d1; rows = 1; break;
  }

  // Record geometry in hex characters, computed in integer arithmetic.
  const int64_t hex_per_val = quant_size / 4;  // 2 for 8-bit, 1 for 4-bit
  const int64_t batches = inner / pack_size;
  const int64_t rem = inner % pack_size;
  const int64_t rec_full = 8 + pack_size * hex_per_val;
  // A partial 4-bit payload is padded to an even number of nibbles.
  const int64_t rem_vals = (quant_size == 4) ? rem + (rem & 1) : rem;
  const int64_t rec_rem = rem ? 8 + rem_vals * hex_per_val : 0;
  const int64_t row_hex = batches * rec_full + rec_rem;

  const int64_t seg = idx / inner;
  const int64_t col = idx - seg * inner;
  const int64_t blk = col / pack_size;
  const int64_t within = col - blk * pack_size;

  // The partial record (blk == batches) starts right after the `batches` full
  // records, which is also blk * rec_full.
  const int64_t rec_hex_off = seg * row_hex + blk * rec_full;

  const float v = decode_hex_record_value(
      hex, rec_hex_off, within, quant_size, half_point, balanced, literal);
  out_a[idx] = v;
  if (mode != 5 && seg == rows) out_b[col] = v;
}
226
+
227
// Number of hex characters in one encoded row of `inner` values.
// The row holds floor(inner / pack_size) full records, plus one partial record
// when inner is not a multiple of pack_size. Each record is an 8-hex-char scale
// header followed by quant_size / 4 hex chars per value; a 4-bit partial payload
// is padded to an even value count. This is the row layout dequant_hex_kernel
// reads.
static int64_t hex_row_length(int64_t inner, int64_t pack_size, int quant_size) {
  const int64_t hex_per_val = quant_size / 4;
  const int64_t batches = inner / pack_size;
  const int64_t rem = inner % pack_size;
  const int64_t rem_vals = (quant_size == 4) ? rem + (rem & 1) : rem;
  const int64_t rec_rem = rem ? 8 + rem_vals * hex_per_val : 0;
  return batches * (8 + pack_size * hex_per_val) + rec_rem;
}

// Host entry point shared by the dequantize_*_hex wrappers.
//
// Copies `hex_cpu` (contiguous uint8 CPU tensor of ASCII hex characters) to the
// current CUDA device. It then launches dequant_hex_kernel with one thread per
// output value and copies the results back to the CPU.
// `mode` selects the layout and return value:
//   0: d1 x d2 weights + d2 bias                -> { w[d1,d2], b[d2] }
//   1: d1 x d2 x d3 weights + d3 bias           -> { w[d1,d2,d3], b[d3] }
//   2: d1 x d2 x d3 x d4 weights + d4 bias      -> { w[d1,d2,d3,d4], b[d4] }
//   3: two length-d1 vectors                    -> { a[d1], b[d1] }
//   4: d1 x d2 values, no bias                  -> { flat[d1*d2] }
//   5: four length-d1 vectors                   -> { w, b, rm, rv } (each [d1])
// Throws c10::Error (via TORCH_CHECK) on invalid arguments, including a hex
// buffer too short for the requested shape.
static std::vector<torch::Tensor> dequantize_hex_impl(
    int mode,
    torch::Tensor hex_cpu,
    int64_t d1, int64_t d2, int64_t d3, int64_t d4,
    int64_t pack_size,
    int quant_size,
    bool balanced,
    bool literal) {
  CHECK_INPUT_CPU_HEX(hex_cpu);
  TORCH_CHECK(mode >= 0 && mode <= 5, "mode must be in [0, 5], got ", mode);
  TORCH_CHECK(pack_size > 0, "pack_size must be > 0");
  TORCH_CHECK(quant_size == 4 || quant_size == 8, "quant_size must be 4 or 8");
  if (quant_size == 4) {
    // A full 4-bit record spans pack_size hex chars; an odd count cannot come
    // from whole bytes, so such input is never decoded correctly.
    TORCH_CHECK((pack_size & 1) == 0, "pack_size must be even for quant_size=4");
  }
  const int half_point = (1 << quant_size) / 2;

  int64_t total = 0;        // number of float outputs (kernel threads)
  int64_t inner = 0;        // length of the packed (innermost) dimension
  int64_t record_rows = 0;  // encoded rows, each hex_row_length(...) chars long
  if (mode == 0) {
    TORCH_CHECK(d2 >= pack_size, "d2 must be >= pack_size");
    total = d1 * d2 + d2;
    inner = d2;
    record_rows = d1 + 1;
  } else if (mode == 1) {
    TORCH_CHECK(d3 >= pack_size, "d3 must be >= pack_size");
    total = d1 * d2 * d3 + d3;
    inner = d3;
    record_rows = d1 * d2 + 1;
  } else if (mode == 2) {
    TORCH_CHECK(d4 >= pack_size, "d4 must be >= pack_size");
    total = d1 * d2 * d3 * d4 + d4;
    inner = d4;
    record_rows = d1 * d2 * d3 + 1;
  } else if (mode == 4) {
    TORCH_CHECK(d2 >= pack_size, "d2 must be >= pack_size");
    total = d1 * d2;
    inner = d2;
    record_rows = d1;
  } else if (mode == 5) {
    TORCH_CHECK(d1 >= pack_size, "d1 must be >= pack_size");
    total = 4 * d1;
    inner = d1;
    record_rows = 4;
  } else {
    TORCH_CHECK(d1 >= pack_size, "d1 must be >= pack_size");
    total = d1 + d1;
    inner = d1;
    record_rows = 2;
  }

  // The kernel does not bounds-check its reads, so a short buffer would cause
  // out-of-bounds device reads; reject it on the host instead.
  const int64_t expected_hex = record_rows * hex_row_length(inner, pack_size, quant_size);
  TORCH_CHECK(hex_cpu.numel() >= expected_hex,
              "hex_cpu is too short: expected at least ", expected_hex,
              " hex characters, got ", hex_cpu.numel());

  const auto f32_cuda = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32);
  auto hex_gpu = hex_cpu.to(torch::kCUDA);
  auto out_a_gpu = torch::empty({total}, f32_cuda);

  // Bias / second-vector buffer; modes without one get a 1-element placeholder
  // so the kernel always receives a valid pointer.
  int64_t out_b_len = 1;
  if (mode == 0) out_b_len = d2;
  else if (mode == 1) out_b_len = d3;
  else if (mode == 2) out_b_len = d4;
  else if (mode == 3) out_b_len = d1;
  auto out_b_gpu = torch::empty({out_b_len}, f32_cuda);

  // A zero-sized grid is an invalid launch configuration (e.g. mode 4, d1 == 0).
  if (total > 0) {
    const int threads = 256;
    const int grid = (int)cdiv_i64(total, threads);
    // NOTE(review): launched on the default stream, not PyTorch's current
    // stream; confirm this is acceptable if callers use custom CUDA streams.
    dequant_hex_kernel<<<grid, threads>>>(
        mode,
        (const uint8_t*)hex_gpu.data_ptr<uint8_t>(),
        (float*)out_a_gpu.data_ptr<float>(),
        (float*)out_b_gpu.data_ptr<float>(),
        d1, d2, d3, d4,
        pack_size,
        quant_size,
        half_point,
        (int)balanced,
        (int)literal,
        total);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }

  if (mode == 0) {
    auto w = out_a_gpu.narrow(0, 0, d1 * d2).view({d1, d2}).to(torch::kCPU);
    auto b = out_b_gpu.to(torch::kCPU);
    return { w, b };
  }
  if (mode == 1) {
    auto w = out_a_gpu.narrow(0, 0, d1 * d2 * d3).view({d1, d2, d3}).to(torch::kCPU);
    auto b = out_b_gpu.to(torch::kCPU);
    return { w, b };
  }
  if (mode == 2) {
    auto w = out_a_gpu.narrow(0, 0, d1 * d2 * d3 * d4).view({d1, d2, d3, d4}).to(torch::kCPU);
    auto b = out_b_gpu.to(torch::kCPU);
    return { w, b };
  }
  if (mode == 4) {
    return { out_a_gpu.to(torch::kCPU) };
  }
  if (mode == 5) {
    auto flat = out_a_gpu.to(torch::kCPU);
    auto w = flat.narrow(0, 0 * d1, d1).contiguous();
    auto b = flat.narrow(0, 1 * d1, d1).contiguous();
    auto rm = flat.narrow(0, 2 * d1, d1).contiguous();
    auto rv = flat.narrow(0, 3 * d1, d1).contiguous();
    return { w, b, rm, rv };
  }
  auto a = out_a_gpu.narrow(0, 0, d1).to(torch::kCPU);
  auto b = out_b_gpu.to(torch::kCPU);
  return { a, b };
}
320
+
321
+ std::vector<torch::Tensor> dequantize_dense_hex(torch::Tensor hex_cpu, int64_t d1, int64_t d2, int64_t pack_size, int quant_size, bool balanced, bool literal) {
322
+ return dequantize_hex_impl(0, hex_cpu, d1, d2, 0, 0, pack_size, quant_size, balanced, literal);
323
+ }
324
+
325
+ std::vector<torch::Tensor> dequantize_conv1d_hex(torch::Tensor hex_cpu, int64_t d1, int64_t d2, int64_t d3, int64_t pack_size, int quant_size, bool balanced, bool literal) {
326
+ return dequantize_hex_impl(1, hex_cpu, d1, d2, d3, 0, pack_size, quant_size, balanced, literal);
327
+ }
328
+
329
+ std::vector<torch::Tensor> dequantize_conv2d_hex(torch::Tensor hex_cpu, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t pack_size, int quant_size, bool balanced, bool literal) {
330
+ return dequantize_hex_impl(2, hex_cpu, d1, d2, d3, d4, pack_size, quant_size, balanced, literal);
331
+ }
332
+
333
+ std::vector<torch::Tensor> dequantize_gru_hex(torch::Tensor hex_cpu, int64_t d1, int64_t units, int64_t biases, int64_t pack_size, int quant_size, bool balanced, bool literal) {
334
+ int64_t d2 = 3 * units;
335
+ int64_t rows = d1 + units + biases;
336
+
337
+ auto tmp = dequantize_hex_impl(4, hex_cpu, rows, d2, 0, 0, pack_size, quant_size, balanced, literal);
338
+ auto flat = tmp[0];
339
+
340
+ auto full = flat.view({rows, d2});
341
+ auto w_in = full.narrow(0, 0, d1).contiguous();
342
+ auto w_rec = full.narrow(0, d1, units).contiguous();
343
+ auto b2d = full.narrow(0, d1 + units, biases).contiguous();
344
+ if (biases == 1) {
345
+ auto b1d = b2d.view({d2}).contiguous();
346
+ return { w_in, w_rec, b1d };
347
+ }
348
+ return { w_in, w_rec, b2d };
349
+ }
350
+
351
+ std::vector<torch::Tensor> dequantize_layernorm_hex(torch::Tensor hex_cpu, int64_t d1, int64_t pack_size, int quant_size, bool balanced, bool literal) {
352
+ return dequantize_hex_impl(3, hex_cpu, d1, 0, 0, 0, pack_size, quant_size, balanced, literal);
353
+ }
354
+
355
+ std::vector<torch::Tensor> dequantize_batchnorm_hex(torch::Tensor hex_cpu, int64_t d1, int64_t pack_size, int quant_size, bool balanced, bool literal) {
356
+ return dequantize_hex_impl(5, hex_cpu, d1, 0, 0, 0, pack_size, quant_size, balanced, literal);
357
+ }
@@ -0,0 +1,21 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "uni-quant-cuda"
7
+ version = "0.2.0"
8
+ description = "Uni-Quant: CUDA-accelerated quantization/dequantization for Keras and XGBoost models"
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ authors = [ { name = "Jakub Grula", email = "ramsters110@gmail.com" } ]
12
+ requires-python = ">=3.11"
13
+ dynamic = ["dependencies"]
14
+
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "License :: OSI Approved :: MIT License",
18
+ "Operating System :: POSIX :: Linux"
19
+ ]
20
+
21
+ ## No [tool.setuptools] section needed for flat layout
@@ -0,0 +1,238 @@
1
+ #include <torch/types.h>
2
+ #include <c10/cuda/CUDAException.h>
3
+ #include <stdio.h>
4
+ #include <cuda.h>
5
+ #include <cuda_runtime.h>
6
+
7
// Argument-validation helpers for the host entry points below.
// Macro arguments are parenthesized so expressions (not just names) are safe.
#define CHECK_CPU(x) TORCH_CHECK(!(x).is_cuda(), #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_F32(x) TORCH_CHECK((x).scalar_type() == torch::kFloat32, #x " must be float32")
// Wrapped in do/while(0) so the three checks form a single statement and stay
// together even inside an unbraced if/else body.
#define CHECK_INPUT_CPU_F32(x) \
  do {                         \
    CHECK_CPU(x);              \
    CHECK_CONTIGUOUS(x);       \
    CHECK_F32(x);              \
  } while (0)
11
+
12
// Ceiling division: rounds num / den up. Intended for num >= 0 and den > 0
// (used to size kernel grids).
static inline int64_t cdiv_i64(int64_t num, int64_t den) {
  const int64_t bumped = num + (den - 1);
  return bumped / den;
}
13
+
14
// Reverses the byte order of a 32-bit word.
__device__ __forceinline__ uint32_t bswap32(uint32_t v) {
  const uint32_t byte0 = v & 0xFFu;
  const uint32_t byte1 = (v >> 8) & 0xFFu;
  const uint32_t byte2 = (v >> 16) & 0xFFu;
  const uint32_t byte3 = v >> 24;
  return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}

// Stores the bit pattern of `v` into dst[0..3] in big-endian byte order.
__device__ __forceinline__ void write_be_f32(uint8_t* dst, float v) {
  const uint32_t swapped = bswap32(__float_as_uint(v));
  for (int i = 0; i < 4; ++i) {
    dst[i] = (uint8_t)((swapped >> (8 * i)) & 0xFFu);
  }
}

// Clamps v to [lo, hi] (lo is checked first).
__device__ __forceinline__ int clamp_int(int v, int lo, int hi) {
  if (v < lo) return lo;
  if (v > hi) return hi;
  return v;
}
30
+
31
// Maps one float to its unsigned quantization code: round(x / scale) clamped to
// [lo, hi], then shifted up by half_point. With a zero scale only the sign of x
// is kept (hi, lo, or 0 before the shift).
__device__ __forceinline__ int quant_code_for(float x, float scale, int lo, int hi, int half_point) {
  float rounded;
  if (scale == 0.0f) {
    rounded = (x > 0.0f) ? (float)hi : ((x < 0.0f) ? (float)lo : 0.0f);
  } else {
    rounded = nearbyintf(x / scale);
  }
  return clamp_int((int)rounded, lo, hi) + half_point;
}

// Quantizes and packs groups of up to pack_size floats, one thread per group.
//
// Launch: 1-D grid (any block size); threads with bid >= total_blocks exit.
// No shared memory is used.
//
// The input is viewed as rows of its innermost dimension (d1/d2/d3 for modes
// 1/2/3; mode 0 is a single row). Each row is split into blocks_per_inner
// groups, and the last group holds `rem` values when rem > 0. Group g of row r
// is written at out + r * row_bytes + g * stride as:
//   - a 4-byte big-endian float32 scale = max|x| / (half_point - 1), then
//   - one byte per value (8-bit), or one nibble per value with the high nibble
//     first (4-bit), an odd tail being padded with the code for 0.0f.
// d0, rem_payload and rem_stride are accepted but not read.
__global__ void quant_pack_kernel(
    int mode,
    const float* w,
    int64_t d0, int64_t d1, int64_t d2, int64_t d3,
    int64_t pack_size,
    int quant_size,
    int half_point,
    int64_t blocks_per_inner,
    int64_t total_blocks,
    int64_t stride,
    int64_t rem,
    int64_t rem_payload,
    int64_t rem_stride,
    int64_t row_bytes,
    uint8_t* out) {
  const int64_t bid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  if (bid >= total_blocks) return;

  const int64_t row = bid / blocks_per_inner;
  const int64_t group = bid - row * blocks_per_inner;
  const bool tail_group = rem > 0 && group == blocks_per_inner - 1;
  const int64_t count = tail_group ? rem : pack_size;

  // Flat index of the group's first input element (input is contiguous, so a
  // row of the innermost dimension starts at row * row_len).
  int64_t first;
  if (mode == 0) {
    first = bid * pack_size;
  } else {
    const int64_t row_len = (mode == 1) ? d1 : ((mode == 2) ? d2 : d3);
    first = row * row_len + group * pack_size;
  }
  const float* src = w + first;

  // The tail group always occupies the last slot, so its slot index is `group`.
  uint8_t* header = out + row * row_bytes + group * stride;
  uint8_t* payload = header + 4;

  float max_abs = 0.0f;
  for (int64_t t = 0; t < count; ++t) {
    const float mag = fabsf(src[t]);
    max_abs = (mag > max_abs) ? mag : max_abs;
  }
  const float scale = max_abs / (float)(half_point - 1);
  write_be_f32(header, scale);

  const int hi = half_point - 1;
  const int lo = -hi;

  if (quant_size == 8) {
    for (int64_t t = 0; t < count; ++t) {
      payload[t] = (uint8_t)quant_code_for(src[t], scale, lo, hi, half_point);
    }
    return;
  }

  const int64_t padded = count + (count & 1);
  for (int64_t t = 0; t < padded; t += 2) {
    const float x0 = (t < count) ? src[t] : 0.0f;
    const float x1 = (t + 1 < count) ? src[t + 1] : 0.0f;
    const int q0 = quant_code_for(x0, scale, lo, hi, half_point);
    const int q1 = quant_code_for(x1, scale, lo, hi, half_point);
    payload[t / 2] = (uint8_t)(((q0 & 0x0F) << 4) | (q1 & 0x0F));
  }
}
135
+
136
// Host entry point shared by the quantize_pack_{1,2,3,4}d wrappers.
//
// Quantizes a contiguous float32 CPU tensor of rank mode + 1 along its last
// dimension in groups of pack_size. Each output row (one per index of the
// leading dimensions) holds floor(inner / pack_size) records of
// 4 + payload bytes, plus one shorter record for the remainder when
// inner % pack_size != 0. The payload is pack_size bytes at 8 bits or
// pack_size / 2 bytes at 4 bits.
// Returns the packed bytes as a 1-D uint8 CPU tensor; an empty input yields an
// empty tensor. Throws c10::Error (via TORCH_CHECK) on invalid arguments.
static torch::Tensor quantize_pack_impl(int mode, torch::Tensor w_cpu, int64_t pack_size, int quant_size) {
  CHECK_INPUT_CPU_F32(w_cpu);
  TORCH_CHECK(pack_size > 0, "pack_size must be > 0");
  TORCH_CHECK(quant_size == 4 || quant_size == 8, "quant_size must be 4 or 8");
  if (quant_size == 4) {
    TORCH_CHECK((pack_size & 1) == 0, "pack_size must be even for quant_size=4");
  }

  const int half_point = (1 << quant_size) / 2;

  int64_t d0 = 0, d1 = 0, d2 = 0, d3 = 0;
  int64_t inner_dim = 0;  // length of the dimension being grouped
  int64_t num_rows = 0;   // product of the leading dimensions
  if (mode == 0) {
    TORCH_CHECK(w_cpu.dim() == 1, "w_cpu must be 1D");
    d0 = w_cpu.size(0);
    inner_dim = d0;
    num_rows = 1;
  } else if (mode == 1) {
    TORCH_CHECK(w_cpu.dim() == 2, "w_cpu must be 2D");
    d0 = w_cpu.size(0);
    d1 = w_cpu.size(1);
    inner_dim = d1;
    num_rows = d0;
  } else if (mode == 2) {
    TORCH_CHECK(w_cpu.dim() == 3, "w_cpu must be 3D");
    d0 = w_cpu.size(0);
    d1 = w_cpu.size(1);
    d2 = w_cpu.size(2);
    inner_dim = d2;
    num_rows = d0 * d1;
  } else {
    TORCH_CHECK(w_cpu.dim() == 4, "w_cpu must be 4D");
    d0 = w_cpu.size(0);
    d1 = w_cpu.size(1);
    d2 = w_cpu.size(2);
    d3 = w_cpu.size(3);
    inner_dim = d3;
    num_rows = d0 * d1 * d2;
  }

  const int64_t blocks_per_inner = cdiv_i64(inner_dim, pack_size);
  const int64_t total_blocks = num_rows * blocks_per_inner;

  const int64_t payload = (quant_size == 4) ? (pack_size / 2) : pack_size;
  const int64_t stride = 4 + payload;

  const int64_t rem = inner_dim % pack_size;
  int64_t rem_payload = 0, rem_stride = 0;
  if (rem > 0) {
    const int64_t rem_padded = rem + (rem & 1);
    rem_payload = (quant_size == 4) ? (rem_padded / 2) : rem;
    rem_stride = 4 + rem_payload;
  }

  const int64_t full_blocks = inner_dim / pack_size;
  const int64_t row_bytes = full_blocks * stride + rem_stride;
  const int64_t total_bytes = num_rows * row_bytes;

  // Empty input: nothing to encode (total_bytes is 0 here). Returning early
  // avoids a zero-sized grid launch and the old total_blocks / blocks_per_inner
  // division by zero when the innermost dimension is empty.
  if (total_blocks == 0) {
    return torch::empty({total_bytes}, torch::TensorOptions().dtype(torch::kUInt8));
  }

  auto w = w_cpu.to(torch::kCUDA);
  auto out = torch::zeros({total_bytes}, torch::TensorOptions().device(torch::kCUDA).dtype(torch::kUInt8));

  const int threads = 256;
  const int grid = (int)cdiv_i64(total_blocks, threads);

  // NOTE(review): launched on the default stream, not PyTorch's current stream;
  // confirm this is acceptable if callers use custom CUDA streams.
  quant_pack_kernel<<<grid, threads>>>(
      mode,
      (const float*)w.data_ptr<float>(),
      d0, d1, d2, d3,
      pack_size,
      quant_size,
      half_point,
      blocks_per_inner,
      total_blocks,
      stride,
      rem,
      rem_payload,
      rem_stride,
      row_bytes,
      (uint8_t*)out.data_ptr<uint8_t>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return out.to(torch::kCPU);
}
223
+
224
+ torch::Tensor quantize_pack_1d(torch::Tensor w_cpu, int64_t pack_size, int quant_size) {
225
+ return quantize_pack_impl(0, w_cpu, pack_size, quant_size);
226
+ }
227
+
228
+ torch::Tensor quantize_pack_2d(torch::Tensor w_cpu, int64_t pack_size, int quant_size) {
229
+ return quantize_pack_impl(1, w_cpu, pack_size, quant_size);
230
+ }
231
+
232
+ torch::Tensor quantize_pack_3d(torch::Tensor w_cpu, int64_t pack_size, int quant_size) {
233
+ return quantize_pack_impl(2, w_cpu, pack_size, quant_size);
234
+ }
235
+
236
+ torch::Tensor quantize_pack_4d(torch::Tensor w_cpu, int64_t pack_size, int quant_size) {
237
+ return quantize_pack_impl(3, w_cpu, pack_size, quant_size);
238
+ }
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.4
2
+ Name: uni-quant-cuda
3
+ Version: 0.2.0
4
+ Summary: Uni-Quant: CUDA-accelerated quantization/dequantization for Keras and XGBoost models
5
+ Author-email: Jakub Grula <ramsters110@gmail.com>
6
+ License: MIT
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: POSIX :: Linux
10
+ Requires-Python: >=3.11
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Dynamic: license-file
14
+
15
+ Uni-Quant
16
+ ========
17
+
18
+ Small library to quantize/dequantize Keras and XGBoost models using PyTorch CUDA kernels.
19
+
20
+ Notes
21
+ - This package compiles CUDA kernels at runtime using `torch.utils.cpp_extension.load_inline`.
22
+ - Installing and using the CUDA compilation requires a compatible PyTorch build and CUDA toolkit on the target machine.
23
+
24
+ Dependencies are listed in `requirements.txt` and synchronized with `pyproject.toml`.
25
+
26
+ Quick publish test
27
+
28
+ Build a source/wheel and check locally:
29
+
30
+ ```
31
+ python -m pip install --upgrade build twine
32
+ python -m build
33
+ python -m twine check dist/*
34
+ ```
35
+
36
+ Upload (example):
37
+
38
+ ```
39
+ python -m twine upload dist/*
40
+ ```
@@ -0,0 +1,10 @@
1
+ LICENSE
2
+ MANIFEST.in
3
+ README.md
4
+ dkernel.cpp
5
+ pyproject.toml
6
+ qkernel.cpp
7
+ uni_quant_cuda.egg-info/PKG-INFO
8
+ uni_quant_cuda.egg-info/SOURCES.txt
9
+ uni_quant_cuda.egg-info/dependency_links.txt
10
+ uni_quant_cuda.egg-info/top_level.txt