ista-daslab-optimizers 1.1.3__tar.gz → 1.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {ista_daslab_optimizers-1.1.3/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.5}/PKG-INFO +7 -3
  2. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/README.md +5 -1
  3. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/dense_mfac/dense_mfac.py +10 -6
  4. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5/ista_daslab_optimizers.egg-info}/PKG-INFO +7 -3
  5. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers.egg-info/SOURCES.txt +2 -0
  6. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/dense_mfac/dense_mfac_kernel.cu +6 -6
  7. ista_daslab_optimizers-1.1.5/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp +57 -0
  8. ista_daslab_optimizers-1.1.5/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu +235 -0
  9. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/pyproject.toml +1 -1
  10. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/LICENSE +0 -0
  11. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/MANIFEST.in +0 -0
  12. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/__init__.py +0 -0
  13. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/acdc/__init__.py +0 -0
  14. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/acdc/acdc.py +0 -0
  15. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/acdc/wd_scheduler.py +0 -0
  16. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/dense_mfac/__init__.py +0 -0
  17. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +0 -0
  18. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/micro_adam/__init__.py +0 -0
  19. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/micro_adam/micro_adam.py +0 -0
  20. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/sparse_mfac/__init__.py +0 -0
  21. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +0 -0
  22. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +0 -0
  23. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/tools.py +0 -0
  24. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers.egg-info/dependency_links.txt +0 -0
  25. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers.egg-info/requires.txt +0 -0
  26. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers.egg-info/top_level.txt +0 -0
  27. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/dense_mfac/dense_mfac.cpp +0 -0
  28. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/micro_adam/micro_adam.cpp +0 -0
  29. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/micro_adam/micro_adam_asymm_block_quant.cu +0 -0
  30. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/micro_adam/micro_adam_asymm_block_quant_inv.cu +0 -0
  31. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/micro_adam/micro_adam_update.cu +0 -0
  32. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/sparse_mfac/sparse_mfac.cpp +0 -0
  33. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu +0 -0
  34. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/sparse_mfac/sparse_mfac_SP_kernel.cu +0 -0
  35. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/tools/tools.cpp +0 -0
  36. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/tools/tools_kernel.cu +0 -0
  37. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/utils.h +0 -0
  38. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/setup.cfg +0 -0
  39. {ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/setup.py +0 -0
{ista_daslab_optimizers-1.1.3/ista_daslab_optimizers.egg-info → ista_daslab_optimizers-1.1.5}/PKG-INFO
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: ista_daslab_optimizers
- Version: 1.1.3
+ Version: 1.1.5
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -242,12 +242,14 @@ The repository contains code for the following optimizers published by DASLab @
  - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)

  ### Installation
- To use the latest stable version of the repository, you can install via pip:
+ To use the latest stable version of this repository, you can install via pip:

  ```shell
  pip3 install ista-daslab-optimizers
  ```

+ and you can also visit the [PyPI page](https://pypi.org/project/ista-daslab-optimizers/).
+
  We also provide a script `install.sh` that creates a new environment, installs requirements
  and then installs the project as a Python package following these steps:

@@ -289,6 +291,8 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.5** @ February 19th, 2025:
+ - adapted `DenseMFAC` to models with multiple classification heads for Continual Learning, where a single feature extractor backbone is paired with a list of classification heads. The issue was that the model size included the backbone and all classification heads, while in practice only one classification head is used for training and inference at a time. This caused size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient had fewer entries than the full model. When using `DenseMFAC` in such settings, set `optimizer.model_size` to the correct size after calling the constructor; the `DenseCoreMFAC` object will then be created automatically in the `step` function.
  - **1.1.3** @ September 5th, 2024:
  - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
  - **1.1.2** @ August 1st, 2024:
{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/README.md
@@ -17,12 +17,14 @@ The repository contains code for the following optimizers published by DASLab @
  - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)

  ### Installation
- To use the latest stable version of the repository, you can install via pip:
+ To use the latest stable version of this repository, you can install via pip:

  ```shell
  pip3 install ista-daslab-optimizers
  ```

+ and you can also visit the [PyPI page](https://pypi.org/project/ista-daslab-optimizers/).
+
  We also provide a script `install.sh` that creates a new environment, installs requirements
  and then installs the project as a Python package following these steps:

@@ -64,6 +66,8 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.5** @ February 19th, 2025:
+ - adapted `DenseMFAC` to models with multiple classification heads for Continual Learning, where a single feature extractor backbone is paired with a list of classification heads. The issue was that the model size included the backbone and all classification heads, while in practice only one classification head is used for training and inference at a time. This caused size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient had fewer entries than the full model. When using `DenseMFAC` in such settings, set `optimizer.model_size` to the correct size after calling the constructor; the `DenseCoreMFAC` object will then be created automatically in the `step` function.
  - **1.1.3** @ September 5th, 2024:
  - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
  - **1.1.2** @ August 1st, 2024:
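
A minimal sketch of the multi-head Continual Learning usage described in the 1.1.5 entry above. The model, the `DenseMFAC` constructor arguments and the import path are assumptions for illustration; only the `optimizer.model_size` override comes from the changelog, and the lazy creation it relies on is shown in the `dense_mfac.py` diff below.

```python
import torch
from ista_daslab_optimizers import DenseMFAC  # import path assumed

# One shared feature extractor plus one classification head per task (placeholders).
backbone = torch.nn.Linear(128, 64)
heads = torch.nn.ModuleList([torch.nn.Linear(64, 10) for _ in range(5)])
active_head = heads[0]  # only one head is trained at a time

# The optimizer receives every parameter, so its constructor measures a model size
# that covers the backbone and all heads.
all_params = list(backbone.parameters()) + [p for h in heads for p in h.parameters()]
optimizer = DenseMFAC(all_params, lr=1e-3, ngrads=1024, damp=1e-6)  # hypothetical arguments

# At runtime only the backbone and the active head produce gradients, so override
# model_size before the first step(); DenseCoreMFAC is then created lazily with
# matching dimensions inside step().
optimizer.model_size = (sum(p.numel() for p in backbone.parameters())
                        + sum(p.numel() for p in active_head.parameters()))
```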
{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers/dense_mfac/dense_mfac.py
@@ -23,22 +23,23 @@ class DenseMFAC(torch.optim.Optimizer):
  self.damp = damp
  self.weight_decay = weight_decay
  self.device = get_first_device()
+ self.create_G = create_G

  self.model_size = None
  self.steps = 0
  self.wandb_data = dict()

-
-
  self.model_size = sum([p.numel() for group in self.param_groups for p in group['params']])
- print(f'Model size: {self.model_size}')

+ self.hinv = None
+
+ def _create_hinv(self):
  self.hinv = DenseCoreMFAC(
- grads=torch.zeros((ngrads, self.model_size), dtype=torch.float),
+ grads=torch.zeros((self.m, self.model_size), dtype=torch.float),
  dev=self.device,
  gpus=get_gpus(),
- damp=damp,
- create_G=create_G)
+ damp=self.damp,
+ create_G=self.create_G)

  @torch.no_grad()
  def empty_buffer(self):
@@ -70,6 +71,9 @@ class DenseMFAC(torch.optim.Optimizer):
  def step(self, closure=None):
  self.steps += 1

+ if self.hinv is None:
+ self._create_hinv()
+
  g = get_weights_and_gradients(self.param_groups, get_weights=False)
  update = self.compute_update(g, g)

{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5/ista_daslab_optimizers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.2
  Name: ista_daslab_optimizers
- Version: 1.1.3
+ Version: 1.1.5
  Summary: Deep Learning optimizers developed in the Distributed Algorithms and Systems group (DASLab) @ Institute of Science and Technology Austria (ISTA)
  Author-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
  Maintainer-email: Ionut-Vlad Modoranu <ionut-vlad.modoranu@ist.ac.at>
@@ -242,12 +242,14 @@ The repository contains code for the following optimizers published by DASLab @
  - official repository: [GitHub](https://github.com/IST-DASLab/MicroAdam)

  ### Installation
- To use the latest stable version of the repository, you can install via pip:
+ To use the latest stable version of this repository, you can install via pip:

  ```shell
  pip3 install ista-daslab-optimizers
  ```

+ and you can also visit the [PyPI page](https://pypi.org/project/ista-daslab-optimizers/).
+
  We also provide a script `install.sh` that creates a new environment, installs requirements
  and then installs the project as a Python package following these steps:

@@ -289,6 +291,8 @@ optimizer = MicroAdam(
  # Versions summary:

  ---
+ - **1.1.5** @ February 19th, 2025:
+ - adapted `DenseMFAC` to models with multiple classification heads for Continual Learning, where a single feature extractor backbone is paired with a list of classification heads. The issue was that the model size included the backbone and all classification heads, while in practice only one classification head is used for training and inference at a time. This caused size mismatch errors at runtime in the `DenseCoreMFAC` module because the gradient had fewer entries than the full model. When using `DenseMFAC` in such settings, set `optimizer.model_size` to the correct size after calling the constructor; the `DenseCoreMFAC` object will then be created automatically in the `step` function.
  - **1.1.3** @ September 5th, 2024:
  - allow using `SparseCoreMFACwithEF` separately by importing it in `sparse_mfac.__init__.py`
  - **1.1.2** @ August 1st, 2024:
{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/ista_daslab_optimizers.egg-info/SOURCES.txt
@@ -42,5 +42,7 @@ kernels/micro_adam/micro_adam_update.cu
  kernels/sparse_mfac/sparse_mfac.cpp
  kernels/sparse_mfac/sparse_mfac_LCG_kernel.cu
  kernels/sparse_mfac/sparse_mfac_SP_kernel.cu
+ kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp
+ kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu
  kernels/tools/tools.cpp
  kernels/tools/tools_kernel.cu
{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/kernels/dense_mfac/dense_mfac_kernel.cu
@@ -45,16 +45,16 @@ torch::Tensor hinv_setup_cuda(torch::Tensor tmp, torch::Tensor coef) {
  const dim3 threads(SIZE, SIZE);
  const dim3 blocks(m / SIZE, m / SIZE);

- AT_DISPATCH_FLOATING_TYPES(tmp.type(), "hinv_setup_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES(tmp.scalar_type(), "hinv_setup_cuda", ([&] {
  HinvCoefKernelDiag<scalar_t><<<m / SIZE, threads>>>(
- m, tmp.data<scalar_t>(), coef.data<scalar_t>()
+ m, tmp.data_ptr<scalar_t>(), coef.data_ptr<scalar_t>()
  );
  })
  );
  for (int i = 0; i < m / SIZE - 1; i++) {
- AT_DISPATCH_FLOATING_TYPES(tmp.type(), "hinv_setup_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES(tmp.scalar_type(), "hinv_setup_cuda", ([&] {
  HinvCoefKernelMain<scalar_t><<<blocks, threads>>>(
- m, tmp.data<scalar_t>(), coef.data<scalar_t>(), i
+ m, tmp.data_ptr<scalar_t>(), coef.data_ptr<scalar_t>(), i
  );
  })
  );
@@ -178,9 +178,9 @@ __global__ void HinvMulKernel(
  // NOTE: currently only works for `m` <= 1024
  torch::Tensor hinv_mul_cuda(int rows, torch::Tensor giHig, torch::Tensor giHix) {
  const auto m = giHig.size(0);
- AT_DISPATCH_FLOATING_TYPES(giHig.type(), "hinv_mul_cuda", ([&] {
+ AT_DISPATCH_FLOATING_TYPES(giHig.scalar_type(), "hinv_mul_cuda", ([&] {
  HinvMulKernel<scalar_t><<<1, m>>>(
- rows, m, giHig.data<scalar_t>(), giHix.data<scalar_t>()
+ rows, m, giHig.data_ptr<scalar_t>(), giHix.data_ptr<scalar_t>()
  );
  })
  );
ista_daslab_optimizers-1.1.5/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cpp
@@ -0,0 +1,57 @@
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include "../utils.h"
+
+ // CUDA methods
+ void compute_row_cuda(
+ torch::Tensor V,
+ torch::Tensor g,
+ torch::Tensor q,
+ torch::Tensor out,
+ int row,
+ int m,
+ float damp,
+ int N,
+ int B,
+ int nbits
+ );
+
+ // C++ methods callable from Python
+ void compute_row(
+ torch::Tensor V,
+ torch::Tensor g,
+ torch::Tensor q,
+ torch::Tensor out,
+ int row,
+ int m,
+ float damp,
+ int N,
+ int B
+ ) {
+ CHECK_INPUT(V);
+ CHECK_INPUT(g);
+ CHECK_INPUT(q);
+ CHECK_INPUT(out);
+
+ int nbits;
+
+ if(IS_BF16(V)) {
+ ASSERT_BF16(g);
+ ASSERT_BF16(q);
+ ASSERT_BF16(out);
+ nbits = 16;
+ } else if(IS_FLOAT(V)) {
+ ASSERT_FLOAT(g);
+ ASSERT_FLOAT(q);
+ ASSERT_FLOAT(out);
+ nbits = 32;
+ }
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(V));
+ compute_row_cuda(V, g, q, out, row, m, damp, N, B, nbits);
+ }
+
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("compute_row", &compute_row, "Computes one row of matrix V used for pruning");
+ }
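
For context, a hedged sketch of how the new `compute_row` binding might be driven from Python. The extension module name, the outer loop and the copy of `out` back into `V` are assumptions for illustration; the tensor shapes follow the comments in `sparse_mfac_pruner.cu` below, and all four tensors must be CUDA, contiguous and share the same dtype (float32 or bfloat16), as enforced by the checks in the binding above.

```python
import torch
import ista_daslab_sparse_mfac_pruner as pruner  # hypothetical extension module name

m, N, B = 32, 16, 1024      # gradients kept, number of Fisher blocks, Fisher block size
damp = 1e-6
dev = torch.device('cuda')

V   = torch.zeros(m, N, B, dtype=torch.float32, device=dev)  # rows are filled one by one
g   = torch.randn(N, B, dtype=torch.float32, device=dev)     # current gradient, split into Fisher blocks
q   = torch.zeros(m, N, dtype=torch.float32, device=dev)
out = torch.zeros(N, B, dtype=torch.float32, device=dev)

for row in range(m):
    # compute the next row from g, q and the previous `row` rows of V
    pruner.compute_row(V, g, q, out, row, m, damp, N, B)
    V[row].copy_(out)  # assumption: the caller stores the produced row back into V
```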
ista_daslab_optimizers-1.1.5/kernels/sparse_mfac_pruner/sparse_mfac_pruner.cu
@@ -0,0 +1,235 @@
+ #include "../utils.h"
+
+ __device__ void parallel_reduce(float *mem, long T, long logT, long Tid, long offset=0, bool zerorize=false) {
+ /*
+ Perform parallel reduce in logarithmic time over the vector mem with T threads (mem has T components).
+ If zerorize=true, then set the components of mem to zero after accumulation.
+ Use offset > 0 to perform the parallel reduction over a different sequence of size T in mem
+ Tid is the current thread id and logT is log2(T).
+ Return the sum, which will be located at mem[offset]
+
+ Resources:
+ https://shreeraman-ak.medium.com/parallel-reduction-with-cuda-d0ae10c1ae2c
+ */
+ long mid = T >> 1; // half of number of threads
+ long offset_PLUS_Tid = offset + Tid;
+ for(long i = 0; i < logT; ++i) { // perform log2(T) rounds of accumulation
+ __syncthreads();
+ if(Tid < mid) { // left half accumulates, right half sends to left half
+ mem[offset_PLUS_Tid] += mem[offset_PLUS_Tid + mid];
+ if(zerorize) {
+ mem[offset_PLUS_Tid + mid] = 0.;
+ }
+ }
+ mid >>= 1;
+ }
+ }
+
+ __device__ inline void copy_global_to_shmem(void *global,
+ float *shmem,
+ int global_start,
+ int global_end,
+ int nbits) {
+ int T = blockDim.x; // number of threads
+ int j_global, j_shmem; // used in for-loops to read from global memory to shared memory
+ if(nbits == 16) {
+ bfloat16 *bf16_global = (bfloat16*) global;
+ // <------------init------------------> <------stop cond------> <-------next steps-------->
+ for(j_global = global_start, j_shmem = 0; j_global < global_end; j_global += T, j_shmem += T) { // coalesced memory access to global
+ shmem[j_shmem] = __bfloat162float(bf16_global[j_global]);
+ }
+ } else if(nbits == 32) {
+ float *float_global = (float*) global;
+ // <------------init------------------> <------stop cond------> <-------next steps-------->
+ for(j_global = global_start, j_shmem = 0; j_global < global_end; j_global += T, j_shmem += T) { // coalesced memory access to global
+ shmem[j_shmem] = float_global[j_global];
+ }
+ }
+ }
+
+ __global__ void compute_row_kernel(void *global_V,
+ void *global_g,
+ void *global_q,
+ void *global_out,
+ int row, /* current row in V to be computed */
+ int m, /* index of the row to compute now using all previous row-1 rows */
+ float damp,
+ int N, /* number of Fisher blocks*/
+ int B, /* size of Fisher Block*/
+ int nbits) {
+ /*
+ Computes one row of matrix V from the M-FAC pruner.
+ Parameters:
+ - global_V: matrix of size m x N x B
+ - global_g: matrix of size N x B
+ - global_q: matrix of size m x N
+ - global_out: matrix of size N x B
+ - row: compute global_V[j, :, :] given the previous row-1 rows
+ - m: total number of gradients (useful to compute q)
+ - N: number of Fisher blocks
+ - B: size of a Fisher block
+
+ This kernel loops through the first dimension of V (from 0 to m-1).
+ Each CUDA thread block processes row Bid from global_V[i] and global_g (e.g. global_V[i, Bid, :] and global_g[Bid, :], equivalent to one Fisher block)
+ Processing these rows can be done in parallel by multiple thread blocks without interfering with each other
+ To be efficient with the memory, we use shared memory in the following way:
+ - V stores the current row global_V[i, Bid, :] (global_V is read only once)
+ - g stores the current row global_g[Bid, :] (global_g is read only once)
+ - prods stores the element-wise products V_j * g_j
+ - Vout accumulates the row Bid of global_out, which is written only once
+
+ B
+ |----------|
+ N | |----------|
+ | | |----------|
+ |-----| | |
+ |-----| |
+ |----------|
+ */
+
+ const int Bid = blockIdx.x; // block id
+ const int T = blockDim.x; // number of threads
+ const int Tid = threadIdx.x; // thread id
+ int logT = log_threads(T);
+
+ extern __shared__ float mem[];
+ float *V = mem; // size B, stores one row of V
+ float *g = mem + B; // size B, stores one row of g
+ float *prods = mem + 2 * B; // size B, stores products V*g before summing up.
+ float *Vout = mem + 3 * B; // size B, accumulates dot * V
+
+ // predefined constants to avoid computing the same quantities multiple times
+ int N_B = N * B;
+ int Bid_B = Bid * B;
+
+ // variables
+ int i; // iterates through the first dimension of V, from 0 to m-1
+ int j, j_global; // universal index variables
+ int global_V_dim0_i_block_Bid_start, global_V_dim0_i_block_Bid_end; // start/end indices for V[i, Bid, :]
+ int global_g_block_Bid_start = Bid_B, global_g_block_Bid_end = Bid_B + B; // start/end indices for g[Bid, :]
+ int global_out_block_Bid_start = Bid_B, global_out_block_Bid_end = Bid_B + B;
+ float dot; // stores the dot product between V and g, which are V[i, Bid, :] and g[Bid, :] (e.g. dots[:, Bid])
+ float q; // stores the value q[i, Bid]
+ float m_float = (float) m;
+
+ // read g[Bid, :] from global memory to shared memory only once
+ copy_global_to_shmem(global_g, g, global_g_block_Bid_start, global_g_block_Bid_end, nbits);
+ __syncthreads();
+
+ // initialize Vout with damp * g:
+ for(j = 0; j < B; j += T) {
+ Vout[j] = damp * g[j];
+ }
+ __syncthreads();
+
+ // compute scalar products, stored in dots
+ for(i = 0; i < row; ++i) { // iterate through the first dimension of V
+
+ // read q[i, Bid] only in thread 0
+ if(Tid == 0) { // read q only on the first thread and save it in prods[0] (prods will be overwritten with V * g after that)
+ if(nbits == 16) {
+ bfloat16 *bf16_global_q = (bfloat16*) global_q;
+ prods[0] = __bfloat162float(bf16_global_q[i * N + Bid]);
+ } else if (nbits == 32) {
+ float *float_global_q = (float*) global_q;
+ prods[0] = float_global_q[i * N + Bid];
+ }
+ }
+ __syncthreads();
+
+ q = prods[0]; // this will be run on all threads (send q to all threads via prods[0])
+
+ __syncthreads();
+
+ // read V[i, Bid, :] from global memory to shared memory only once
+ global_V_dim0_i_block_Bid_start = i * N_B + // go to the beginning of row i
+ Bid_B; // go to the beginning of block Bid on row i
+ global_V_dim0_i_block_Bid_end = global_V_dim0_i_block_Bid_start + B;
+ copy_global_to_shmem(global_V, V, global_V_dim0_i_block_Bid_start, global_V_dim0_i_block_Bid_end, nbits);
+ __syncthreads();
+
+ // compute dot product between V and g (e.g. element-wise multiplication)
+ for(j = 0; j < B; j += T) {
+ prods[j] = V[j] * g[j];
+ }
+ __syncthreads();
+
+ // TODO: how to compute q[row] directly here?
+
+ // compute the sum of all elements in prods (result will be stored at prods[0])
+ parallel_reduce(prods, T, logT, Tid, 0, false);
+ dot = prods[0]; // all threads will have the dot product in this variable
+
+ // write m + dot in global_q[i, Bid] only if row < m
+ if(Tid == 0) {
+ if(row < m) { // computing rows of V is not finished
+ if(nbits == 16) {
+ bfloat16 *bf16_global_q = (bfloat16*) global_q;
+ bf16_global_q[row * N + Bid] = __float2bfloat16(m_float + dot);
+ } else if(nbits == 32) {
+ float *float_global_q = (float*) global_q;
+ float_global_q[row * N + Bid] = m_float + dot;
+ }
+ }
+ }
+ __syncthreads();
+
+ // store (dot/q) * V to Vout
+ for(j = 0; j < B; j += T) {
+ Vout[j] -= (dot / q) * V[j];
+ }
+ __syncthreads();
+ }
+
+ // write to out
+ if(nbits == 16) {
+ bfloat16 *bf16_global_out = (bfloat16*) global_out;
+ // <---------------init-----------------------> <------------stop cond-------------> <---next steps----->
+ for(j_global = global_out_block_Bid_start, j = 0; j_global < global_out_block_Bid_end; j_global+=T, j += T) {
+ bf16_global_out[j_global] = __float2bfloat16(Vout[j]);
+ }
+ } else if (nbits == 32) {
+ float *float_global_out = (float*) global_out;
+ // <---------------init-----------------------> <------------stop cond-------------> <---next steps----->
+ for(j_global = global_out_block_Bid_start, j = 0; j_global < global_out_block_Bid_end; j_global+=T, j += T) {
+ float_global_out[j_global] = Vout[j];
+ }
+ }
+ }
+
+ void compute_row_cuda(torch::Tensor V,
+ torch::Tensor g,
+ torch::Tensor q,
+ torch::Tensor out,
+ int row,
+ int m,
+ float damp,
+ int N,
+ int B,
+ int nbits) {
+
+ dim3 blocks(N, 1, 1);
+ dim3 threads(1024, 1, 1);
+ int shared_mem_size_bytes = (4 * B) * sizeof(float);
+
+ if(shared_mem_size_bytes > 48 * 1024) {
+ //// if we want to allocate more than 48KB, then we have to call this method
+ cudaFuncSetAttribute(compute_row_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size_bytes);
+ }
+
+ compute_row_kernel<<<blocks, threads, shared_mem_size_bytes>>>(
+ (void*) V.data_ptr(),
+ (void*) g.data_ptr(),
+ (void*) q.data_ptr(),
+ (void*) out.data_ptr(),
+ row,
+ m,
+ damp,
+ N,
+ B,
+ nbits);
+
+ GPU_ERROR_CHECK(cudaGetLastError());
+ GPU_ERROR_CHECK(cudaPeekAtLastError());
+ GPU_ERROR_CHECK(cudaDeviceSynchronize());
+ }
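
Since the launcher above requests `4 * B * sizeof(float)` bytes of dynamic shared memory per thread block and only opts in beyond 48 KB via `cudaFuncSetAttribute`, a caller may want to sanity-check the Fisher block size against the device limit before launching. A minimal sketch; the helper name and the 48 KB fallback are assumptions:

```python
import torch

def check_fisher_block_size(B: int, device_index: int = 0) -> None:
    """Pre-flight check for the four B-element float buffers (V, g, prods, Vout) per block."""
    needed = 4 * B * 4  # bytes of dynamic shared memory requested by the kernel
    props = torch.cuda.get_device_properties(device_index)
    # newer PyTorch builds expose the opt-in limit; fall back to the conservative 48 KB otherwise
    limit = getattr(props, 'shared_memory_per_block_optin', 48 * 1024)
    if needed > limit:
        raise ValueError(f'B={B} needs {needed} bytes of shared memory per block, '
                         f'but device {device_index} allows at most {limit} bytes')
```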
{ista_daslab_optimizers-1.1.3 → ista_daslab_optimizers-1.1.5}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name='ista_daslab_optimizers'
- version='1.1.3'
+ version='1.1.5'
  dependencies = [
  "torch", # >=2.3.1",
  "torchaudio", # >=2.3.1",