torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
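
The hunks below show two recurring API changes in 0.4.x: the `Var` class is replaced by `Objective` (with pluralized accessors such as `get_updates()`, `grads` and `updates`), and docstring examples construct `tz.Optimizer` where 0.3.x used `tz.Modular`. A minimal sketch of the renamed entry point follows; the module names `tz.m.Adam` and `tz.m.LR` and the `backward` closure flag appear verbatim in the diff, while the toy model and training step are illustrative assumptions.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# 0.3.15 spelling was tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-3))
opt = tz.Optimizer(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-3),
)

def closure(backward=True):
    # torchzero closures take a `backward` flag, as seen in the hunks below
    loss = model(torch.randn(8, 4)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```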
@@ -2,49 +2,49 @@ from collections.abc import Iterable
 
  import torch
 
- from ...core import Chainable, Module, Var
+ from ...core import Chainable, Module, Objective
  from ...utils import TensorList
 
- def _sequential_step(self: Module, var: Var, sequential: bool):
- params = var.params
+ def _sequential_step(self: Module, objective: Objective, sequential: bool):
+ params = objective.params
  steps = self.settings[params[0]]['steps']
 
- if sequential: modules = self.get_children_sequence() * steps
+ if sequential: modules: list[Module] = self.get_children_sequence() * steps
  else: modules = [self.children['module']] * steps
 
- if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+ if objective.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
 
  # store original params unless this is last module and can update params directly
  params_before_steps = [p.clone() for p in params]
 
  # first step - pass var as usual
- var = modules[0].step(var)
- new_var = var
+ objective = modules[0].step(objective)
+ new_objective = objective
 
  # subsequent steps - update parameters and create new var
  if len(modules) > 1:
  for m in modules[1:]:
 
  # update params
- if (not new_var.skip_update):
+ if (not new_objective.skip_update):
  # if new_var.last_module_lrs is not None:
  # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
- torch._foreach_sub_(params, new_var.get_update())
+ torch._foreach_sub_(params, new_objective.get_updates())
 
  # create new var since we are at a new point, that means grad, update and loss will be None
- new_var = Var(params=new_var.params, closure=new_var.closure,
- model=new_var.model, current_step=new_var.current_step + 1)
+ new_objective = Objective(params=new_objective.params, closure=new_objective.closure,
+ model=new_objective.model, current_step=new_objective.current_step + 1)
 
  # step
- new_var = m.step(new_var)
+ new_objective = m.step(new_objective)
 
  # final parameter update
- if (not new_var.skip_update):
+ if (not new_objective.skip_update):
  # if new_var.last_module_lrs is not None:
  # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
- torch._foreach_sub_(params, new_var.get_update())
+ torch._foreach_sub_(params, new_objective.get_updates())
 
  # if last module, update is applied so return new var
  # if params_before_steps is None:
@@ -53,13 +53,13 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
  # return new_var
 
  # otherwise use parameter difference as update
- var.update = list(torch._foreach_sub(params_before_steps, params))
+ objective.updates = list(torch._foreach_sub(params_before_steps, params))
  for p, bef in zip(params, params_before_steps):
  p.set_(bef) # pyright:ignore[reportArgumentType]
- return var
+ return objective
 
  class Multistep(Module):
- """Performs :code:`steps` inner steps with :code:`module` per each step.
+ """Performs ``steps`` inner steps with ``module`` per each step.
 
  The update is taken to be the parameter difference between parameters before and after the inner loop."""
  def __init__(self, module: Chainable, steps: int):
@@ -68,11 +68,11 @@ class Multistep(Module):
  self.set_child('module', module)
 
  @torch.no_grad
- def step(self, var):
- return _sequential_step(self, var, sequential=False)
+ def apply(self, objective):
+ return _sequential_step(self, objective, sequential=False)
 
  class Sequential(Module):
- """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+ """On each step, this sequentially steps with ``modules`` ``steps`` times.
 
  The update is taken to be the parameter difference between parameters before and after the inner loop."""
  def __init__(self, modules: Iterable[Chainable], steps: int=1):
@@ -81,28 +81,28 @@ class Sequential(Module):
  self.set_children_sequence(modules)
 
  @torch.no_grad
- def step(self, var):
- return _sequential_step(self, var, sequential=True)
+ def apply(self, objective):
+ return _sequential_step(self, objective, sequential=True)
 
 
  class NegateOnLossIncrease(Module):
- """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
- if loss is larger than at :code:`parameters`,
- the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
+ """Uses an extra forward pass to evaluate loss at ``parameters+update``,
+ if loss is larger than at ``parameters``,
+ the update is set to 0 if ``backtrack=False`` and to ``-update`` otherwise"""
  def __init__(self, backtrack=False):
  defaults = dict(backtrack=backtrack)
  super().__init__(defaults=defaults)
 
  @torch.no_grad
- def step(self, var):
- closure = var.closure
+ def apply(self, objective):
+ closure = objective.closure
  if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
  backtrack = self.defaults['backtrack']
 
- update = var.get_update()
- f_0 = var.get_loss(backward=False)
+ update = objective.get_updates()
+ f_0 = objective.get_loss(backward=False)
 
- torch._foreach_sub_(var.params, update)
+ torch._foreach_sub_(objective.params, update)
  f_1 = closure(False)
 
  if f_1 <= f_0:
@@ -111,15 +111,15 @@ class NegateOnLossIncrease(Module):
  # var.skip_update = True
  # return var
 
- torch._foreach_add_(var.params, update)
- return var
+ torch._foreach_add_(objective.params, update)
+ return objective
 
- torch._foreach_add_(var.params, update)
+ torch._foreach_add_(objective.params, update)
  if backtrack:
- torch._foreach_neg_(var.update)
+ torch._foreach_neg_(objective.updates)
  else:
- torch._foreach_zero_(var.update)
- return var
+ torch._foreach_zero_(objective.updates)
+ return objective
 
 
  class Online(Module):
@@ -129,7 +129,7 @@ class Online(Module):
 
  Online L-BFGS with Backtracking line search
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.Online(tz.m.LBFGS()),
  tz.m.Backtracking()
@@ -138,57 +138,56 @@ class Online(Module):
 
  Online L-BFGS trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.TrustCG(tz.m.Online(tz.m.LBFGS()))
  )
  ```
 
  """
- def __init__(self, *modules: Module,):
+ def __init__(self, module: Module,):
  super().__init__()
-
- self.set_child('module', modules)
+ self.set_child('module', module)
 
  @torch.no_grad
- def update(self, var):
- closure = var.closure
+ def update(self, objective):
+ closure = objective.closure
  if closure is None: raise ValueError("Closure must be passed for Online")
 
  step = self.global_state.get('step', 0) + 1
  self.global_state['step'] = step
 
- params = TensorList(var.params)
+ params = TensorList(objective.params)
  p_cur = params.clone()
  p_prev = self.get_state(params, 'p_prev', cls=TensorList)
 
  module = self.children['module']
- var_c = var.clone(clone_update=False)
+ var_c = objective.clone(clone_updates=False)
 
  # on 1st step just step and store previous params
  if step == 1:
  p_prev.copy_(params)
 
  module.update(var_c)
- var.update_attrs_from_clone_(var_c)
+ objective.update_attrs_from_clone_(var_c)
  return
 
  # restore previous params and update
- var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+ prev_objective = Objective(params=params, closure=closure, model=objective.model, current_step=objective.current_step)
  params.set_(p_prev)
  module.reset_for_online()
- module.update(var_prev)
+ module.update(prev_objective)
 
  # restore current params and update
  params.set_(p_cur)
  p_prev.copy_(params)
  module.update(var_c)
- var.update_attrs_from_clone_(var_c)
+ objective.update_attrs_from_clone_(var_c)
 
  @torch.no_grad
- def apply(self, var):
+ def apply(self, objective):
  module = self.children['module']
- return module.apply(var.clone(clone_update=False))
+ return module.apply(objective.clone(clone_updates=False))
 
- def get_H(self, var):
- return self.children['module'].get_H(var)
+ def get_H(self, objective):
+ return self.children['module'].get_H(objective)
@@ -1,14 +1,14 @@
  import torch
 
- from ...core import Chainable, Module, Target, Transform
+ from ...core import Chainable, Module, Transform
  from ...core.reformulation import Reformulation
- from ...utils import Distributions, NumberList, TensorList
+ from ...utils import Distributions, Metrics, NumberList, TensorList, evaluate_metric
 
 
  class Dropout(Transform):
  """Applies dropout to the update.
 
- For each weight the update to that weight has :code:`p` probability to be set to 0.
+ For each weight the update to that weight has ``p`` probability to be set to 0.
  This can be used to implement gradient dropout or update dropout depending on placement.
 
  Args:
@@ -18,36 +18,37 @@ class Dropout(Transform):
  target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
 
 
- Examples:
- Gradient dropout.
+ ### Examples:
 
- .. code-block:: python
+ Gradient dropout.
 
- opt = tz.Modular(
- model.parameters(),
- tz.m.Dropout(0.5),
- tz.m.Adam(),
- tz.m.LR(1e-3)
- )
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Dropout(0.5),
+ tz.m.Adam(),
+ tz.m.LR(1e-3)
+ )
+ ```
 
- Update dropout.
+ Update dropout.
 
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Adam(),
- tz.m.Dropout(0.5),
- tz.m.LR(1e-3)
- )
+ ``python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.Dropout(0.5),
+ tz.m.LR(1e-3)
+ )
+ ```
 
  """
- def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+ def __init__(self, p: float = 0.5, graft: bool=False):
  defaults = dict(p=p, graft=graft)
- super().__init__(defaults, uses_grad=False, target=target)
+ super().__init__(defaults)
 
  @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
  tensors = TensorList(tensors)
  p = NumberList(s['p'] for s in settings)
  graft = settings[0]['graft']
@@ -67,32 +68,31 @@ class WeightDropout(Module):
  """
  Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
 
- Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.
+ Dropout can be disabled for a parameter by setting ``use_dropout=False`` in corresponding parameter group.
 
  Args:
  p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
- graft (bool, optional):
- if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
  """
- def __init__(self, p: float = 0.5, graft: bool = True):
- defaults = dict(p=p, graft=graft, use_dropout=True)
+ def __init__(self, p: float = 0.5):
+ defaults = dict(p=p, use_dropout=True)
  super().__init__(defaults)
 
  @torch.no_grad
- def step(self, var):
- closure = var.closure
+ def update(self, objective):
+ closure = objective.closure
  if closure is None: raise RuntimeError('WeightDropout requires closure')
- params = TensorList(var.params)
+ params = TensorList(objective.params)
  p = NumberList(self.settings[p]['p'] for p in params)
 
  # create masks
  mask = []
- for p, m in zip(params, mask):
+ for p in params:
  prob = self.settings[p]['p']
  use_dropout = self.settings[p]['use_dropout']
  if use_dropout: mask.append(_bernoulli_like(p, prob))
  else: mask.append(torch.ones_like(p))
 
+ # create a closure that evaluates masked parameters
  @torch.no_grad
  def dropout_closure(backward=True):
  orig_params = params.clone()
@@ -104,15 +104,14 @@ class WeightDropout(Module):
  params.copy_(orig_params)
  return loss
 
- var.closure = dropout_closure
- return var
+ objective.closure = dropout_closure
 
 
  class PerturbWeights(Module):
  """
  Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
 
- Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.
+ Can be disabled for a parameter by setting ``perturb=False`` in corresponding parameter group.
 
  Args:
  alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
@@ -120,15 +119,22 @@ class PerturbWeights(Module):
  distribution (bool, optional):
  distribution of the random perturbation. Defaults to False.
  """
- def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
- defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+
+ def __init__(
+ self,
+ alpha: float = 0.1,
+ relative: bool = True,
+ distribution: Distributions = "normal",
+ metric: Metrics = "mad",
+ ):
+ defaults = dict(alpha=alpha, relative=relative, distribution=distribution, metric=metric, perturb=True)
  super().__init__(defaults)
 
  @torch.no_grad
- def step(self, var):
- closure = var.closure
+ def update(self, objective):
+ closure = objective.closure
  if closure is None: raise RuntimeError('WeightDropout requires closure')
- params = TensorList(var.params)
+ params = TensorList(objective.params)
 
  # create perturbations
  perts = []
@@ -140,7 +146,7 @@ class PerturbWeights(Module):
 
  alpha = settings['alpha']
  if settings['relative']:
- alpha *= p.abs().mean()
+ alpha *= evaluate_metric(p, settings["metric"])
 
  distribution = self.settings[p]['distribution'].lower()
  if distribution in ('normal', 'gaussian'):
@@ -163,5 +169,4 @@ class PerturbWeights(Module):
  params.sub_(perts)
  return loss
 
- var.closure = perturbed_closure
- return var
+ objective.closure = perturbed_closure
@@ -1,54 +1,53 @@
- import warnings
  from collections.abc import Callable, Sequence, Iterable
  from typing import cast
 
  import torch
 
- from ...core import Chainable, Module, Var
+ from ...core import Chainable, Module, Objective
 
 
  def _split(
  module: Module,
  idxs,
  params,
- var: Var,
+ objective: Objective,
  ):
  split_params = [p for i,p in enumerate(params) if i in idxs]
 
  split_grad = None
- if var.grad is not None:
- split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
+ if objective.grads is not None:
+ split_grad = [g for i,g in enumerate(objective.grads) if i in idxs]
 
  split_update = None
- if var.update is not None:
- split_update = [u for i,u in enumerate(var.update) if i in idxs]
+ if objective.updates is not None:
+ split_update = [u for i,u in enumerate(objective.updates) if i in idxs]
 
- split_var = var.clone(clone_update=False, parent=var)
- split_var.params = split_params
- split_var.grad = split_grad
- split_var.update = split_update
+ split_obj = objective.clone(clone_updates=False, parent=objective)
+ split_obj.params = split_params
+ split_obj.grads = split_grad
+ split_obj.updates = split_update
 
- split_var = module.step(split_var)
+ split_obj = module.step(split_obj)
 
  # those should be set due to var being parent
- if split_var.grad is not None:
- assert var.grad is not None
+ if split_obj.grads is not None:
+ assert objective.grads is not None
 
- if split_var.loss is not None:
- assert var.loss is not None
+ if split_obj.loss is not None:
+ assert objective.loss is not None
 
- if split_var.update is not None:
+ if split_obj.updates is not None:
 
  # make sure update is set, it will be filled with ``true`` and ``false`` tensors
- if var.update is None:
- if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
- else: var.update = [g.clone() for g in var.grad]
+ if objective.updates is None:
+ if objective.grads is None: objective.updates = [cast(torch.Tensor, None) for _ in objective.params]
+ else: objective.updates = [g.clone() for g in objective.grads]
 
  # set all tensors from this split
- for idx, u in zip(idxs, split_var.update):
- var.update[idx] = u
+ for idx, u in zip(idxs, split_obj.updates):
+ objective.updates[idx] = u
 
- return var
+ return objective
 
  _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
  Filter = _SingleFilter | Iterable[_SingleFilter]
@@ -82,7 +81,7 @@ class Split(Module):
  Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon
 
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NAG(0.95),
  tz.m.Split(
@@ -101,9 +100,12 @@ class Split(Module):
  if true is not None: self.set_child('true', true)
  if false is not None: self.set_child('false', false)
 
- def step(self, var):
+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
 
- params = var.params
+ def step(self, objective):
+
+ params = objective.params
  filter = _make_filter(self.settings[params[0]]['filter'])
 
  true_idxs = []
@@ -114,10 +116,10 @@ class Split(Module):
 
  if 'true' in self.children and len(true_idxs) > 0:
  true = self.children['true']
- var = _split(true, idxs=true_idxs, params=params, var=var)
+ objective = _split(true, idxs=true_idxs, params=params, objective=objective)
 
  if 'false' in self.children and len(false_idxs) > 0:
  false = self.children['false']
- var = _split(false, idxs=false_idxs, params=params, var=var)
+ objective = _split(false, idxs=false_idxs, params=params, objective=objective)
 
- return var
+ return objective
@@ -14,20 +14,21 @@ class Alternate(Module):
  Args:
  steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
 
- Examples:
- Alternate between Adam, SignSGD and RMSprop
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Alternate(
- tz.m.Adam(),
- [tz.m.SignSGD(), tz.m.Mul(0.5)],
- tz.m.RMSprop(),
- ),
- tz.m.LR(1e-3),
- )
+ ### Examples:
+ Alternate between Adam, SignSGD and RMSprop
+
+ ```python
+
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Alternate(
+ tz.m.Adam(),
+ [tz.m.SignSGD(), tz.m.Mul(0.5)],
+ tz.m.RMSprop(),
+ ),
+ tz.m.LR(1e-3),
+ )
+ ```
  """
  LOOP = True
  def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
@@ -43,14 +44,17 @@ class Alternate(Module):
  self.global_state['current_module_idx'] = 0
  self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
+ def update(self, objective): raise RuntimeError
+ def apply(self, objective): raise RuntimeError
+
  @torch.no_grad
- def step(self, var):
+ def step(self, objective):
  # get current module
  current_module_idx = self.global_state.setdefault('current_module_idx', 0)
  module = self.children[f'module_{current_module_idx}']
 
  # step
- var = module.step(var.clone(clone_update=False))
+ objective = module.step(objective.clone(clone_updates=False))
 
  # number of steps until next module
  steps = self.defaults['steps']
@@ -72,28 +76,29 @@ class Alternate(Module):
 
  self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
- return var
+ return objective
 
  class Switch(Alternate):
- """After :code:`steps` steps switches to the next module.
+ """After ``steps`` steps switches to the next module.
 
  Args:
  steps (int | Iterable[int]): Number of steps to perform with each module.
 
- Examples:
- Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.Switch(
- [tz.m.Adam(), tz.m.LR(1e-3)],
- [tz.m.LBFGS(), tz.m.Backtracking()],
- [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
- steps = (1000, 2000)
- )
- )
+ ### Examples:
+
+ Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+ ```python
+ opt = tz.Optimizer(
+ model.parameters(),
+ tz.m.Switch(
+ [tz.m.Adam(), tz.m.LR(1e-3)],
+ [tz.m.LBFGS(), tz.m.Backtracking()],
+ [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+ steps = (1000, 2000)
+ )
+ )
+ ```
  """
 
  LOOP = False
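
Beyond the rename, these hunks also show the module lifecycle changing from a single `step(self, var)` override to `update(self, objective)`/`apply(self, objective)`, with `step` retained only where a module must control both phases (as in `Split` and `Alternate`). A minimal sketch of a custom module following the 0.4.x conventions visible above; the class itself is illustrative and not part of torchzero.

```python
import torch
from torchzero.core import Module


class ScaleUpdate(Module):
    """Multiplies the current update by a constant factor (illustrative sketch)."""

    def __init__(self, factor: float = 0.5):
        # per-module hyperparameters are passed as a `defaults` dict,
        # mirroring NegateOnLossIncrease in the hunks above
        defaults = dict(factor=factor)
        super().__init__(defaults=defaults)

    @torch.no_grad
    def apply(self, objective):
        # objective.get_updates() replaces var.get_update() from 0.3.x
        factor = self.defaults['factor']
        torch._foreach_mul_(objective.get_updates(), factor)
        return objective
```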