torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/momentum/matrix_momentum.py
@@ -0,0 +1,124 @@
+from typing import Literal
+
+import torch
+
+from ...core import Module, apply
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+class MatrixMomentum(Module):
+    """
+    May be useful for ill-conditioned stochastic quadratic objectives but I need to test this.
+    Evaluates a Hessian-vector product on each step (via finite differences or autograd).
+
+    `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
+    """
+    def __init__(self, mu=0.1, beta:float=1, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
+        defaults = dict(mu=mu, beta=beta, hvp_mode=hvp_mode, h=h)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    @torch.no_grad
+    def step(self, vars):
+        assert vars.closure is not None
+        prev_update = self.get_state('prev_update', params=vars.params, cls=TensorList)
+        hvp_mode = self.settings[vars.params[0]]['hvp_mode']
+        h = self.settings[vars.params[0]]['h']
+
+        mu,beta = self.get_settings('mu','beta', params=vars.params, cls=NumberList)
+
+        if hvp_mode == 'autograd':
+            with torch.enable_grad():
+                grad = vars.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+
+        elif hvp_mode == 'forward':
+            vars.get_grad()
+            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        elif hvp_mode == 'central':
+            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        else:
+            raise ValueError(hvp_mode)
+
+        if 'hvp_tfm' in self.children:
+            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+
+        update = TensorList(vars.get_update())
+
+        hvp_ = as_tensorlist(hvp_)
+        update.add_(prev_update - hvp_*mu)
+        prev_update.set_(update * beta)
+        vars.update = update
+        return vars
+
+
+class AdaptiveMatrixMomentum(Module):
+    """
+    Mu here is estimated as ||s_k||/||y_k||.
+    """
+    def __init__(self, mu_mul:float=1, beta:float=1, eps=1e-4, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_mode=hvp_mode, h=h, eps=eps)
+        super().__init__(defaults)
+
+        if hvp_tfm is not None:
+            self.set_child('hvp_tfm', hvp_tfm)
+
+    @torch.no_grad
+    def step(self, vars):
+        assert vars.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad', params=vars.params, cls=TensorList)
+
+        settings = self.settings[vars.params[0]]
+        hvp_mode = settings['hvp_mode']
+        h = settings['h']
+        eps = settings['eps']
+
+        mu_mul, beta = self.get_settings('mu_mul','beta', params=vars.params, cls=NumberList)
+
+        if hvp_mode == 'autograd':
+            with torch.enable_grad():
+                grad = vars.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+
+        elif hvp_mode == 'forward':
+            vars.get_grad()
+            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        elif hvp_mode == 'central':
+            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
+            if vars.loss_approx is None: vars.loss_approx = l
+
+        else:
+            raise ValueError(hvp_mode)
+
+        if 'hvp_tfm' in self.children:
+            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+
+        # adaptive part
+        update = TensorList(vars.get_update())
+
+        s_k = vars.params - prev_params
+        prev_params.copy_(vars.params)
+
+        assert vars.grad is not None
+        y_k = vars.grad - prev_grad
+        prev_grad.copy_(vars.grad)
+
+        ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
+
+        # matrix momentum update
+        hvp_ = as_tensorlist(hvp_)
+        update.add_(prev_update - hvp_*ada_mu)
+        prev_update.set_(update * beta)
+        vars.update = update
+        return vars
+
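
As a plain-PyTorch illustration of the recurrence MatrixMomentum implements above (u_k = g_k + u_{k-1} - mu * H u_{k-1}, with the stored update scaled by beta), here is a hypothetical standalone sketch on a toy quadratic. It does not use the torchzero API, every name in it is illustrative, and it also shows why `mu` should stay below 1/(largest Hessian eigenvalue):

import torch

torch.manual_seed(0)
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))   # ill-conditioned quadratic f(x) = 0.5 * x @ A @ x
x = torch.ones(3)
prev_u = torch.zeros(3)
lr, mu, beta = 0.02, 0.009, 1.0                     # mu < 1 / lambda_max = 1/100

for _ in range(1000):
    g = A @ x                                       # gradient at the current point
    hvp = A @ prev_u                                # Hessian-vector product with the previous update
    u = g + prev_u - mu * hvp                       # u_k = g_k + (I - mu*H) u_{k-1}
    prev_u = beta * u                               # stored state, scaled by beta
    x = x - lr * u

print(x.norm())                                     # small (~1e-2): the iterate approaches the minimizer at 0

With `mu` above 1/lambda_max the (I - mu*H) factor amplifies the stored update along the largest eigendirection and the same loop diverges, which is the instability the docstring warns about.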
torchzero/modules/momentum/momentum.py
@@ -1,106 +1,43 @@
-from collections import abc
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule
-
-def _heavyball_step(ascent, velocity: TensorList, momentum, dampening: TensorList):
-    velocity.mul_(momentum).add_(ascent * (1 - dampening))
-    return velocity.clone()
-
-class HeavyBall(OptimizerModule):
-    """Polyak's (heavy ball) momentum. Exactly matches pytorch SGD `momentum` option.
-
-    Args:
-        decay (float, optional): momentum decay. Defaults to 0.9.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-    """
-    def __init__(self, momentum: float = 0.9, dampening: float = 0, ):
-        defaults = dict(momentum = momentum, dampening = dampening)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity', init = ascent)
-        settings = self.get_all_group_keys()
-        updated_direction = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])
-        return updated_direction
-
-
-def _nesterov_step_(ascent, velocity: TensorList, momentum, dampening,):
-    # update velocity with the ascent direction
-    velocity += ascent
-
-    # decay velocity (this can be moved before previous line for slightly different results)
-    velocity *= momentum
-
-    # update ascent direction with velocity
-    ascent += velocity * (1 - dampening)
-
-
-class NesterovMomentum(OptimizerModule):
-    """Nesterov momentum. Exactly matches pytorch SGD with `nesterov=True`,
-    except this also supports dampening.
-
-    Args:
-        decay (float, optional): momentum decay. Defaults to 0.9.
-        dampening (float, optional): momentum dampening. Defaults to 0.
-    """
-    def __init__(self, decay: float = 0.9, dampening: float = 0, ):
-        defaults = dict(momentum = decay, dampening = dampening)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity')
-        settings = self.get_all_group_keys()
-        _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
-        return ascent
-
-class GradientAveraging(OptimizerModule):
-    """Averages last 2 gradients (TODO)"""
-    def __init__(self, dampening: float = 0, ):
-        defaults = dict(dampening = dampening)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity')
-        dampening = self.get_group_key('dampening')
-
-        new_direction = ascent + velocity * (1-dampening)
-        velocity.copy_(ascent)
-
-        return new_direction
-
-
-class RandomCoordinateMomentum(OptimizerModule):
-    """Only uses `p` random coordinates of the new update. Other coordinates remain from previous update.
-    This works but I don't know if it is any good.
-
-    Args:
-        p (float, optional): probability to update velocity with a new weigh value. Defaults to 0.1.
-        nesterov (bool, optional): if False, update uses delayed momentum. Defaults to True.
-    """
-    def __init__(self, p: float = 0.1, nesterov=True):
-        defaults = dict(p=p)
-        super().__init__(defaults)
-        self.nesterov = nesterov
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        velocity = self.get_state_key('velocity', init = ascent)
-        settings = self.get_all_group_keys()
-
-        # pick p veclocity indexes to update with the new ascent direction
-        indexes = ascent.bernoulli_like(settings['p']).as_bool()
-
-        if self.nesterov:
-            # update the velocity at those indexes
-            velocity.masked_set_(mask = indexes, value = ascent)
-            return velocity.clone()
-
-        new_ascent = velocity.clone()
-        velocity.masked_set_(mask = indexes, value = ascent)
-        return new_ascent
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import NumberList, TensorList
+from .ema import EMA
+
+
+class HeavyBall(EMA):
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
+        super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
+
+def nag_(
+    tensors_: TensorList,
+    velocity_: TensorList,
+    momentum: float | NumberList,
+    dampening: float | NumberList,
+    lerp: bool = False,
+):
+    """Nesterov momentum.
+
+    Returns `tensors_`"""
+    if lerp: velocity_.lerp_(tensors_, 1 - momentum)
+    else: velocity_.add_(tensors_).mul_(momentum)
+
+    tensors_ += velocity_.lazy_mul(1 - dampening)
+
+    return tensors_
+
+
+class NAG(Transform):
+    def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
+        defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        velocity = self.get_state('velocity', params=params, cls=TensorList)
+        lerp = self.settings[params[0]]['lerp']
+
+        momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+        return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
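
The new `nag_` helper is the standard Nesterov recurrence written for TensorLists. As a hypothetical standalone check (plain PyTorch with illustrative names, not the torchzero API), the `lerp=False`, `dampening=0` case should reproduce `torch.optim.SGD` with `nesterov=True`:

import torch

torch.manual_seed(0)
p_ref = torch.nn.Parameter(torch.randn(5))
p_man = p_ref.detach().clone()
velocity = torch.zeros(5)
momentum, lr = 0.9, 0.1
opt = torch.optim.SGD([p_ref], lr=lr, momentum=momentum, nesterov=True)

for _ in range(5):
    # gradient of f(p) = p.pow(2).sum(), evaluated at each copy
    p_ref.grad = 2.0 * p_ref.detach()
    opt.step()
    g = 2.0 * p_man
    velocity = momentum * (velocity + g)      # velocity update from nag_ with lerp=False
    p_man = p_man - lr * (g + velocity)       # dampening = 0, so the full velocity is added

print(torch.allclose(p_ref.detach(), p_man))  # expected: True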
torchzero/modules/ops/__init__.py
@@ -0,0 +1,103 @@
+from .accumulate import (
+    AccumulateMaximum,
+    AccumulateMean,
+    AccumulateMinimum,
+    AccumulateProduct,
+    AccumulateSum,
+)
+from .binary import (
+    Add,
+    BinaryOperation,
+    Clip,
+    CopyMagnitude,
+    CopySign,
+    Div,
+    Graft,
+    GraftToUpdate,
+    GramSchimdt,
+    Maximum,
+    Minimum,
+    Mul,
+    Pow,
+    RCopySign,
+    RDiv,
+    RGraft,
+    RPow,
+    RSub,
+    Sub,
+    Threshold,
+)
+from .debug import PrintShape, PrintUpdate
+from .misc import (
+    DivByLoss,
+    Dropout,
+    FillLoss,
+    GradientAccumulation,
+    GradSign,
+    GraftGradToUpdate,
+    GraftToGrad,
+    GraftToParams,
+    LastAbsoluteRatio,
+    LastDifference,
+    LastGradDifference,
+    LastProduct,
+    LastRatio,
+    MulByLoss,
+    Multistep,
+    NegateOnLossIncrease,
+    NoiseSign,
+    Previous,
+    Relative,
+    Sequential,
+    UpdateSign,
+    WeightDropout,
+)
+from .multi import (
+    ClipModules,
+    DivModules,
+    GraftModules,
+    LerpModules,
+    MultiOperation,
+    PowModules,
+    SubModules,
+)
+from .reduce import (
+    MaximumModules,
+    Mean,
+    MinimumModules,
+    Prod,
+    ReduceOperation,
+    Sum,
+    WeightedMean,
+    WeightedSum,
+)
+from .split import Split
+from .switch import Alternate, Switch
+from .unary import (
+    Abs,
+    CustomUnaryOperation,
+    Exp,
+    NanToNum,
+    Negate,
+    Reciprocal,
+    Sign,
+    Sqrt,
+    UnaryLambda,
+    UnaryParameterwiseLambda,
+)
+from .utility import (
+    Clone,
+    Fill,
+    Grad,
+    GradToNone,
+    Identity,
+    NoOp,
+    Ones,
+    Params,
+    Randn,
+    RandomSample,
+    Uniform,
+    Update,
+    UpdateToNone,
+    Zeros,
+)
torchzero/modules/ops/accumulate.py
@@ -0,0 +1,65 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import TensorList, NumberList
+
+class AccumulateSum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        sum = self.get_state('sum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return sum.add_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMean(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        mean = self.get_state('mean', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return mean.add_(tensors).lazy_mul(1-decay, clone=True).div_(step)
+
+class AccumulateProduct(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        prod = self.get_state('prod', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return prod.mul_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMaximum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        maximum = self.get_state('maximum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return maximum.maximum_(tensors).lazy_mul(1-decay, clone=True)
+
+class AccumulateMinimum(Transform):
+    def __init__(self, decay: float = 0, target: Target = 'update',):
+        defaults = dict(decay=decay)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        minimum = self.get_state('minimum', params=params, cls=TensorList)
+        decay = self.get_settings('decay', params=params, cls=NumberList)
+        return minimum.minimum_(tensors).lazy_mul(1-decay, clone=True)
+
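
For reference, `AccumulateMean` above stores a raw running sum in state and returns `(1 - decay) * sum / step` on each call. A minimal standalone sketch of that recurrence (illustrative names, plain PyTorch, not the torchzero API):

import torch

def accumulate_mean(stream, decay=0.0):
    """Running decayed mean: the state keeps the raw sum, the return value is (1 - decay) * sum / step."""
    total = None
    out = None
    for step, g in enumerate(stream, start=1):
        total = g.clone() if total is None else total.add_(g)
        out = total * (1 - decay) / step
    return out

print(accumulate_mean([torch.full((3,), v) for v in (1.0, 2.0, 3.0)]))  # tensor([2., 2., 2.])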
torchzero/modules/ops/binary.py
@@ -0,0 +1,240 @@
+#pyright: reportIncompatibleMethodOverride=false
+""""""
+from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
+from operator import itemgetter
+from typing import Any
+
+import torch
+
+from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...utils import TensorList, tensorlist
+
+
+class BinaryOperation(Module, ABC):
+    """Base class for operations that use the update as the first operand. This is an abstract class; subclass it and override the `transform` method to use it."""
+    def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
+        super().__init__(defaults=defaults)
+
+        self.operands = {}
+        for k,v in operands.items():
+
+            if isinstance(v, (Module, Sequence)):
+                self.set_child(k, v)
+                self.operands[k] = self.children[k]
+            else:
+                self.operands[k] = v
+
+    @abstractmethod
+    def transform(self, vars: Vars, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
+        """applies the operation to operands"""
+        raise NotImplementedError
+
+    @torch.no_grad
+    def step(self, vars: Vars) -> Vars:
+        # pass cloned update to all module operands
+        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
+
+        for k,v in self.operands.items():
+            if k in self.children:
+                v: Module
+                updated_vars = v.step(vars.clone(clone_update=True))
+                processed_operands[k] = updated_vars.get_update()
+                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+
+        transformed = self.transform(vars, update=vars.get_update(), **processed_operands)
+        vars.update = list(transformed)
+        return vars
+
+
+class Add(BinaryOperation):
+    def __init__(self, other: Chainable | float, alpha: float = 1):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[vars.params[0]]['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+        return update
+
+class Sub(BinaryOperation):
+    def __init__(self, other: Chainable | float, alpha: float = 1):
+        defaults = dict(alpha=alpha)
+        super().__init__(defaults, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[vars.params[0]]['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+        return update
+
+class RSub(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        return other - TensorList(update)
+
+class Mul(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        torch._foreach_mul_(update, other)
+        return update
+
+class Div(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        torch._foreach_div_(update, other)
+        return update
+
+class RDiv(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        return other / TensorList(update)
+
+class Pow(BinaryOperation):
+    def __init__(self, exponent: Chainable | float):
+        super().__init__({}, exponent=exponent)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
+        torch._foreach_pow_(update, exponent)
+        return update
+
+class RPow(BinaryOperation):
+    def __init__(self, other: Chainable | float):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
+        torch._foreach_pow_(other, update)
+        return other
+
+class Lerp(BinaryOperation):
+    def __init__(self, end: Chainable, weight: float):
+        defaults = dict(weight=weight)
+        super().__init__(defaults, end=end)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings('weight', params=vars.params))
+        return update
+
+class CopySign(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        return [u.copysign_(o) for u, o in zip(update, other)]
+
+class RCopySign(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        return [o.copysign_(u) for u, o in zip(update, other)]
+CopyMagnitude = RCopySign
+
+class Clip(BinaryOperation):
+    def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
+        super().__init__({}, min=min, max=max)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
+        return TensorList(update).clamp_(min=min, max=max)
+
+class MirroredClip(BinaryOperation):
+    """clip to the range [-value, value]"""
+    def __init__(self, value: float | Chainable):
+        super().__init__({}, value=value)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], value: float | list[torch.Tensor]):
+        min = -value if isinstance(value, (int,float)) else [-v for v in value]
+        return TensorList(update).clamp_(min=min, max=value)
+
+class Graft(BinaryOperation):
+    """use direction from update and magnitude from `magnitude` module"""
+    def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, magnitude=magnitude)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+        return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class RGraft(BinaryOperation):
+    """use direction from `direction` module and magnitude from update"""
+
+    def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, direction=direction)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], direction: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+        return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
+
+GraftToUpdate = RGraft
+
+class Maximum(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        torch._foreach_maximum_(update, other)
+        return update
+
+class Minimum(BinaryOperation):
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        torch._foreach_minimum_(update, other)
+        return update
+
+
+class GramSchimdt(BinaryOperation):
+    """makes update orthogonal to `other`"""
+    def __init__(self, other: Chainable):
+        super().__init__({}, other=other)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+        update = TensorList(update); other = TensorList(other)
+        return update - (other*update) / ((other*other) + 1e-8)
+
+
+class Threshold(BinaryOperation):
+    """uses the update where it is above (or below, depending on `update_above`) `threshold`, and `value` elsewhere"""
+    def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
+        defaults = dict(update_above=update_above)
+        super().__init__(defaults, threshold=threshold, value=value)
+
+    @torch.no_grad
+    def transform(self, vars, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+        update_above = self.settings[vars.params[0]]['update_above']
+        update = TensorList(update)
+        if update_above:
+            if isinstance(value, list): return update.where_(update>threshold, value)
+            return update.masked_fill_(update<=threshold, value)
+
+        if isinstance(value, list): return update.where_(update<threshold, value)
+        return update.masked_fill_(update>=threshold, value)
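
`Graft` and `RGraft` implement the usual grafting trick: take the direction from one operand and the norm from the other. A single-tensor sketch under that assumption (plain PyTorch; the exact `eps` placement and tensorwise handling of `TensorList.graft_` may differ):

import torch

def graft(update: torch.Tensor, magnitude: torch.Tensor, ord: float = 2, eps: float = 1e-6) -> torch.Tensor:
    """Keep the direction of `update`, rescale it to the norm of `magnitude`."""
    scale = torch.linalg.vector_norm(magnitude, ord) / (torch.linalg.vector_norm(update, ord) + eps)
    return update * scale

u = torch.tensor([3.0, 4.0])    # direction source, norm 5
m = torch.tensor([0.6, -0.8])   # magnitude source, norm 1
print(graft(u, m))              # ~tensor([0.6000, 0.8000]): direction of u, norm of m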