torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
@@ -5,8 +5,8 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module, TensorwiseTransform, Target, Transform, Vars
-from ...utils import Distributions, NumberList, TensorList
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states


 class Previous(TensorwiseTransform):
@@ -17,9 +17,8 @@ class Previous(TensorwiseTransform):


     @torch.no_grad
-    def transform(self, tensor, param, grad, vars):
-        n = self.settings[param]['n']
-        state = self.state[param]
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        n = settings['n']

         if 'history' not in state:
             state['history'] = deque(maxlen=n+1)
@@ -35,10 +34,10 @@ class LastDifference(Transform):
         super().__init__({}, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params) # initialized to 0
-        difference = torch._foreach_sub(tensors, prev_target)
-        for p, c in zip(prev_target, tensors): p.set_(c)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
+        difference = torch._foreach_sub(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
         return difference

 class LastGradDifference(Module):
@@ -47,13 +46,13 @@ class LastGradDifference(Module):
         super().__init__({})

     @torch.no_grad
-    def step(self, vars):
-        grad = vars.get_grad()
-        prev_grad = self.get_state('prev_grad', params=vars.params) # initialized to 0
+    def step(self, var):
+        grad = var.get_grad()
+        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
         difference = torch._foreach_sub(grad, prev_grad)
         for p, c in zip(prev_grad, grad): p.set_(c)
-        vars.update = list(difference)
-        return vars
+        var.update = list(difference)
+        return var


 class LastProduct(Transform):
@@ -62,10 +61,10 @@ class LastProduct(Transform):
         super().__init__({}, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to 1 for prod
-        prod = torch._foreach_mul(tensors, prev_target)
-        for p, c in zip(prev_target, tensors): p.set_(c)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
         return prod

 class LastRatio(Transform):
@@ -75,12 +74,12 @@ class LastRatio(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to ones
-        numerator = self.settings[params[0]]['numerator']
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
-        else: ratio = torch._foreach_div(prev_target, tensors)
-        for p, c in zip(prev_target, tensors): p.set_(c)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
         return ratio

 class LastAbsoluteRatio(Transform):
@@ -90,17 +89,17 @@ class LastAbsoluteRatio(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to 0
-        numerator = self.settings[params[0]]['numerator']
-        eps = self.get_settings('eps', params=params, cls = NumberList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        eps = NumberList(s['eps'] for s in settings)

         torch._foreach_abs_(tensors)
-        torch._foreach_clamp_min_(prev_target, eps)
+        torch._foreach_clamp_min_(prev, eps)

-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
-        else: ratio = torch._foreach_div(prev_target, tensors)
-        for p, c in zip(prev_target, tensors): p.set_(c)
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
         return ratio

 class GradSign(Transform):
@@ -109,7 +108,7 @@ class GradSign(Transform):
         super().__init__({}, uses_grad=True, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [t.copysign_(g) for t,g in zip(tensors, grads)]

@@ -119,7 +118,7 @@ class UpdateSign(Transform):
         super().__init__({}, uses_grad=True, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place

@@ -130,9 +129,9 @@ class GraftToGrad(Transform):
         super().__init__(defaults, uses_grad=True, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)

 class GraftGradToUpdate(Transform):
@@ -142,9 +141,9 @@ class GraftGradToUpdate(Transform):
         super().__init__(defaults, uses_grad=True, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)


@@ -155,8 +154,8 @@ class GraftToParams(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)

 class Relative(Transform):
@@ -166,8 +165,8 @@ class Relative(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        mul = TensorList(params).abs().clamp_(self.get_settings('min_value', params=params))
+    def apply(self, tensors, params, grads, loss, states, settings):
+        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
         torch._foreach_mul_(tensors, mul)
         return tensors

@@ -178,94 +177,94 @@ class FillLoss(Module):
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, vars):
-        alpha = self.get_settings('alpha', params=vars.params)
-        loss = vars.get_loss(backward=self.settings[vars.params[0]]['backward'])
-        vars.update = [torch.full_like(p, loss*a) for p,a in zip(vars.params, alpha)]
-        return vars
+    def step(self, var):
+        alpha = self.get_settings(var.params, 'alpha')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+        return var

-class MulByLoss(Transform):
+class MulByLoss(Module):
     """multiplies update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars): #vars used for loss
-        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
-        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(tensors, mul)
-        return tensors
+        torch._foreach_mul_(var.update, mul)
+        return var

-class DivByLoss(Transform):
+class DivByLoss(Module):
     """divides update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars): #vars used for loss
-        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
-        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(tensors, mul)
-        return tensors
+        torch._foreach_div_(var.update, mul)
+        return var



-def _sequential_step(self: Module, vars: Vars, sequential: bool):
-    params = vars.params
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
     steps = self.settings[params[0]]['steps']

     if sequential: modules = self.get_children_sequence()
     else: modules = [self.children['module']] * steps

-    if vars.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')

     # store original params unless this is last module and can update params directly
-    params_before_steps = None if (vars.is_last and vars.last_module_lrs is None) else [p.clone() for p in params]
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]

-    # first step - pass vars as usual
-    vars = modules[0].step(vars)
-    new_vars = vars
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var

-    # subsequent steps - update parameters and create new vars
+    # subsequent steps - update parameters and create new var
     if len(modules) > 1:
         for m in modules[1:]:

             # update params
-            if (not new_vars.skip_update):
-                if new_vars.last_module_lrs is not None:
-                    torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

-                torch._foreach_sub_(params, new_vars.get_update())
+                torch._foreach_sub_(params, new_var.get_update())

-            # create new vars since we are at a new point, that means grad, update and loss will be None
-            new_vars = Vars(params=new_vars.params, closure=new_vars.closure,
-                            model=new_vars.model, current_step=new_vars.current_step + 1)
+            # create new var since we are at a new point, that means grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)

             # step
-            new_vars = m.step(new_vars)
+            new_var = m.step(new_var)

     # final parameter update
-    if (not new_vars.skip_update):
-        if new_vars.last_module_lrs is not None:
-            torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

-        torch._foreach_sub_(params, new_vars.get_update())
+        torch._foreach_sub_(params, new_var.get_update())

-    # if last module, update is applied so return new vars
+    # if last module, update is applied so return new var
     if params_before_steps is None:
-        new_vars.stop = True
-        new_vars.skip_update = True
-        return new_vars
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var

     # otherwise use parameter difference as update
-    vars.update = list(torch._foreach_sub(params_before_steps, params))
+    var.update = list(torch._foreach_sub(params_before_steps, params))
     for p, bef in zip(params, params_before_steps):
         p.set_(bef) # pyright:ignore[reportArgumentType]
-    return vars
+    return var

 class Multistep(Module):
     def __init__(self, module: Chainable, steps: int):
@@ -274,8 +273,8 @@ class Multistep(Module):
         self.set_child('module', module)

     @torch.no_grad
-    def step(self, vars):
-        return _sequential_step(self, vars, sequential=False)
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)

 class Sequential(Module):
     def __init__(self, modules: Iterable[Chainable], steps: int):
@@ -284,8 +283,8 @@ class Sequential(Module):
         self.set_children_sequence(modules)

     @torch.no_grad
-    def step(self, vars):
-        return _sequential_step(self, vars, sequential=True)
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)


 class GradientAccumulation(Module):
@@ -297,22 +296,22 @@ class GradientAccumulation(Module):


     @torch.no_grad
-    def step(self, vars):
-        accumulator = self.get_state('accumulator', params=vars.params)
-        settings = self.settings[vars.params[0]]
+    def step(self, var):
+        accumulator = self.get_state(var.params, 'accumulator')
+        settings = self.settings[var.params[0]]
         n = settings['n']; mean = settings['mean']; stop = settings['stop']
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1

         # add update to accumulator
-        torch._foreach_add_(accumulator, vars.get_update())
+        torch._foreach_add_(accumulator, var.get_update())

         # step with accumulated updates
         if step % n == 0:
             if mean:
                 torch._foreach_div_(accumulator, n)

-            vars.update = [a.clone() for a in accumulator]
-            vars = self.children['modules'].step(vars)
+            var.update = [a.clone() for a in accumulator]
+            var = self.children['modules'].step(var)

             # zero accumulator
             torch._foreach_zero_(accumulator)
@@ -320,10 +319,10 @@ class GradientAccumulation(Module):
         else:
             # prevent update
             if stop:
-                vars.stop=True
-                vars.skip_update=True
+                var.stop=True
+                var.skip_update=True

-        return vars
+        return var


 class Dropout(Transform):
@@ -332,10 +331,10 @@ class Dropout(Transform):
         super().__init__(defaults, uses_grad=False, target=target)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-        p = self.get_settings('p', params=params, cls=NumberList)
-        graft = self.settings[params[0]]['graft']
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']

         if graft:
             target_norm = tensors.global_vector_norm()
@@ -351,11 +350,11 @@ class WeightDropout(Module):
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, vars):
-        closure = vars.closure
+    def step(self, var):
+        closure = var.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(vars.params)
-        p = self.get_settings('p', params=params)
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
         mask = params.rademacher_like(p).add_(1).div_(2).as_bool()

         @torch.no_grad
@@ -369,8 +368,8 @@ class WeightDropout(Module):
             params.copy_(orig_params)
             return loss

-        vars.closure = dropout_closure
-        return vars
+        var.closure = dropout_closure
+        return var

 class NoiseSign(Transform):
     """uses random vector with update sign"""
@@ -379,8 +378,8 @@ class NoiseSign(Transform):
         super().__init__(defaults, uses_grad=False)

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        alpha = self.get_settings('alpha', params=params)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        alpha = [s['alpha'] for s in settings]
         distribution = self.settings[params[0]]['distribution']
         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)

@@ -391,29 +390,29 @@ class NegateOnLossIncrease(Module):
         super().__init__(defaults=defaults)

     @torch.no_grad
-    def step(self, vars):
-        closure = vars.closure
+    def step(self, var):
+        closure = var.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[vars.params[0]]['backtrack']
+        backtrack = self.settings[var.params[0]]['backtrack']

-        update = vars.get_update()
-        f_0 = vars.get_loss(backward=False)
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)

-        torch._foreach_sub_(vars.params, update)
+        torch._foreach_sub_(var.params, update)
         f_1 = closure(False)

         if f_1 <= f_0:
-            if vars.is_last and vars.last_module_lrs is None:
-                vars.stop = True
-                vars.skip_update = True
-                return vars
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var

-            torch._foreach_add_(vars.params, update)
-            return vars
+            torch._foreach_add_(var.params, update)
+            return var

-        torch._foreach_add_(vars.params, update)
+        torch._foreach_add_(var.params, update)
         if backtrack:
-            torch._foreach_neg_(vars.update)
+            torch._foreach_neg_(var.update)
         else:
-            torch._foreach_zero_(vars.update)
-        return vars
+            torch._foreach_zero_(var.update)
+        return var
@@ -7,7 +7,7 @@ from typing import Any

 import torch

-from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...core import Chainable, Module, Target, Var, maybe_chain
 from ...utils import TensorList, tensorlist


@@ -29,25 +29,25 @@ class MultiOperation(Module, ABC):
             raise ValueError('At least one operand must be a module')

     @abstractmethod
-    def transform(self, vars: Vars, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError

     @torch.no_grad
-    def step(self, vars: Vars) -> Vars:
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()

         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-                updated_vars = v.step(vars.clone(clone_update=True))
-                processed_operands[k] = updated_vars.get_update()
-                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[k] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them

-        transformed = self.transform(vars, **processed_operands)
-        vars.update = transformed
-        return vars
+        transformed = self.transform(var, **processed_operands)
+        var.update = transformed
+        return var



@@ -57,8 +57,8 @@ class SubModules(MultiOperation):
         super().__init__(defaults, input=input, other=other)

     @torch.no_grad
-    def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
-        alpha = self.settings[vars.params[0]]['alpha']
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+        alpha = self.settings[var.params[0]]['alpha']

         if isinstance(input, (int,float)):
             assert isinstance(other, list)
@@ -74,7 +74,7 @@ class DivModules(MultiOperation):
         super().__init__(defaults, input=input, other=other)

     @torch.no_grad
-    def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
             return input / TensorList(other)
@@ -88,7 +88,7 @@ class PowModules(MultiOperation):
         super().__init__(defaults, input=input, exponent=exponent)

     @torch.no_grad
-    def transform(self, vars: Vars, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(exponent, list)
             return input ** TensorList(exponent)
@@ -102,8 +102,8 @@ class LerpModules(MultiOperation):
         super().__init__(defaults, input=input, end=end)

     @torch.no_grad
-    def transform(self, vars: Vars, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
-        torch._foreach_lerp_(input, end, weight=self.settings[vars.params[0]]['weight'])
+    def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+        torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
         return input

 class ClipModules(MultiOperation):
@@ -112,7 +112,7 @@ class ClipModules(MultiOperation):
         super().__init__(defaults, input=input, min=min, max=max)

     @torch.no_grad
-    def transform(self, vars: Vars, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
         return TensorList(input).clamp_(min=min, max=max)


@@ -122,8 +122,8 @@ class GraftModules(MultiOperation):
         super().__init__(defaults, direction=direction, magnitude=magnitude)

     @torch.no_grad
-    def transform(self, vars, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
-        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[vars.params[0]])
+    def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+        tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)


@@ -132,6 +132,6 @@ class Where(MultiOperation):
         super().__init__({}, condition=condition, input=input, other=other)

     @torch.no_grad
-    def transform(self, vars, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
+    def transform(self, var, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
         return tensorlist.where(TensorList(condition).as_bool(), input, other)
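Taken together, the hunks above show the 0.3.10 module API: Vars is renamed to Var, Module.get_state/get_settings now take the parameter list as the first argument, and Transform subclasses override apply(tensors, params, grads, loss, states, settings) (or apply_tensor for TensorwiseTransform) instead of transform(tensors, params, grads, vars), with per-parameter state pulled via unpack_states and per-parameter settings passed as a sequence of dicts. The sketch below is inferred from these hunks rather than from documented API; the class name and its behaviour are illustrative only.

# hypothetical example, not part of the diff: a custom Transform in the 0.3.10 style
import torch
from torchzero.core import Transform
from torchzero.utils import NumberList, unpack_states

class ScaleByFactor(Transform):
    """scales the update by `factor`, keeping an EMA of the raw update in per-parameter state"""
    def __init__(self, factor: float = 0.5, target = 'update'):
        defaults = dict(factor=factor)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        factor = NumberList(s['factor'] for s in settings)                  # per-parameter settings, as in Dropout above
        ema = unpack_states(states, tensors, 'ema', init=torch.zeros_like)  # per-parameter state, as in LastProduct above
        torch._foreach_lerp_(ema, tensors, 0.1)   # update the running average in place
        torch._foreach_mul_(tensors, factor)      # scale the update in place
        return tensors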