torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/core/preconditioner.py
@@ -0,0 +1,137 @@
+ from abc import ABC, abstractmethod
+ from collections import ChainMap, defaultdict
+ from collections.abc import Mapping, Sequence
+ from typing import Any, overload, final
+
+ import torch
+
+ from .module import Module, Chainable, Vars
+ from .transform import apply, Transform, Target
+ from ..utils import TensorList, vec_to_tensors
+
+ class Preconditioner(Transform):
+     """Abstract class for a preconditioner."""
+     def __init__(
+         self,
+         defaults: dict | None,
+         uses_grad: bool,
+         concat_params: bool = False,
+         update_freq: int = 1,
+         scale_first: bool = False,
+         inner: Chainable | None = None,
+         target: Target = "update",
+     ):
+         if defaults is None: defaults = {}
+         defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
+         super().__init__(defaults, uses_grad=uses_grad, target=target)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @abstractmethod
+     def update(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
+         """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
+
+     @abstractmethod
+     def apply(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
+         """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
+
+
+     def _tensor_wise_transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> list[torch.Tensor]:
+         step = self.global_state.get('step', 0)
+         states = [self.state[p] for p in params]
+         settings = [self.settings[p] for p in params]
+         global_settings = settings[0]
+         update_freq = global_settings['__update_freq']
+
+         scale_first = global_settings['__scale_first']
+         scale_factor = 0
+         if scale_first and step == 0:
+             # initial step size guess from pytorch LBFGS
+             scale_factor = TensorList(tensors).abs().sum()
+
+         # update preconditioner
+         if step % update_freq == 0:
+             self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
+
+         # step with inner
+         if 'inner' in self.children:
+             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
+
+         # apply preconditioner
+         tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
+
+         # scale initial step, when preconditioner might not have been applied
+         if scale_first and step == 0:
+             torch._foreach_div_(tensors, scale_factor)
+
+         self.global_state['step'] = step + 1
+         return tensors
+
+     def _concat_transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> list[torch.Tensor]:
+         step = self.global_state.get('step', 0)
+         tensors_vec = torch.cat([t.ravel() for t in tensors])
+         params_vec = torch.cat([p.ravel() for p in params])
+         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
+
+         states = [self.state[params[0]]]
+         settings = [self.settings[params[0]]]
+         global_settings = settings[0]
+         update_freq = global_settings['__update_freq']
+
+         scale_first = global_settings['__scale_first']
+         scale_factor = 0
+         if scale_first and step == 0:
+             # initial step size guess from pytorch LBFGS
+             scale_factor = tensors_vec.abs().sum()
+
+         # update preconditioner
+         if step % update_freq == 0:
+             self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
+
+         # step with inner
+         if 'inner' in self.children:
+             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
+             tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
+
+         # apply preconditioner
+         tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
+
+         # scale initial step, when preconditioner might not have been applied
+         if scale_first and step == 0:
+             if scale_factor >= torch.finfo(tensors_vec.dtype).eps:
+                 tensors_vec /= scale_factor
+
+         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
+         self.global_state['step'] = step + 1
+         return tensors
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         concat_params = self.settings[params[0]]['__concat_params']
+         if concat_params: return self._concat_transform(tensors, params, grads, vars)
+         return self._tensor_wise_transform(tensors, params, grads, vars)
+
+ class TensorwisePreconditioner(Preconditioner, ABC):
+     @abstractmethod
+     def update_tensor(self, tensor: torch.Tensor, param: torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
+         """update preconditioner with `tensor`"""
+
+     @abstractmethod
+     def apply_tensor(self, tensor: torch.Tensor, param: torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+         """apply preconditioner to `tensor`"""
+
+     @final
+     def update(self, tensors, params, grads, states, settings):
+         if grads is None: grads = [None]*len(tensors)
+         for t, p, g, state, setting in zip(tensors, params, grads, states, settings):
+             self.update_tensor(t, p, g, state, setting)
+
+     @final
+     def apply(self, tensors, params, grads, states, settings):
+         preconditioned = []
+         if grads is None: grads = [None]*len(tensors)
+         for t, p, g, state, setting in zip(tensors, params, grads, states, settings):
+             preconditioned.append(self.apply_tensor(t, p, g, state, setting))
+         return preconditioned
+
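To make the new preconditioner API above concrete, here is a minimal sketch (not part of the package) of how `TensorwisePreconditioner` might be subclassed: an Adagrad-style diagonal preconditioner that accumulates squared updates in each parameter's `state` dict and rescales by their square root. The class name, the `eps` default, and the import path are assumptions for illustration; only the base-class signatures come from the diff.

    # Illustrative sketch only -- not part of this diff. The import path is
    # assumed from the file location shown above.
    import torch
    from torchzero.core.preconditioner import TensorwisePreconditioner

    class DiagonalPreconditioner(TensorwisePreconditioner):
        """Adagrad-like diagonal preconditioner (hypothetical example)."""
        def __init__(self, eps: float = 1e-8, update_freq: int = 1):
            # 'eps' becomes a per-parameter setting; gradients are not needed
            super().__init__(defaults=dict(eps=eps), uses_grad=False, update_freq=update_freq)

        def update_tensor(self, tensor, param, grad, state, settings):
            # accumulate squared updates in this parameter's state dict
            if 'sq_sum' not in state:
                state['sq_sum'] = torch.zeros_like(tensor)
            state['sq_sum'].addcmul_(tensor, tensor)

        def apply_tensor(self, tensor, param, grad, state, settings):
            # divide the update by the root of the accumulated squared updates
            return tensor / (state['sq_sum'].sqrt() + settings['eps'])

With `concat_params=False` (the default), `transform` dispatches to `_tensor_wise_transform`, which calls `update` every `update_freq` steps and `apply` on every step, so the example above only needs to manage its own per-parameter state.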
torchzero/core/transform.py
@@ -0,0 +1,252 @@
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable, Sequence
+ from typing import Any, Literal
+
+ import torch
+
+ from ..utils import set_storage_
+ from .module import Module, Vars, Chain, Chainable
+
+ Target = Literal['grad', 'update', 'closure', 'params_direct', 'params_difference', 'update_difference']
+
+ class Transform(Module, ABC):
+     """Base class for a transform.
+
+     This is an abstract class, to use it, subclass it and override `transform`.
+
+     Args:
+         defaults (dict[str,Any] | None): dict with default values.
+         uses_grad (bool):
+             Set this to True if `transform` method uses the `grad` argument. This will ensure
+             `grad` is always computed and can't be None. Otherwise set to False.
+         target (Target, optional):
+             what to set on vars. Defaults to 'update'.
+     """
+     def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+         super().__init__(defaults)
+         self._target: Target = target
+         self._uses_grad = uses_grad
+
+     @abstractmethod
+     def transform(self, tensors: list[torch.Tensor], params: list[torch.Tensor], grads: list[torch.Tensor] | None, vars: Vars) -> Iterable[torch.Tensor]:
+         """applies the update rule to `target`."""
+
+     def step(self, vars: Vars) -> Vars:
+         # vars may change, therefore current params and grads have to be extracted and passed explicitly
+         if self._uses_grad: vars.get_grad()
+         params = vars.params; grad = vars.grad
+
+         # ---------------------------------- update ---------------------------------- #
+         if self._target == 'update':
+             vars.update = list(self.transform(vars.get_update(), params, grad, vars))
+             return vars
+
+         # ----------------------------------- grad ----------------------------------- #
+         if self._target == 'grad':
+             vars.grad = list(self.transform(vars.get_grad(), params, grad, vars))
+             return vars
+
+         # ------------------------------- params_direct ------------------------------ #
+         if self._target == 'params_direct':
+             new_params = self.transform(vars.params, params, grad, vars)
+             for p, new_p in zip(vars.params, new_params): set_storage_(p, new_p)
+             return vars
+
+         # ----------------------------- params_differnce ----------------------------- #
+         if self._target == 'params_difference':
+             new_params = tuple(self.transform([p.clone() for p in vars.params], params, grad, vars))
+             vars.update = list(torch._foreach_sub(vars.params, new_params))
+             return vars
+
+         # ----------------------------- update_difference ---------------------------- #
+         if self._target == 'update_difference':
+             update = vars.get_update()
+             new_update = tuple(self.transform([u.clone() for u in update], params, grad, vars))
+             vars.update = list(torch._foreach_sub(update, new_update))
+             return vars
+
+         # ---------------------------------- closure --------------------------------- #
+         if self._target == 'closure':
+             original_closure = vars.closure
+             if original_closure is None: raise ValueError('Target = "closure", but closure is None')
+
+             params = vars.params
+             def transformed_closure(backward=True):
+                 if backward:
+                     loss = original_closure()
+                     current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                     transformed_grad = list(self.transform(current_grad, params, grad, vars))
+                     for p, g in zip(params, transformed_grad):
+                         p.grad = g
+
+                 else:
+                     loss = original_closure(False)
+
+                 return loss
+
+             vars.closure = transformed_closure
+             return vars
+
+         # ---------------------------------- invalid --------------------------------- #
+         raise ValueError(f'Invalid target: {self._target}')
+
+
+ class TensorwiseTransform(Module, ABC):
+     """Base class for a parameter-wise transform.
+
+     This is an abstract class, to use it, subclass it and override `transform`.
+
+     Args:
+         defaults (dict[str,Any] | None): dict with default values.
+         uses_grad (bool):
+             Set this to True if `transform` method uses the `grad` argument. This will ensure
+             `grad` is always computed and can't be None. Otherwise set to False.
+         target (Target, optional):
+             what to set on vars. Defaults to 'update'.
+     """
+     def __init__(self, defaults: dict[str,Any] | None, uses_grad: bool, target: Target = 'update'):
+         super().__init__(defaults)
+         self._target: Target = target
+         self._uses_grad: bool = uses_grad
+
+     @abstractmethod
+     def transform(
+         self,
+         tensor: torch.Tensor,
+         param: torch.Tensor,
+         grad: torch.Tensor | None,
+         vars: Vars,
+     ) -> torch.Tensor:
+         """applies the update rule to `target`"""
+
+     def step(self, vars: Vars) -> Vars:
+         params = vars.params
+         if self._uses_grad and vars.grad is None: vars.get_grad()
+
+         # ---------------------------------- update ---------------------------------- #
+         if self._target == 'update':
+             update = vars.get_update()
+             grad = vars.grad if vars.grad is not None else [None] * len(params)
+             transformed_update = []
+
+             for p, g, u in zip(params, grad, update):
+                 # settings = self.settings[p] # couldn't make typing work with this
+                 #, self.transform(target=u, param=p, grad=g, vars=vars, **{k:settings[k] for k in self.defaults})
+                 transformed_update.append(self.transform(tensor=u, param=p, grad=g, vars=vars))
+
+             vars.update = transformed_update
+             return vars
+
+         # ----------------------------------- grad ----------------------------------- #
+         if self._target == 'grad':
+             grad = vars.get_grad()
+             transformed_grad = []
+
+             for p, g in zip(params, grad):
+                 transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
+
+             vars.grad = transformed_grad
+             return vars
+
+         # ------------------------------- params_direct ------------------------------ #
+         if self._target == 'params_direct':
+             grad = vars.grad if vars.grad is not None else [None] * len(params)
+
+             for p, g in zip(params, grad):
+                 set_storage_(p, self.transform(tensor=p, param=p, grad=g, vars=vars))
+
+             return vars
+
+         # ----------------------------- params_difference ---------------------------- #
+         if self._target == 'params_difference':
+             grad = vars.grad if vars.grad is not None else [None] * len(params)
+             transformed_params = []
+
+             for p, g in zip(params, grad):
+                 transformed_params.append(
+                     self.transform(tensor=p.clone(), param=p, grad=g, vars=vars)
+                 )
+
+             vars.update = list(torch._foreach_sub(params, transformed_params))
+             return vars
+
+         # ----------------------------- update_difference ---------------------------- #
+         if self._target == 'update_difference':
+             update = vars.get_update()
+             grad = vars.grad if vars.grad is not None else [None] * len(params)
+             transformed_update = []
+
+             for p, g, u in zip(params, grad, update):
+                 transformed_update.append(
+                     self.transform(tensor=u.clone(), param=p, grad=g, vars=vars)
+                 )
+
+             vars.update = list(torch._foreach_sub(update, transformed_update))
+             return vars
+
+         # ---------------------------------- closure --------------------------------- #
+         if self._target == 'closure':
+             original_closure = vars.closure
+             if original_closure is None: raise ValueError('Target = "closure", but closure is None')
+
+             params = vars.params
+             def transformed_closure(backward=True):
+                 if backward:
+                     loss = original_closure()
+                     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                     transformed_grad = []
+
+                     for p, g in zip(params, grad):
+                         transformed_grad.append(self.transform(tensor=g, param=p, grad=g, vars=vars))
+
+                     for p, g in zip(params, transformed_grad):
+                         p.grad = g
+
+                 else:
+                     loss = original_closure(False)
+
+                 return loss
+
+             vars.closure = transformed_closure
+             return vars
+
+         # ---------------------------------- invalid --------------------------------- #
+         raise ValueError(f'Invalid target: {self._target}')
+
+
+
+ def apply(
+     tfm: Chainable,
+     tensors: list[torch.Tensor],
+     params: list[torch.Tensor],
+     grads: list[torch.Tensor] | None,
+     vars: Vars | None = None,
+     current_step: int = 0,
+ ):
+     if vars is None: vars = Vars(params=params, closure=None, model=None, current_step=current_step)
+     if isinstance(tfm, Transform):
+         if tfm._uses_grad and grads is None: grads = vars.get_grad()
+         return list(tfm.transform(tensors, params, grads, vars))
+
+     if isinstance(tfm, TensorwiseTransform):
+         grads_list = grads
+         if grads_list is None:
+             if tfm._uses_grad: grads_list = vars.get_grad()
+             else: grads_list = [None] * len(tensors)
+         return [tfm.transform(t, p, g, vars) for t, p, g in zip(tensors, params, grads_list)]
+
+     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
+     if isinstance(tfm, Sequence):
+         for module in tfm:
+             tensors = apply(module, tensors=tensors, params=params, grads=grads, vars=vars)
+         return tensors
+
+     if isinstance(tfm, Module):
+         cvars = vars.clone(clone_update=False)
+         cvars.update = tensors
+         cvars = tfm.step(cvars)
+         vars.update_attrs_from_clone_(cvars)
+         assert cvars.update is not None
+         return cvars.update
+
+     raise TypeError(type(tfm))
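As a rough usage sketch (again, not taken from the package), the `Transform` base class above is subclassed by overriding `transform`, and the `apply` helper at the end of the file runs any chainable transform over a list of tensors. The `SignTransform` class and the import path are assumptions for illustration; only the `Transform` and `apply` signatures come from the diff.

    # Illustrative sketch only -- not part of this diff.
    import torch
    from torchzero.core.transform import Transform, apply

    class SignTransform(Transform):
        """Replaces each update tensor with its elementwise sign (hypothetical example)."""
        def __init__(self):
            # no per-parameter defaults; transform() does not read grads
            super().__init__(defaults=None, uses_grad=False, target='update')

        def transform(self, tensors, params, grads, vars):
            # the returned list becomes vars.update when target='update'
            return [t.sign() for t in tensors]

    # apply() builds a Vars object when none is given and dispatches on the
    # transform type, so a module can also be driven outside a full optimizer
    params = [torch.zeros(3, requires_grad=True)]
    updates = [torch.tensor([0.5, -2.0, 0.0])]
    out = apply(SignTransform(), tensors=updates, params=params, grads=None)
    # out == [tensor([ 1., -1.,  0.])]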
torchzero/modules/__init__.py
@@ -1,21 +1,13 @@
- r"""
- This submodule contains composable optimizer "building blocks".
- """
-
- from ..core.module import OptimizerModule
- from . import experimental
- from .adaptive import *
- from .gradient_approximation import *
- from .line_search import *
- from .meta import *
- from .misc import *
- from .momentum import *
- from .operations import *
- from .optimizers import *
- from .orthogonalization import *
- from .quasi_newton import *
- from .regularization import *
- from .scheduling import *
- from .second_order import *
- from .smoothing import *
- from .weight_averaging import *
+ from .clipping import *
+ from .grad_approximation import *
+ from .line_search import *
+ from .lr import *
+ from .momentum import *
+ from .ops import *
+ from .optimizers import *
+ from .projections import *
+ from .quasi_newton import *
+ from .smoothing import *
+ from .weight_decay import *
+ from .wrappers import *
+ from .second_order import *
torchzero/modules/clipping/__init__.py
@@ -0,0 +1,3 @@
+ from .clipping import ClipValue, ClipNorm, Normalize, clip_grad_norm_, clip_grad_value_, normalize_grads_, Centralize
+ from .growth_clipping import ClipNormGrowth, ClipValueGrowth
+ from .ema_clipping import ClipNormByEMA, NormalizeByEMA, ClipValueByEMA