torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/variance_reduction/svrg.py
@@ -0,0 +1,208 @@
+ import warnings
+ from functools import partial
+
+ import torch
+
+ from ...core.module import Module
+ from ...utils import tofloat
+
+
+ def _reset_except_self(optimizer, var, self: Module):
+     for m in optimizer.unrolled_modules:
+         if m is not self:
+             m.reset()
+
+ class SVRG(Module):
+     """Stochastic variance reduced gradient method (SVRG).
+
+     To use, put SVRG as the first module; it can be combined with any other modules.
+     To reduce the variance of a gradient estimator, put the gradient estimator before SVRG.
+
+     First it uses the first ``accum_steps`` batches to compute the full gradient at the initial
+     parameters using gradient accumulation; the model is not updated during this phase.
+
+     Then it performs ``svrg_steps`` SVRG steps, each requiring two forward and backward passes.
+
+     After ``svrg_steps``, it goes back to the full gradient computation step.
+
+     As an alternative to gradient accumulation you can pass a "full_closure" argument to the ``step`` method,
+     which should compute full gradients, set them to the ``.grad`` attributes of the parameters,
+     and return the full loss.
+
+     Args:
+         svrg_steps (int): number of steps before recalculating the full gradient. This can be set to the length of the dataloader.
+         accum_steps (int | None, optional):
+             number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses the value of ``svrg_steps``. Defaults to None.
+         reset_before_accum (bool, optional):
+             whether to reset all other modules when re-calculating the full gradient. Defaults to True.
+         svrg_loss (bool, optional):
+             whether to replace the loss with the SVRG loss (calculated by the same formula as the SVRG gradient). Defaults to True.
+         alpha (float, optional):
+             multiplier for the ``g_full(x_0) - g_batch(x_0)`` term; can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6
+
+     ## Examples:
+     SVRG-LBFGS
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SVRG(len(dataloader)),
+         tz.m.LBFGS(),
+         tz.m.Backtracking(),
+     )
+     ```
+
+     For extra variance reduction one can use Online versions of algorithms, although it won't always help.
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SVRG(len(dataloader)),
+         tz.m.Online(tz.m.LBFGS()),
+         tz.m.Backtracking(),
+     )
+     ```
+
+     Variance reduction can also be applied to gradient estimators.
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SPSA(),
+         tz.m.SVRG(100),
+         tz.m.LR(1e-2),
+     )
+     ```
+     ## Notes
+
+     The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
+     - ``x`` is the current parameters
+     - ``x_0`` is the initial parameters, where the full gradient was computed
+     - ``g_b`` refers to the mini-batch gradient at ``x`` or ``x_0``
+     - ``g_f`` refers to the full gradient at ``x_0``.
+
+     The SVRG loss is computed using the same formula.
+     """
+     def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum: bool = True, svrg_loss: bool = True, alpha: float = 1):
+         defaults = dict(svrg_steps=svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         closure = var.closure
+         assert closure is not None
+
+         if "full_grad" not in self.global_state:
+
+             # -------------------------- calculate full gradient ------------------------- #
+             if "full_closure" in var.storage:
+                 full_closure = var.storage['full_closure']
+                 with torch.enable_grad():
+                     full_loss = full_closure()
+                 if all(p.grad is None for p in params):
+                     warnings.warn("all gradients are None after evaluating full_closure.")
+
+                 full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                 self.global_state["full_loss"] = full_loss
+                 self.global_state["full_grad"] = full_grad
+                 self.global_state['x_0'] = [p.clone() for p in params]
+
+                 # current batch will be used for svrg update
+
+             else:
+                 # accumulate gradients over n steps
+                 accum_steps = self.defaults['accum_steps']
+                 if accum_steps is None: accum_steps = self.defaults['svrg_steps']
+
+                 current_accum_step = self.global_state.get('current_accum_step', 0) + 1
+                 self.global_state['current_accum_step'] = current_accum_step
+
+                 # accumulate grads
+                 accumulator = self.get_state(params, 'accumulator')
+                 grad = var.get_grad()
+                 torch._foreach_add_(accumulator, grad)
+
+                 # accumulate loss
+                 loss_accumulator = self.global_state.get('loss_accumulator', 0)
+                 loss_accumulator += tofloat(var.loss)
+                 self.global_state['loss_accumulator'] = loss_accumulator
+
+                 # on nth step, use the accumulated gradient
+                 if current_accum_step >= accum_steps:
+                     torch._foreach_div_(accumulator, accum_steps)
+                     self.global_state["full_grad"] = accumulator
+                     self.global_state["full_loss"] = loss_accumulator / accum_steps
+
+                     self.global_state['x_0'] = [p.clone() for p in params]
+                     self.clear_state_keys('accumulator')
+                     del self.global_state['current_accum_step']
+
+                 # otherwise skip update until enough grads are accumulated
+                 else:
+                     var.update = None
+                     var.stop = True
+                     var.skip_update = True
+                     return var
+
+
+         svrg_steps = self.defaults['svrg_steps']
+         current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
+         self.global_state['current_svrg_step'] = current_svrg_step
+
+         # --------------------------- SVRG gradient closure -------------------------- #
+         x0 = self.global_state['x_0']
+         gf_x0 = self.global_state["full_grad"]
+         ff_x0 = self.global_state['full_loss']
+         use_svrg_loss = self.defaults['svrg_loss']
+         alpha = self.get_settings(params, 'alpha')
+         alpha_0 = alpha[0]
+         if all(a == 1 for a in alpha): alpha = None
+
+         def svrg_closure(backward=True):
+             # g_b(x) - α * (g_b(x_0) - g_f(x_0)), and the same for the loss
+             with torch.no_grad():
+                 x = [p.clone() for p in params]
+
+                 if backward:
+                     # f and g at x
+                     with torch.enable_grad(): fb_x = closure()
+                     gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                     # f and g at x_0
+                     torch._foreach_copy_(params, x0)
+                     with torch.enable_grad(): fb_x0 = closure()
+                     gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                     torch._foreach_copy_(params, x)
+
+                     # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
+                     correction = torch._foreach_sub(gb_x0, gf_x0)
+                     if alpha is not None: torch._foreach_mul_(correction, alpha)
+                     g_svrg = torch._foreach_sub(gb_x, correction)
+
+                     f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                     for p, g in zip(params, g_svrg):
+                         p.grad = g
+
+                     if use_svrg_loss: return f_svrg
+                     return fb_x
+
+                 # no backward
+                 if use_svrg_loss:
+                     fb_x = closure(False)
+                     torch._foreach_copy_(params, x0)
+                     fb_x0 = closure(False)
+                     torch._foreach_copy_(params, x)
+                     f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                     return f_svrg
+
+                 return closure(False)
+
+         var.closure = svrg_closure
+
+         # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
+         if current_svrg_step >= svrg_steps:
+             del self.global_state['current_svrg_step']
+             del self.global_state['full_grad']
+             del self.global_state['full_loss']
+             del self.global_state['x_0']
+             if self.defaults['reset_before_accum']:
+                 var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+         return var
torchzero/modules/weight_decay/__init__.py
@@ -1 +1 @@
- from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, NormalizedWeightDecay
+ from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
torchzero/modules/weight_decay/weight_decay.py
@@ -4,7 +4,7 @@ from typing import Literal
  import torch

  from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+ from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics


  @torch.no_grad
@@ -14,7 +14,7 @@ def weight_decay_(
      weight_decay: float | NumberList,
      ord: int = 2
  ):
-     """returns `grad_`."""
+     """modifies in-place and returns ``grad_``."""
      if ord == 1: return grad_.add_(params.sign().mul_(weight_decay))
      if ord == 2: return grad_.add_(params.mul(weight_decay))
      if ord - 1 % 2 != 0: return grad_.add_(params.pow(ord-1).mul_(weight_decay))
@@ -22,34 +22,113 @@


  class WeightDecay(Transform):
+     """Weight decay.
+
+     Args:
+         weight_decay (float): weight decay scale.
+         ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+         target (Target, optional): what to set on var. Defaults to 'update'.
+
+     ### Examples:
+
+     Adam with non-decoupled weight decay
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.WeightDecay(1e-3),
+         tz.m.Adam(),
+         tz.m.LR(1e-3)
+     )
+     ```
+
+     Adam with decoupled weight decay that still scales with the learning rate
+     ```python
+
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.Adam(),
+         tz.m.WeightDecay(1e-3),
+         tz.m.LR(1e-3)
+     )
+     ```
+
+     Adam with fully decoupled weight decay that doesn't scale with the learning rate
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.Adam(),
+         tz.m.LR(1e-3),
+         tz.m.WeightDecay(1e-6)
+     )
+     ```
+
+     """
      def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+
          defaults = dict(weight_decay=weight_decay, ord=ord)
          super().__init__(defaults, uses_grad=False, target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          weight_decay = NumberList(s['weight_decay'] for s in settings)
          ord = settings[0]['ord']

          return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

- class NormalizedWeightDecay(Transform):
+ class RelativeWeightDecay(Transform):
+     """Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the ``norm_input`` argument.
+
+     Args:
+         weight_decay (float): relative weight decay scale.
+         ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+         norm_input (str, optional):
+             determines what weight decay should be relative to. "update", "grad" or "params".
+             Defaults to "update".
+         metric (Metrics, optional):
+             metric (norm, etc.) that weight decay should be relative to.
+             Defaults to 'mad' (mean absolute deviation).
+         target (Target, optional): what to set on var. Defaults to 'update'.
+
+     ### Examples:
+
+     Adam with non-decoupled relative weight decay
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.RelativeWeightDecay(1e-1),
+         tz.m.Adam(),
+         tz.m.LR(1e-3)
+     )
+     ```
+
+     Adam with decoupled relative weight decay
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.Adam(),
+         tz.m.RelativeWeightDecay(1e-1),
+         tz.m.LR(1e-3)
+     )
+     ```
+     """
      def __init__(
          self,
          weight_decay: float = 0.1,
-         ord: int = 2,
+         ord: int = 2,
          norm_input: Literal["update", "grad", "params"] = "update",
+         metric: Metrics = 'mad',
          target: Target = "update",
      ):
-         defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+         defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
          super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

      @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          weight_decay = NumberList(s['weight_decay'] for s in settings)

          ord = settings[0]['ord']
          norm_input = settings[0]['norm_input']
+         metric = settings[0]['metric']

          if norm_input == 'update': src = TensorList(tensors)
          elif norm_input == 'grad':
@@ -60,8 +139,7 @@ class NormalizedWeightDecay(Transform):
          else:
              raise ValueError(norm_input)

-         norm = src.global_vector_norm(ord)
-
+         norm = src.global_metric(metric)
          return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)


@@ -72,7 +150,12 @@ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberL
      weight_decay_(params, params, -weight_decay, ord)

  class DirectWeightDecay(Module):
-     """directly decays weights in-place"""
+     """Directly applies weight decay to parameters.
+
+     Args:
+         weight_decay (float): weight decay scale.
+         ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+     """
      def __init__(self, weight_decay: float, ord: int = 2,):
          defaults = dict(weight_decay=weight_decay, ord=ord)
          super().__init__(defaults)
@@ -80,7 +163,7 @@ class DirectWeightDecay(Module):
      @torch.no_grad
      def step(self, var):
          weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
-         ord = self.settings[var.params[0]]['ord']
+         ord = self.defaults['ord']

          decay_weights_(var.params, weight_decay, ord)
          return var
torchzero/modules/wrappers/optim_wrapper.py
@@ -7,7 +7,35 @@ from ...utils import Params, _copy_param_groups, _make_param_groups


  class Wrap(Module):
-     """Custom param groups are supported only by `set_param_groups`. Settings passed to Modular will be ignored."""
+     """
+     Wraps a PyTorch optimizer to use it as a module.
+
+     .. note::
+         Custom param groups are supported only by `set_param_groups`; settings passed to Modular will be ignored.
+
+     Args:
+         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
+             function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
+             or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+         *args:
+         **kwargs:
+             Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+
+     Example:
+         wrapping pytorch_optimizer.StableAdamW
+
+         .. code-block:: py
+
+             from pytorch_optimizer import StableAdamW
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Wrap(StableAdamW, lr=1),
+                 tz.m.Cautious(),
+                 tz.m.LR(1e-2)
+             )
+
+
+     """
      def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
          super().__init__()
          self._opt_fn = opt_fn
torchzero/modules/zeroth_order/__init__.py
@@ -0,0 +1 @@
+ from .cd import CD, CCD, CCDLS