ista-daslab-optimizers-cuda 1.0.0__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/PKG-INFO +11 -2
  2. ista_daslab_optimizers_cuda-1.1.0/README.md +9 -0
  3. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/ista_daslab_optimizers_cuda.egg-info/PKG-INFO +11 -2
  4. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/ista_daslab_optimizers_cuda.egg-info/SOURCES.txt +9 -0
  5. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/ista_daslab_optimizers_cuda.egg-info/top_level.txt +1 -0
  6. ista_daslab_optimizers_cuda-1.1.0/kernels/sparse_mfac_pruner/mfac_pruner_cpp.cpp +150 -0
  7. ista_daslab_optimizers_cuda-1.1.0/kernels/sparse_mfac_pruner/mfac_pruner_dense.cu +132 -0
  8. ista_daslab_optimizers_cuda-1.1.0/kernels/sparse_mfac_pruner/mfac_pruner_initial.cu +244 -0
  9. ista_daslab_optimizers_cuda-1.1.0/kernels/sparse_mfac_pruner/mfac_pruner_sparse.cu +156 -0
  10. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/utils.h +46 -0
  11. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/pyproject.toml +41 -42
  12. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/setup.py +9 -0
  13. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/LICENSE +0 -0
  14. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/MANIFEST.in +0 -0
  15. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/ista_daslab_optimizers_cuda.egg-info/dependency_links.txt +0 -0
  16. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/ista_daslab_optimizers_cuda.egg-info/requires.txt +0 -0
  17. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/dense_mfac/dense_mfac.cpp +0 -0
  18. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/dense_mfac/dense_mfac_kernel.cu +0 -0
  19. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/micro_adam/micro_adam.cpp +0 -0
  20. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -0
  21. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +0 -0
  22. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/micro_adam/micro_adam_update.cu +0 -0
  23. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/sparse_mfac/sparse_mfac.cpp +0 -0
  24. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +0 -0
  25. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -0
  26. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/tools/tools.cpp +0 -0
  27. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/kernels/tools/tools_kernel.cu +0 -0
  28. {ista_daslab_optimizers_cuda-1.0.0 → ista_daslab_optimizers_cuda-1.1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ista_daslab_optimizers_cuda
- Version: 1.0.0
+ Version: 1.1.0
  Summary: CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -29,7 +29,6 @@ License: MIT License
  Project-URL: Repository, https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA
  Keywords: adaptive optimization,deep learning,low memory optimization
  Classifier: Programming Language :: Python :: 3.8
- Classifier: License :: OSI Approved :: Apache Software License
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -38,3 +37,13 @@ Requires-Dist: torchaudio
  Requires-Dist: torchvision
  Requires-Dist: numpy
  Dynamic: license-file
+
+ # Core dependency of ISTA DAS Lab Optimization Package containing CUDA kernels
+ This project contains CUDA kernels designed for [ISTA-DASLab-Optimizers](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers) as a
+ dependency.
+
+ # Versions summary:
+ - **1.1.0** @ February 5th, 2026:
+   - added kernels for the Sparse M-FAC Pruner
+ - **1.0.0** @ February 5th, 2026:
+   - created this repository to decouple the CUDA kernels from the main **ISTA-DASLab-Optimizers** project
@@ -0,0 +1,9 @@
+ # Core dependency of ISTA DAS Lab Optimization Package containing CUDA kernels
+ This project contains CUDA kernels designed for [ISTA-DASLab-Optimizers](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers) as a
+ dependency.
+
+ # Versions summary:
+ - **1.1.0** @ February 5th, 2026:
+   - added kernels for the Sparse M-FAC Pruner
+ - **1.0.0** @ February 5th, 2026:
+   - created this repository to decouple the CUDA kernels from the main **ISTA-DASLab-Optimizers** project
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ista_daslab_optimizers_cuda
- Version: 1.0.0
+ Version: 1.1.0
  Summary: CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -29,7 +29,6 @@ License: MIT License
  Project-URL: Repository, https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA
  Keywords: adaptive optimization,deep learning,low memory optimization
  Classifier: Programming Language :: Python :: 3.8
- Classifier: License :: OSI Approved :: Apache Software License
  Requires-Python: >=3.8
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -38,3 +37,13 @@ Requires-Dist: torchaudio
  Requires-Dist: torchvision
  Requires-Dist: numpy
  Dynamic: license-file
+
+ # Core dependency of ISTA DAS Lab Optimization Package containing CUDA kernels
+ This project contains CUDA kernels designed for [ISTA-DASLab-Optimizers](https://github.com/IST-DASLab/ISTA-DASLab-Optimizers) as a
+ dependency.
+
+ # Versions summary:
+ - **1.1.0** @ February 5th, 2026:
+   - added kernels for the Sparse M-FAC Pruner
+ - **1.0.0** @ February 5th, 2026:
+   - created this repository to decouple the CUDA kernels from the main **ISTA-DASLab-Optimizers** project
@@ -1,5 +1,6 @@
  LICENSE
  MANIFEST.in
+ README.md
  pyproject.toml
  setup.py
  ./kernels/dense_mfac/dense_mfac.cpp
@@ -11,6 +12,10 @@ setup.py
  ./kernels/sparse_mfac/sparse_mfac.cpp
  ./kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
  ./kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ ./kernels/sparse_mfac_pruner/mfac_pruner_cpp.cpp
+ ./kernels/sparse_mfac_pruner/mfac_pruner_dense.cu
+ ./kernels/sparse_mfac_pruner/mfac_pruner_initial.cu
+ ./kernels/sparse_mfac_pruner/mfac_pruner_sparse.cu
  ./kernels/tools/tools.cpp
  ./kernels/tools/tools_kernel.cu
  ista_daslab_optimizers_cuda.egg-info/PKG-INFO
@@ -28,5 +33,9 @@ kernels/micro_adam/micro_adam_update.cu
  kernels/sparse_mfac/sparse_mfac.cpp
  kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
  kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ kernels/sparse_mfac_pruner/mfac_pruner_cpp.cpp
+ kernels/sparse_mfac_pruner/mfac_pruner_dense.cu
+ kernels/sparse_mfac_pruner/mfac_pruner_initial.cu
+ kernels/sparse_mfac_pruner/mfac_pruner_sparse.cu
  kernels/tools/tools.cpp
  kernels/tools/tools_kernel.cu
@@ -1,4 +1,5 @@
  ista_daslab_cuda_dense_mfac
  ista_daslab_cuda_micro_adam
  ista_daslab_cuda_sparse_mfac
+ ista_daslab_cuda_sparse_mfac_pruner
  ista_daslab_cuda_tools
@@ -0,0 +1,150 @@
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include "../utils.h"
+ //#include "parallel_reduce.h"
+
+ __global__ void compute_row_initial_kernel (float *global_V, float *global_g, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int N, int B, int nbits, int use_kahan, int grad_const, int do_init, int do_debug);
+ void compute_row_initial_cuda (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int nbits, int use_kahan, int grad_const, int do_init, int do_debug);
+ void compute_row_initial (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int nbits, int use_kahan, int grad_const, int do_init, int do_debug)
+ {
+     assert((nbits == 32) || (nbits == 64));
+     assert((use_kahan == 0) || (use_kahan == 1));
+     assert((grad_const == 0) || (grad_const == 512));
+     assert((0 <= row_start) && (row_start < row_end) && (row_end <= m));
+     CHECK_INPUT(V);
+     CHECK_INPUT(g);
+     CHECK_INPUT(q);
+     CHECK_INPUT(out);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(V));
+     compute_row_initial_cuda(V, g, q, out, row_start, row_end, m, damp, N, B, nblocks, nthreads, nbits, use_kahan, grad_const, do_init, do_debug);
+ }
+
+ __global__ void compute_row_dense_kernel(float *global_V, float *global_g, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int N, int B, int grad_const);
+ void compute_row_dense_cuda (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int grad_const);
+ void compute_row_dense (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int grad_const)
+ {
+     //assert((grad_const == 0) || (grad_const == 512));
+     assert((0 <= row_start) && (row_start < row_end) && (row_end <= m));
+     CHECK_INPUT(V);
+     CHECK_INPUT(g);
+     CHECK_INPUT(q);
+     CHECK_INPUT(out);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(V));
+     compute_row_dense_cuda(V, g, q, out, row_start, row_end, m, damp, N, B, nblocks, nthreads, grad_const);
+ }
+
+ __global__ void compute_row_sparse_kernel(float *global_V, int16 *global_gi, float *global_gv, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int density, int N, int B, int do_init);
+ void compute_row_sparse_cuda (TT V, TT gi, TT gv, TT q, TT out, int row_start, int row_end, int m, float damp, int density, int N, int B, int nblocks, int nthreads, int do_init);
+ void compute_row_sparse (TT V, TT gi, TT gv, TT q, TT out, int row_start, int row_end, int m, float damp, int density, int N, int B, int nblocks, int nthreads, int do_init)
+ {
+     assert((0 <= row_start) && (row_start < row_end) && (row_end <= m));
+     CHECK_INPUT(V);
+     CHECK_INPUT(gi);
+     CHECK_INPUT(gv);
+     CHECK_INPUT(q);
+     CHECK_INPUT(out);
+
+     const at::cuda::OptionalCUDAGuard device_guard(device_of(V));
+     compute_row_sparse_cuda(V, gi, gv, q, out, row_start, row_end, m, damp, density, N, B, nblocks, nthreads, do_init);
+ }
+ /*
+ void pipeline_copy_compute(TT Vcpu, TT Vgpu0, TT Vgpu1,
+                            TT Qcpu, TT Qgpu0, TT Qgpu1,
+                            TT Vtmp, TT grad, TT gi, TT gv,
+                            int start_copy_cpu, int end_copy_cpu,
+                            int start_copy_gpu, int end_copy_gpu,
+                            int start_comp_gpu, int end_comp_gpu,
+                            int half_copy, int half_compute,
+                            int m, int N, int B, float damp,
+                            int grad_const, int kernel_call_count, int topk_type,
+                            int nblocks, int nthreads)
+ {
+     float *pVcpu = (float*) Vcpu.data_ptr();
+     float *pVgpu0 = (float*) Vgpu0.data_ptr();
+     float *pVgpu1 = (float*) Vgpu1.data_ptr();
+     float *pQcpu = (float*) Qcpu.data_ptr();
+     float *pQgpu0 = (float*) Qgpu0.data_ptr();
+     float *pQgpu1 = (float*) Qgpu1.data_ptr();
+     float *pVtmp = (float*) Vtmp.data_ptr();
+     float *pgrad = (float*) grad.data_ptr();
+     int16 *pgi = (int16*) gi.data_ptr();
+     float *pgv = (float*) gv.data_ptr();
+
+     dim3 blocks(nblocks, 1, 1);
+     dim3 threads(nthreads, 1, 1);
+
+     cudaStream_t stream_copy_V, stream_copy_Q, stream_compute;
+     cudaStreamCreate(&stream_copy_V);
+     cudaStreamCreate(&stream_copy_Q);
+     cudaStreamCreate(&stream_compute);
+
+     /// START SECTION COPY
+     int NB = N * B;
+     int rows_copy = end_copy_cpu - start_copy_cpu;
+     int sizeVcpu = rows_copy * N * B;
+     int sizeQcpu = rows_copy * N;
+     int offsetVcpu = start_copy_cpu * NB;
+     int offsetQcpu = start_copy_cpu * N;
+
+     float *copyVgpu = (half_copy == 0) ? pVgpu0 : pVgpu1;
+     float *copyQgpu = (half_copy == 0) ? pQgpu0 : pQgpu1;
+     /// END SECTION COPY
+
+     /// START SECTION COMPUTE
+     int density = gi.sizes()[1];
+     long shmem_initial = 4 * B * sizeof(float);
+     long shmem_sparse = (2 * B + 2 * density) * sizeof(float) + density * sizeof(int16);
+
+     float *compVgpu = (half_compute == 0) ? pVgpu0 : pVgpu1;
+     float *compQgpu = (half_compute == 0) ? pQgpu0 : pQgpu1;
+
+     // kernel_call_count=-1 means kernel_call_count=None
+     int do_init = (kernel_call_count == -1) ? 0 : static_cast<int>(kernel_call_count == 1);
+     /// END SECTION COMPUTE
+
+     /// COPY V
+     cudaMemcpyAsync(
+         copyVgpu,               // device pointer
+         pVcpu + offsetVcpu,     // host pointer
+         sizeVcpu,               // size
+         cudaMemcpyHostToDevice, // direction
+         stream_copy_V           // stream
+     );
+
+     /// COPY Q
+     cudaMemcpyAsync(
+         copyQgpu,               // device pointer
+         pQcpu + offsetQcpu,     // host pointer
+         sizeQcpu,               // size
+         cudaMemcpyHostToDevice, // direction
+         stream_copy_Q           // stream
+     );
+
+     /// COMPUTE
+     if(topk_type == 0) { // global topk
+         if(shmem_initial > 48 * 1024) {
+             cudaFuncSetAttribute(compute_row_initial_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_initial);
+         }
+         compute_row_initial_kernel<<<blocks, threads, shmem_initial, stream_compute>>>(compVgpu, pgrad, compQgpu, pVtmp, start_comp_gpu, end_comp_gpu, m, damp, N, B, 32, 0, grad_const, do_init, 0);
+     } else { // row topk
+         if(shmem_sparse > 48 * 1024){
+             cudaFuncSetAttribute(shmem_sparse, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_sparse);
+         }
+         compute_row_sparse_kernel<<<blocks, threads, shmem_sparse, stream_compute>>>(compVgpu, pgi, pgv, compQgpu, pVtmp, start_comp_gpu, end_comp_gpu, m, damp, density, N, B, do_init);
+     }
+
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
+ */
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("compute_row_initial", &compute_row_initial, "Computes one row of matrix V used for pruning");
+     m.def("compute_row_dense", &compute_row_dense, "Computes one row of matrix V used for pruning using dense gradients");
+     m.def("compute_row_sparse", &compute_row_sparse, "Computes one row of matrix V used for pruning using sparse gradients");
+     //m.def("pipeline_copy_compute", &pipeline_copy_compute, "CPU-GPU transfer and GPU computation using streams");
+ }
+
+
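For orientation, here is a minimal Python sketch of how these bindings might be called once the extension above is built as `ista_daslab_cuda_sparse_mfac_pruner` (the module name declared in setup.py and top_level.txt). The function names and argument order come from the `m.def` bindings above; the tensor shapes (V: m×N×B, g: N×B, q: m×N, out: N×B, all float32, CUDA, contiguous) and the launch configuration are inferred from the kernel indexing and are not documented by the package, so treat them as assumptions.

```python
# Hypothetical usage sketch -- shapes and launch parameters are assumptions, not package documentation.
import torch
import ista_daslab_cuda_sparse_mfac_pruner as pruner_cuda

m, N, B = 8, 4, 1024  # assumed layout: m history rows, N column blocks of size B
V   = torch.randn(m, N, B, dtype=torch.float32, device='cuda').contiguous()
g   = torch.randn(N, B, dtype=torch.float32, device='cuda').contiguous()
q   = torch.rand(m, N, dtype=torch.float32, device='cuda').contiguous() + 1.0
out = torch.zeros(N, B, dtype=torch.float32, device='cuda')

# Accumulate rows [0, 3) of V into `out`; nblocks=N gives one CUDA block per B-sized slice,
# and grad_const=0 disables the should_skip() masking defined in utils.h.
pruner_cuda.compute_row_dense(V, g, q, out, 0, 3, m, 1e-4, N, B, N, 1024, 0)
```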
@@ -0,0 +1,132 @@
+ #include "../utils.h"
+ #include <float.h>
+ #include <limits>
+ // #include "parallel_reduce.h"
+
+ __device__ inline void generic_parallel_reduce(float *mem, int N, const long THREADS, const long Tid) {
+     /*
+         Compute parallel reduce on a shared memory array of `n` elements using `T` threads, even when n > T.
+     */
+
+     // perform addition to the first T elements (when n > T)
+     for(int i = Tid + THREADS; i < N; i += THREADS) {
+         mem[Tid] += mem[i];
+     }
+     __syncthreads();
+     for(int stride = (THREADS >> 1); stride > 0; stride >>= 1) {
+         if(Tid < stride && Tid + stride < N) {
+             mem[Tid] += mem[Tid + stride];
+         }
+         __syncthreads();
+     }
+ }
+
+ __global__ void
+ compute_row_dense_kernel(float *global_V, float *global_g, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int N, int B, int grad_const)
+ {
+     const long Bid = blockIdx.x; // block id
+     const long THREADS = blockDim.x; // number of threads
+     const long Tid = threadIdx.x; // thread id
+
+     extern __shared__ float mem[];
+     float *V = mem; // size B, stores one row of V, e.g. V[i, Bid, :]
+     float *g = mem + B; // size B, stores one row of g, e.g. g[Bid, :]
+     float *prods = mem + 2 * B; // size B, stores products V*g before summing up.
+     float *Vout = mem + 3 * B; // size B, accumulates dot * V
+
+     // predefined constants to avoid computing the same quantities multiple times
+     long N_B = N * B;
+     long Bid_B = Bid * B;
+
+     int i, j, j_global;
+     float dot, q, delta;
+     long V_start;
+
+     // g = global_g[Bid, :]
+     copy_global_to_shmem(global_g, g, Bid_B, Bid_B + B, THREADS, Tid);
+     __syncthreads();
+
+     // copy_global_to_shmem(global_out, Vout, Bid_B, Bid_B + B, THREADS, Tid); // Vout = out[Bid, :]
+     if(row_end < m) { // we call the kernel to compute rows of V
+         // V_start = 0 * N_B + Bid_B;
+         // copy_global_to_shmem<T>(global_V, Vout, V_start, V_start + B, THREADS, Tid);
+         for(i = Tid; i < B; i += THREADS) {
+             if(should_skip(g[i], grad_const)) {
+                 Vout[i] = static_cast<float>(0);
+             } else {
+                 Vout[i] = static_cast<float>(damp) * static_cast<float>(g[i]);
+             }
+         }
+     } else if(row_end == m) { // we call the kernel to compute the final update that prunes the model
+         for(i = Tid; i < B; i += THREADS) {
+             Vout[i] = static_cast<float>(0);
+         }
+     }
+     __syncthreads();
+
+     for(j = row_start; j < row_end; ++j) {
+         // V = global_V[j, Bid, :]
+         V_start = j * N_B + Bid_B;
+         copy_global_to_shmem(global_V, V, V_start, V_start + B, THREADS, Tid);
+         __syncthreads();
+
+         // (1) compute dot products
+         for(i = Tid; i < B; i += THREADS) {
+             if(should_skip(g[i], grad_const)) {
+                 prods[i] = static_cast<float>(0);
+             } else {
+                 prods[i] = V[i] * g[i];
+             }
+         }
+
+         __syncthreads();
+
+         generic_parallel_reduce(prods, B, THREADS, Tid);
+         dot = prods[0];
+
+         // read q from global memory: q = global_q[j, Bid]
+         if(Tid == 0) {
+             prods[0] = global_q[j * N + Bid];
+         }
+         __syncthreads();
+         q = prods[0];
+         delta = dot / q;
+
+         for(i = Tid; i < B; i += THREADS) {
+             Vout[i] -= delta * V[i];
+         }
+     } // end for j < row
+
+     // out[Bid, :] = Vout
+     for(j_global = Bid_B + Tid, j = Tid;
+         j_global < Bid_B + B;
+         j_global += THREADS, j += THREADS)
+     {
+         global_out[j_global] = Vout[j];
+     }
+
+     // TODO: compute q here, based on Vout: q[row, Bid] = m + dot_product(Vout, g)
+ }
+
+ void
+ compute_row_dense_cuda (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int grad_const)
+ {
+     dim3 blocks(nblocks, 1, 1);
+     dim3 threads(nthreads, 1, 1);
+     long sh_mem_size_bytes = 4 * B * sizeof(float);
+
+     if(sh_mem_size_bytes > 48 * 1024) {
+         cudaFuncSetAttribute(compute_row_dense_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sh_mem_size_bytes);
+     }
+
+     float* fpV = (float*) V.data_ptr();
+     float* fpg = (float*) g.data_ptr();
+     float* fpq = (float*) q.data_ptr();
+     float* fpout = (float*) out.data_ptr();
+
+     compute_row_dense_kernel<<<blocks, threads, sh_mem_size_bytes>>>(fpV, fpg, fpq, fpout, row_start, row_end, m, damp, N, B, grad_const);
+
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     // GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
@@ -0,0 +1,244 @@
+ #include "../utils.h"
+ // #include "parallel_reduce.h"
+
+ __device__ inline void generic_parallel_reduce(float *mem, int N, const long THREADS, const long Tid) {
+     /*
+         Compute parallel reduce on a shared memory array of `n` elements using `T` threads, even when n > T.
+     */
+
+     // perform addition to the first T elements (when n > T)
+     for(int i = Tid + THREADS; i < N; i += THREADS) {
+         mem[Tid] += mem[i];
+     }
+     __syncthreads();
+     for(int stride = (THREADS >> 1); stride > 0; stride >>= 1) {
+         if(Tid < stride && Tid + stride < N) {
+             mem[Tid] += mem[Tid + stride];
+         }
+         __syncthreads();
+     }
+ }
+
+ __device__ inline void kahan_parallel_reduce(float* mem, int N, int THREADS, int Tid) {
+     // initially, we sum everything in interval [THREADS, N-1] to [0, THREADS-1]
+     // the already existing values in mem[Tid] serve as initial values, so sum=mem[Tid]
+     double sum = static_cast<double>(mem[Tid]);
+     double c = static_cast<double>(0);
+     double y, t;
+
+     // the following for-loop implements mem[Tid] += mem[i] using Kahan summation
+     // for the values at indices i > THREADS
+     for(int i = Tid + THREADS; i < N; i += THREADS) {
+         y = static_cast<double>(mem[i]) - c;
+         t = sum + y;
+         c = (t - sum) - y;
+         sum = t;
+     }
+     mem[Tid] = static_cast<float>(sum);
+     __syncthreads();
+
+     // the following for-loop implements mem[Tid] += mem[Tid + stride] using
+     // Kahan summation and parallel reduce in logarithmic time
+     for(int stride = (THREADS >> 1); stride > 0; stride >>= 1) {
+         if(Tid < stride && Tid + stride < N) {
+             y = static_cast<double>(mem[Tid + stride]) - c; // mem[Tid+stride] is the value to be summed up
+             t = sum + y; // mem[Tid] stores the sum
+             c = (t - sum) - y;
+             sum = t; // update sum
+             mem[Tid] = static_cast<float>(sum);
+         }
+         __syncthreads();
+     }
+ }
+
+ __global__ void compute_row_initial_kernel(float *global_V, float *global_g, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int N, int B, int nbits, int use_kahan, int grad_const, int do_init, int do_debug)
+ {
+     const long Bid = blockIdx.x; // block id
+     const long THREADS = blockDim.x; // number of threads
+     const long Tid = threadIdx.x; // thread id
+
+     if(do_debug) {
+         printf("[Bid=%ld][Tid=%ld][rows=%d-%d] THREADS=%ld, damp=%f\n", Bid, Tid, row_start, row_end, THREADS, damp);
+     }
+
+     extern __shared__ unsigned char shmem[];
+     float *mem = reinterpret_cast<float*>(shmem);
+     float *V = mem; // size B, stores one row of V, e.g. V[i, Bid, :]
+     float *g = mem + B; // size B, stores one row of g, e.g. g[Bid, :]
+     float *prods = mem + 2 * B; // size B, stores products V*g before summing up.
+     float *Vout = mem + 3 * B; // size B, accumulates dot * V
+     double *comps = 0;
+     if(use_kahan) { //
+         comps = reinterpret_cast<double*>(shmem); // Kahan compensations for each component in Vout
+         if(nbits == 32) {
+             comps += 2 * B;
+         } else if(nbits == 64) {
+             comps += 4 * B;
+         }
+     }
+
+     // predefined constants to avoid computing the same quantities multiple times
+     long N_B = N * B;
+     long Bid_B = Bid * B;
+
+     long g_start = Bid_B;
+     long V_start;
+
+     int i, j, j_global;
+     float dot, q, delta;
+
+     // g = global_g[Bid, :]
+     copy_global_to_shmem(global_g, g, g_start, g_start + B, THREADS, Tid);
+     __syncthreads();
+
+     if(do_debug) {
+         for(i = Tid; i < B; i += THREADS) {
+             printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-1] g[%d]=%lf\n", Bid, Tid, row_start, row_end, i, g[i]);
+         }
+     }
+
+     if(row_end < m) { // we call the kernel to compute rows of V
+         // V_start = 0 * N_B + Bid_B;
+         // copy_global_to_shmem(global_V, Vout, V_start, V_start + B, THREADS, Tid);
+         for(i = Tid; i < B; i += THREADS) {
+             if(do_init) {
+                 if(should_skip(g[i], grad_const)) {
+                     Vout[i] = static_cast<float>(0);
+                 } else {
+                     Vout[i] = static_cast<float>(damp) * static_cast<float>(g[i]);
+                 }
+             } else {
+                 Vout[i] = static_cast<float>(0);
+             }
+         }
+
+     } else if(row_end == m) { // we call the kernel to compute the final update that prunes the model
+         for(i = Tid; i < B; i += THREADS) {
+             Vout[i] = static_cast<float>(0);
+         }
+     }
+     __syncthreads();
+
+     double y, t; // for Kahan
+
+     // initialize compensations to zero
+     if(use_kahan) {
+         for(i = Tid; i < B; i += THREADS) {
+             comps[i] = static_cast<float>(0);
+         }
+     }
+
+     for(j = row_start; j < row_end; ++j) {
+         // V = global_V[j, Bid, :]
+         V_start = j * N_B + Bid_B;
+         copy_global_to_shmem(global_V, V, V_start, V_start + B, THREADS, Tid);
+         __syncthreads();
+
+         if(do_debug) {
+             for(i = Tid; i < B; i += THREADS) {
+                 printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-2] v[%d, %ld, %d]=%lf\n", Bid, Tid, row_start, row_end, j, Bid, i, V[i]);
+             }
+         }
+
+         // (1) compute dot products
+         for(i = Tid; i < B; i += THREADS) {
+             if(should_skip(g[i], grad_const)) {
+                 prods[i] = static_cast<float>(0);
+             } else {
+                 prods[i] = V[i] * g[i];
+             }
+         }
+
+         if(do_debug) {
+             for(i = Tid; i < B; i += THREADS) {
+                 printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-3] prods[%d]=%lf (pre-reduce)\n", Bid, Tid, row_start, row_end, i, prods[i]);
+             }
+         }
+         __syncthreads();
+         if(use_kahan) {
+             kahan_parallel_reduce(prods, B, THREADS, Tid);
+         } else {
+             generic_parallel_reduce(prods, B, THREADS, Tid);
+         }
+         dot = prods[0];
+
+         if(do_debug) {
+             for(i = Tid; i < B; i += THREADS) {
+                 printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-4] prods[%d]=%lf (post-reduce)\n", Bid, Tid, row_start, row_end, i, prods[i]);
+             }
+             printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-5] dot=%lf\n", Bid, Tid, row_start, row_end, dot);
+         }
+
+         // read q from global memory: q = global_q[j, Bid]
+         if(Tid == 0) {
+             prods[0] = static_cast<float>(global_q[j * N + Bid]);
+         }
+         __syncthreads();
+         q = prods[0];
+         delta = dot / q;
+
+         if(do_debug) {
+             printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-6] q=%lf, delta=%lf\n", Bid, Tid, row_start, row_end, q, delta);
+         }
+
+         if(use_kahan) {
+             for(i = Tid; i < B; i += THREADS) {
+                 y = static_cast<double>(-delta) * static_cast<double>(V[i]) - comps[i];
+                 t = static_cast<double>(Vout[i]) + y;
+                 comps[i] = (t - static_cast<double>(Vout[i])) - y;
+                 Vout[i] = t;
+             }
+         } else {
+             for(i = Tid; i < B; i += THREADS) {
+                 Vout[i] -= delta * V[i];
+                 if(do_debug) {
+                     printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-7] delta*V[%d, %ld, %d]=%lf\n", Bid, Tid, row_start, row_end, j, Bid, i, delta * V[i]);
+                 }
+             }
+         }
+     } // end for j < row
+
+     for(j_global = Bid_B + Tid, j = Tid;
+         j_global < Bid_B + B;
+         j_global += THREADS, j += THREADS)
+     {
+         global_out[j_global] = static_cast<float>(Vout[j]);
+     }
+     if(do_debug) {
+         for(i = Tid; i < B; i += THREADS) {
+             printf("[Bid=%ld][Tid=%ld][rows=%d-%d][step-8] vout[%d]=%lf[OUT]\n", Bid, Tid, row_start, row_end, i, Vout[i]);
+         }
+     }
+
+     // TODO: compute q here, based on Vout: q[row, Bid] = m + dot_product(Vout, g)
+ }
+
+ void
+ compute_row_initial_cuda (TT V, TT g, TT q, TT out, int row_start, int row_end, int m, float damp, int N, int B, int nblocks, int nthreads, int nbits, int use_kahan, int grad_const, int do_init, int do_debug) {
+     assert(nbits == 32);
+     dim3 blocks(nblocks, 1, 1);
+     dim3 threads(nthreads, 1, 1);
+     long sh_mem_size_bytes = 4 * B * ((nbits == 32) ? sizeof(float) : sizeof(double));
+     if(use_kahan){
+         sh_mem_size_bytes += B * sizeof(double); // add shared memory space for the Kahan compensations
+     }
+
+     // printf("row=%d, N=%d, B=%d, blocks=%d, threads=%d, sh_mem_size_bytes=%ld\n", row, N, B, nblocks, nthreads, sh_mem_size_bytes);
+
+     if(sh_mem_size_bytes > 48 * 1024) {
+         //// if we want to allocate more than 48KB, then we have to call this method
+         cudaFuncSetAttribute(compute_row_initial_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sh_mem_size_bytes);
+     }
+
+     float* fpV = (float*) V.data_ptr();
+     float* fpg = (float*) g.data_ptr();
+     float* fpq = (float*) q.data_ptr();
+     float* fpout = (float*) out.data_ptr();
+
+     compute_row_initial_kernel<<<blocks, threads, sh_mem_size_bytes>>>(fpV, fpg, fpq, fpout, row_start, row_end, m, damp, N, B, nbits, use_kahan, grad_const, do_init, do_debug);
+
+
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     // GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
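The main loop of `compute_row_initial_kernel` is, per B-sized block, the recurrence Vout = damp*g (when do_init is set and row_end < m) minus the sum over j of (V[j]·g / q[j]) * V[j]. A rough PyTorch mirror of that logic (no Kahan compensation, grad_const=0) can serve as a correctness reference for small shapes; it is a sketch derived from the kernel above, not part of the package:

```python
import torch

def compute_row_initial_ref(V, g, q, row_start, row_end, m, damp, do_init=True):
    """Pure-PyTorch mirror of compute_row_initial_kernel (assumes grad_const=0, use_kahan=0).
    V: (m, N, B), g: (N, B), q: (m, N) -- shapes inferred from the kernel's indexing."""
    out = damp * g if (do_init and row_end < m) else torch.zeros_like(g)
    for j in range(row_start, row_end):
        dot = (V[j] * g).sum(dim=1)            # per-block dot product, shape (N,)
        delta = dot / q[j]                     # shape (N,)
        out = out - delta.unsqueeze(1) * V[j]  # Vout[i] -= delta * V[i] within each block
    return out
```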
@@ -0,0 +1,156 @@
+ #include "../utils.h"
+ // #include "parallel_reduce.h"
+
+ __device__ inline void generic_parallel_reduce(float *mem, int N, const long THREADS, const long Tid) {
+     /*
+         Compute parallel reduce on a shared memory array of `n` elements using `T` threads, even when n > T.
+     */
+
+     // perform addition to the first T elements (when n > T)
+     for(int i = Tid + THREADS; i < N; i += THREADS) {
+         mem[Tid] += mem[i];
+     }
+     __syncthreads();
+     for(int stride = (THREADS >> 1); stride > 0; stride >>= 1) {
+         if(Tid < stride && Tid + stride < N) {
+             mem[Tid] += mem[Tid + stride];
+         }
+         __syncthreads();
+     }
+ }
+
+ __device__ inline void unpack(float packed, int &i, float &f) { // ct stands for constant
+     /*
+     *i = static_cast<int>(packed);
+     *f = packed - (*i);
+     if((*i) < 0) {
+         (*i) = -(*i);
+     }
+     */
+     i = static_cast<int>(packed);
+     f = packed - i;
+     if(i < 0) {
+         i = -i;
+     }
+ }
+
+ __global__ void
+ compute_row_sparse_kernel(float *global_V, int16 *global_gi, float *global_gv, float *global_q, float *global_out, int row_start, int row_end, int m, float damp, int density, int N, int B, int do_init)
+ {
+     const long Bid = blockIdx.x; // block id
+     const long THREADS = blockDim.x; // number of threads
+     const long Tid = threadIdx.x; // thread id
+
+     extern __shared__ unsigned char shmem[];
+
+     // sh_mem_size_bytes += B * sizeof(float); // for a row of V
+     // sh_mem_size_bytes += B * sizeof(float); // for a row of Vout
+     // sh_mem_size_bytes += density * sizeof(float); // for prods
+     // sh_mem_size_bytes += density * sizeof(float); // for gv
+     // sh_mem_size_bytes += density * sizeof(int16); // for gi
+     float *V = (float*) shmem;
+     float *Vout = V + B;
+     float *prods = V + 2 * B;
+     float *gv = V + 2 * B + density;
+     int16 *gi = (int16*)(shmem + (2 * B + 2 * density) * sizeof(float));
+
+     // predefined constants to avoid computing the same quantities multiple times
+     long N_B = N * B;
+     long Bid_B = Bid * B;
+     long Bid_density = Bid * density;
+
+     long V_start;
+
+     int i, j, j_global;
+     float dot, q, delta;
+
+     copy_global_to_shmem(global_gv, gv, Bid_density, Bid_density + density, THREADS, Tid);
+     __syncthreads();
+
+     copy_global_to_shmem(global_gi, gi, Bid_density, Bid_density + density, THREADS, Tid);
+     __syncthreads();
+
+     // for(i = Tid; i < density; i += THREADS) {
+     //     printf("[Bid=%ld][Tid=%ld][i=%d] gi=%d, gv=%.8f\n", Bid, Tid, i, gi[i], gv[i]);
+     // }
+
+     // initialize Vout with zeros in the first place
+     for(i = Tid; i < B; i += THREADS) {
+         Vout[i] = static_cast<float>(0);
+     }
+     __syncthreads();
+
+     if(do_init) {
+         // initialize with damp * grad
+         for(i = Tid; i < density; i += THREADS) {
+             Vout[gi[i]] = damp * gv[i];
+         }
+     }
+     __syncthreads();
+
+     for(j = row_start; j < row_end; ++j) {
+         // V = global_V[j, Bid, :]
+         V_start = j * N_B + Bid_B;
+         copy_global_to_shmem(global_V, V, V_start, V_start + B, THREADS, Tid);
+         __syncthreads();
+
+         // (1) compute dot products
+         for(i = Tid; i < density; i += THREADS) {
+             prods[i] = V[gi[i]] * gv[i];
+         }
+         __syncthreads();
+
+         generic_parallel_reduce(prods, density, THREADS, Tid);
+         dot = prods[0];
+
+         // read q from global memory: q = global_q[j, Bid]
+         if(Tid == 0) {
+             prods[0] = static_cast<float>(global_q[j * N + Bid]);
+         }
+         __syncthreads();
+         q = prods[0];
+         delta = dot / q;
+
+         for(i = Tid; i < B; i += THREADS) {
+             Vout[i] -= delta * V[i];
+         }
+     } // end for j < row
+
+     for(j_global = Bid_B + Tid, j = Tid;
+         j_global < Bid_B + B;
+         j_global += THREADS, j += THREADS)
+     {
+         global_out[j_global] = Vout[j];
+     }
+     // TODO: compute q here, based on Vout: q[row, Bid] = m + dot_product(Vout, g)
+ }
+
+ void compute_row_sparse_cuda (TT V, TT gi, TT gv, TT q, TT out, int row_start, int row_end, int m, float damp, int density, int N, int B, int nblocks, int nthreads, int do_init)
+ {
+     dim3 blocks(nblocks, 1, 1);
+     dim3 threads(nthreads, 1, 1);
+
+     long sh_mem_size_bytes = 0;
+     sh_mem_size_bytes += B * sizeof(float); // for a row of V
+     sh_mem_size_bytes += B * sizeof(float); // for a row of Vout
+     sh_mem_size_bytes += density * sizeof(float); // for prods
+     sh_mem_size_bytes += density * sizeof(float); // for gv
+     sh_mem_size_bytes += density * sizeof(int16); // for gi
+
+     if(sh_mem_size_bytes > 48 * 1024) {
+         //// if we want to allocate more than 48KB, then we have to call this method
+         cudaFuncSetAttribute(compute_row_sparse_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, sh_mem_size_bytes);
+     }
+
+     float* pV = (float*) V.data_ptr();
+     int16* pgi = (int16*) gi.data_ptr();
+     float* pgv = (float*) gv.data_ptr();
+     float* pq = (float*) q.data_ptr();
+     float* pout = (float*) out.data_ptr();
+
+     compute_row_sparse_kernel<<<blocks, threads, sh_mem_size_bytes>>>(pV, pgi, pgv, pq, pout, row_start, row_end, m, damp, density, N, B, do_init);
+
+     GPU_ERROR_CHECK(cudaGetLastError());
+     GPU_ERROR_CHECK(cudaPeekAtLastError());
+     // GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
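The sparse kernel consumes the gradient as per-block top-k pairs: `gi` holds int16 indices local to each B-sized block (they are used as `Vout[gi[i]]`), and `gv` holds the matching values, both laid out as (N, density). A hypothetical sketch of building such inputs with PyTorch, assuming the caller reshapes the flat gradient into N blocks of size B (the actual ISTA-DASLab-Optimizers caller may construct them differently):

```python
import torch

def topk_per_block(g, density):
    """g: (N, B) float32 CUDA gradient; returns (gi, gv) as (N, density) int16 / float32 tensors.
    Note: int16 indices imply B <= 32768 -- an assumption inferred from the kernel's index type."""
    _, idx = torch.topk(g.abs(), k=density, dim=1)  # largest-magnitude entries per block
    gv = torch.gather(g, 1, idx).contiguous()       # their signed values
    gi = idx.to(torch.int16).contiguous()           # block-local indices
    return gi, gv
```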
@@ -18,6 +18,7 @@
  #include <limits> // for epsilon
 
  using namespace std;
+ using TT = torch::Tensor;
 
  typedef __nv_bfloat16 bfloat16;
  typedef __nv_bfloat162 bfloat162;
@@ -73,6 +74,7 @@ __device__ inline long log_threads(long T) {
  if(T == 256) return 8;
  if(T == 512) return 9;
  if(T == 1024) return 10;
+ return 1; // default to avoid compilation errors
  }
 
  inline LL get_threads(LL max_threads) {
@@ -86,6 +88,19 @@ inline LL get_threads(LL max_threads) {
  return threads;
  }
 
+ __device__ inline void dynamically_assign_float(void *out, int out_index, float value, int out_bits) {
+     /*
+         This function assigns out[out_index] = value.
+         If nbits=16, then it means out is bfloat16 and we need to convert value to bfloat16.
+         If nbits=32, then it means out is float and no conversion is needed
+     */
+     if(out_bits == 16) {
+         ((bfloat16*) out)[out_index] = __float2bfloat16(value);
+     } else {
+         ((float*) out)[out_index] = value;
+     }
+ }
+
  __device__ inline void dynamically_assign(void *out, void *inp, int out_index, int inp_index, int out_bits, int inp_bits) {
  /*
  This function assigns out[out_index] = inp[inp_index] based on the types and performs the conversions when needed:
@@ -108,6 +123,34 @@ __device__ inline void dynamically_assign(void *out, void *inp, int out_index, i
  }
  }
 
+ __device__ inline bool should_skip(float x, int ct) { // ct stands for constant
+     if(ct == 0) {
+         return false;
+     }
+     int x_int = static_cast<int>(x);
+     if(((ct - 10) <= x_int) && (x_int <= (ct + 10))) {
+         return true;
+     }
+     return false;
+ }
+
+ template<typename T>
+ __device__ inline void copy_global_to_shmem(T *global,
+                                             T *shmem,
+                                             long global_start,
+                                             long global_end,
+                                             const long THREADS,
+                                             const long Tid) {
+     long j_global, j_shmem; // used in for-loops to read from global memory to shared memory
+
+     for(j_global = global_start + Tid, j_shmem = Tid;
+         j_global < global_end;
+         j_global += THREADS, j_shmem += THREADS)
+     {
+         shmem[j_shmem] = global[j_global];
+     }
+ }
+
  #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
@@ -116,7 +159,10 @@ __device__ inline void dynamically_assign(void *out, void *inp, int out_index, i
  #define FLOAT_EPS std::numeric_limits<float>::epsilon()
  #define DOUBLE_EPS std::numeric_limits<double>::epsilon()
  #define GPU_ERROR_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
+ #define IS_BF16(x) torch::ScalarType::BFloat16 == x.scalar_type()
+ #define IS_FLOAT(x) torch::ScalarType::Float == x.scalar_type()
  #define ASSERT_BF16(x) { assert(torch::ScalarType::BFloat16 == x.scalar_type()); }
+ #define ASSERT_FLOAT(x) { assert(IS_FLOAT(x)); }
  #define ASSERT_FLOAT_16_OR_32(x) { assert(torch::ScalarType::BFloat16 == x.scalar_type() || torch::ScalarType::Float == x.scalar_type()); }
 
  #define COPY_DIRECTION_k2d 0
@@ -1,42 +1,41 @@
- [build-system]
- requires = ["setuptools", "wheel", "torch"]
- build-backend = "setuptools.build_meta"
-
- [project]
- name='ista_daslab_optimizers_cuda'
- version='1.0.0'
- dependencies = [
-     "torch", # >=2.3.1",
-     "torchaudio", # >=2.3.1",
-     "torchvision", #>=0.18.1",
-     "numpy", # >=1.24.1",
-     # "wandb",#>=0.17.1",
-     # "gpustat",#>=1.1.1",
-     # "timm", # >=1.0.3",
-     # "einops", # >=0.7.0",
-     # "psutil", # >=5.9.8",
-     # "fast-hadamard-transform",
-     # "fast-hadamard-transform @ git+https://github.com/Dao-AILab/fast-hadamard-transform.git",
- ]
- requires-python = '>= 3.8'
- authors = [
-     {name = "Ionut-Vlad Modoranu", email = "ionut-vlad.modoranu@ist.ac.at"}
- ]
- maintainers = [
-     {name = "Ionut-Vlad Modoranu", email = "ionut-vlad.modoranu@ist.ac.at"},
- ]
- description = 'CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)'
- readme = "README.md"
- license = {file = "LICENSE"}
- keywords = [
-     "adaptive optimization",
-     "deep learning",
-     "low memory optimization",
- ]
- classifiers = [
-     "Programming Language :: Python :: 3.8",
-     "License :: OSI Approved :: Apache Software License",
- ]
-
- [project.urls]
- Repository = 'https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA'
+ [build-system]
+ requires = ["setuptools", "wheel", "torch"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name='ista_daslab_optimizers_cuda'
+ version='1.1.0'
+ dependencies = [
+     "torch", # >=2.3.1",
+     "torchaudio", # >=2.3.1",
+     "torchvision", #>=0.18.1",
+     "numpy", # >=1.24.1",
+     # "wandb",#>=0.17.1",
+     # "gpustat",#>=1.1.1",
+     # "timm", # >=1.0.3",
+     # "einops", # >=0.7.0",
+     # "psutil", # >=5.9.8",
+     # "fast-hadamard-transform",
+     # "fast-hadamard-transform @ git+https://github.com/Dao-AILab/fast-hadamard-transform.git",
+ ]
+ requires-python = '>= 3.8'
+ authors = [
+     {name = "Ionut-Vlad Modoranu", email = "ionut-vlad.modoranu@ist.ac.at"}
+ ]
+ maintainers = [
+     {name = "Ionut-Vlad Modoranu", email = "ionut-vlad.modoranu@ist.ac.at"},
+ ]
+ description = 'CUDA kernels for ISTA-DASLab-Optimizers project developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)'
+ readme = "README.md"
+ license = {file = "LICENSE"}
+ keywords = [
+     "adaptive optimization",
+     "deep learning",
+     "low memory optimization",
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3.8",
+ ]
+
+ [project.urls]
+ Repository = 'https://github.com/IST-DASLab/ISTA-DASLab-Optimizers-CUDA'
@@ -51,6 +51,15 @@ setup(
  './kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu',
  ],
  ),
+ get_cuda_extension(
+     name=f'ista_daslab_cuda_sparse_mfac_pruner',
+     sources=[
+         './kernels/sparse_mfac_pruner/mfac_pruner_cpp.cpp',
+         './kernels/sparse_mfac_pruner/mfac_pruner_dense.cu',
+         './kernels/sparse_mfac_pruner/mfac_pruner_initial.cu',
+         './kernels/sparse_mfac_pruner/mfac_pruner_sparse.cu',
+     ],
+ ),
  ],
  cmdclass={'build_ext': BuildExtension.with_options(verbose=True)},
  )
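The new extension is registered in setup.py alongside the existing ones, so after installing the package (for example via `pip install ista-daslab-optimizers-cuda`, assuming a CUDA toolchain compatible with the installed torch) all extension modules listed in top_level.txt should be importable. A small smoke-test sketch:

```python
# Smoke test: import every extension module listed in top_level.txt (requires a CUDA-enabled torch).
import torch  # noqa: F401 -- load torch's shared libraries before the extensions
import importlib

for name in ("ista_daslab_cuda_dense_mfac",
             "ista_daslab_cuda_micro_adam",
             "ista_daslab_cuda_sparse_mfac",
             "ista_daslab_cuda_sparse_mfac_pruner",
             "ista_daslab_cuda_tools"):
    print(name, "->", importlib.import_module(name))
```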