torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/matrix_momentum.py

@@ -2,7 +2,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, apply
+from ...core import Module, apply_transform, Chainable
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
@@ -13,105 +13,147 @@ class MatrixMomentum(Module):
 
     `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
 
-    Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
+    Args:
+        mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            How to calculate hessian-vector products.
+            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-    def __init__(self, mu=0.1, beta:float=1, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
-        defaults = dict(mu=mu, beta=beta, hvp_mode=hvp_mode, h=h)
+
+    def __init__(
+        self,
+        mu=0.1,
+        beta: float = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
     @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
-        prev_update = self.get_state('prev_update', params=vars.params, cls=TensorList)
-        hvp_mode = self.settings[vars.params[0]]['hvp_mode']
-        h = self.settings[vars.params[0]]['h']
+    def step(self, var):
+        assert var.closure is not None
+        prev_update = self.get_state(var.params, 'prev_update', cls=TensorList)
+        hvp_method = self.settings[var.params[0]]['hvp_method']
+        h = self.settings[var.params[0]]['h']
 
-        mu,beta = self.get_settings('mu','beta', params=vars.params, cls=NumberList)
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-        if hvp_mode == 'autograd':
+        if hvp_method == 'autograd':
             with torch.enable_grad():
-                grad = vars.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+                grad = var.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
 
-        elif hvp_mode == 'forward':
-            vars.get_grad()
-            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        elif hvp_method == 'forward':
+            var.get_grad()
+            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
-        elif hvp_mode == 'central':
-            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        elif hvp_method == 'central':
+            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
         else:
-            raise ValueError(hvp_mode)
+            raise ValueError(hvp_method)
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
-        update = TensorList(vars.get_update())
+        update = TensorList(var.get_update())
 
         hvp_ = as_tensorlist(hvp_)
         update.add_(prev_update - hvp_*mu)
         prev_update.set_(update * beta)
-        vars.update = update
-        return vars
+        var.update = update
+        return var
 
 
 class AdaptiveMatrixMomentum(Module):
     """
-    Mu here is estimated as ||s_k||/||y_k||.
+    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
+    Evaluates hessian vector product on each step (via finite difference or autograd).
+
+    This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
+
+    Args:
+        mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
+        beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
+        hvp_method (str, optional):
+            How to calculate hessian-vector products.
+            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+        h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
+        hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
+
+    Reference:
+        Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
     """
-    def __init__(self, mu_mul:float=1, beta:float=1, eps=1e-4, hvp_mode: Literal['autograd', 'forward', 'central'] = 'forward', h=1e-3, hvp_tfm=None):
-        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_mode=hvp_mode, h=h, eps=eps)
+
+    def __init__(
+        self,
+        mu_mul: float = 1,
+        beta: float = 1,
+        eps=1e-4,
+        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        h: float = 1e-3,
+        hvp_tfm: Chainable | None = None,
+    ):
+        defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
         super().__init__(defaults)
 
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
     @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
-        prev_update, prev_params, prev_grad = self.get_state('prev_update', 'prev_params', 'prev_grad', params=vars.params, cls=TensorList)
+    def step(self, var):
+        assert var.closure is not None
+        prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
-        settings = self.settings[vars.params[0]]
-        hvp_mode = settings['hvp_mode']
+        settings = self.settings[var.params[0]]
+        hvp_method = settings['hvp_method']
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul, beta = self.get_settings('mu_mul','beta', params=vars.params, cls=NumberList)
+        mu_mul, beta = self.get_settings(var.params, 'mu_mul','beta', cls=NumberList)
 
-        if hvp_mode == 'autograd':
+        if hvp_method == 'autograd':
             with torch.enable_grad():
-                grad = vars.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(vars.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
+                grad = var.get_grad(create_graph=True)
+                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
 
-        elif hvp_mode == 'forward':
-            vars.get_grad()
-            l, hvp_ = hvp_fd_forward(vars.closure, vars.params, vec=prev_update, g_0=vars.grad, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        elif hvp_method == 'forward':
+            var.get_grad()
+            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
-        elif hvp_mode == 'central':
-            l, hvp_ = hvp_fd_central(vars.closure, vars.params, vec=prev_update, h=h, normalize=True)
-            if vars.loss_approx is None: vars.loss_approx = l
+        elif hvp_method == 'central':
+            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
+            if var.loss_approx is None: var.loss_approx = l
 
         else:
-            raise ValueError(hvp_mode)
+            raise ValueError(hvp_method)
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply(self.children['hvp_tfm'], hvp_, params=vars.params, grads=vars.grad, vars=vars))
+            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-        update = TensorList(vars.get_update())
+        update = TensorList(var.get_update())
 
-        s_k = vars.params - prev_params
-        prev_params.copy_(vars.params)
+        s_k = var.params - prev_params
+        prev_params.copy_(var.params)
 
-        assert vars.grad is not None
-        y_k = vars.grad - prev_grad
-        prev_grad.copy_(vars.grad)
+        assert var.grad is not None
+        y_k = var.grad - prev_grad
+        prev_grad.copy_(var.grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
@@ -119,6 +161,6 @@ class AdaptiveMatrixMomentum(Module):
         hvp_ = as_tensorlist(hvp_)
         update.add_(prev_update - hvp_*ada_mu)
         prev_update.set_(update * beta)
-        vars.update = update
-        return vars
+        var.update = update
+        return var
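For code that builds this module directly, the visible break in 0.3.10 is the keyword rename hvp_mode → hvp_method (the core helper apply was also renamed to apply_transform, and hvp_tfm is now typed as Chainable | None). A minimal migration sketch, assuming MatrixMomentum is imported from the module path listed above; the surrounding setup is hypothetical:

    # hypothetical sketch: adapting a direct MatrixMomentum construction to 0.3.10
    from torchzero.modules.momentum.matrix_momentum import MatrixMomentum

    # torchzero 0.3.8 keyword:
    # mm = MatrixMomentum(mu=0.1, beta=1, hvp_mode="forward", h=1e-3)

    # torchzero 0.3.10 keyword:
    mm = MatrixMomentum(mu=0.1, beta=1, hvp_method="forward", h=1e-3)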
 
torchzero/modules/momentum/momentum.py

@@ -3,11 +3,22 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from .ema import EMA
 
 
 class HeavyBall(EMA):
+    """Polyak's momentum (heavy-ball method).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes exponential moving average. Defaults to False.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=False, ema_init: Literal['zeros', 'update'] = 'update', target: Target = 'update'):
         super().__init__(momentum=momentum, dampening=dampening, debiased=debiased, lerp=lerp, ema_init=ema_init, target=target)
 
@@ -30,14 +41,23 @@ def nag_(
 
 
 class NAG(Transform):
+    """Nesterov accelerated gradient method (nesterov momentum).
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        lerp (bool, optional):
+            whether to use linear interpolation, if True, this becomes similar to exponential moving average. Defaults to False.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
     def __init__(self, momentum:float=0.9, dampening:float=0, lerp=False, target: Target = 'update'):
         defaults = dict(momentum=momentum,dampening=dampening, lerp=lerp)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        velocity = self.get_state('velocity', params=params, cls=TensorList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
-        momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+        momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
torchzero/modules/ops/accumulate.py

@@ -5,61 +5,91 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, NumberList, unpack_states, unpack_dicts
 
 class AccumulateSum(Transform):
+    """Accumulates sum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        sum = self.get_state('sum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return sum.add_(tensors).lazy_mul(1-decay, clone=True)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        sum = unpack_states(states, tensors, 'sum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return sum.add_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMean(Transform):
+    """Accumulates mean of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        mean = self.get_state('mean', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return mean.add_(tensors).lazy_mul(1-decay, clone=True).div_(step)
+        mean = unpack_states(states, tensors, 'mean', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return mean.add_(tensors).lazy_mul(decay, clone=True).div_(step)
 
 class AccumulateProduct(Transform):
+    """Accumulates product of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prod = self.get_state('prod', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return prod.mul_(tensors).lazy_mul(1-decay, clone=True)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prod = unpack_states(states, tensors, 'prod', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return prod.mul_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMaximum(Transform):
+    """Accumulates maximum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        maximum = self.get_state('maximum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return maximum.maximum_(tensors).lazy_mul(1-decay, clone=True)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
 
 class AccumulateMinimum(Transform):
+    """Accumulates minimum of all past updates.
+
+    Args:
+        decay (float, optional): decays the accumulator. Defaults to 0.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, decay: float = 0, target: Target = 'update',):
         defaults = dict(decay=decay)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        minimum = self.get_state('minimum', params=params, cls=TensorList)
-        decay = self.get_settings('decay', params=params, cls=NumberList)
-        return minimum.minimum_(tensors).lazy_mul(1-decay, clone=True)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
+        decay = [1-s['decay'] for s in settings]
+        return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
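The accumulator classes illustrate the broader Transform change in 0.3.10: the per-step hook is now apply(self, tensors, params, grads, loss, states, settings) rather than transform(self, tensors, params, grads, vars), with per-parameter state and settings passed in explicitly and unpacked via unpack_states / unpack_dicts instead of self.get_state / self.get_settings. A hedged sketch of how a third-party Transform subclass would follow the same pattern as the diff above; DecayedSign is a made-up example, not part of torchzero:

    import torch
    from torchzero.core import Target, Transform
    from torchzero.utils import TensorList, unpack_states

    class DecayedSign(Transform):
        """Made-up example: sign of a decayed running sum of past updates."""

        def __init__(self, decay: float = 0, target: Target = 'update'):
            super().__init__(dict(decay=decay), uses_grad=False, target=target)

        # 0.3.8 hook:  def transform(self, tensors, params, grads, vars)
        # 0.3.10 hook: per-parameter `states` and `settings` dicts are passed in
        @torch.no_grad
        def apply(self, tensors, params, grads, loss, states, settings):
            buf = unpack_states(states, tensors, 'buf', cls=TensorList)  # per-parameter buffers
            decay = [1 - s['decay'] for s in settings]                   # one settings dict per parameter
            decayed = buf.add_(tensors).lazy_mul(decay, clone=True)      # same pattern as AccumulateSum
            return [t.sign() for t in decayed]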
 
torchzero/modules/ops/binary.py

@@ -7,7 +7,7 @@ from typing import Any
 
 import torch
 
-from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist
 
 
@@ -26,25 +26,25 @@ class BinaryOperation(Module, ABC):
             self.operands[k] = v
 
     @abstractmethod
-    def transform(self, vars: Vars, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
+    def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError
 
    @torch.no_grad
-    def step(self, vars: Vars) -> Vars:
+    def step(self, var: Var) -> Var:
        # pass cloned update to all module operands
        processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
        for k,v in self.operands.items():
            if k in self.children:
                v: Module
-                updated_vars = v.step(vars.clone(clone_update=True))
-                processed_operands[k] = updated_vars.get_update()
-                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(vars, update=vars.get_update(), **processed_operands)
-        vars.update = list(transformed)
-        return vars
+        transformed = self.transform(var, update=var.get_update(), **processed_operands)
+        var.update = list(transformed)
+        return var
 
 
 class Add(BinaryOperation):
@@ -53,9 +53,9 @@ class Add(BinaryOperation):
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[vars.params[0]]['alpha'])
-        else: torch._foreach_add_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update
 
 class Sub(BinaryOperation):
@@ -64,9 +64,9 @@ class Sub(BinaryOperation):
         super().__init__(defaults, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
-        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[vars.params[0]]['alpha'])
-        else: torch._foreach_sub_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+        if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
+        else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
         return update
 
 class RSub(BinaryOperation):
@@ -74,7 +74,7 @@ class RSub(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other - TensorList(update)
 
 class Mul(BinaryOperation):
@@ -82,7 +82,7 @@ class Mul(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_mul_(update, other)
         return update
 
@@ -91,7 +91,7 @@ class Div(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         torch._foreach_div_(update, other)
         return update
 
@@ -100,7 +100,7 @@ class RDiv(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         return other / TensorList(update)
 
 class Pow(BinaryOperation):
@@ -108,7 +108,7 @@ class Pow(BinaryOperation):
         super().__init__({}, exponent=exponent)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
         torch._foreach_pow_(update, exponent)
         return update
 
@@ -117,7 +117,7 @@ class RPow(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
         if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
         torch._foreach_pow_(other, update)
         return other
@@ -128,8 +128,8 @@ class Lerp(BinaryOperation):
         super().__init__(defaults, end=end)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], end: list[torch.Tensor]):
-        torch._foreach_lerp_(update, end, weight=self.get_settings('weight',params=vars))
+    def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
+        torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
         return update
 
 class CopySign(BinaryOperation):
@@ -137,7 +137,7 @@ class CopySign(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [u.copysign_(o) for u, o in zip(update, other)]
 
 class RCopySign(BinaryOperation):
@@ -145,7 +145,7 @@ class RCopySign(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         return [o.copysign_(u) for u, o in zip(update, other)]
 CopyMagnitude = RCopySign
 
@@ -154,7 +154,7 @@ class Clip(BinaryOperation):
         super().__init__({}, min=min, max=max)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
+    def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
         return TensorList(update).clamp_(min=min, max=max)
 
 class MirroredClip(BinaryOperation):
@@ -163,7 +163,7 @@ class MirroredClip(BinaryOperation):
         super().__init__({}, value=value)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], value: float | list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], value: float | list[torch.Tensor]):
         min = -value if isinstance(value, (int,float)) else [-v for v in value]
         return TensorList(update).clamp_(min=min, max=value)
 
@@ -174,8 +174,8 @@ class Graft(BinaryOperation):
         super().__init__(defaults, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+    def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
 class RGraft(BinaryOperation):
@@ -186,8 +186,8 @@ class RGraft(BinaryOperation):
         super().__init__(defaults, direction=direction)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], direction: list[torch.Tensor]):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+    def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
         return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
 GraftToUpdate = RGraft
@@ -197,7 +197,7 @@ class Maximum(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_maximum_(update, other)
         return update
 
@@ -206,7 +206,7 @@ class Minimum(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         torch._foreach_minimum_(update, other)
         return update
 
@@ -217,7 +217,7 @@ class GramSchimdt(BinaryOperation):
         super().__init__({}, other=other)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+    def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
         update = TensorList(update); other = TensorList(other)
         return update - (other*update) / ((other*other) + 1e-8)
 
@@ -229,8 +229,8 @@ class Threshold(BinaryOperation):
         super().__init__(defaults, threshold=threshold, value=value)
 
     @torch.no_grad
-    def transform(self, vars, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
-        update_above = self.settings[vars.params[0]]['update_above']
+    def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+        update_above = self.settings[var.params[0]]['update_above']
         update = TensorList(update)
         if update_above:
             if isinstance(value, list): return update.where_(update>threshold, value)
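BinaryOperation subclasses keep the transform hook, but its first argument is now the renamed var (a Var, previously Vars). A hedged sketch of a custom operation written against the 0.3.10 base class, following the Maximum class above; AbsMaximum is a made-up example, not part of torchzero:

    import torch
    from torchzero.modules.ops.binary import BinaryOperation

    class AbsMaximum(BinaryOperation):
        """Made-up example: elementwise maximum of |update| and |other|."""

        def __init__(self, other):
            super().__init__({}, other=other)  # `other` can be tensors or another module

        @torch.no_grad
        def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
            torch._foreach_abs_(update)                                   # |update|, in place
            torch._foreach_maximum_(update, torch._foreach_abs(other))    # max(|update|, |other|)
            return update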
torchzero/modules/ops/debug.py

@@ -10,16 +10,16 @@ class PrintUpdate(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def step(self, vars):
-        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{vars.update}')
-        return vars
+    def step(self, var):
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
+        return var
 
 class PrintShape(Module):
     def __init__(self, text = 'shapes = ', print_fn = print):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def step(self, vars):
-        shapes = [u.shape for u in vars.update] if vars.update is not None else None
-        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{shapes}')
-        return vars
+    def step(self, var):
+        shapes = [u.shape for u in var.update] if var.update is not None else None
+        self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
+        return var
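The debug modules show the rename that runs through every hunk above: Module.step now takes and returns var (a Var), not vars (a Vars). A hedged sketch of what the adjustment looks like for a downstream custom Module; NanToZero is a made-up example, not part of torchzero:

    import torch
    from torchzero.core import Module, Var

    class NanToZero(Module):
        """Made-up example module that zeroes out non-finite update entries."""

        def __init__(self):
            super().__init__({})

        # 0.3.8: def step(self, vars: Vars) -> Vars
        @torch.no_grad
        def step(self, var: Var) -> Var:
            var.update = [torch.nan_to_num(u, nan=0.0, posinf=0.0, neginf=0.0)
                          for u in var.get_update()]
            return var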