torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
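Most of the hunks below follow a single migration pattern: the Var carrier object from 0.3.15 is replaced by Objective in 0.4.0 (with pluralized accessors such as get_updates() and the grads/updates attributes), and the old step(var) entry point on modules is split into update(objective) and apply(objective), with step kept only where a module needs to control the whole step itself. A minimal before/after sketch of that pattern as it appears in these hunks; the HalveUpdate* classes are illustrative, not part of the package:

    import torch
    from torchzero.core import Module   # 0.4.0 also exports Objective from torchzero.core

    # 0.3.15-style module, as seen on the "-" side of the hunks below
    class HalveUpdateOld(Module):
        def step(self, var):
            torch._foreach_mul_(var.get_update(), 0.5)          # singular accessor
            return var

    # 0.4.0-style module, as seen on the "+" side of the hunks below
    class HalveUpdateNew(Module):
        def apply(self, objective):
            torch._foreach_mul_(objective.get_updates(), 0.5)   # pluralized accessor
            return objective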
@@ -2,49 +2,49 @@ from collections.abc import Iterable
 
 import torch
 
-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList
 
-def _sequential_step(self: Module, var: Var, sequential: bool):
-    params = var.params
+def _sequential_step(self: Module, objective: Objective, sequential: bool):
+    params = objective.params
     steps = self.settings[params[0]]['steps']
 
-    if sequential: modules = self.get_children_sequence() * steps
+    if sequential: modules: list[Module] = self.get_children_sequence() * steps
     else: modules = [self.children['module']] * steps
 
-    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+    if objective.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
 
     # store original params unless this is last module and can update params directly
     params_before_steps = [p.clone() for p in params]
 
     # first step - pass var as usual
-    var = modules[0].step(var)
-    new_var = var
+    objective = modules[0].step(objective)
+    new_objective = objective
 
     # subsequent steps - update parameters and create new var
     if len(modules) > 1:
         for m in modules[1:]:
 
             # update params
-            if (not new_var.skip_update):
+            if (not new_objective.skip_update):
                 # if new_var.last_module_lrs is not None:
                 # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
-                torch._foreach_sub_(params, new_var.get_update())
+                torch._foreach_sub_(params, new_objective.get_updates())
 
             # create new var since we are at a new point, that means grad, update and loss will be None
-            new_var = Var(params=new_var.params, closure=new_var.closure,
-                          model=new_var.model, current_step=new_var.current_step + 1)
+            new_objective = Objective(params=new_objective.params, closure=new_objective.closure,
+                                      model=new_objective.model, current_step=new_objective.current_step + 1)
 
             # step
-            new_var = m.step(new_var)
+            new_objective = m.step(new_objective)
 
     # final parameter update
-    if (not new_var.skip_update):
+    if (not new_objective.skip_update):
         # if new_var.last_module_lrs is not None:
         # torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
 
-        torch._foreach_sub_(params, new_var.get_update())
+        torch._foreach_sub_(params, new_objective.get_updates())
 
     # if last module, update is applied so return new var
     # if params_before_steps is None:
@@ -53,13 +53,13 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
    # return new_var
 
    # otherwise use parameter difference as update
-    var.update = list(torch._foreach_sub(params_before_steps, params))
+    objective.updates = list(torch._foreach_sub(params_before_steps, params))
     for p, bef in zip(params, params_before_steps):
         p.set_(bef) # pyright:ignore[reportArgumentType]
-    return var
+    return objective
 
 class Multistep(Module):
-    """Performs :code:`steps` inner steps with :code:`module` per each step.
+    """Performs ``steps`` inner steps with ``module`` per each step.
 
     The update is taken to be the parameter difference between parameters before and after the inner loop."""
     def __init__(self, module: Chainable, steps: int):
@@ -68,11 +68,11 @@ class Multistep(Module):
         self.set_child('module', module)
 
     @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=False)
+    def apply(self, objective):
+        return _sequential_step(self, objective, sequential=False)
 
 class Sequential(Module):
-    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+    """On each step, this sequentially steps with ``modules`` ``steps`` times.
 
     The update is taken to be the parameter difference between parameters before and after the inner loop."""
     def __init__(self, modules: Iterable[Chainable], steps: int=1):
@@ -81,28 +81,28 @@ class Sequential(Module):
         self.set_children_sequence(modules)
 
     @torch.no_grad
-    def step(self, var):
-        return _sequential_step(self, var, sequential=True)
+    def apply(self, objective):
+        return _sequential_step(self, objective, sequential=True)
 
 
 class NegateOnLossIncrease(Module):
-    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
-    if loss is larger than at :code:`parameters`,
-    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
+    """Uses an extra forward pass to evaluate loss at ``parameters+update``,
+    if loss is larger than at ``parameters``,
+    the update is set to 0 if ``backtrack=False`` and to ``-update`` otherwise"""
     def __init__(self, backtrack=False):
         defaults = dict(backtrack=backtrack)
         super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def apply(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
         backtrack = self.defaults['backtrack']
 
-        update = var.get_update()
-        f_0 = var.get_loss(backward=False)
+        update = objective.get_updates()
+        f_0 = objective.get_loss(backward=False)
 
-        torch._foreach_sub_(var.params, update)
+        torch._foreach_sub_(objective.params, update)
         f_1 = closure(False)
 
         if f_1 <= f_0:
@@ -111,15 +111,15 @@ class NegateOnLossIncrease(Module):
            # var.skip_update = True
            # return var
 
-            torch._foreach_add_(var.params, update)
-            return var
+            torch._foreach_add_(objective.params, update)
+            return objective
 
-        torch._foreach_add_(var.params, update)
+        torch._foreach_add_(objective.params, update)
         if backtrack:
-            torch._foreach_neg_(var.update)
+            torch._foreach_neg_(objective.updates)
         else:
-            torch._foreach_zero_(var.update)
-        return var
+            torch._foreach_zero_(objective.updates)
+        return objective
 
 
 class Online(Module):
@@ -147,48 +147,50 @@ class Online(Module):
     """
     def __init__(self, *modules: Module,):
         super().__init__()
+        if len(modules) == 0:
+            raise RuntimeError("Online got empty list of modules. To make a module online, wrap it in tz.m.Online, e.g. `tz.m.Online(tz.m.LBFGS())`")
 
         self.set_child('module', modules)
 
     @torch.no_grad
-    def update(self, var):
-        closure = var.closure
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise ValueError("Closure must be passed for Online")
 
         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         p_cur = params.clone()
         p_prev = self.get_state(params, 'p_prev', cls=TensorList)
 
         module = self.children['module']
-        var_c = var.clone(clone_update=False)
+        var_c = objective.clone(clone_updates=False)
 
         # on 1st step just step and store previous params
         if step == 1:
             p_prev.copy_(params)
 
             module.update(var_c)
-            var.update_attrs_from_clone_(var_c)
+            objective.update_attrs_from_clone_(var_c)
             return
 
         # restore previous params and update
-        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+        prev_objective = Objective(params=params, closure=closure, model=objective.model, current_step=objective.current_step)
         params.set_(p_prev)
         module.reset_for_online()
-        module.update(var_prev)
+        module.update(prev_objective)
 
         # restore current params and update
         params.set_(p_cur)
         p_prev.copy_(params)
         module.update(var_c)
-        var.update_attrs_from_clone_(var_c)
+        objective.update_attrs_from_clone_(var_c)
 
     @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
         module = self.children['module']
-        return module.apply(var.clone(clone_update=False))
+        return module.apply(objective.clone(clone_updates=False))
 
-    def get_H(self, var):
-        return self.children['module'].get_H(var)
+    def get_H(self, objective):
+        return self.children['module'].get_H(objective)
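The new error message in the Online constructor above spells out the intended usage: wrap a stateful module in tz.m.Online. A usage sketch in the spirit of the docstring examples elsewhere in this diff; the model, data, and choice of modules are placeholders, and tz.Modular is assumed to accept a closure-based step() matching the closures seen in these hunks:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)                    # placeholder model and data
    X, y = torch.randn(32, 4), torch.randn(32, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.Online(tz.m.LBFGS()),                   # wrap a stateful module to make it online
        tz.m.Backtracking(),
    )

    def closure(backward=True):
        # torchzero closures take a backward flag, as seen in the hunks above
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            model.zero_grad()
            loss.backward()
        return loss

    opt.step(closure)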
@@ -1,14 +1,14 @@
 import torch
 
-from ...core import Chainable, Module, Target, Transform
+from ...core import Chainable, Module, Transform
 from ...core.reformulation import Reformulation
-from ...utils import Distributions, NumberList, TensorList
+from ...utils import Distributions, Metrics, NumberList, TensorList, evaluate_metric
 
 
 class Dropout(Transform):
     """Applies dropout to the update.
 
-    For each weight the update to that weight has :code:`p` probability to be set to 0.
+    For each weight the update to that weight has ``p`` probability to be set to 0.
     This can be used to implement gradient dropout or update dropout depending on placement.
 
     Args:
@@ -18,36 +18,37 @@ class Dropout(Transform):
         target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
 
 
-    Examples:
-        Gradient dropout.
+    ### Examples:
 
-        .. code-block:: python
+    Gradient dropout.
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Dropout(0.5),
-                tz.m.Adam(),
-                tz.m.LR(1e-3)
-            )
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Dropout(0.5),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```
 
-        Update dropout.
+    Update dropout.
 
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Adam(),
-                tz.m.Dropout(0.5),
-                tz.m.LR(1e-3)
-            )
+    ``python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.Dropout(0.5),
+        tz.m.LR(1e-3)
+    )
+    ```
 
     """
-    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+    def __init__(self, p: float = 0.5, graft: bool=False):
         defaults = dict(p=p, graft=graft)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         p = NumberList(s['p'] for s in settings)
         graft = settings[0]['graft']
@@ -67,32 +68,31 @@ class WeightDropout(Module):
     """
     Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
 
-    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.
+    Dropout can be disabled for a parameter by setting ``use_dropout=False`` in corresponding parameter group.
 
     Args:
         p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
-        graft (bool, optional):
-            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
     """
-    def __init__(self, p: float = 0.5, graft: bool = True):
-        defaults = dict(p=p, graft=graft, use_dropout=True)
+    def __init__(self, p: float = 0.5):
+        defaults = dict(p=p, use_dropout=True)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         p = NumberList(self.settings[p]['p'] for p in params)
 
         # create masks
         mask = []
-        for p, m in zip(params, mask):
+        for p in params:
            prob = self.settings[p]['p']
            use_dropout = self.settings[p]['use_dropout']
            if use_dropout: mask.append(_bernoulli_like(p, prob))
            else: mask.append(torch.ones_like(p))
 
+        # create a closure that evaluates masked parameters
         @torch.no_grad
         def dropout_closure(backward=True):
             orig_params = params.clone()
@@ -104,15 +104,14 @@ class WeightDropout(Module):
             params.copy_(orig_params)
             return loss
 
-        var.closure = dropout_closure
-        return var
+        objective.closure = dropout_closure
 
 
 class PerturbWeights(Module):
     """
     Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
 
-    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.
+    Can be disabled for a parameter by setting ``perturb=False`` in corresponding parameter group.
 
     Args:
         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
@@ -120,15 +119,22 @@ class PerturbWeights(Module):
         distribution (bool, optional):
             distribution of the random perturbation. Defaults to False.
     """
-    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
-        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+
+    def __init__(
+        self,
+        alpha: float = 0.1,
+        relative: bool = True,
+        distribution: Distributions = "normal",
+        metric: Metrics = "mad",
+    ):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, metric=metric, perturb=True)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
 
         # create perturbations
         perts = []
@@ -140,7 +146,7 @@ class PerturbWeights(Module):
 
             alpha = settings['alpha']
             if settings['relative']:
-                alpha *= p.abs().mean()
+                alpha *= evaluate_metric(p, settings["metric"])
 
             distribution = self.settings[p]['distribution'].lower()
             if distribution in ('normal', 'gaussian'):
@@ -163,5 +169,4 @@ class PerturbWeights(Module):
             params.sub_(perts)
             return loss
 
-        var.closure = perturbed_closure
-        return var
+        objective.closure = perturbed_closure
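WeightDropout and PerturbWeights above share one idea: rather than transforming the update, they replace objective.closure with a wrapper that temporarily modifies the parameters, evaluates the original closure, and restores the originals. A stripped-down sketch of that wrapping pattern in plain PyTorch, independent of torchzero's Module machinery (the function names are illustrative):

    import torch

    def wrap_closure_with_masks(closure, params, masks):
        # Returns a closure that evaluates `closure` with each parameter multiplied
        # by its mask and then restores the original values, mirroring the
        # dropout_closure / perturbed_closure wrappers in the hunks above.
        @torch.no_grad()
        def masked_closure(backward=True):
            originals = [p.clone() for p in params]
            for p, m in zip(params, masks):
                p.mul_(m)
            loss = closure(backward)      # the wrapped closure handles backward itself
            for p, orig in zip(params, originals):
                p.copy_(orig)
            return loss
        return masked_closure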
@@ -1,54 +1,53 @@
-import warnings
 from collections.abc import Callable, Sequence, Iterable
 from typing import cast
 
 import torch
 
-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 
 
 def _split(
     module: Module,
     idxs,
     params,
-    var: Var,
+    objective: Objective,
 ):
     split_params = [p for i,p in enumerate(params) if i in idxs]
 
     split_grad = None
-    if var.grad is not None:
-        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
+    if objective.grads is not None:
+        split_grad = [g for i,g in enumerate(objective.grads) if i in idxs]
 
     split_update = None
-    if var.update is not None:
-        split_update = [u for i,u in enumerate(var.update) if i in idxs]
+    if objective.updates is not None:
+        split_update = [u for i,u in enumerate(objective.updates) if i in idxs]
 
-    split_var = var.clone(clone_update=False, parent=var)
-    split_var.params = split_params
-    split_var.grad = split_grad
-    split_var.update = split_update
+    split_obj = objective.clone(clone_updates=False, parent=objective)
+    split_obj.params = split_params
+    split_obj.grads = split_grad
+    split_obj.updates = split_update
 
-    split_var = module.step(split_var)
+    split_obj = module.step(split_obj)
 
     # those should be set due to var being parent
-    if split_var.grad is not None:
-        assert var.grad is not None
+    if split_obj.grads is not None:
+        assert objective.grads is not None
 
-    if split_var.loss is not None:
-        assert var.loss is not None
+    if split_obj.loss is not None:
+        assert objective.loss is not None
 
-    if split_var.update is not None:
+    if split_obj.updates is not None:
 
         # make sure update is set, it will be filled with ``true`` and ``false`` tensors
-        if var.update is None:
-            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
-            else: var.update = [g.clone() for g in var.grad]
+        if objective.updates is None:
+            if objective.grads is None: objective.updates = [cast(torch.Tensor, None) for _ in objective.params]
+            else: objective.updates = [g.clone() for g in objective.grads]
 
         # set all tensors from this split
-        for idx, u in zip(idxs, split_var.update):
-            var.update[idx] = u
+        for idx, u in zip(idxs, split_obj.updates):
+            objective.updates[idx] = u
 
-    return var
+    return objective
 
 _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
 Filter = _SingleFilter | Iterable[_SingleFilter]
@@ -101,9 +100,12 @@ class Split(Module):
         if true is not None: self.set_child('true', true)
         if false is not None: self.set_child('false', false)
 
-    def step(self, var):
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
 
-        params = var.params
+    def step(self, objective):
+
+        params = objective.params
         filter = _make_filter(self.settings[params[0]]['filter'])
 
         true_idxs = []
@@ -114,10 +116,10 @@ class Split(Module):
 
         if 'true' in self.children and len(true_idxs) > 0:
             true = self.children['true']
-            var = _split(true, idxs=true_idxs, params=params, var=var)
+            objective = _split(true, idxs=true_idxs, params=params, objective=objective)
 
         if 'false' in self.children and len(false_idxs) > 0:
             false = self.children['false']
-            var = _split(false, idxs=false_idxs, params=params, var=var)
+            objective = _split(false, idxs=false_idxs, params=params, objective=objective)
 
-        return var
+        return objective
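Split above routes the parameters selected by a filter to a 'true' child and the rest to a 'false' child, then merges the resulting updates back by index. A usage sketch based on the _SingleFilter type and the set_child('true', ...) / set_child('false', ...) calls; the constructor signature is not shown in these hunks, so the keyword names and the choice of child modules are assumptions:

    import torch
    import torchzero as tz

    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

    opt = tz.Modular(
        model.parameters(),
        tz.m.Split(
            lambda p: p.ndim >= 2,      # a callable over tensors, one of the allowed _SingleFilter forms
            true=tz.m.Adam(),           # applied to weight matrices
            false=tz.m.SignSGD(),       # applied to everything else (biases here)
        ),
        tz.m.LR(1e-3),
    )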
@@ -14,20 +14,21 @@ class Alternate(Module):
     Args:
         steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
 
-    Examples:
-        Alternate between Adam, SignSGD and RMSprop
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Alternate(
-                    tz.m.Adam(),
-                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
-                    tz.m.RMSprop(),
-                ),
-                tz.m.LR(1e-3),
-            )
+    ### Examples:
+    Alternate between Adam, SignSGD and RMSprop
+
+    ```python
+
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Alternate(
+            tz.m.Adam(),
+            [tz.m.SignSGD(), tz.m.Mul(0.5)],
+            tz.m.RMSprop(),
+        ),
+        tz.m.LR(1e-3),
+    )
+    ```
     """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
@@ -43,14 +44,17 @@ class Alternate(Module):
         self.global_state['current_module_idx'] = 0
         self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var):
+    def step(self, objective):
         # get current module
         current_module_idx = self.global_state.setdefault('current_module_idx', 0)
         module = self.children[f'module_{current_module_idx}']
 
         # step
-        var = module.step(var.clone(clone_update=False))
+        objective = module.step(objective.clone(clone_updates=False))
 
         # number of steps until next module
         steps = self.defaults['steps']
@@ -72,28 +76,29 @@ class Alternate(Module):
 
             self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
-        return var
+        return objective
 
 class Switch(Alternate):
-    """After :code:`steps` steps switches to the next module.
+    """After ``steps`` steps switches to the next module.
 
     Args:
         steps (int | Iterable[int]): Number of steps to perform with each module.
 
-    Examples:
-        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.Switch(
-                    [tz.m.Adam(), tz.m.LR(1e-3)],
-                    [tz.m.LBFGS(), tz.m.Backtracking()],
-                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
-                    steps = (1000, 2000)
-                )
-            )
+    ### Examples:
+
+    Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Switch(
+            [tz.m.Adam(), tz.m.LR(1e-3)],
+            [tz.m.LBFGS(), tz.m.Backtracking()],
+            [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+            steps = (1000, 2000)
+        )
+    )
+    ```
     """
 
     LOOP = False