torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
@@ -1,18 +1,18 @@
 from abc import ABC, abstractmethod
-from collections.abc import Iterable, Sequence
-from typing import Any, Literal
+from collections.abc import Iterable, Sequence, Mapping
+from typing import Any, Literal, final

 import torch

-from ..utils import set_storage_
-from .module import Module, Vars, Chain, Chainable
+from ..utils import set_storage_, TensorList, vec_to_tensors
+from .module import Module, Var, Chain, Chainable

 Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']

 class Transform(Module, ABC):
-    """Base class for a transform.
+    """Base class for a transform. This is an abstract class, to use it, subclass it and override `update` and `apply` methods.

-    This is an abstract class, to use it, subclass it and override `transform`.
+    A transform is a module that can also be applied manually to an arbitrary sequence of tensors.

     Args:
         defaults (dict[str,Any] | None): dict with default values.
@@ -20,62 +20,180 @@ class Transform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on vars. Defaults to 'update'.
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._concat_params = concat_params
+        self._update_freq = update_freq
+        self._scale_first = scale_first
+        self._inner = inner
+
+    def update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> Iterable[torch.Tensor]:
-        """applies the update rule to `target`."""
+    def apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]],
+    ) -> Sequence[torch.Tensor]:
+        """Applies the update rule to `tensors`."""
+
+    @final
+    @torch.no_grad
+    def transform(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors."""
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+        if self._concat_params:
+            tensors = [torch.cat([t.ravel() for t in tensors])]
+            params = [torch.cat([p.ravel() for p in params])]
+            grads = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        step = self.global_state.get('__step', 0)
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        update_freq = self._update_freq
+        scale_first = self._scale_first
+        scale_factor = 1
+
+        # scaling factor for 1st step
+        if scale_first and step == 0:
+            # initial step size guess from pytorch LBFGS
+            scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
+            scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
+
+        # update transform
+        if step % update_freq == 0:
+            self.update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # step with inner
+        if self._inner is not None:
+            tensors = apply_transform(self._inner, tensors=un_tensors, params=un_params, grads=un_grads)
+            if self._concat_params:
+                tensors = [torch.cat([t.ravel() for t in tensors])]
+
+        # apply transform
+        tensors = list(self.apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+
+        # scale initial step, when preconditioner might not have been applied
+        if scale_first and step == 0:
+            torch._foreach_mul_(tensors, scale_factor)
+
+        self.global_state['__step'] = step + 1
+        if self._concat_params:
+            tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
+        return tensors
+

-    def step(self, vars: Vars) -> Vars:
-        # vars may change, therefore current params and grads have to be extracted and passed explicitly
-        if self._uses_grad: vars.get_grad()
-        params=vars.params; grad = vars.grad
+    @torch.no_grad
+    def keyed_transform(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+    ):
+        """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+        if self._concat_params:
+            p = params[0]
+            states = [self.state[p]]
+            settings = [self.settings[p]]
+
+        else:
+            states = []
+            settings = []
+            for p in params:
+                states.append(self.state[p])
+                settings.append(self.settings[p])
+
+        return self.transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+    def step(self, var: Var) -> Var:
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._uses_grad: var.get_grad()
+        params=var.params

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
-            vars.update = list(self.transform(vars.get_update(), params, grad, vars))
-            return vars
+            update = var.get_update()
+            var.update = list(self.keyed_transform(update, params, var.grad, var.loss))
+            return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
-            vars.grad = list(self.transform(vars.get_grad(), params, grad, vars))
-            return vars
+            grad = var.get_grad()
+            var.grad = list(self.keyed_transform(grad, params, grad, var.loss))
+            return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-            new_params = self.transform(vars.params, params, grad, vars)
-            for p, new_p in zip(vars.params, new_params): set_storage_(p, new_p)
-            return vars
+            new_params = self.keyed_transform(var.params, params, var.grad, var.loss)
+            for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
+            return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-            new_params = tuple(self.transform([p.clone() for p in vars.params], params, grad, vars))
-            vars.update = list(torch._foreach_sub(vars.params, new_params))
-            return vars
+            new_params = tuple(self.keyed_transform([p.clone() for p in var.params], params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(var.params, new_params))
+            return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
-            update = vars.get_update()
-            new_update = tuple(self.transform([u.clone() for u in update], params, grad, vars))
-            vars.update = list(torch._foreach_sub(update, new_update))
-            return vars
+            update = var.get_update()
+            new_update = tuple(self.keyed_transform([u.clone() for u in update], params, var.grad, var.loss))
+            var.update = list(torch._foreach_sub(update, new_update))
+            return var

         # ---------------------------------- closure --------------------------------- #
         if self._target == 'closure':
-            original_closure = vars.closure
+            original_closure = var.closure
             if original_closure is None: raise ValueError('Target = "closure", but closure is None')

-            params = vars.params
+            params = var.params
             def transformed_closure(backward=True):
                 if backward:
                     loss = original_closure()
                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = list(self.transform(current_grad, params, grad, vars))
+                    transformed_grad = list(self.keyed_transform(current_grad, params, var.grad, var.loss))
                     for p, g in zip(params, transformed_grad):
                         p.grad = g

@@ -84,14 +202,14 @@ class Transform(Module, ABC):

                 return loss

-            vars.closure = transformed_closure
-            return vars
+            var.closure = transformed_closure
+            return var

         # ---------------------------------- invalid --------------------------------- #
         raise ValueError(f'Invalid target: {self._target}')


-class TensorwiseTransform(Module, ABC):
+class TensorwiseTransform(Transform, ABC):
     """Base class for a parameter-wise transform.

     This is an abstract class, to use it, subclass it and override `transform`.
@@ -102,151 +220,94 @@ class TensorwiseTransform(Module, ABC):
             Set this to True if `transform` method uses the `grad` argument. This will ensure
             `grad` is always computed and can't be None. Otherwise set to False.
         target (Target, optional):
-            what to set on vars. Defaults to 'update'.
+            what to set on var. Defaults to 'update'.
     """
-    def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
-        super().__init__(defaults)
-        self._target: Target = target
-        self._uses_grad: bool = uses_grad
+    def __init__(
+        self,
+        defaults: dict[str,Any] | None,
+        uses_grad: bool,
+        concat_params: bool = False,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        inner: Chainable | None = None,
+        target: Target = 'update',
+    ):
+        super().__init__(
+            defaults=defaults,
+            uses_grad=uses_grad,
+            concat_params=concat_params,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            inner=inner,
+            target=target,
+        )
+
+    def update_tensor(
+        self,
+        tensor: torch.Tensor,
+        param: torch.Tensor,
+        grad: torch.Tensor | None,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        settings: Mapping[str, Any],
+    ) -> None:
+        """Updates this transform. By default does nothing - if logic is in `apply` method."""

     @abstractmethod
-    def transform(
+    def apply_tensor(
         self,
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        vars: Vars,
+        loss: torch.Tensor | None,
+        state: dict[str, Any],
+        settings: Mapping[str, Any],
     ) -> torch.Tensor:
-        """applies the update rule to `target`"""
-
-    def step(self, vars: Vars) -> Vars:
-        params = vars.params
-        if self._uses_grad and vars.grad is None: vars.get_grad()
-
-        # ---------------------------------- update ---------------------------------- #
-        if self._target == 'update':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                # settings = self.settings[p] # couldn't make typing work with this
-                #, self.transform(target=u, param=p, grad=g, vars=vars, **{k:settings[k] for k in self.defaults})
-                transformed_update.append(self.transform(tensor=u, param=p, grad=g, vars=vars))
-
-            vars.update = transformed_update
-            return vars
-
-        # ----------------------------------- grad ----------------------------------- #
-        if self._target == 'grad':
-            grad = vars.get_grad()
-            transformed_grad = []
-
-            for p, g in zip(params, grad):
-                transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-            vars.grad = transformed_grad
-            return vars
-
-        # ------------------------------- params_direct ------------------------------ #
-        if self._target == 'params_direct':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-
-            for p, g in zip(params, grad):
-                set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
-
-            return vars
-
-        # ----------------------------- params_difference ---------------------------- #
-        if self._target == 'params_difference':
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_params = []
-
-            for p, g in zip(params, grad):
-                transformed_params.append(
-                    self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(params, transformed_params))
-            return vars
-
-        # ----------------------------- update_difference ---------------------------- #
-        if self._target == 'update_difference':
-            update = vars.get_update()
-            grad = vars.grad if vars.grad is not None else [None] * len(params)
-            transformed_update = []
-
-            for p, g, u in zip(params, grad, update):
-                transformed_update.append(
-                    self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
-                )
-
-            vars.update = list(torch._foreach_sub(update, transformed_update))
-            return vars
-
-        # ---------------------------------- closure --------------------------------- #
-        if self._target == 'closure':
-            original_closure = vars.closure
-            if original_closure is None: raise ValueError('Target = "closure", but closure is None')
-
-            params = vars.params
-            def transformed_closure(backward=True):
-                if backward:
-                    loss = original_closure()
-                    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                    transformed_grad = []
-
-                    for p, g in zip(params, grad):
-                        transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
-
-                    for p, g in zip(params, transformed_grad):
-                        p.grad = g
-
-                else:
-                    loss = original_closure(False)
-
-                return loss
-
-            vars.closure = transformed_closure
-            return vars
-
-        # ---------------------------------- invalid --------------------------------- #
-        raise ValueError(f'Invalid target: {self._target}')
-
-
-
-def apply(
+        """Applies the update rule to `tensor`."""
+
+    @final
+    def update(self, tensors, params, grads, loss, states, settings):
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            self.update_tensor(t, p, g, loss, state, setting)
+
+    @final
+    def apply(self, tensors, params, grads, loss, states, settings):
+        applied = []
+        if grads is None: grads = [None]*len(tensors)
+        for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
+            applied.append(self.apply_tensor(t, p, g, loss, state, setting))
+        return applied
+
+def apply_transform(
     tfm: Chainable,
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
    grads: list[torch.Tensor] | None,
-    vars: Vars | None = None,
+    loss: torch.Tensor | None = None,
+    var: Var | None = None,
     current_step: int = 0,
 ):
-    if vars is None: vars = Vars(params=params, closure=None, model=None, current_step=current_step)
-    if isinstance(tfm, Transform):
-        if tfm._uses_grad and grads is None: grads = vars.get_grad()
-        return list(tfm.transform(tensors, params, grads, vars))
+    if var is None:
+        var = Var(params=params, closure=None, model=None, current_step=current_step)
+    var.loss = loss

-    if isinstance(tfm, TensorwiseTransform):
-        grads_list = grads
-        if grads_list is None:
-            if tfm._uses_grad: grads_list = vars.get_grad()
-            else: grads_list = [None] * len(tensors)
-        return [tfm.transform(t, p, g, vars) for t,p,g in zip(tensors,params,grads_list)]
+    if isinstance(tfm, Transform):
+        if tfm._uses_grad and grads is None: grads = var.get_grad()
+        return list(tfm.keyed_transform(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
         for module in tfm:
-            tensors = apply(module, tensors=tensors, params=params, grads=grads, vars=vars)
+            tensors = apply_transform(module, tensors=tensors, params=params, grads=grads, var=var)
         return tensors

     if isinstance(tfm, Module):
-        cvars = vars.clone(clone_update=False)
-        cvars.update = tensors
-        cvars = tfm.step(cvars)
-        vars.update_attrs_from_clone_(cvars)
-        assert cvars.update is not None
-        return cvars.update
+        cvar = var.clone(clone_update=False)
+        cvar.update = tensors
+        cvar = tfm.step(cvar)
+        var.update_attrs_from_clone_(cvar)
+        assert cvar.update is not None
+        return cvar.update

     raise TypeError(type(tfm))
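For reference, a minimal sketch of the 0.3.10 Transform API shown in the hunk above: the subclass overrides `apply` and reads per-tensor options from the `settings` sequence instead of `self.settings[params[0]]`, and the docstring's "can also be applied manually to an arbitrary sequence of tensors" is exercised by calling `transform` directly. `SignScale` is hypothetical and written only for illustration; it is not part of torchzero.

# Hypothetical example against the 0.3.10 signatures in the hunk above; not part of torchzero.
import torch
from torchzero.core import Transform

class SignScale(Transform):
    """Replaces each tensor with sign(tensor) * alpha."""
    def __init__(self, alpha: float = 1.0):
        super().__init__(defaults=dict(alpha=alpha), uses_grad=False)

    def apply(self, tensors, params, grads, loss, states, settings):
        # per-tensor options come from the `settings` sequence in 0.3.10
        return [t.sign() * s['alpha'] for t, s in zip(tensors, settings)]

# applying the transform manually to an arbitrary sequence of tensors
tensors = [torch.randn(3, 3), torch.randn(5)]
params = [torch.zeros(3, 3), torch.zeros(5)]        # only used as references here
tfm = SignScale(alpha=0.1)
out = tfm.transform(
    tensors=tensors, params=params, grads=None, loss=None,
    states=[{} for _ in tensors], settings=None,    # settings=None falls back to self.defaults
)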
torchzero/modules/__init__.py
@@ -11,3 +11,4 @@ from .smoothing import *
 from .weight_decay import *
 from .wrappers import *
 from .second_order import *
+from .higher_order import *
torchzero/modules/clipping/clipping.py
@@ -151,8 +151,8 @@ class ClipValue(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        value = self.get_settings('value', params=params)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

 class ClipNorm(Transform):
@@ -186,9 +186,9 @@ class ClipNorm(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        max_norm = self.get_settings('max_norm', params=params, cls=NumberList)
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        max_norm = NumberList(s['max_norm'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
             tensors_ = TensorList(tensors),
             min = 0,
@@ -232,9 +232,9 @@ class Normalize(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        norm_value = self.get_settings('norm_value', params=params, cls=NumberList)
-        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        norm_value = NumberList(s['norm_value'] for s in settings)
+        ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

         _clip_norm_(
             tensors_ = TensorList(tensors),
@@ -311,8 +311,8 @@ class Centralize(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)

torchzero/modules/clipping/ema_clipping.py
@@ -4,8 +4,8 @@ from collections.abc import Iterable, Sequence

 import torch

-from ...core import Module, Target, Transform, apply, Chainable
-from ...utils import NumberList, TensorList, generic_eq
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, generic_eq, unpack_dicts, unpack_states

 class ClipNormByEMA(Transform):
     """Clips norm to be no larger than the norm of an exponential moving average of past updates.
@@ -34,13 +34,14 @@ class ClipNormByEMA(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(self.settings[params[0]])
-
-        beta, eps = self.get_settings('beta', 'eps', params=params, cls=NumberList)
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
+        ord, tensorwise, ema_init, max_ema_growth = itemgetter('ord', 'tensorwise', 'ema_init', 'max_ema_growth')(settings[0])
+
+        beta, eps = unpack_dicts(settings, 'beta', 'eps', cls=NumberList)
+
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)

-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else tensors), cls=TensorList)
         ema.lerp_(tensors, 1-beta)

         if tensorwise:
@@ -48,7 +49,7 @@ class ClipNormByEMA(Transform):

             # clip ema norm growth
             if max_ema_growth is not None:
-                prev_ema_norm = self.get_state('prev_ema_norm', params=params, init=ema_norm, cls=TensorList)
+                prev_ema_norm = unpack_states(states, tensors, 'prev_ema_norm', init=ema_norm, cls=TensorList)
                 allowed_norm = (prev_ema_norm * max_ema_growth).clip(min=1e-6)
                 ema_denom = (ema_norm / allowed_norm).clip(min=1)
                 ema.div_(ema_denom)
@@ -119,17 +120,17 @@ class ClipValueByEMA(Transform):
         self.set_child('ema_tfm', ema_tfm)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        ema_init = itemgetter('ema_init')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        ema_init = itemgetter('ema_init')(settings[0])

-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        beta = unpack_dicts(settings, 'beta', cls=NumberList)
         tensors = TensorList(tensors)

-        ema = self.get_state('ema', params=params, init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
+        ema = unpack_states(states, tensors, 'ema', init = (torch.zeros_like if ema_init=='zeros' else lambda t: t.abs()), cls=TensorList)
         ema.lerp_(tensors.abs(), 1-beta)

         if 'ema_tfm' in self.children:
-            ema = TensorList(apply(self.children['ema_tfm'], ema, params, vars.grad, vars))
+            ema = TensorList(apply_transform(self.children['ema_tfm'], ema, params, grads, loss))

         tensors.clip_(-ema, ema)
         return tensors
torchzero/modules/clipping/growth_clipping.py
@@ -19,7 +19,7 @@ class ClipValueGrowth(TensorwiseTransform):
            bounds the tracked multiplicative clipping decay to prevent collapse to 0.
            Next update is at most :code:`max(previous update * mul, max_decay)`.
            Defaults to 2.
-        target (Target, optional): what to set on vars.. Defaults to "update".
+        target (Target, optional): what to set on var.. Defaults to "update".
     """
     def __init__(
         self,
@@ -33,12 +33,10 @@ class ClipValueGrowth(TensorwiseTransform):
         super().__init__(defaults, uses_grad=False, target=target)


-    def transform(self, tensor, param, grad, vars):
-        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(self.settings[param])
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        add, mul, min_value, max_decay = itemgetter('add','mul','min_value','max_decay')(settings)
         add: float | None

-        state = self.state[param]
-
         if add is None and mul is None:
             return tensor

@@ -133,7 +131,7 @@ class ClipNormGrowth(Transform):
         ord (float, optional): norm order. Defaults to 2.
         parameterwise (bool, optional):
             if True, norms are calculated parameter-wise, otherwise treats all parameters as single vector. Defaults to True.
-        target (Target, optional): what to set on vars. Defaults to "update".
+        target (Target, optional): what to set on var. Defaults to "update".
     """
     def __init__(
         self,
@@ -150,35 +148,35 @@



-    def transform(self, tensors, params, grads, vars):
-        parameterwise = self.settings[params[0]]['parameterwise']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        parameterwise = settings[0]['parameterwise']
         tensors = TensorList(tensors)

         if parameterwise:
             ts = tensors
-            stts = [self.state[p] for p in params]
-            stns = [self.settings[p] for p in params]
+            stts = states
+            stns = settings

         else:
             ts = [tensors.to_vec()]
             stts = [self.global_state]
-            stns = [self.settings[params[0]]]
+            stns = [settings[0]]


-        for t,state, settings in zip(ts, stts, stns):
+        for t, state, setting in zip(ts, stts, stns):
             if 'prev_norm' not in state:
-                state['prev_norm'] = torch.linalg.vector_norm(t, ord=settings['ord']) # pylint:disable=not-callable
+                state['prev_norm'] = torch.linalg.vector_norm(t, ord=setting['ord']) # pylint:disable=not-callable
                 state['prev_denom'] = 1
                 continue

             _, state['prev_norm'], state['prev_denom'] = norm_growth_clip_(
                 tensor_ = t,
                 prev_norm = state['prev_norm'],
-                add = settings['add'],
-                mul = settings['mul'],
-                min_value = settings['min_value'],
-                max_decay = settings['max_decay'],
-                ord = settings['ord'],
+                add = setting['add'],
+                mul = setting['mul'],
+                min_value = setting['min_value'],
+                max_decay = setting['max_decay'],
+                ord = setting['ord'],
             )

         if not parameterwise:
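The same migration applies to tensor-wise transforms such as ClipValueGrowth above: per-parameter scalar options now arrive through the `settings` mapping and mutable per-parameter storage through the `state` dict, instead of `self.settings[param]` / `self.state[param]`. Below is a hedged sketch against the `TensorwiseTransform` API from the transform.py hunk; `ClampValue` is hypothetical, written only for illustration, and not part of torchzero.

# Hypothetical tensor-wise transform against the 0.3.10 API shown earlier; not part of torchzero.
import torch
from torchzero.core.transform import TensorwiseTransform

class ClampValue(TensorwiseTransform):
    """Clamps each update tensor to [-bound, bound] and counts how often it was applied."""
    def __init__(self, bound: float = 1.0):
        super().__init__(defaults=dict(bound=bound), uses_grad=False, target='update')

    def apply_tensor(self, tensor, param, grad, loss, state, settings):
        # `settings` is the per-parameter options mapping, `state` the per-parameter dict,
        # matching the apply_tensor signature that ClipValueGrowth migrates to above
        state['num_clamps'] = state.get('num_clamps', 0) + 1
        return tensor.clamp_(-settings['bound'], settings['bound'])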