torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/core/modular.py CHANGED
@@ -1,27 +1,16 @@
 
 import warnings
-from abc import ABC, abstractmethod
-from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable, MutableMapping, Sequence
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any
 
 import torch
 
-from ..utils import (
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .module import Chainable, Module
-from .var import Var
+from ..utils.params import Params, _make_param_groups
 from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+
 
 class _EvalCounterClosure:
     """keeps track of how many times closure has been evaluated, and sets closure return"""
@@ -32,7 +21,7 @@ class _EvalCounterClosure:
 
     def __call__(self, *args, **kwargs):
         if self.closure is None:
-            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")
 
         v = self.closure(*args, **kwargs)
 
@@ -44,17 +33,17 @@ class _EvalCounterClosure:
         return v
 
 
-def unroll_modules(*modules: Chainable) -> list[Module]:
-    unrolled = []
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []
 
     for m in modules:
         if isinstance(m, Module):
-            unrolled.append(m)
-            unrolled.extend(unroll_modules(list(m.children.values())))
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
         else:
-            unrolled.extend(unroll_modules(*m))
+            flat.extend(flatten_modules(*m))
 
-    return unrolled
+    return flat
 
 
 # have to inherit from Modular to support lr schedulers
@@ -83,7 +72,7 @@ class Modular(torch.optim.Optimizer):
         self.modules = modules
         """Top-level modules providedduring initialization."""
 
-        self.unrolled_modules = unroll_modules(self.modules)
+        self.flat_modules = flatten_modules(self.modules)
         """A flattened list of all modules including all children."""
 
         param_groups = _make_param_groups(params, differentiable=False)
@@ -92,7 +81,7 @@ class Modular(torch.optim.Optimizer):
         Each element in the list is ChainDict's 2nd map of a module."""
 
         # make sure there is no more than a single learning rate module
-        lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
         if len(lr_modules) > 1:
             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')
 
@@ -100,13 +89,13 @@ class Modular(torch.optim.Optimizer):
         for group in param_groups:
             for k in group:
                 if k in ('params', 'lr'): continue
-                modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
                 if len(modules_with_k) > 1:
                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
 
         # defaults for schedulers
         defaults = {}
-        for m in self.unrolled_modules: defaults.update(m.defaults)
+        for m in self.flat_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)
 
         # note - this is what super().__init__(param_groups, defaults=defaults) does:
@@ -146,7 +135,7 @@
 
         for p in proc_param_group['params']:
             # updates global per-parameter setting overrides (medium priority)
-            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]
 
     def state_dict(self):
         all_params = [p for g in self.param_groups for p in g['params']]
@@ -163,7 +152,7 @@ class Modular(torch.optim.Optimizer):
             "params": all_params,
             "groups": groups,
             "defaults": self.defaults,
-            "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
         }
         return state_dict
 
@@ -183,7 +172,7 @@ class Modular(torch.optim.Optimizer):
             self.add_param_group(group)
 
         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-        for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
             m._load_state_dict(sd, id_to_tensor)
 
 
@@ -201,35 +190,42 @@ class Modular(torch.optim.Optimizer):
             if not p.requires_grad: continue
             for map in self._per_parameter_global_settings[p]: map.update(settings)
 
-        # create var
+        # create Objective
        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
 
-        # if closure is None, assume backward has been called and gather grads
-        if closure is None:
-            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            self.num_evaluations += 1
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)
 
-        if len(self.modules) == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )
 
-        # step
-        var = step(var, self.modules)
+        # step with all modules
+        objective = step(objective, self.modules)
 
-        # apply update
-        if not var.skip_update:
-            with torch.no_grad():
-                torch._foreach_sub_(params, var.get_update())
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()
 
         # update attributes
-        self.attrs.update(var.attrs)
-        if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-        # hooks
-        for hook in var.post_step_hooks:
-            hook(self, var)
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate
 
         self.current_step += 1
-        #return var.loss if var.loss is not None else var.loss_approx
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed but that's pointless
         return self._closure_return
 
     def __repr__(self):
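
For context, user code drives the new step flow the same way as before: ``Modular.step`` takes an optional closure and now returns the first closure evaluation's return value. The snippet below is a minimal usage sketch, not taken from this diff; the exact ``Modular`` constructor signature, the ``tz.m.Adam``/``tz.m.LR`` module names, and the ``backward=True`` closure convention are assumptions based on the rest of the package.

    # Hypothetical usage sketch of Modular.step with a closure.
    # Names such as tz.m.Adam and tz.m.LR are assumptions, not shown in this diff.
    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

    X, y = torch.randn(64, 10), torch.randn(64, 1)

    def closure(backward=True):
        loss = torch.nn.functional.mse_loss(model(X), y)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    for _ in range(100):
        # step() wraps the closure in _EvalCounterClosure, builds an Objective,
        # runs all modules via step(objective, modules), then applies the update.
        loss = opt.step(closure)
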
torchzero/core/module.py CHANGED
@@ -1,24 +1,18 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable, MutableMapping, Sequence
-from operator import itemgetter
-from typing import Any, Literal, cast, final, overload
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, overload, TYPE_CHECKING
 
 import torch
 
-from ..utils import (
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .var import Var
+from ..linalg.linear_operator import LinearOperator
+from ..utils.optimizer import Init, ListLike, get_state_vals
+from ..utils.params import Params, _make_param_groups
+from .functional import step_tensors
+
+if TYPE_CHECKING:
+    from .objective import Objective
 
 
 class Module(ABC):
@@ -36,6 +30,7 @@ class Module(ABC):
     """
     def __init__(self, defaults: dict[str, Any] | None = None):
         if defaults is None: defaults = {}
+        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
         self.defaults: dict[str, Any] = defaults
 
         # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
@@ -55,7 +50,7 @@ class Module(ABC):
         """A dictionary of child modules."""
 
         self._overridden_keys = set()
-        """tracks keys overridden with `set_param_groups`, only used to not give a warning"""
+        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
 
     def set_param_groups(self, param_groups: Params):
@@ -71,7 +66,12 @@ class Module(ABC):
             self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
         return self
 
-    def set_child(self, key: str, module: "Module | Sequence[Module]"):
+    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
+        if key in self.children:
+            warnings.warn(f"set_child overwriting child `{key}`")
+
+        if module is None: return
+
         from .chain import maybe_chain
         self.children[key] = maybe_chain(module)
 
@@ -85,6 +85,62 @@ class Module(ABC):
     def get_children_sequence(self, prefix = 'module_'):
         return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
 
+    def inner_step(
+        self,
+        key: str,
+        objective: "Objective",
+        must_exist: bool = True,
+    ) -> "Objective":
+        """Passes ``objective`` to child and returns it."""
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return objective
+
+        return child.step(objective)
+
+
+    def inner_step_tensors(
+        self,
+        key: str,
+        tensors: list[torch.Tensor],
+        clone: bool,
+        params: Iterable[torch.Tensor] | None = None,
+        grads: Sequence[torch.Tensor] | None = None,
+        loss: torch.Tensor | None = None,
+        closure: Callable | None = None,
+        objective: "Objective | None" = None,
+        must_exist: bool = True
+    ) -> list[torch.Tensor]:
+        """Steps with child module. Can be used to apply transforms to any internal buffers.
+
+        If ``objective`` is specified, other attributes shouldn't to be specified.
+
+        Args:
+            key (str): Child module key.
+            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
+            clone (bool):
+                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
+                If ``key`` doesn't exist, ``tensors`` are always returned without cloning
+            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
+            loss (torch.Tensor | None, optional): loss. Defaults to None.
+            closure (Callable | None, optional): closure. Defaults to None.
+            must_exist (bool, optional): if True, if ``key`` doesn't exist, raises ``KeyError``. Defaults to True.
+        """
+
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return tensors
+
+        if clone: tensors = [t.clone() for t in tensors]
+        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+                            loss=loss, closure=closure, objective=objective)
+
+
     def __repr__(self):
         s = self.__class__.__name__
         if self.children:
@@ -106,7 +162,6 @@ class Module(ABC):
 
     def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                      *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
@@ -176,13 +231,8 @@ class Module(ABC):
         - if state_keys has multiple keys and keys has a single key, return cls.
         - if state_keys has multiple keys and keys has multiple keys, return list of cls.
         """
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
 
-    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
-    #     # if isinstance(params, Vars): params = params.params
-    #     return itemgetter(*keys)(self.settings[params[0]])
-
     def clear_state_keys(self, *keys:str):
         for s in self.state.values():
             for k in keys:
@@ -248,36 +298,73 @@ class Module(ABC):
         # extra info
         self._extra_unpack(state_dict['extra'])
 
-    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    def step(self, var: Var) -> Var:
-        """performs a step, returns new ``var`` but may update it in-place."""
-        self.update(var)
-        return self.apply(var)
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        """If ``seed=None``, returns ``None``.
+
+        Otherwise, if generator on this device and with this seed hasn't been created,
+        creates it and stores in global state.
+
+        Returns ``torch.Generator``."""
+        if seed is None: return None
 
-    def update(self, var:Var) -> Any:
-        """Updates the internal state of this module. This should not modify ``var.update``.
+        if device is None: device_obj = torch.get_default_device()
+        else: device_obj = torch.device(device)
+        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"
+
+        if key not in self.global_state:
+            self.global_state[key] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state[key]
+
+    def increment_counter(self, key: str, start: int):
+        """first value is ``start``"""
+        value = self.global_state.get(key, start - 1) + 1
+        self.global_state[key] = value
+        return value
+
+    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
+    def update(self, objective:"Objective") -> None:
+        """Updates internal state of this module. This should not modify ``objective.update``.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
 
-    def apply(self, var: Var) -> Var:
-        """Applies this module to ``var.get_update()``.
-        This should not modify the internal state of this module if possible.
+    @abstractmethod
+    def apply(self, objective: "Objective") -> "Objective":
+        """Updates ``objective`` using the internal state of this module.
+
+        If ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
-        return self.step(var)
+        # if apply is empty, it should be defined explicitly.
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")
+
+    def step(self, objective: "Objective") -> "Objective":
+        """Perform a step with this module. Calls ``update``, then ``apply``."""
+        self.update(objective)
+        return self.apply(objective)
 
-    def get_H(self, var: Var) -> LinearOperator | None:
+    def get_H(self, objective: "Objective") -> LinearOperator | None:
         """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
         The hessian approximation is assumed to be for all parameters concatenated to a vector."""
         # if this method is not defined it searches in children
         # this should be overwritten to return None if child params are different from this modules params
         H = None
         for k,v in self.children.items():
-            H_v = v.get_H(var)
+            H_v = v.get_H(objective)
 
             if (H is not None) and (H_v is not None):
                 raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
@@ -307,21 +394,14 @@ class Module(ABC):
         """
         for c in self.children.values(): c.reset_for_online()
 
-    def _extra_pack(self):
-        """extra information to store in state_dict of this optimizer.
-        Will be passed to ``_extra_unpack`` when loading the state_dict."""
+    def _extra_pack(self) -> dict:
+        """extra information to store in ``state_dict`` of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
         return {}
 
-    def _extra_unpack(self, x):
-        """``_extra_pack`` return will be passed to this method when loading state_dict.
+    def _extra_unpack(self, d: dict):
+        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
         This method is called after loading the rest of the state dict"""
 
-    def get_generator(self, device: torch.types.Device, seed: int | None):
-        if seed is None: return None
-
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
-
-        return self.global_state['generator']
 
 Chainable = Module | Sequence[Module]
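
To illustrate the new ``update``/``apply`` split documented above, here is a minimal sketch of a custom module. It is not part of torchzero: the ``torchzero.core`` import path and the ``objective.params``/``objective.get_updates()`` accessors (taken from the comments in ``Modular.step`` earlier in this diff) are assumptions about the new ``Objective`` API, and the EMA logic itself is purely illustrative.

    # Illustrative sketch only. objective.params and objective.get_updates() are assumed
    # from the Modular.step comments above; everything else (defaults, settings, state,
    # update/apply) follows the Module methods shown in this diff.
    import torch
    from torchzero.core import Module  # import path assumed

    class EMAUpdate(Module):
        def __init__(self, beta: float = 0.9):
            super().__init__(defaults=dict(beta=beta))

        def update(self, objective) -> None:
            # update() only touches this module's internal per-parameter state
            for p, u in zip(objective.params, objective.get_updates()):
                beta = self.settings[p]['beta']
                buf = self.state[p].setdefault('ema', torch.zeros_like(u))
                buf.mul_(beta).add_(u, alpha=1 - beta)

        def apply(self, objective):
            # apply() writes the buffered EMA back into the update tensors in place,
            # leaving the internal state unchanged
            for p, u in zip(objective.params, objective.get_updates()):
                u.copy_(self.state[p]['ema'])
            return objective

Because the logic is split this way, meta-modules such as ``tz.m.Online`` or trust regions can re-run ``apply`` without re-running ``update``, which is the guarantee the new docstrings describe.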