torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
@@ -25,7 +25,8 @@ class Transform(Module, ABC):
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
         scale_first: bool = False,
@@ -35,49 +36,48 @@ class Transform(Module, ABC):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._uses_loss = uses_loss
         self._concat_params = concat_params
         self._update_freq = update_freq
         self._scale_first = scale_first
         self._inner = inner

-    def update(
+    def update_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> None:
-        """Updates this transform. By default does nothing - if logic is in `apply` method."""
+        """update function, this shouldn't be called directly. Updates this module."""

     @abstractmethod
-    def apply(
+    def apply_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> Sequence[torch.Tensor]:
-        """Applies the update rule to `tensors`."""
+        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+        If possible, this shouldn't modify the internal state of this transform."""

     @final
     @torch.no_grad
-    def transform(
+    def transform_update(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]] | None,
-    ) -> list[torch.Tensor]:
-        """Applies this transform to an arbitrary sequence of tensors."""
-        un_tensors = tensors
-        un_params = params
-        un_grads = grads
+    ) -> None:
+        """Updates this transform from an arbitrary sequence of tensors."""
         if self._concat_params:
             tensors = [torch.cat([t.ravel() for t in tensors])]
             params = [torch.cat([p.ravel() for p in params])]
@@ -86,24 +86,61 @@ class Transform(Module, ABC):
         if settings is None:
             settings = [self.defaults for _ in tensors]

-        step = self.global_state.get('__step', 0)
+        step = self.global_state.get('__step', 0) # that way it gets reset correctly
+        self.global_state['__step'] = step + 1
+
         num = len(tensors)
         states = states[:num]
         settings = settings[:num]

-        update_freq = self._update_freq
-        scale_first = self._scale_first
         scale_factor = 1

         # scaling factor for 1st step
-        if scale_first and step == 0:
+        if self._scale_first and step == 0:
             # initial step size guess from pytorch LBFGS
             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)

         # update transform
-        if step % update_freq == 0:
-            self.update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+        if step % self._update_freq == 0:
+            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # store for transform_apply
+        self.global_state["__tensors"] = tensors
+        self.global_state["__params"] = params
+        self.global_state["__grads"] = grads
+        self.global_state["__scale_factor"] = scale_factor
+
+
+    @final
+    @torch.no_grad
+    def transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors.
+        This can be used after ``transform_update`` has been used at least once."""
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+
+        tensors = self.global_state.pop("__tensors")
+        params = self.global_state.pop("__params")
+        grads = self.global_state.pop("__grads")
+        scale_factor = self.global_state.pop("__scale_factor")

         # step with inner
         if self._inner is not None:
@@ -112,27 +149,17 @@ class Transform(Module, ABC):
             tensors = [torch.cat([t.ravel() for t in tensors])]

         # apply transform
-        tensors = list(self.apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))
+        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))

         # scale initial step, when preconditioner might not have been applied
-        if scale_first and step == 0:
+        if self._scale_first and self.global_state['__step'] == 1:
             torch._foreach_mul_(tensors, scale_factor)

-        self.global_state['__step'] = step + 1
         if self._concat_params:
             tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
         return tensors

-
-    @torch.no_grad
-    def keyed_transform(
-        self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
-    ):
-        """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
         if self._concat_params:
             p = params[0]
             states = [self.state[p]]
@@ -145,41 +172,116 @@ class Transform(Module, ABC):
                 states.append(self.state[p])
                 settings.append(self.settings[p])

-        return self.transform(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+        return states, settings
+
+    @final
+    @torch.no_grad
+    def keyed_transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    @final
+    @torch.no_grad
+    def keyed_transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    def pre_step(self, var: Var) -> None:
+        """Logic to run pre-transform, this way transform has access to Var."""
+    def post_step(self, var: Var) -> None:
+        """Logic to run post-transform, this way transform has access to Var."""
+
+    def update(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)
+
+        # update
+        self.keyed_transform_update(update, params, var.grad, var.loss)
+
+    def apply(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")

-    def step(self, var: Var) -> Var:
         # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
         if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+
+        # apply
+        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self.post_step(var)
+        return var
+
+    def step(self, var: Var) -> Var:
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+        if self._uses_grad or self._target == 'grad': var.get_grad()
+        if self._uses_loss: var.get_loss(False)
         params=var.params
+        self.pre_step(var)

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
             update = var.get_update()
-            var.update = list(self.keyed_transform(update, params, var.grad, var.loss))
+            self.keyed_transform_update(update, params, var.grad, var.loss)
+            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
             return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
             grad = var.get_grad()
-            var.grad = list(self.keyed_transform(grad, params, grad, var.loss))
+            self.keyed_transform_update(grad, params, grad, var.loss)
+            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
             return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-            new_params = self.keyed_transform(var.params, params, var.grad, var.loss)
+            self.keyed_transform_update(var.params, params, var.grad, var.loss)
+            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
             for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
             return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-            new_params = tuple(self.keyed_transform([p.clone() for p in var.params], params, var.grad, var.loss))
+            p_clone = [p.clone() for p in var.params]
+            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(var.params, new_params))
             return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
             update = var.get_update()
-            new_update = tuple(self.keyed_transform([u.clone() for u in update], params, var.grad, var.loss))
+            u_clone = [u.clone() for u in update]
+            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(update, new_update))
             return var

@@ -193,7 +295,8 @@ class Transform(Module, ABC):
             if backward:
                 loss = original_closure()
                 current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-                transformed_grad = list(self.keyed_transform(current_grad, params, var.grad, var.loss))
+                self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
                 for p, g in zip(params, transformed_grad):
                     p.grad = g

@@ -203,6 +306,7 @@ class Transform(Module, ABC):
                 return loss

             var.closure = transformed_closure
+            self.post_step(var)
             return var

         # ---------------------------------- invalid --------------------------------- #
@@ -225,7 +329,8 @@ class TensorwiseTransform(Transform, ABC):
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
         scale_first: bool = False,
@@ -238,6 +343,7 @@ class TensorwiseTransform(Transform, ABC):
             concat_params=concat_params,
             update_freq=update_freq,
             scale_first=scale_first,
+            uses_loss=uses_loss,
             inner=inner,
             target=target,
         )
@@ -247,9 +353,9 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-        settings: Mapping[str, Any],
+        setting: Mapping[str, Any],
     ) -> None:
         """Updates this transform. By default does nothing - if logic is in `apply` method."""

@@ -259,20 +365,20 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-        settings: Mapping[str, Any],
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """Applies the update rule to `tensor`."""

     @final
-    def update(self, tensors, params, grads, loss, states, settings):
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
             self.update_tensor(t, p, g, loss, state, setting)

     @final
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         applied = []
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
@@ -284,7 +390,7 @@ def apply_transform(
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-    loss: torch.Tensor | None = None,
+    loss: torch.Tensor | float | None = None,
     var: Var | None = None,
     current_step: int = 0,
 ):
@@ -292,9 +398,10 @@ def apply_transform(
         var = Var(params=params, closure=None, model=None, current_step=current_step)
         var.loss = loss

-    if isinstance(tfm, Transform):
+    if isinstance(tfm, Transform) and tfm._target == 'update':
         if tfm._uses_grad and grads is None: grads = var.get_grad()
-        return list(tfm.keyed_transform(tensors, params, grads, loss))
+        tfm.keyed_transform_update(tensors, params, grads, loss)
+        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
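
Note: the hunks above rename the per-tensor hooks (`update` → `update_tensors`, `apply` → `apply_tensors`) and split the old `transform` entry point into `transform_update` / `transform_apply`. As a rough illustration only, a subclass under the renamed hooks might look like the sketch below; the `Mul` class and its `factor` setting are hypothetical, and only the signatures and the `torchzero.core` import path are taken from this diff.

    import torch
    from torchzero.core import Transform  # same path used by clipping.py further down

    class Mul(Transform):
        """Hypothetical transform that multiplies the update by a per-parameter 'factor' setting."""
        def __init__(self, factor: float = 0.5, target='update'):
            # uses_grad / uses_loss now default to False, so they can be omitted
            super().__init__(dict(factor=factor), target=target)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # one settings mapping per tensor, mirroring ClipValue.apply_tensors below
            return [t.mul_(s['factor']) for t, s in zip(tensors, settings)]
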
torchzero/modules/__init__.py
@@ -1,7 +1,7 @@
 from .clipping import *
 from .grad_approximation import *
 from .line_search import *
-from .lr import *
+from .step_size import *
 from .momentum import *
 from .ops import *
 from .optimizers import *
@@ -11,4 +11,5 @@ from .smoothing import *
 from .weight_decay import *
 from .wrappers import *
 from .second_order import *
-from .higher_order import *
+from .higher_order import *
+from .misc import *
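
Note: together with the moves in the file list (`torchzero/modules/lr/*` removed, `torchzero/modules/step_size/lr.py` added), this re-export change implies an import migration roughly like the sketch below; the exact public name `LR` under the new path is an assumption based on the `tz.m.LR` usage shown in the docstrings further down.

    # up to 0.3.10 (subpackage removed in 0.3.11):
    # from torchzero.modules.lr import LR

    # from 0.3.11 on, presumably:
    from torchzero.modules.step_size import LR
    # still reachable through the star re-export:
    from torchzero.modules import LR
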
torchzero/modules/clipping/clipping.py
@@ -5,7 +5,7 @@ import math
 import torch

 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, generic_eq
+from ...utils import NumberList, TensorList


 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +24,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord: float,
+    ord: float | Literal['mean_abs'],
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,
@@ -54,9 +54,13 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue

-        norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+        if ord == 'mean_abs':
+            norm = tensor.abs().mean(dim=real_dim, keepdim=True)
+        else:
+            norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
         if norm.numel() == 1 and norm == 0: continue
-        norm = torch.where(norm == 0, 1, norm)
+        norm = torch.where(norm <= 1e-12, 1, norm)

         # normalize = True, perform normalization
         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
@@ -90,7 +94,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,
@@ -118,7 +122,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
@@ -145,13 +149,43 @@ def normalize_grads_(


 class ClipValue(Transform):
-    """Clips update magnitude to be within `(-value, value)` range."""
+    """Clips update magnitude to be within `(-value, value)` range.
+
+    Args:
+        value (float): value to clip to.
+        target (str): refer to :ref:`target argument` in documentation.
+
+    Examples:
+
+        Gradient clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipValue(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipValue(1),
+                tz.m.LR(1e-2),
+            )
+
+    """
     def __init__(self, value: float, target: Target = 'update'):
         defaults = dict(value=value)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

@@ -172,21 +206,45 @@ class ClipNorm(Transform):
             minimal numer of elements in a parameter or slice to clip norm. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipNorm(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipNorm(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         max_norm: float,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         max_norm = NumberList(s['max_norm'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
@@ -218,21 +276,45 @@ class Normalize(Transform):
             minimal size of a dimension to normalize along it. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Normalize(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Normalize(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         norm_value = NumberList(s['norm_value'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

@@ -299,6 +381,21 @@ class Centralize(Transform):
             if True, the `dims` argument is inverted, and all other dimensions are centralized.
         min_size (int, optional):
             minimal size of a dimension to normalize along it. Defaults to 1.
+
+    Examples:
+
+        Standard gradient centralization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Centralize(dim=0),
+                tz.m.LR(1e-2),
+            )
+
+    References:
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,
@@ -308,10 +405,10 @@ class Centralize(Transform):
         target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)
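
Note: the clipping hunks above add `'mean_abs'` as an accepted `ord` (the "norm" is then the mean absolute value) and move these modules to the new `apply_tensors` hook. A usage sketch combining the docstring examples with the new option; the specific values and module order are illustrative rather than prescribed by the diff, `model` stands for any `torch.nn.Module`, and the `tz` alias is assumed from the docstrings above.

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.ClipNorm(1, ord='mean_abs'),  # clip the gradient's mean-absolute-value "norm" to at most 1
        tz.m.Adam(),
        tz.m.LR(1e-2),
    )
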