torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -1,115 +0,0 @@
- from collections.abc import Callable, Iterable
-
- import torch
-
- from ...tensorlist import TensorList
-
- from ...core import OptimizerModule, _Chainable
-
-
- class Alpha(OptimizerModule):
-     """Multiplies update by the learning rate, won't get picked up by learning rate schedulers."""
-     def __init__(self, alpha = 1e-3):
-         defaults = dict(alpha = alpha)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         # multiply ascent direction by lr in-place
-         lr = self.get_group_key('alpha')
-         ascent *= lr
-         return ascent
-
- class Clone(OptimizerModule):
-     """Clones the update. Some modules update ascent in-place, so this may be
-     useful if you need to preserve it."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent.clone()
-
- class Identity(OptimizerModule):
-     """Does nothing."""
-     def __init__(self, *args, **kwargs):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent
-
- class Lambda(OptimizerModule):
-     """Applies a function to the ascent direction.
-     The function must take a TensorList as the argument, and return the modified tensorlist.
-
-     Args:
-         f (Callable): function
-     """
-     def __init__(self, f: Callable[[TensorList], TensorList]):
-         super().__init__({})
-         self.f = f
-
-     @torch.no_grad()
-     def _update(self, vars, ascent): return self.f(ascent)
-
- class Grad(OptimizerModule):
-     """Uses gradient as the update. This is useful for chains."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         ascent = vars.ascent = vars.maybe_compute_grad_(self.get_params())
-         return ascent
-
- class Zeros(OptimizerModule):
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         return ascent.zeros_like()
-
- class Fill(OptimizerModule):
-     def __init__(self, value):
-         super().__init__({"value": value})
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         return ascent.fill(self.get_group_key('value'))
-
-
- class GradToUpdate(OptimizerModule):
-     """sets gradient and .grad attributes to current update"""
-     def __init__(self):
-         super().__init__({})
-
-     def _update(self, vars, ascent):
-         vars.set_grad_(ascent, self.get_params())
-         return ascent
-
- class MakeClosure(OptimizerModule):
-     """Makes a closure that sets `.grad` attribute to the update generated by `modules`"""
-     def __init__(self, modules: _Chainable):
-         super().__init__({})
-         self._set_child_('modules', modules)
-
-     def step(self, vars):
-         if vars.closure is None: raise ValueError("MakeClosure requires a closure")
-
-         params = self.get_params()
-         orig_closure = vars.closure
-         orig_state = vars.copy(True)
-
-         def new_closure(backward = True):
-             if backward:
-                 cloned_state = orig_state.copy(True)
-                 g = self.children['modules'].return_ascent(cloned_state)
-                 params.set_grad_(g)
-                 return cloned_state.get_loss()
-
-             else:
-                 return orig_closure(False)
-
-         vars.closure = new_closure # type:ignore
-         return self._update_params_or_step_with_next(vars)
-
@@ -1,96 +0,0 @@
- import random
- from collections.abc import Callable, Iterable
- from functools import partial
- from typing import TYPE_CHECKING, Any, overload
-
- import torch
-
- from ...tensorlist import TensorList
-
- from ...core import OptimizerModule
-
- if TYPE_CHECKING:
-     from ...optim import Modular
-
- def _init_scheduler_hook(opt: "Modular", module: "LR", scheduler_cls, **kwargs):
-     """post init hook that initializes the lr scheduler to the LR module and sets `_scheduler_step_fn`."""
-     scheduler = scheduler_cls(module, **kwargs)
-     module._scheduler_step_fn = scheduler.step
-
- def _set_momentum_hook(optimizer, state, momentum):
-     for module in optimizer.unrolled_modules:
-         if 'momentum' in module.defaults:
-             for g in module.param_groups:
-                 g['momentum'] = momentum
-         elif 'beta1' in module.defaults:
-             for g in module.param_groups:
-                 g['beta1'] = momentum
-
- class LR(OptimizerModule):
-     """Multiplies update by the learning rate. Optionally uses an lr scheduler.
-
-     Args:
-         lr (float, optional): learning rate. Defaults to 1e-3.
-         scheduler (Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None, optional):
-             A scheduler class, for example `torch.optim.lr_scheduler.OneCycleLR`. Defaults to None.
-         cycle_momentum (bool, optional):
-             enables schedulers that support it to affect momentum (like OneCycleLR).
-             The momentum will be cycled on ALL modules that have `momentum` or `beta1` setting.
-             This does not support external optimizers, wrapped with `Wrap`. Defaults to True.
-         sheduler_step_every (int, optional):
-             step with scheduler every n optimizer steps.
-             Useful when the scheduler steps once per epoch. Defaults to 1.
-         **kwargs:
-             kwargs to pass to `scheduler`.
-     """
-     IS_LR_MODULE = True
-     def __init__(
-         self,
-         lr: float = 1e-3,
-         scheduler_cls: Callable[..., torch.optim.lr_scheduler.LRScheduler | Any] | None = None,
-         cycle_momentum: bool = True,
-         sheduler_step_every: int = 1,
-         # *args,
-         **kwargs,
-     ):
-
-         defaults = dict(lr = lr)
-
-         if (scheduler_cls is not None) and cycle_momentum:
-             defaults['momentum'] = 0
-         super().__init__(defaults)
-
-         self._scheduler_step_fn = None
-         self.sheduler_step_every = sheduler_step_every
-         self.cycle_momentum = cycle_momentum
-         self.cur = 0
-
-         if scheduler_cls is not None:
-             self.post_init_hooks.append(lambda opt, module: _init_scheduler_hook(opt, module, scheduler_cls, **kwargs))
-
-         self._skip = False
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         # step with scheduler
-         if self._scheduler_step_fn is not None:
-             if self.cur != 0 and self.cur % self.sheduler_step_every == 0:
-                 self._scheduler_step_fn()
-
-         # add a hook to cycle momentum
-         if self.cycle_momentum:
-             vars.add_post_step_hook(_set_momentum_hook)
-
-         # remove init hook to delete reference to scheduler
-         if self.cur == 0 and len(self.post_init_hooks) == 1:
-             del self.post_init_hooks[0]
-
-         # skip if lr was applied by previous module (LR fusing)
-         if not self._skip:
-             # multiply ascent direction by lr in-place
-             lr = self.get_group_key('lr')
-             ascent *= lr
-
-         self.cur += 1
-         self._skip = False
-         return ascent
@@ -1,51 +0,0 @@
- from collections.abc import Callable, Iterable
-
- import torch
-
- from ...tensorlist import TensorList
-
- from ...core import OptimizerModule, _Chainable
-
-
- class Multistep(OptimizerModule):
-     """Performs multiple steps (per batch), passes total update to the next module.
-
-     Args:
-         modules (_Chainable): modules to perform multiple steps with.
-         num_steps (int, optional): number of steps to perform. Defaults to 2.
-     """
-     def __init__(self, modules: _Chainable, num_steps: int = 2):
-         super().__init__({})
-         self.num_steps = num_steps
-
-         self._set_child_('modules', modules)
-
-     def step(self, vars):
-         # no next module, just perform multiple steps
-         if self.next_module is None:
-             ret = None
-             for step in range(self.num_steps):
-                 state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
-                 ret = self.children['modules'].step(state_copy)
-
-                 # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-                 vars.grad = None; vars.fx0 = None
-
-             return ret
-
-         # accumulate steps and pass to next module
-         p0 = self.get_params().clone()
-         for step in range(self.num_steps):
-             state_copy = vars.copy(clone_ascent=True) if step != self.num_steps - 1 else vars
-             self.children['modules'].step(state_copy)
-
-             # since parameters are updated after stepping, grad and fx0 must be erased as they are no longer correct
-             vars.grad = None; vars.fx0 = None
-
-         p1 = self.get_params()
-         vars.ascent = p0 - p1
-
-         # undo ascent
-         p1.set_(p0)
-
-         return self._update_params_or_step_with_next(vars, p1)
@@ -1,53 +0,0 @@
- import torch
-
- from ...core import OptimizerModule
-
-
- class NegateOnLossIncrease(OptimizerModule):
-     """Performs an additional evaluation to check if update increases the loss. If it does,
-     negates or backtracks the update.
-
-     Args:
-         backtrack (bool, optional):
-             if True, sets update to minus update, otherwise sets it to zero. Defaults to True.
-     """
-     def __init__(self, backtrack = True):
-         super().__init__({})
-         self.backtrack = backtrack
-
-     @torch.no_grad()
-     def step(self, vars):
-         if vars.closure is None: raise ValueError('NegateOnLossIncrease requires closure.')
-         if vars.fx0 is None: vars.fx0 = vars.closure(False)
-
-         # subtract ascent direction to params and see if loss decreases
-         params = self.get_params()
-         ascent_direction = vars.maybe_use_grad_(params)
-         params -= ascent_direction
-         vars.fx0_approx = vars.closure(False)
-
-         # if this has no children, update params and return loss
-         if self.next_module is None:
-             if params is None: params = self.get_params()
-
-             if vars.fx0_approx > vars.fx0:
-                 # loss increased, so we negate thea scent direction
-                 # we are currently at params - ascent direction
-                 # so we add twice the ascent direction
-                 params.add_(ascent_direction, alpha = 2 if self.backtrack else 1)
-
-             # else: we are already at a lower loss point
-             return vars.get_loss()
-
-         # otherwise undo the ascent direction because it is passed to the child
-         params += ascent_direction
-
-         # if loss increases, negate ascent direction
-         if vars.fx0_approx > vars.fx0:
-             if self.backtrack: ascent_direction.neg_()
-             else: ascent_direction.zero_()
-
-         # otherwise undo the ascent direction and pass the updated ascent direction to the child
-         return self.next_module.step(vars)
-
-
@@ -1,29 +0,0 @@
- from .multi import (
-     Add,
-     AddMagnitude,
-     Div,
-     Divide,
-     Interpolate,
-     Lerp,
-     Mul,
-     Pow,
-     Power,
-     RDiv,
-     RPow,
-     RSub,
-     Sub,
-     Subtract,
- )
- from .reduction import Mean, Product, Sum
- from .singular import (
-     Abs,
-     Cos,
-     MagnitudePower,
-     NanToNum,
-     Negate,
-     Operation,
-     Reciprocal,
-     Sign,
-     Sin,
-     sign_grad_,
- )
@@ -1,298 +0,0 @@
- from collections.abc import Iterable
- import torch
-
- from ...core import OptimizerModule
-
- _Value = int | float | OptimizerModule | Iterable[OptimizerModule]
-
- class Add(OptimizerModule):
-     """add `value` to update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, value: _Value):
-         super().__init__({})
-
-         if not isinstance(value, (int, float)):
-             self._set_child_('value', value)
-
-         self.value = value
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.value, (int, float)):
-             return ascent.add_(self.value)
-
-         state_copy = vars.copy(clone_ascent = True)
-         v = self.children['value'].return_ascent(state_copy)
-         return ascent.add_(v)
-
-
- class Sub(OptimizerModule):
-     """subtracts `value` from update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, subtrahend: _Value):
-         super().__init__({})
-
-         if not isinstance(subtrahend, (int, float)):
-             self._set_child_('subtrahend', subtrahend)
-
-         self.subtrahend = subtrahend
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.subtrahend, (int, float)):
-             return ascent.sub_(self.subtrahend)
-
-         state_copy = vars.copy(clone_ascent = True)
-         subtrahend = self.children['subtrahend'].return_ascent(state_copy)
-         return ascent.sub_(subtrahend)
-
- class RSub(OptimizerModule):
-     """subtracts update from `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, minuend: _Value):
-         super().__init__({})
-
-         if not isinstance(minuend, (int, float)):
-             self._set_child_('minuend', minuend)
-
-         self.minuend = minuend
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.minuend, (int, float)):
-             return ascent.sub_(self.minuend).neg_()
-
-         state_copy = vars.copy(clone_ascent = True)
-         minuend = self.children['minuend'].return_ascent(state_copy)
-         return ascent.sub_(minuend).neg_()
-
- class Subtract(OptimizerModule):
-     """Calculates `minuend - subtrahend`"""
-     def __init__(
-         self,
-         minuend: OptimizerModule | Iterable[OptimizerModule],
-         subtrahend: OptimizerModule | Iterable[OptimizerModule],
-     ):
-         super().__init__({})
-         self._set_child_('minuend', minuend)
-         self._set_child_('subtrahend', subtrahend)
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent = True)
-         minuend = self.children['minuend'].return_ascent(state_copy)
-         vars.update_attrs_(state_copy)
-         subtrahend = self.children['subtrahend'].return_ascent(vars)
-
-         vars.ascent = minuend.sub_(subtrahend)
-         return self._update_params_or_step_with_next(vars)
-
- class Mul(OptimizerModule):
-     """multiplies update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, value: _Value):
-         super().__init__({})
-
-         if not isinstance(value, (int, float)):
-             self._set_child_('value', value)
-
-         self.value = value
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.value, (int, float)):
-             return ascent.mul_(self.value)
-
-         state_copy = vars.copy(clone_ascent = True)
-         v = self.children['value'].return_ascent(state_copy)
-         return ascent.mul_(v)
-
-
- class Div(OptimizerModule):
-     """divides update by `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, denominator: _Value):
-         super().__init__({})
-
-         if not isinstance(denominator, (int, float)):
-             self._set_child_('denominator', denominator)
-
-         self.denominator = denominator
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.denominator, (int, float)):
-             return ascent.div_(self.denominator)
-
-         state_copy = vars.copy(clone_ascent = True)
-         denominator = self.children['denominator'].return_ascent(state_copy)
-         return ascent.div_(denominator)
-
- class RDiv(OptimizerModule):
-     """`value` by update. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, numerator: _Value):
-         super().__init__({})
-
-         if not isinstance(numerator, (int, float)):
-             self._set_child_('numerator', numerator)
-
-         self.numerator = numerator
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.numerator, (int, float)):
-             return ascent.reciprocal_().mul_(self.numerator)
-
-         state_copy = vars.copy(clone_ascent = True)
-         numerator = self.children['numerator'].return_ascent(state_copy)
-         return ascent.reciprocal_().mul_(numerator)
-
- class Divide(OptimizerModule):
-     """calculates *numerator / denominator*"""
-     def __init__(
-         self,
-         numerator: OptimizerModule | Iterable[OptimizerModule],
-         denominator: OptimizerModule | Iterable[OptimizerModule],
-     ):
-         super().__init__({})
-         self._set_child_('numerator', numerator)
-         self._set_child_('denominator', denominator)
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent = True)
-         numerator = self.children['numerator'].return_ascent(state_copy)
-         vars.update_attrs_(state_copy)
-         denominator = self.children['denominator'].return_ascent(vars)
-
-         vars.ascent = numerator.div_(denominator)
-         return self._update_params_or_step_with_next(vars)
-
-
- class Pow(OptimizerModule):
-     """takes ascent to the power of `value`. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, power: _Value):
-         super().__init__({})
-
-         if not isinstance(power, (int, float)):
-             self._set_child_('power', power)
-
-         self.power = power
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.power, (int, float)):
-             return ascent.pow_(self.power)
-
-         state_copy = vars.copy(clone_ascent = True)
-         power = self.children['power'].return_ascent(state_copy)
-         return ascent.pow_(power)
-
- class RPow(OptimizerModule):
-     """takes `value` to the power of ascent. `value` can be a scalar, an OptimizerModule or sequence of OptimizerModules"""
-     def __init__(self, base: _Value):
-         super().__init__({})
-
-         if not isinstance(base, (int, float)):
-             self._set_child_('base', base)
-
-         self.base = base
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.base, (int, float)):
-             return self.base ** ascent
-
-         state_copy = vars.copy(clone_ascent = True)
-         base = self.children['base'].return_ascent(state_copy)
-         return base.pow_(ascent)
-
- class Power(OptimizerModule):
-     """calculates *base ^ power*"""
-     def __init__(
-         self,
-         base: OptimizerModule | Iterable[OptimizerModule],
-         power: OptimizerModule | Iterable[OptimizerModule],
-     ):
-         super().__init__({})
-         self._set_child_('base', base)
-         self._set_child_('power', power)
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent = True)
-         base = self.children['base'].return_ascent(state_copy)
-         vars.update_attrs_(state_copy)
-         power = self.children['power'].return_ascent(vars)
-
-         vars.ascent = base.pow_(power)
-         return self._update_params_or_step_with_next(vars)
-
-
- class Lerp(OptimizerModule):
-     """Linear interpolation between update and `end` based on scalar `weight`.
-
-     `out = update + weight * (end - update)`"""
-     def __init__(self, end: OptimizerModule | Iterable[OptimizerModule], weight: float):
-         super().__init__({})
-
-         self._set_child_('end', end)
-         self.weight = weight
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-
-         state_copy = vars.copy(clone_ascent = True)
-         end = self.children['end'].return_ascent(state_copy)
-         return ascent.lerp_(end, self.weight)
-
-
- class Interpolate(OptimizerModule):
-     """Does a linear interpolation of two module's updates - `start` (given by input), and `end`, based on a scalar
-     `weight`.
-
-     `out = input + weight * (end - input)`"""
-     def __init__(
-         self,
-         input: OptimizerModule | Iterable[OptimizerModule],
-         end: OptimizerModule | Iterable[OptimizerModule],
-         weight: float,
-     ):
-         super().__init__({})
-         self._set_child_('input', input)
-         self._set_child_('end', end)
-         self.weight = weight
-
-     @torch.no_grad
-     def step(self, vars):
-         state_copy = vars.copy(clone_ascent = True)
-         input = self.children['input'].return_ascent(state_copy)
-         vars.update_attrs_(state_copy)
-         end = self.children['end'].return_ascent(vars)
-
-         vars.ascent = input.lerp_(end, weight = self.weight)
-
-         return self._update_params_or_step_with_next(vars)
-
- class AddMagnitude(OptimizerModule):
-     """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update.
-
-     Args:
-         value (Value): value to add to magnitude, either a float or an OptimizerModule.
-         add_to_zero (bool, optional):
-             if True, adds `value` to 0s. Otherwise, zeros remain zero.
-             Only has effect if value is a float. Defaults to True.
-     """
-     def __init__(self, value: _Value, add_to_zero=True):
-         super().__init__({})
-
-         if not isinstance(value, (int, float)):
-             self._set_child_('value', value)
-
-         self.value = value
-         self.add_to_zero = add_to_zero
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if isinstance(self.value, (int, float)):
-             if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
-             return ascent.add_(ascent.sign_().mul_(self.value))
-
-         state_copy = vars.copy(clone_ascent = True)
-         v = self.children['value'].return_ascent(state_copy)
-         return ascent.add_(v.abs_().mul_(ascent.sign()))