ista-daslab-optimizers-cuda 1.0.0 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registries.
Files changed (23)
  1. ista_daslab_optimizers_cuda-1.0.0/LICENSE +21 -0
  2. ista_daslab_optimizers_cuda-1.0.0/MANIFEST.in +1 -0
  3. ista_daslab_optimizers_cuda-1.0.0/PKG-INFO +40 -0
  4. ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/PKG-INFO +40 -0
  5. ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/SOURCES.txt +32 -0
  6. ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/dependency_links.txt +1 -0
  7. ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/requires.txt +4 -0
  8. ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/top_level.txt +4 -0
  9. ista_daslab_optimizers_cuda-1.0.0/kernels/dense_mfac/dense_mfac.cpp +20 -0
  10. ista_daslab_optimizers_cuda-1.0.0/kernels/dense_mfac/dense_mfac_kernel.cu +216 -0
  11. ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam.cpp +62 -0
  12. ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam_asymm_block_quant.cu +64 -0
  13. ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +83 -0
  14. ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam_update.cu +165 -0
  15. ista_daslab_optimizers_cuda-1.0.0/kernels/sparse_mfac/sparse_mfac.cpp +84 -0
  16. ista_daslab_optimizers_cuda-1.0.0/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +246 -0
  17. ista_daslab_optimizers_cuda-1.0.0/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +251 -0
  18. ista_daslab_optimizers_cuda-1.0.0/kernels/tools/tools.cpp +127 -0
  19. ista_daslab_optimizers_cuda-1.0.0/kernels/tools/tools_kernel.cu +315 -0
  20. ista_daslab_optimizers_cuda-1.0.0/kernels/utils.h +125 -0
  21. ista_daslab_optimizers_cuda-1.0.0/pyproject.toml +42 -0
  22. ista_daslab_optimizers_cuda-1.0.0/setup.cfg +4 -0
  23. ista_daslab_optimizers_cuda-1.0.0/setup.py +56 -0
ista_daslab_optimizers_cuda-1.0.0/LICENSE
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 IST Austria Distributed Algorithms and Systems Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
ista_daslab_optimizers_cuda-1.0.0/MANIFEST.in
@@ -0,0 +1 @@
+ graft ./kernels
ista_daslab_optimizers_cuda-1.0.0/PKG-INFO
@@ -0,0 +1,40 @@
+ Metadata-Version: 2.4
+ Name: ista_daslab_optimizers_cuda
+ Version: 1.0.0
+ Summary: CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
+ Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
+ Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
+ License: MIT License
+
+ Copyright (c) 2026 IST Austria Distributed Algorithms and Systems Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ Project-URL: Repository, https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA
+ Keywords: adaptive optimization,deep learning,low memory optimization
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: License :: OSI Approved :: Apache Software License
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: torchaudio
+ Requires-Dist: torchvision
+ Requires-Dist: numpy
+ Dynamic: license-file
ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/PKG-INFO
@@ -0,0 +1,40 @@
+ Metadata-Version: 2.4
+ Name: ista_daslab_optimizers_cuda
+ Version: 1.0.0
+ Summary: CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
+ Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
+ Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
+ License: MIT License
+
+ Copyright (c) 2026 IST Austria Distributed Algorithms and Systems Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ Project-URL: Repository, https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA
+ Keywords: adaptive optimization,deep learning,low memory optimization
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: License :: OSI Approved :: Apache Software License
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: torchaudio
+ Requires-Dist: torchvision
+ Requires-Dist: numpy
+ Dynamic: license-file
ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/SOURCES.txt
@@ -0,0 +1,32 @@
+ LICENSE
+ MANIFEST.in
+ pyproject.toml
+ setup.py
+ ./kernels/dense_mfac/dense_mfac.cpp
+ ./kernels/dense_mfac/dense_mfac_kernel.cu
+ ./kernels/micro_adam/micro_adam.cpp
+ ./kernels/micro_adam/micro_adam_asymm_block_quant.cu
+ ./kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
+ ./kernels/micro_adam/micro_adam_update.cu
+ ./kernels/sparse_mfac/sparse_mfac.cpp
+ ./kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
+ ./kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ ./kernels/tools/tools.cpp
+ ./kernels/tools/tools_kernel.cu
+ ista_daslab_optimizers_cuda.egg-info/PKG-INFO
+ ista_daslab_optimizers_cuda.egg-info/SOURCES.txt
+ ista_daslab_optimizers_cuda.egg-info/dependency_links.txt
+ ista_daslab_optimizers_cuda.egg-info/requires.txt
+ ista_daslab_optimizers_cuda.egg-info/top_level.txt
+ kernels/utils.h
+ kernels/dense_mfac/dense_mfac.cpp
+ kernels/dense_mfac/dense_mfac_kernel.cu
+ kernels/micro_adam/micro_adam.cpp
+ kernels/micro_adam/micro_adam_asymm_block_quant.cu
+ kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
+ kernels/micro_adam/micro_adam_update.cu
+ kernels/sparse_mfac/sparse_mfac.cpp
+ kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
+ kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ kernels/tools/tools.cpp
+ kernels/tools/tools_kernel.cu
ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/requires.txt
@@ -0,0 +1,4 @@
+ torch
+ torchaudio
+ torchvision
+ numpy
ista_daslab_optimizers_cuda-1.0.0/ista_daslab_optimizers_cuda.egg-info/top_level.txt
@@ -0,0 +1,4 @@
+ ista_daslab_cuda_dense_mfac
+ ista_daslab_cuda_micro_adam
+ ista_daslab_cuda_sparse_mfac
+ ista_daslab_cuda_tools
ista_daslab_optimizers_cuda-1.0.0/kernels/dense_mfac/dense_mfac.cpp
@@ -0,0 +1,20 @@
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAGuard.h>
+
+ torch::Tensor hinv_setup_cuda(torch::Tensor tmp, torch::Tensor coef);
+ torch::Tensor hinv_mul_cuda(int rows, torch::Tensor giHig, torch::Tensor giHix);
+
+ torch::Tensor hinv_setup(torch::Tensor tmp, torch::Tensor coef) {
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(tmp));
+     return hinv_setup_cuda(tmp, coef);
+ }
+
+ torch::Tensor hinv_mul(int rows, torch::Tensor giHig, torch::Tensor giHix) {
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(giHig));
+     return hinv_mul_cuda(rows, giHig, giHix);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("hinv_setup", &hinv_setup, "Hinv setup (CUDA)");
+     m.def("hinv_mul", &hinv_mul, "Hinv mul (CUDA)");
+ }
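Note: this wrapper only sets the CUDA device guard and forwards to the kernels in dense_mfac_kernel.cu below. A minimal usage sketch from Python, assuming the extension builds as ista_daslab_cuda_dense_mfac (the name listed in top_level.txt); the shapes and the rows argument are illustrative, not prescribed by the package:

    import torch
    import ista_daslab_cuda_dense_mfac as dense_mfac  # module name from top_level.txt

    m = 1024  # hinv_setup assumes m % 32 == 0; hinv_mul additionally requires m <= 1024
    tmp = torch.randn(m, m, device="cuda")   # illustrative m x m input
    coef = torch.zeros(m, m, device="cuda")  # illustrative initial contents; written in place and returned
    coef = dense_mfac.hinv_setup(tmp, coef)

    giHig = torch.randn(m, m, device="cuda")
    giHix = torch.randn(1, m, device="cuda")      # the kernel reads and writes row 0
    giHix = dense_mfac.hinv_mul(m, giHig, giHix)  # rows=m is a placeholder value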
ista_daslab_optimizers_cuda-1.0.0/kernels/dense_mfac/dense_mfac_kernel.cu
@@ -0,0 +1,216 @@
+ #include <torch/extension.h>
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+ const int SIZE = 32;
+ const int MAX = 1024;
+
+ template <typename scalar_t>
+ __device__ __forceinline__ scalar_t GetElement(
+     const scalar_t* __restrict__ A, int m,
+     int row, int col
+ ) {
+     return A[row * m + col];
+ }
+
+ template <typename scalar_t>
+ __device__ __forceinline__ void SetElement(
+     scalar_t* __restrict__ A, int m,
+     int row, int col,
+     scalar_t value
+ ) {
+     A[row * m + col] = value;
+ }
+
+ /* Kernel for computing `coef` (required setup & update operations) */
+
+ template <typename scalar_t>
+ __global__ void HinvCoefKernelDiag(
+     int m,
+     const scalar_t* __restrict__ tmp,
+     scalar_t* __restrict__ coef
+ );
+
+ template <typename scalar_t>
+ __global__ void HinvCoefKernelMain(
+     int m,
+     const scalar_t* __restrict__ tmp,
+     scalar_t* __restrict__ coef,
+     int iter
+ );
+
+ // NOTE: for simplicity, we assume that `m` is always divisible by `SIZE`
+ torch::Tensor hinv_setup_cuda(torch::Tensor tmp, torch::Tensor coef) {
+     const auto m = tmp.size(0);
+     const dim3 threads(SIZE, SIZE);
+     const dim3 blocks(m / SIZE, m / SIZE);
+
+     AT_DISPATCH_FLOATING_TYPES(tmp.scalar_type(), "hinv_setup_cuda", ([&] {
+         HinvCoefKernelDiag<scalar_t><<<m / SIZE, threads>>>(
+             m, tmp.data_ptr<scalar_t>(), coef.data_ptr<scalar_t>()
+         );
+     })
+     );
+     for (int i = 0; i < m / SIZE - 1; i++) {
+         AT_DISPATCH_FLOATING_TYPES(tmp.scalar_type(), "hinv_setup_cuda", ([&] {
+             HinvCoefKernelMain<scalar_t><<<blocks, threads>>>(
+                 m, tmp.data_ptr<scalar_t>(), coef.data_ptr<scalar_t>(), i
+             );
+         })
+         );
+     }
+
+     return coef;
+ }
+
+ template <typename scalar_t>
+ __global__ void HinvCoefKernelMain(
+     int m,
+     const scalar_t* __restrict__ tmp,
+     scalar_t* __restrict__ coef,
+     int iter
+ ) {
+     // one thread per block element
+
+     // top-left of target block
+     int toi = blockIdx.x * SIZE;
+     int toj = blockIdx.y * SIZE;
+     // top-left of source block
+     int fromi = (blockIdx.y + iter) * SIZE;
+     int fromj = toj;
+
+     // only compute below (current) diagonal
+     if (fromi >= toi)
+         return;
+
+     // current relative position
+     int x = threadIdx.x;
+     int y = threadIdx.y;
+     // current absolute position
+     int i = toi + x;
+     int j = toj + y;
+
+     // parallel load relevant blocks of `coef` and `tmp`
+     __shared__ scalar_t from_coef[SIZE][SIZE];
+     __shared__ scalar_t to_tmp[SIZE][SIZE];
+     from_coef[x][y] = GetElement(coef, m, fromi + x, fromj + y);
+     to_tmp[x][y] = GetElement(tmp, m, i, fromi + y);
+     __syncthreads();
+
+     // parallel matmul
+     scalar_t res = GetElement(coef, m, i, j);
+     for (int k = 0; k < SIZE; k++)
+         res += to_tmp[x][k] * from_coef[k][y];
+     SetElement(coef, m, i, j, res);
+
+     // keep only next sequential blocks
+     if (toi != fromi + SIZE)
+         return;
+     __syncthreads();
+
+     // parallel load relevant blocks of `coef` and `tmp`
+     from_coef[x][y] = GetElement(coef, m, i, j);
+     to_tmp[x][y] = GetElement(tmp, m, i, toi + y);
+     __syncthreads();
+
+     // parallel sequential vector-matrix multiplies
+     res = from_coef[x][y];
+     for (int k = 0; k < SIZE; k++) {
+         if (k < x)
+             res += to_tmp[x][k] * from_coef[k][y];
+         if (k == x - 1) {
+             // parallel write block row
+             from_coef[x][y] = res;
+             SetElement(coef, m, i, j, res);
+         }
+         __syncthreads();
+     }
+ }
+
+ template <typename scalar_t>
+ __global__ void HinvCoefKernelDiag(
+     int m,
+     const scalar_t* __restrict__ tmp,
+     scalar_t* __restrict__ coef
+ ) {
+     // one thread per block element
+
+     // current relative position
+     int x = threadIdx.x;
+     int y = threadIdx.y;
+     // current absolute position
+     int i = blockIdx.x * SIZE + x;
+     int j = blockIdx.x * SIZE + y;
+
+     // parallel load relevant blocks of `coef` and `tmp`
+     __shared__ scalar_t from_coef[SIZE][SIZE];
+     __shared__ scalar_t to_tmp[SIZE][SIZE];
+     from_coef[x][y] = GetElement(coef, m, i, j);
+     to_tmp[x][y] = GetElement(tmp, m, i, j);
+     __syncthreads();
+
+     // parallel sequential vector-matrix multiplies
+     scalar_t res = 0;
+     for (int k = 0; k < SIZE; k++) {
+         if (k < x)
+             res += to_tmp[x][k] * from_coef[k][y];
+         // don't write diagonal elements
+         if (k == x - 1 && x != y) {
+             // parallel write block row
+             from_coef[x][y] = res;
+             SetElement(coef, m, i, j, res);
+         }
+         __syncthreads();
+     }
+ }
+
+
+ /* Kernel for computing `giHix` (required for multiplication) */
+
+ template <typename scalar_t>
+ __global__ void HinvMulKernel(
+     int rows,
+     int m,
+     const scalar_t* __restrict__ giHig,
+     scalar_t* __restrict__ giHix
+ );
+
+ // NOTE: currently only works for `m` <= 1024
+ torch::Tensor hinv_mul_cuda(int rows, torch::Tensor giHig, torch::Tensor giHix) {
+     const auto m = giHig.size(0);
+     AT_DISPATCH_FLOATING_TYPES(giHig.scalar_type(), "hinv_mul_cuda", ([&] {
+         HinvMulKernel<scalar_t><<<1, m>>>(
+             rows, m, giHig.data_ptr<scalar_t>(), giHix.data_ptr<scalar_t>()
+         );
+     })
+     );
+     return giHix;
+ }
+
+ template <typename scalar_t>
+ __global__ void HinvMulKernel(
+     int rows,
+     int m,
+     const scalar_t* __restrict__ giHig,
+     scalar_t* __restrict__ giHix
+ ) {
+     int i = threadIdx.x;
+
+     // parallel load relevant coefficients from `giHix` and `giHig`
+     __shared__ scalar_t denom[MAX];
+     __shared__ scalar_t tmp[MAX];
+     denom[i] = GetElement(giHig, m, i, i) + rows; // changed by ionut: fix_scaling
+     tmp[i] = GetElement(giHix, m, 0, i);
+     __syncthreads();
+
+     // compute parallel sequential linear combination
+     for (int j = 1; j < m; j++) {
+         if (j <= i) {
+             scalar_t sub = GetElement(giHig, m, j - 1, i) * tmp[j - 1] / denom[j - 1];
+             tmp[i] -= sub;
+         }
+         __syncthreads();
+     }
+
+     SetElement(giHix, m, 0, i, tmp[i]);
+ }
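For reference, HinvMulKernel runs one thread per entry i and applies, one step j at a time, tmp[i] -= giHig[j-1][i] * tmp[j-1] / denom[j-1] with denom[i] = giHig[i][i] + rows. A single-threaded NumPy sketch of the same recurrence, useful as a cross-check (the function name is mine, not part of the package):

    import numpy as np

    def hinv_mul_reference(rows, giHig, giHix):
        # Mirrors HinvMulKernel: giHig is (m, m), giHix is (1, m);
        # returns the updated row that the kernel writes back in place.
        m = giHig.shape[0]
        denom = np.diag(giHig) + rows  # the "fix_scaling" term from the kernel
        tmp = giHix[0].copy()
        for j in range(1, m):          # step j updates every entry i >= j
            tmp[j:] -= giHig[j - 1, j:] * (tmp[j - 1] / denom[j - 1])
        return tmp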
ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam.cpp
@@ -0,0 +1,62 @@
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <cuda_bf16.h>
+ #include "../utils.h"
+
+ typedef long long LL;
+
+ // CUDA methods
+ void compute_microadam_update_cuda(int blocks, int threads, int carveout,
+                                    LL t, float beta1, float beta2, float eps,
+                                    LL d_block_size, LL k_block_size,
+                                    LL d, LL m, LL k,
+                                    torch::Tensor indices, torch::Tensor values, torch::Tensor out);
+
+ void asymm_block_quant_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x);
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha);
+
+ // C++ methods
+ void compute_microadam_update(int blocks, int threads, int carveout,
+                               LL t, float beta1, float beta2, float eps,
+                               LL d_block_size, LL k_block_size,
+                               LL d, LL m, LL k,
+                               torch::Tensor indices, torch::Tensor values, torch::Tensor out) {
+     CHECK_INPUT(indices);
+     CHECK_INPUT(values);
+     CHECK_INPUT(out);
+     CHECK_THREADS(threads);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(indices));
+     compute_microadam_update_cuda(blocks, threads, carveout,
+                                   t, beta1, beta2, eps,
+                                   d_block_size, k_block_size,
+                                   d, m, k,
+                                   indices, values, out);
+ }
+
+ void asymm_block_quant(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+     CHECK_INPUT(xq);
+     CHECK_INPUT(xmin);
+     CHECK_INPUT(xmax);
+     CHECK_INPUT(x);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+     asymm_block_quant_cuda(d, q_block_size, xq, xmin, xmax, x);
+ }
+
+ void asymm_block_quant_inv(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
+     CHECK_INPUT(xq);
+     CHECK_INPUT(xmin);
+     CHECK_INPUT(xmax);
+     CHECK_INPUT(x);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+     asymm_block_quant_inv_cuda(d, q_block_size, xq, xmin, xmax, x, alpha);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("compute_microadam_update", &compute_microadam_update, "Computes the update from Compressed Adam.");
+
+     m.def("asymm_block_quant", &asymm_block_quant, "Asymmetrically quantizes a vector to 4 bits in blocks.");
+     m.def("asymm_block_quant_inv", &asymm_block_quant_inv, "Asymmetrically dequantizes a vector from 4 bits in blocks.");
+ }
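A minimal usage sketch for the two quantization entry points above, assuming the module builds as ista_daslab_cuda_micro_adam (per top_level.txt) and that d is divisible by q_block_size; the sizes are illustrative:

    import torch
    import ista_daslab_cuda_micro_adam as micro_adam  # module name from top_level.txt

    d, q_block_size = 4096, 256
    x = torch.randn(d, device="cuda", dtype=torch.bfloat16)
    xq = torch.zeros((d + 1) // 2, device="cuda", dtype=torch.uint8)  # two 4-bit codes per byte
    xmin = x.view(-1, q_block_size).min(dim=1).values.contiguous()    # per-block minima
    xmax = x.view(-1, q_block_size).max(dim=1).values.contiguous()    # per-block maxima

    micro_adam.asymm_block_quant(d, q_block_size, xq, xmin, xmax, x)  # pack x into xq
    out = torch.zeros(d, device="cuda", dtype=torch.bfloat16)
    micro_adam.asymm_block_quant_inv(d, q_block_size, xq, xmin, xmax, out, 1.0)  # out += 1.0 * Q_inv(xq)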
ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam_asymm_block_quant.cu
@@ -0,0 +1,64 @@
+ #include "../utils.h"
+
+ __global__ void asymm_block_quant_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x) {
+     /*
+     This kernel computes xq = Q(x, x_min, x_max) for 4 bits (implements point 4 from PhD notebook page 55)
+     In contrast to the "globally" kernel, this kernel works with a single block
+     Make sure block_size is always divisible by 2
+
+     We have to read:
+     - q_block_size values from x
+     - one value from ranges
+     - q_block_size / 2 values from xq
+     */
+     bfloat162 *x2 = reinterpret_cast<bfloat162*>(x); // we will read two values from x at once
+
+     const LL B = gridDim.x; // number of blocks
+     const LL Bid = blockIdx.x; // block id
+     const LL T = blockDim.x; // number of threads
+     const LL Tid = threadIdx.x; // thread id
+
+     LL half_d = (d >> 1);
+     LL half_q_block_size = (q_block_size >> 1); // block size in xq
+     LL half_start_index = Bid * half_q_block_size; // start index in vector x
+     LL half_end_index = min(half_start_index + half_q_block_size, half_d); // end index in vector x
+     float m = __bfloat162float(xmin[Bid]);
+     float M = __bfloat162float(xmax[Bid]);
+     float u = (M - m) / 15.0f; // 15 = 16 - 1 = 2^4 - 1
+
+     bfloat162 vx2; // the value that will store x2[index]
+     uint8_t msb; // the MSB of an xq component
+     uint8_t lsb; // the LSB of an xq component
+
+     for(LL half_index = half_start_index + Tid; half_index < half_end_index; half_index += T) {
+         vx2 = x2[half_index];
+         msb = (uint8_t) floorf((__bfloat162float(vx2.x) - m) / u + 0.5f);
+         lsb = (uint8_t) floorf((__bfloat162float(vx2.y) - m) / u + 0.5f);
+         xq[half_index] = (msb << 4) | lsb;
+     }
+
+     if((d & 1) && (Bid == B-1) && (Tid == T-1)) {
+         msb = (uint8_t) floorf((__bfloat162float(x[d - 1]) - m) / u + 0.5f);
+         xq[half_d] = (msb << 4);
+     }
+ }
+
+ void asymm_block_quant_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x) {
+     torch::ScalarType bf16 = torch::ScalarType::BFloat16;
+     assert(xmin.scalar_type() == bf16 && xmax.scalar_type() == bf16 && x.scalar_type() == bf16);
+
+     LL blocks = 1 + (LL)(d / q_block_size);
+     uint8_t *ptr_xq = (uint8_t*) xq.data_ptr();
+
+     asymm_block_quant_kernel_bf16_bf16<<<blocks, 1024>>>(d,
+                                                          q_block_size,
+                                                          ptr_xq,
+                                                          (bfloat16*) xmin.data_ptr(),
+                                                          (bfloat16*) xmax.data_ptr(),
+                                                          (bfloat16*) x.data_ptr());
+
+     // error checks
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     // GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
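The packing convention above places the even-indexed element of each pair in the high nibble and the odd-indexed element in the low nibble, with per-block scale u = (xmax - xmin) / 15. A NumPy sketch of the same scheme for cross-checking (names are mine; it assumes xmax > xmin within each block, as the kernel does):

    import numpy as np

    def asymm_block_quant_reference(x, q_block_size):
        # Per block: code = floor((v - xmin) / u + 0.5) with u = (xmax - xmin) / 15;
        # even global index -> high nibble, odd -> low nibble, as in the kernel.
        d = len(x)
        xq = np.zeros((d + 1) // 2, dtype=np.uint8)
        for start in range(0, d, q_block_size):
            block = x[start:start + q_block_size]
            m = block.min()
            u = (block.max() - m) / 15.0
            codes = np.floor((block - m) / u + 0.5).astype(np.uint8)
            for k, i in enumerate(range(start, start + len(block))):
                xq[i // 2] |= (codes[k] << 4) if i % 2 == 0 else codes[k]
        return xq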
ista_daslab_optimizers_cuda-1.0.0/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu
@@ -0,0 +1,83 @@
+ #include "../utils.h"
+
+ __global__ void asymm_block_quant_inv_kernel_bf16_bf16(LL d, LL q_block_size, uint8_t *xq, bfloat16 *xmin, bfloat16 *xmax, bfloat16 *x, float alpha) {
+     /*
+     This kernel computes x += alpha * Q_inv(xq, xmin, xmax) for 4 bits (implements point 1 from PhD #9 notebook page 55)
+     Here, x is the output buffer and will already contain the dense gradient
+     The output buffer x has d components and xq has d/2 components because one uint8_t stores two 4-bit values
+     In contrast to the "globally" kernel, this kernel works with a single block
+     Make sure block_size is always divisible by 2
+
+     We have to read:
+     - q_block_size values from x
+     - one value from ranges
+     - q_block_size / 2 values from xq
+     */
+     bfloat162 *x2 = reinterpret_cast<bfloat162*>(x); // we will read two values from x at once
+
+     const LL B = (LL) gridDim.x; // number of blocks
+     const LL Bid = (LL) blockIdx.x; // block id
+     const LL T = (LL) blockDim.x; // number of threads
+     const LL Tid = (LL) threadIdx.x; // thread id
+
+     LL half_d = (d >> 1);
+     LL half_q_block_size = (q_block_size >> 1); // block size in xq
+     LL half_start_index = Bid * half_q_block_size; // start index in vector x
+     LL half_end_index = minLL(half_start_index + half_q_block_size, half_d); // end index in vector x
+     // if (Bid == 0 && Tid == 0) {
+     //     printf("\n\n\n\t\t\t&&&&&&&&&& half_d=%lld, half_q_block_size=%lld, half_start_index=%lld, half_end_index=%lld\n\n\n");
+     // }
+     float m = __bfloat162float(xmin[Bid]);
+     float M = __bfloat162float(xmax[Bid]);
+     float u = (M - m) / 15.0f; // 15 = 16 - 1 = 2^4 - 1
+     bfloat162 vx2; // the value that will store x2[index]
+     uint8_t vq; // the value that will store xq[index]
+     uint8_t msb; // the MSB of an xq component
+     uint8_t lsb; // the LSB of an xq component
+
+     for(LL half_index = half_start_index + Tid; half_index < half_end_index; half_index += T) {
+         vx2 = x2[half_index];
+         vq = xq[half_index];
+
+         msb = (vq & 0xF0) >> 4;
+         lsb = (vq & 0x0F);
+
+         // += operation happens here
+         vx2.x += __float2bfloat16((msb * u + m) * alpha);
+         vx2.y += __float2bfloat16((lsb * u + m) * alpha);
+         x2[half_index] = vx2;
+     }
+     if((d & 1) && (Bid == B-1) && (Tid == T-1)) {
+         bfloat16 vx = x[d - 1];
+         vq = xq[half_d];
+         msb = (vq & 0xF0) >> 4;
+         // += operation happens here
+         vx += __float2bfloat16((msb * u + m) * alpha);
+         x[d - 1] = vx;
+     }
+ }
+
+ void asymm_block_quant_inv_cuda(LL d, LL q_block_size, torch::Tensor xq, torch::Tensor xmin, torch::Tensor xmax, torch::Tensor x, float alpha) {
+     ASSERT_BF16(xmin);
+     ASSERT_BF16(xmax);
+     ASSERT_BF16(x);
+
+     LL blocks = 1 + (LL)(d / q_block_size);
+     dim3 B(blocks, 1, 1);
+     dim3 T(1024, 1, 1);
+
+     uint8_t *ptr_xq = (uint8_t*) xq.data_ptr();
+
+     asymm_block_quant_inv_kernel_bf16_bf16<<<B, T>>>(d,
+                                                      q_block_size,
+                                                      ptr_xq,
+                                                      (bfloat16*) xmin.data_ptr(),
+                                                      (bfloat16*) xmax.data_ptr(),
+                                                      (bfloat16*) x.data_ptr(),
+                                                      alpha);
+
+     // error checks
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     // GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
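And the matching dequantize-accumulate, x += alpha * Q_inv(xq, xmin, xmax), as a NumPy sketch (names are mine, not part of the package):

    import numpy as np

    def asymm_block_quant_inv_reference(xq, xmin, xmax, x, q_block_size, alpha):
        # Mirrors the kernel: the high nibble decodes the even-indexed element
        # of each byte, the low nibble the odd-indexed one; accumulates into x.
        d = len(x)
        for i in range(d):
            b = i // q_block_size  # quantization block id
            u = (float(xmax[b]) - float(xmin[b])) / 15.0
            code = (xq[i // 2] >> 4) if i % 2 == 0 else (xq[i // 2] & 0x0F)
            x[i] += alpha * (code * u + float(xmin[b]))
        return x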