torchzero 0.3.9 → 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/regularization.py (new file)
@@ -0,0 +1,171 @@
+ from collections import deque
+ from collections.abc import Iterable
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+ from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ class Dropout(Transform):
+     """Applies dropout to the update.
+
+     For each weight, the update to that weight has probability :code:`p` of being set to 0.
+     Depending on placement, this can be used to implement gradient dropout or update dropout.
+
+     Args:
+         p (float, optional): probability that the update for a weight is replaced with 0. Defaults to 0.5.
+         graft (bool, optional):
+             if True, the update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+         target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+     Examples:
+         Gradient dropout.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Dropout(0.5),
+                 tz.m.Adam(),
+                 tz.m.LR(1e-3)
+             )
+
+         Update dropout.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Adam(),
+                 tz.m.Dropout(0.5),
+                 tz.m.LR(1e-3)
+             )
+
+     """
+     def __init__(self, p: float = 0.5, graft: bool = False, target: Target = 'update'):
+         defaults = dict(p=p, graft=graft)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         p = NumberList(s['p'] for s in settings)
+         graft = settings[0]['graft']
+
+         if graft:
+             target_norm = tensors.global_vector_norm()
+             tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+             return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+         return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+ def _bernoulli_like(tensor, p=0.5, generator=None):
+     """p is the probability of a 1, all other values will be 0."""
+     return torch.bernoulli(torch.full_like(tensor, p), generator=generator)
+
+ class WeightDropout(Module):
+     """
+     Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+     Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in the corresponding parameter group.
+
+     Args:
+         p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+         graft (bool, optional):
+             if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+     """
+     def __init__(self, p: float = 0.5, graft: bool = True):
+         defaults = dict(p=p, graft=graft, use_dropout=True)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('WeightDropout requires closure')
+         params = TensorList(var.params)
+         p = NumberList(self.settings[p]['p'] for p in params)
+
+         # create masks
+         mask = []
+         for p in params:
+             prob = self.settings[p]['p']
+             use_dropout = self.settings[p]['use_dropout']
+             if use_dropout: mask.append(_bernoulli_like(p, prob))
+             else: mask.append(torch.ones_like(p))
+
+         @torch.no_grad
+         def dropout_closure(backward=True):
+             orig_params = params.clone()
+             params.mul_(mask)
+             if backward:
+                 with torch.enable_grad(): loss = closure()
+             else:
+                 loss = closure(False)
+             params.copy_(orig_params)
+             return loss
+
+         var.closure = dropout_closure
+         return var
+
+
+ class PerturbWeights(Module):
+     """
+     Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+     Can be disabled for a parameter by setting :code:`perturb=False` in the corresponding parameter group.
+
+     Args:
+         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+         relative (bool, optional): whether to multiply the perturbation by the mean absolute value of the parameter. Defaults to True.
+         distribution (Distributions, optional):
+             distribution of the random perturbation, one of 'normal', 'uniform' or 'sphere'. Defaults to 'normal'.
+     """
+     def __init__(self, alpha: float = 0.1, relative: bool = True, distribution: Distributions = 'normal'):
+         defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('PerturbWeights requires closure')
+         params = TensorList(var.params)
+
+         # create perturbations
+         perts = []
+         for p in params:
+             settings = self.settings[p]
+             if not settings['perturb']:
+                 perts.append(torch.zeros_like(p))
+                 continue
+
+             alpha = settings['alpha']
+             if settings['relative']:
+                 alpha *= p.abs().mean()
+
+             distribution = self.settings[p]['distribution'].lower()
+             if distribution in ('normal', 'gaussian'):
+                 perts.append(torch.randn_like(p).mul_(alpha))
+             elif distribution == 'uniform':
+                 perts.append(torch.empty_like(p).uniform_(-alpha, alpha))
+             elif distribution == 'sphere':
+                 r = torch.randn_like(p)
+                 perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+             else:
+                 raise ValueError(distribution)
+
+         @torch.no_grad
+         def perturbed_closure(backward=True):
+             params.add_(perts)
+             if backward:
+                 with torch.enable_grad(): loss = closure()
+             else:
+                 loss = closure(False)
+             params.sub_(perts)
+             return loss
+
+         var.closure = perturbed_closure
+         return var
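The two closure-wrapping modules above (WeightDropout, PerturbWeights) follow the closure convention used in their own code, where closure(backward=True) recomputes the loss and gradients. A minimal usage sketch, assuming these classes are exported under tz.m like the other modules in this diff (illustrative only, not the package's documented example):

    import torch
    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    # Hypothetical placement: evaluate loss/gradients with ~30% of the weights zeroed,
    # then take an Adam step on the resulting gradients.
    # tz.m.PerturbWeights(alpha=0.01) could be swapped in to evaluate at perturbed weights instead.
    opt = tz.Modular(
        model.parameters(),
        tz.m.WeightDropout(p=0.3),
        tz.m.Adam(),
        tz.m.LR(1e-3),
    )

    def closure(backward=True):
        loss = nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(10):
        opt.step(closure)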
torchzero/modules/misc/split.py (new file)
@@ -0,0 +1,103 @@
+ from collections.abc import Callable
+ from typing import cast
+
+ import torch
+
+ from ...core import Chainable, Module, Var
+
+
+ def _split(
+     module: Module,
+     idxs,
+     params,
+     var: Var,
+ ):
+     split_params = [p for i, p in enumerate(params) if i in idxs]
+
+     split_grad = None
+     if var.grad is not None:
+         split_grad = [g for i, g in enumerate(var.grad) if i in idxs]
+
+     split_update = None
+     if var.update is not None:
+         split_update = [u for i, u in enumerate(var.update) if i in idxs]
+
+     split_var = var.clone(clone_update=False)
+     split_var.params = split_params
+     split_var.grad = split_grad
+     split_var.update = split_update
+
+     split_var = module.step(split_var)
+
+     if (var.grad is None) and (split_var.grad is not None):
+         var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+     if split_var.update is not None:
+
+         if var.update is None:
+             if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+             else: var.update = [g.clone() for g in var.grad]
+
+         for idx, u in zip(idxs, split_var.update):
+             var.update[idx] = u
+
+     var.update_attrs_from_clone_(split_var)
+     return var
+
+ class Split(Module):
+     """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
+
+     Args:
+         filter (Callable[[torch.Tensor], bool]): a function that takes in a parameter tensor and returns a boolean value.
+         true (Chainable | None): modules that are applied to tensors where :code:`filter` returned True.
+         false (Chainable | None): modules that are applied to tensors where :code:`filter` returned False.
+
+     Examples:
+         standard Muon with Adam fallback
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.head.parameters(),
+                 tz.m.Split(
+                     # apply muon only to 2D+ parameters
+                     filter = lambda t: t.ndim >= 2,
+                     true = [
+                         tz.m.HeavyBall(),
+                         tz.m.Orthogonalize(),
+                         tz.m.LR(1e-2),
+                     ],
+                     false = tz.m.Adam()
+                 ),
+                 tz.m.LR(1e-2)
+             )
+
+
+     """
+     def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
+         defaults = dict(filter=filter)
+         super().__init__(defaults)
+
+         if true is not None: self.set_child('true', true)
+         if false is not None: self.set_child('false', false)
+
+     def step(self, var):
+
+         params = var.params
+         filter = self.settings[params[0]]['filter']
+
+         true_idxs = []
+         false_idxs = []
+         for i, p in enumerate(params):
+             if filter(p): true_idxs.append(i)
+             else: false_idxs.append(i)
+
+         if 'true' in self.children:
+             true = self.children['true']
+             var = _split(true, idxs=true_idxs, params=params, var=var)
+
+         if 'false' in self.children:
+             false = self.children['false']
+             var = _split(false, idxs=false_idxs, params=params, var=var)
+
+         return var
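A compact sketch of the routing behaviour that _split implements: parameters for which filter returns True are stepped by the true branch, the rest by the false branch, and the two partial updates are merged back into one. Module names other than Split (Adam, SignSGD, LR) are taken from docstrings elsewhere in this diff; their availability under tz.m is assumed:

    import torch.nn as nn
    import torchzero as tz

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

    # Matrix-shaped parameters (weights) take Adam steps; vectors (biases) take sign steps.
    opt = tz.Modular(
        model.parameters(),
        tz.m.Split(
            filter=lambda t: t.ndim >= 2,
            true=tz.m.Adam(),
            false=tz.m.SignSGD(),
        ),
        tz.m.LR(1e-3),
    )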
torchzero/modules/{ops → misc}/switch.py
@@ -7,7 +7,28 @@ from ...core import Chainable, Module
 
 
  class Alternate(Module):
-     """alternate between stepping with `modules`"""
+     """Alternates between stepping with :code:`modules`.
+
+     That is, the first step is performed with the 1st module, the second step with the 2nd module, and so on.
+
+     Args:
+         steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+     Examples:
+         Alternate between Adam, SignSGD and RMSprop
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Alternate(
+                     tz.m.Adam(),
+                     [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                     tz.m.RMSprop(),
+                 ),
+                 tz.m.LR(1e-3),
+             )
+     """
      LOOP = True
      def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
          if isinstance(steps, Iterable):
@@ -23,16 +44,16 @@ class Alternate(Module):
          self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):
          # get current module
          current_module_idx = self.global_state.setdefault('current_module_idx', 0)
          module = self.children[f'module_{current_module_idx}']
 
          # step
-         vars = module.step(vars.clone(clone_update=False))
+         var = module.step(var.clone(clone_update=False))
 
          # number of steps until next module
-         steps = self.settings[vars.params[0]]['steps']
+         steps = self.settings[var.params[0]]['steps']
          if isinstance(steps, int): steps = [steps]*len(self.children)
 
          if 'steps_to_next' not in self.global_state:
@@ -51,17 +72,37 @@ class Alternate(Module):
 
          self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
-         return vars
+         return var
 
  class Switch(Alternate):
-     """switch to next module after some steps"""
+     """After :code:`steps` steps, switches to the next module.
+
+     Args:
+         steps (int | Iterable[int]): number of steps to perform with each module.
+
+     Examples:
+         Start with Adam, switch to L-BFGS after the 1000th step, and to truncated Newton after the 2000th step.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.Switch(
+                     [tz.m.Adam(), tz.m.LR(1e-3)],
+                     [tz.m.LBFGS(), tz.m.Backtracking()],
+                     [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                     steps = (1000, 2000)
+                 )
+             )
+     """
+
      LOOP = False
      def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
 
          if isinstance(steps, Iterable):
              steps = list(steps)
              if len(steps) != len(modules) - 1:
-                 raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                 raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")
 
              steps.append(1)
 
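The steps argument of Alternate and Switch accepts either one int shared by all modules or one entry per module. A hedged sketch of the per-module form, under the same assumed tz.m exports as the docstrings above:

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(8, 1)

    # Two Adam steps, then one SignSGD step, then back to Adam (Alternate loops; Switch does not).
    opt = tz.Modular(
        model.parameters(),
        tz.m.Alternate(
            tz.m.Adam(),
            tz.m.SignSGD(),
            steps=(2, 1),
        ),
        tz.m.LR(1e-3),
    )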
torchzero/modules/momentum/__init__.py
@@ -11,4 +11,4 @@ from .experimental import CoordinateMomentum
  # from .matrix_momentum import MatrixMomentum
 
  from .momentum import NAG, HeavyBall
- from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+ from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
torchzero/modules/momentum/averaging.py
@@ -1,3 +1,4 @@
+ """Modules that perform averaging over a history of past updates."""
  from collections import deque
  from collections.abc import Sequence
  from typing import Any, Literal, cast
@@ -9,14 +10,19 @@ from ...utils import tolist
 
 
  class Averaging(TensorwiseTransform):
+     """Average of past :code:`history_size` updates.
+
+     Args:
+         history_size (int): number of past updates to average.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, history_size: int, target: Target = 'update'):
          defaults = dict(history_size=history_size)
          super().__init__(uses_grad=False, defaults=defaults, target=target)
 
      @torch.no_grad
-     def transform(self, tensor, param, grad, vars):
-         history_size = self.settings[param]['history_size']
-         state = self.state[param]
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         history_size = setting['history_size']
          if 'history' not in state:
              state['history'] = deque(maxlen=history_size)
              state['average'] = torch.zeros_like(tensor)
@@ -29,15 +35,19 @@ class Averaging(TensorwiseTransform):
          return average / len(history)
 
  class WeightedAveraging(TensorwiseTransform):
-     """weights are oldest to newest"""
+     """Weighted average of past :code:`len(weights)` updates.
+
+     Args:
+         weights (Sequence[float]): a sequence of weights from oldest to newest.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
          defaults = dict(weights = tolist(weights))
          super().__init__(uses_grad=False, defaults=defaults, target=target)
 
      @torch.no_grad
-     def transform(self, tensor, param, grad, vars):
-         weights = self.settings[param]['weights']
-         state = self.state[param]
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         weights = setting['weights']
 
          if 'history' not in state:
              state['history'] = deque(maxlen=len(weights))
@@ -59,14 +69,19 @@ class WeightedAveraging(TensorwiseTransform):
 
 
  class MedianAveraging(TensorwiseTransform):
+     """Median of past :code:`history_size` updates.
+
+     Args:
+         history_size (int): number of past updates to take the median over.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, history_size: int, target: Target = 'update'):
          defaults = dict(history_size = history_size)
          super().__init__(uses_grad=False, defaults=defaults, target=target)
 
      @torch.no_grad
-     def transform(self, tensor, param, grad, vars):
-         history_size = self.settings[param]['history_size']
-         state = self.state[param]
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         history_size = setting['history_size']
 
          if 'history' not in state:
              state['history'] = deque(maxlen=history_size)
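A usage sketch for the averaging transforms above; weights run oldest to newest, as stated in the WeightedAveraging docstring, and the averaging is applied to Adam's update before the learning rate. It assumes these classes are exported under tz.m like the other modules in this diff:

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(4, 1)

    # Smooth Adam's update with a weighted average of the last 4 updates,
    # giving newer updates more influence.
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.WeightedAveraging(weights=[0.1, 0.2, 0.3, 0.4]),
        tz.m.LR(1e-3),
    )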