ista-daslab-optimizers 1.1.8 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. ista_daslab_optimizers/__init__.py +6 -0
  2. ista_daslab_optimizers/acdc/__init__.py +5 -0
  3. ista_daslab_optimizers/acdc/acdc.py +387 -0
  4. ista_daslab_optimizers/acdc/wd_scheduler.py +31 -0
  5. ista_daslab_optimizers/dense_mfac/__init__.py +5 -0
  6. ista_daslab_optimizers/dense_mfac/dense_core_mfac.py +164 -0
  7. ista_daslab_optimizers/dense_mfac/dense_mfac.py +93 -0
  8. ista_daslab_optimizers/fft_low_rank/dct_adamw.py +351 -0
  9. ista_daslab_optimizers/fft_low_rank/fft_projector.py +192 -0
  10. ista_daslab_optimizers/fft_low_rank/trion.py +242 -0
  11. ista_daslab_optimizers/ista_optimizer/__init__.py +5 -0
  12. ista_daslab_optimizers/ista_optimizer/ista_optimizer.py +36 -0
  13. ista_daslab_optimizers/micro_adam/__init__.py +5 -0
  14. ista_daslab_optimizers/micro_adam/micro_adam.py +402 -0
  15. ista_daslab_optimizers/sparse_mfac/__init__.py +7 -0
  16. ista_daslab_optimizers/sparse_mfac/sparse_core_mfac_w_ef.py +226 -0
  17. ista_daslab_optimizers/sparse_mfac/sparse_mfac.py +87 -0
  18. ista_daslab_optimizers/tools.py +218 -0
  19. ista_daslab_optimizers/utils/dct.py +45 -0
  20. ista_daslab_optimizers/utils/global_cache.py +45 -0
  21. ista_daslab_optimizers/utils/matrix_storage.py +58 -0
  22. ista_daslab_optimizers/utils/newton_schulz_triton.py +374 -0
  23. ista_daslab_optimizers/utils/quantizers.py +71 -0
  24. ista_daslab_optimizers/utils/schedulers.py +41 -0
  25. ista_daslab_optimizers-1.1.8.dist-info/METADATA +333 -0
  26. ista_daslab_optimizers-1.1.8.dist-info/RECORD +29 -0
  27. ista_daslab_optimizers-1.1.8.dist-info/WHEEL +5 -0
  28. ista_daslab_optimizers-1.1.8.dist-info/licenses/LICENSE +201 -0
  29. ista_daslab_optimizers-1.1.8.dist-info/top_level.txt +1 -0
ista_daslab_optimizers/utils/newton_schulz_triton.py
@@ -0,0 +1,374 @@
+ import torch
+ import triton
+ import triton.language as tl
+ from torch import Tensor
+
+
+ def _get_autotune_configs():
+     return [
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": bm,
+                 "BLOCK_SIZE_N": bn,
+                 "BLOCK_SIZE_K": bk,
+                 "GROUP_SIZE_M": 8,
+                 "LOWER_UPPER": 1,
+             },
+             num_stages=stages,
+             num_warps=warps,
+         )
+         for bm in [64, 128]
+         for bn in [64, 128, 256]
+         for bk in [64, 128]
+         for stages, warps in [(3, 4), (3, 8), (4, 4)]
+         if bm // bn <= 2 and bn // bm <= 2
+     ]
+
+
+ @triton.jit
+ def _pid_to_block(
+     pid,
+     M,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """
+     Helper function to map Triton program ID to (batch, row, col) of the output matrix.
+     """
+     # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
+
+     # Map PID to a single matrix in batch
+     batch_idx = pid // (num_pid_m * num_pid_n)
+     pid = pid % (num_pid_m * num_pid_n)
+
+     # Map PID to 2D grid of blocks
+     pid_m = pid // num_pid_n
+     pid_n = pid % num_pid_n
+     pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
+
+     m_idx = pid_m * BLOCK_SIZE_M
+     n_idx = pid_n * BLOCK_SIZE_N
+
+     return batch_idx, m_idx, n_idx
+
+
+ @triton.autotune(
+     configs=_get_autotune_configs(),
+     key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+ )
+ @triton.jit
+ def ns_line_1_kernel(
+     A_ptr,
+     C_ptr,
+     M,
+     K,
+     a_stride_b,
+     a_stride_r,
+     a_stride_c,
+     c_stride_b,
+     c_stride_r,
+     c_stride_c,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+     LOWER_UPPER: tl.constexpr,
+ ):
+     """
+     Input A has shape (M, K)
+     Output C has shape (M, M)
+     Compute C = A @ A.T
+     """
+
+     pid = tl.program_id(axis=0)
+     batch_idx, m_idx, n_idx = _pid_to_block(
+         pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+     )
+
+     # Skip blocks that don't need to be computed
+     skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+     skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+     if skip_block_below_diag or skip_block_above_diag:
+         return
+
+     # Index into one matrix of batch
+     A_ptr += batch_idx * a_stride_b
+     C_ptr += batch_idx * c_stride_b
+
+     # Create pointer arrays for A and A.T
+     offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+     at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+     # Accumulate over blocks of K
+     for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+         accumulator = tl.dot(a, at, accumulator)
+         a_ptrs += BLOCK_SIZE_K * a_stride_c
+         at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+     out_dtype = C_ptr.dtype.element_ty
+     output = accumulator.to(out_dtype)
+
+     # Store block of C
+     offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+     tl.store(c_ptrs, output, mask=c_mask)
+
+     # Store block of C mirrored across the diagonal
+     c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+     c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+     tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+
+ def ns_line_1(A: Tensor, *, out: Tensor = None):
+     """
+     Launch Triton kernel to compute C = A @ A.T
+     """
+     if A.ndim > 3 or A.ndim < 2:
+         raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.")
+
+     M, K = A.shape[-2:]
+
+     if out is None:
+         out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype)
+     assert out.size(-2) == M, "Output matrix has incorrect shape"
+     assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+     batch_size = A.size(0) if A.ndim == 3 else 1
+     input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+     output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+     grid = lambda meta: (
+         batch_size
+         * triton.cdiv(M, meta["BLOCK_SIZE_M"])
+         * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+     )
+     ns_line_1_kernel[grid](
+         A_ptr=A,
+         C_ptr=out,
+         M=M,
+         K=K,
+         a_stride_b=input_batch_stride,
+         a_stride_r=A.stride(-2),
+         a_stride_c=A.stride(-1),
+         c_stride_b=output_batch_stride,
+         c_stride_r=out.stride(-2),
+         c_stride_c=out.stride(-1),
+     )
+
+     return out
+
+
+ @triton.autotune(
+     configs=_get_autotune_configs(),
+     key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
+ )
+ @triton.jit
+ def ns_line_2_kernel(
+     A_ptr,
+     C_ptr,
+     M,
+     a_stride_b,
+     a_stride_r,
+     a_stride_c,
+     c_stride_b,
+     c_stride_r,
+     c_stride_c,
+     alpha,
+     beta,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+     LOWER_UPPER: tl.constexpr,
+ ):
+     """
+     Input A is square matrix with shape (M, M)
+     Output C has shape (M, M)
+     Compute C = alpha * A @ A.T + beta * A
+     """
+
+     pid = tl.program_id(axis=0)
+     batch_idx, m_idx, n_idx = _pid_to_block(
+         pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
+     )
+
+     # Skip blocks that don't need to be computed
+     skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
+     skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
+     if skip_block_below_diag or skip_block_above_diag:
+         return
+
+     # Index into one matrix of batch
+     A_ptr += batch_idx * a_stride_b
+     C_ptr += batch_idx * c_stride_b
+
+     # Create pointer arrays for A and A.T
+     offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
+     at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+     # Accumulate over blocks of K
+     for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
+         at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
+         accumulator = tl.dot(a, at, accumulator)
+         a_ptrs += BLOCK_SIZE_K * a_stride_c
+         at_ptrs += BLOCK_SIZE_K * a_stride_c
+
+     # Load block of A to add (corresponds to the current block of C)
+     offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
+     offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
+     a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
+     a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
+     a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
+
+     # Apply alpha and beta
+     accumulator *= alpha
+     accumulator += a_add * beta
+
+     out_dtype = C_ptr.dtype.element_ty
+     output = accumulator.to(out_dtype)
+
+     # Store block of C
+     offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
+     tl.store(c_ptrs, output, mask=c_mask)
+
+     # Store block of C mirrored across the diagonal
+     c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
+     c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
+     tl.store(c_ptrs_t, output.T, mask=c_mask_t)
+
+
+ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None):
+     """
+     Launch Triton kernel to compute C = alpha * A @ A.T + beta * A
+     """
+     if A.ndim > 3 or A.ndim < 2:
+         raise ValueError(f"Input tensor must be 2D or 3D, but got {A.ndim}D tensor.")
+
+     M, K = A.shape[-2:]
+     if M != K:
+         raise ValueError(
+             f"Input must be symmetric square matrix, but got shape {A.shape}"
+         )
+
+     if out is None:
+         out = torch.empty((*A.shape[:-1], M), device=A.device, dtype=A.dtype)
+     assert out.size(-2) == M, "Output matrix has incorrect shape"
+     assert out.size(-1) == M, "Output matrix has incorrect shape"
+
+     batch_size = A.size(0) if A.ndim == 3 else 1
+     input_batch_stride = A.stride(0) if A.ndim == 3 else 0
+     output_batch_stride = out.stride(0) if out.ndim == 3 else 0
+
+     grid = lambda meta: (
+         batch_size
+         * triton.cdiv(M, meta["BLOCK_SIZE_M"])
+         * triton.cdiv(M, meta["BLOCK_SIZE_N"]),
+     )
+     ns_line_2_kernel[grid](
+         A_ptr=A,
+         C_ptr=out,
+         M=M,
+         a_stride_b=input_batch_stride,
+         a_stride_r=A.stride(-2),
+         a_stride_c=A.stride(-1),
+         c_stride_b=output_batch_stride,
+         c_stride_r=out.stride(-2),
+         c_stride_c=out.stride(-1),
+         alpha=alpha,
+         beta=beta,
+     )
+
+     return out
+
+
+ @torch.compile(dynamic=False, fullgraph=True)
+ def zeropower_via_newtonschulz5(G: Tensor, epsilon: float = 1e-7):
+     """
+     Reference implementation of Newton-Schulz without Triton.
+     """
+     # Newton-Schulz constants
+     ns_consts = [
+         (4.0848, -6.8946, 2.9270),
+         (3.9505, -6.3029, 2.6377),
+         (3.7418, -5.5913, 2.3037),
+         (2.8769, -3.1427, 1.2046),
+         (2.8366, -3.0525, 1.2012),
+     ]
+
+     X = G.to(dtype=torch.bfloat16)
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)
+
+     for a, b, c in ns_consts:
+         A = X @ X.mT
+         B = b * A + c * (A @ A)
+         X = a * X + B @ X
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+     return X
+
+
+ @torch.compile(dynamic=False, fullgraph=True)
+ def newton_schulz_triton(G: Tensor, epsilon: float = 1e-7):
+     """
+     Triton implementation of Newton-Schulz iteration
+     """
+     # Newton-Schulz constants
+     ns_consts = [
+         (4.0848, -6.8946, 2.9270),
+         (3.9505, -6.3029, 2.6377),
+         (3.7418, -5.5913, 2.3037),
+         (2.8769, -3.1427, 1.2046),
+         (2.8366, -3.0525, 1.2012),
+     ]
+
+     X = G.to(dtype=torch.bfloat16)
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + epsilon)
+
+     # Allocate buffers
+     X = X.contiguous()
+     A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
+     B = torch.empty_like(A)
+     C = torch.empty_like(X)
+
+     ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm
+
+     # Perform the NS iterations
+     for a, b, c in ns_consts:
+         ns_line_1(X, out=A)  # A = X @ X.mT
+         ns_line_2(A, alpha=c, beta=b, out=B)  # B = b * A + c * A @ A
+         ns_line_3(X, B, X, beta=a, out=C)  # C = a * X + B @ X
+         X, C = C, X  # Swap references to avoid unnecessary copies
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+     return X
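
The two entry points above compute the same quantity: five Newton-Schulz iterations that approximately orthogonalize G, with zeropower_via_newtonschulz5 as the pure-PyTorch reference and newton_schulz_triton replacing two of the three per-iteration matrix products with the custom kernels (the third uses torch.addmm/baddbmm). Below is a minimal usage sketch, not part of the package: it assumes a CUDA device with Triton available, the import path is taken from the file list above, and the shapes are illustrative.

import torch

from ista_daslab_optimizers.utils.newton_schulz_triton import (
    newton_schulz_triton,
    zeropower_via_newtonschulz5,
)

# Illustrative batch of 8 gradient matrices; both functions cast to bfloat16 internally.
G = torch.randn(8, 1024, 4096, device="cuda")

X_ref = zeropower_via_newtonschulz5(G)  # reference path: plain matmuls
X_tri = newton_schulz_triton(G)         # ns_line_1 / ns_line_2 Triton kernels + baddbmm

# The two results should agree up to bfloat16 rounding.
print((X_ref.float() - X_tri.float()).abs().max())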
ista_daslab_optimizers/utils/quantizers.py
@@ -0,0 +1,71 @@
+ import torch
+ import numpy as np
+
+ class Quantizer4bit:
+     def __init__(self, shape, device, dtype, bucket_size):
+         assert np.prod(shape) % bucket_size == 0
+         self.shape = shape
+         self.device = device
+         self.bucket_size = bucket_size
+         self.numel = np.prod(shape)
+
+         self.n_buckets = self.numel // self.bucket_size
+
+         self.xq = torch.zeros(self.numel // 2, dtype=torch.uint8, device=self.device)
+         self.x_min = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
+         self.x_max = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
+
+     def quantize(self, x):
+         N, B = self.n_buckets, self.bucket_size
+         N = self.numel // B
+         self.x_min.copy_(x.view(N, B).min(dim=1).values.view(-1, 1))
+         self.x_max.copy_(x.view(N, B).max(dim=1).values.view(-1, 1))
+         u = (self.x_max - self.x_min) / 15.
+         xq = ((x.view(N, B) - self.x_min) / u + 0.5).floor().to(torch.uint8).view(-1, 2)
+         byte_left = xq[:, 0] << 4
+         byte_right = xq[:, 1]
+         self.xq.copy_(byte_left | byte_right)
+
+     def quantize_inv(self):
+         N, B = self.n_buckets, self.bucket_size
+         u = (self.x_max - self.x_min) / 15.
+         byte_left = (self.xq & 0xF0) >> 4
+         byte_right = self.xq & 0x0F
+         # Interleave the high/low nibbles so the element order matches quantize()
+         # (hstack would concatenate the two halves instead of interleaving them)
+         xq = torch.stack(
+             (
+                 byte_left.view(-1),
+                 byte_right.view(-1),
+             ),
+             dim=1,
+         ).view(N, B)
+         x = xq * u + self.x_min
+         return x.view(*self.shape)
+
+ class Quantizer8bit:
+     def __init__(self, shape, device, dtype, bucket_size):
+         assert np.prod(shape) % bucket_size == 0
+         self.shape = shape
+         self.device = device
+         self.bucket_size = bucket_size
+         self.numel = np.prod(shape)
+
+         self.n_buckets = self.numel // self.bucket_size
+
+         self.xq = torch.zeros(self.numel, dtype=torch.uint8, device=self.device)
+         self.x_min = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
+         self.x_max = torch.zeros(self.n_buckets, 1, dtype=dtype, device=self.device)
+
+     def quantize(self, x):
+         N, B = self.n_buckets, self.bucket_size
+         N = self.numel // B
+         self.x_min.copy_(x.view(N, B).min(dim=1).values.view(-1, 1))
+         self.x_max.copy_(x.view(N, B).max(dim=1).values.view(-1, 1))
+         u = (self.x_max - self.x_min) / 15.
+         xq = ((x.view(N, B) - self.x_min) / u + 0.5).floor().to(torch.uint8)
+         self.xq.copy_(xq.view(-1))
+         del xq, u
+
+     def quantize_inv(self):
+         N, B = self.n_buckets, self.bucket_size
+         u = (self.x_max - self.x_min) / 15.
+         x = self.xq.view(N, B) * u + self.x_min
+         return x.view(*self.shape)
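
Both quantizers keep per-bucket min/max statistics plus a flat uint8 code tensor, with Quantizer4bit packing two 4-bit codes per byte. A round-trip sketch follows; the shape, bucket size, device, and dtype are illustrative choices for the example, not values required by the package, and the import path is taken from the file list above.

import torch

from ista_daslab_optimizers.utils.quantizers import Quantizer4bit

shape, bucket_size = (4096, 256), 128  # numel must be divisible by bucket_size
q = Quantizer4bit(shape, device="cpu", dtype=torch.float32, bucket_size=bucket_size)

x = torch.randn(shape)
q.quantize(x)             # per-bucket min/max scaling, two 4-bit codes per uint8
x_hat = q.quantize_inv()  # dequantized tensor with the original shape

# With 16 levels per bucket, the round-trip error is at most half a quantization step.
print((x - x_hat).abs().max())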
ista_daslab_optimizers/utils/schedulers.py
@@ -0,0 +1,41 @@
+ import torch
+
+ def ema_standard_schedule(m, g, beta):
+     """
+     Implements the standard EMA: m_new = beta * m_old + (1 - beta) * g
+     :param m: momentum buffer
+     :param g: gradient
+     :param beta: EMA coefficient
+     """
+     m.lerp_(g, 1 - beta)
+
+ def ema_delayed_decay_schedule(m, g, beta, beta_prev, t, T_decay, alpha):
+     """
+     This EMA variant was proposed by Mher Safaryan in June 2025, while he was a postdoc at ISTA:
+
+     beta_0 = 1 (tracks the largest weight)
+     alpha >= 0 (sub-schedule slope)
+
+     if t == 1 or t % T_decay == 0:
+         m_t = beta * m_{t-1} + (1 - beta) * g
+         beta_t = 1 - beta
+     else:
+         m_t = (1 / (1 + alpha + beta_{t-1})) * m_{t-1} + ((alpha + beta_{t-1}) / (1 + alpha + beta_{t-1})) * g
+         beta_t = (alpha + beta_{t-1}) / (1 + alpha + beta_{t-1})
+
+     :param m: momentum buffer
+     :param g: gradient buffer
+     :param beta: EMA coefficient
+     :param beta_prev: previous EMA coefficient (beta_{t-1})
+     :param t: current step index (starting at 1)
+     :param T_decay: decay interval
+     :param alpha: slope (use values between 0.001 and 0.007)
+     :return: returns beta_t
+     """
+
+     if t == 1 or t % T_decay == 0:
+         ema_standard_schedule(m, g, beta)
+         return 1 - beta
+     else:
+         beta_t = (alpha + beta_prev) / (1 + alpha + beta_prev)
+         ema_standard_schedule(m, g, 1 - beta_t)
+         return beta_t
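
ema_delayed_decay_schedule is designed to be called once per step, feeding the returned coefficient back in as beta_prev on the next call. A small driver sketch under assumed hyperparameters (the values of beta, alpha, T_decay and the buffer size below are illustrative, not defaults from the package):

import torch

from ista_daslab_optimizers.utils.schedulers import ema_delayed_decay_schedule

m = torch.zeros(10)                     # momentum buffer, updated in place
beta, alpha, T_decay = 0.9, 0.005, 100  # illustrative hyperparameters
beta_prev = 1.0                         # beta_0 = 1, as in the docstring

for t in range(1, 501):
    g = torch.randn(10)                 # stand-in for a gradient
    # Standard EMA at t == 1 and every T_decay steps; the alpha sub-schedule otherwise.
    beta_prev = ema_delayed_decay_schedule(m, g, beta, beta_prev, t, T_decay, alpha)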