torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_tensorlist.py CHANGED
@@ -1567,13 +1567,6 @@ def test_where(simple_tl: TensorList):
     assert_tl_allclose(result_module, expected_tl)


-    # Test inplace where_ (needs TensorList other)
-    tl_copy = simple_tl.clone()
-    result_inplace = tl_copy.where_(condition_tl, other_tl)
-    assert result_inplace is tl_copy
-    assert_tl_allclose(tl_copy, expected_tl)
-
-
 def test_masked_fill(simple_tl: TensorList):
     mask_tl = simple_tl.lt(0)
     fill_value_scalar = 99.0
@@ -1600,7 +1593,6 @@ def test_select_set_(simple_tl: TensorList):
     mask_tl = simple_tl.gt(0.5)
     value_scalar = -1.0
     value_list_scalar = [-1.0, -2.0, -3.0]
-    value_tl = simple_tl.clone().mul_(0.1)

     # Set with scalar value
     tl_copy_scalar = simple_tl.clone()
tests/test_utils_optimizer.py CHANGED
@@ -4,7 +4,6 @@ from functools import partial
 import pytest
 import torch
 from torchzero.utils.optimizer import (
-    Optimizer,
     get_group_vals,
     get_params,
     get_state_vals,
torchzero/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from . import core, optim, utils
-from .core import Modular
-from .utils import set_compilation
+from .core import Optimizer
+from .utils.compile import enable_compilation
 from . import modules as m
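For orientation, a minimal import-level sketch of the rename: both 0.4.1 names are re-exported exactly as shown in the diff above, and the commented 0.3.x lines are reconstructed from the removed imports.

# 0.3.x
# from torchzero import Modular
# from torchzero.utils import set_compilation

# 0.4.1
from torchzero import Optimizer
from torchzero.utils.compile import enable_compilation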
torchzero/core/__init__.py CHANGED
@@ -1,5 +1,8 @@
-from .chain import Chain, maybe_chain
-from .modular import Modular
+from .transform import TensorTransform, Transform
 from .module import Chainable, Module
-from .transform import Target, TensorwiseTransform, Transform, apply_transform
-from .var import Var
+from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
+
+# order is important to avoid circular imports
+from .modular import Optimizer
+from .functional import apply, step, step_tensors, update
+from .chain import Chain, maybe_chain
torchzero/core/chain.py CHANGED
@@ -2,36 +2,33 @@ from collections.abc import Iterable

 from ..utils.python_tools import flatten
 from .module import Module, Chainable
-
+from .functional import _chain_step

 class Chain(Module):
-    """Chain of modules, mostly used internally"""
+    """Chain modules, mostly used internally"""
     def __init__(self, *modules: Module | Iterable[Module]):
         super().__init__()
         flat_modules: list[Module] = flatten(modules)
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)

-    def update(self, var):
-        # note here that `update` and `apply` shouldn't be used directly
-        # as it will update all modules, and then apply all modules
-        # it is used in specific cases like Chain as trust region hessian module
-        for i in range(len(self.children)):
-            self.children[f'module_{i}'].update(var)
-            if var.stop: break
-        return var
-
-    def apply(self, var):
-        for i in range(len(self.children)):
-            var = self.children[f'module_{i}'].apply(var)
-            if var.stop: break
-        return var
-
-    def step(self, var):
-        for i in range(len(self.children)):
-            var = self.children[f'module_{i}'].step(var)
-            if var.stop: break
-        return var
+    def update(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return
+        return self.children['module_0'].update(objective)
+
+    def apply(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return objective
+        return self.children['module_0'].apply(objective)
+
+    def step(self, objective):
+        children = [self.children[f'module_{i}'] for i in range(len(self.children))]
+        return _chain_step(objective, children)

     def __repr__(self):
         s = self.__class__.__name__
@@ -41,7 +38,7 @@ class Chain(Module):
         return s

 def maybe_chain(*modules: Chainable) -> Module:
-    """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
+    """Returns a single module directly if only one is provided, otherwise wraps them in a ``Chain``."""
     flat_modules: list[Module] = flatten(modules)
     if len(flat_modules) == 1:
         return flat_modules[0]
torchzero/core/functional.py CHANGED
@@ -1,37 +1,103 @@
-from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from collections.abc import Mapping, Sequence, Iterable, Callable
+from typing import TYPE_CHECKING, Any
+
+import torch
+
+from .objective import Objective

 if TYPE_CHECKING:
     from .module import Module
-    from .var import Var
+    from .transform import Transform
+
+

+def update(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> None:
+    if states is None:
+        assert settings is None
+        module.update(objective)

-def step(var: "Var", modules: "Sequence[Module]",) -> "Var":
-    """steps with ``modules`` and returns modified ``var``, doesn't update parameters.
+    else:
+        assert settings is not None
+        module.update_states(objective, states, settings)

-    Args:
-        var (Var): Var object.
-        modules (Sequence[Module]): sequence of modules to step with.
+def apply(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> "Objective":
+    if states is None:
+        assert settings is None
+        return module.apply(objective)

-    Returns:
-        Var: modified Var
-    """
-    # n_modules = len(modules)
-    # if n_modules == 0: return var.clone(clone_update=False)
-    # last_module = modules[-1]
-    # last_lr = last_module.defaults.get('lr', None)
+    else:
+        assert settings is not None
+        return module.apply_states(objective, states, settings)

+def _chain_step(objective: "Objective", modules: "Sequence[Module]"):
+    """steps with ``modules`` and returns updated objective, this is used within ``step`` and within ``Chain.step``"""
     # step
     for i, module in enumerate(modules):
-        if i!=0: var = var.clone(clone_update=False)
+        if i!=0: objective = objective.clone(clone_updates=False)
+
+        objective = module.step(objective)
+        if objective.stop: break
+
+    return objective
+
+def step(objective: "Objective", modules: "Module | Sequence[Module]"):
+    """doesn't apply hooks!"""
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    if len(modules) == 0:
+        raise RuntimeError("`modules` is an empty sequence")
+
+    # if closure is None, assume backward has been called and gather grads
+    if objective.closure is None:
+        objective.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in objective.params]
+
+    # step and return
+    return _chain_step(objective, modules)
+
+
+def step_tensors(
+    modules: "Module | Sequence[Module]",
+    tensors: Sequence[torch.Tensor],
+    params: Iterable[torch.Tensor] | None = None,
+    grads: Sequence[torch.Tensor] | None = None,
+    loss: torch.Tensor | None = None,
+    closure: Callable | None = None,
+    objective: "Objective | None" = None
+) -> list[torch.Tensor]:
+    if objective is not None:
+        if any(i is not None for i in (params, grads, loss, closure)):
+            raise RuntimeError("Specify either `objective` or `(params, grads, loss, closure)`")
+
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    # make fake params if they are only used for shapes
+    if params is None:
+        params = [t.view_as(t).requires_grad_() for t in tensors]
+
+    # create objective
+    if objective is None:
+        objective = Objective(params=params, loss=loss, closure=closure)
+
+    if grads is not None:
+        objective.grads = list(grads)

-        # last module, or next to last module before lr
-        # if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-        #     if len(module.children) != 0 or is_nested: var.nested_is_last = True
-        #     else: var.is_last = True
-        # if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
+    objective.updates = list(tensors)

-        var = module.step(var)
-        if var.stop: break
+    # step with modules
+    # this won't update parameters in-place because objective.Optimizer is None
+    objective = _chain_step(objective, modules)

-    return var
+    # return updates
+    return objective.get_updates()
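A hedged usage sketch for the new step_tensors helper: it wraps plain tensors in an Objective (creating fake parameters when only shapes are needed) and returns the transformed updates without touching any parameters. The step_tensors signature comes from the diff above; tz.m.Adam is an assumed module name used only for illustration.

import torch
import torchzero as tz
from torchzero.core import step_tensors

# pretend these are raw gradient tensors produced elsewhere
tensors = [torch.randn(4, 4), torch.randn(8)]

# run them through a module; parameters are never updated in place, because the
# Objective created internally is not attached to an Optimizer.
# tz.m.Adam is an assumed module name.
updates = step_tensors(tz.m.Adam(), tensors)
print([u.shape for u in updates])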
torchzero/core/modular.py CHANGED
@@ -1,38 +1,27 @@

 import warnings
-from abc import ABC, abstractmethod
-from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable, MutableMapping, Sequence
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any

 import torch

-from ..utils import (
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .module import Chainable, Module
-from .var import Var
+from ..utils.params import Params, _make_param_groups
 from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+

 class _EvalCounterClosure:
     """keeps track of how many times closure has been evaluated, and sets closure return"""
     __slots__ = ("modular", "closure")
-    def __init__(self, modular: "Modular", closure):
+    def __init__(self, modular: "Optimizer", closure):
         self.modular = modular
         self.closure = closure

     def __call__(self, *args, **kwargs):
         if self.closure is None:
-            raise RuntimeError("One of the modules requires closure to be passed to the step method")
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")

         v = self.closure(*args, **kwargs)

@@ -44,22 +33,22 @@ class _EvalCounterClosure:
         return v


-def unroll_modules(*modules: Chainable) -> list[Module]:
-    unrolled = []
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []

     for m in modules:
         if isinstance(m, Module):
-            unrolled.append(m)
-            unrolled.extend(unroll_modules(list(m.children.values())))
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
         else:
-            unrolled.extend(unroll_modules(*m))
+            flat.extend(flatten_modules(*m))

-    return unrolled
+    return flat


-# have to inherit from Modular to support lr schedulers
+# have to inherit from Optimizer to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
-class Modular(torch.optim.Optimizer):
+class Optimizer(torch.optim.Optimizer):
     """Chains multiple modules into an optimizer.

     Args:
@@ -73,7 +62,7 @@ class Modular(torch.optim.Optimizer):
     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]

     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
-        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Optimizer`")
         self.model: torch.nn.Module | None = None
         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
         if isinstance(params, torch.nn.Module):
@@ -83,7 +72,7 @@ class Modular(torch.optim.Optimizer):
         self.modules = modules
         """Top-level modules providedduring initialization."""

-        self.unrolled_modules = unroll_modules(self.modules)
+        self.flat_modules = flatten_modules(self.modules)
         """A flattened list of all modules including all children."""

         param_groups = _make_param_groups(params, differentiable=False)
@@ -92,7 +81,7 @@ class Modular(torch.optim.Optimizer):
        Each element in the list is ChainDict's 2nd map of a module."""

         # make sure there is no more than a single learning rate module
-        lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
         if len(lr_modules) > 1:
             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')

@@ -100,13 +89,13 @@ class Modular(torch.optim.Optimizer):
         for group in param_groups:
             for k in group:
                 if k in ('params', 'lr'): continue
-                modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
                 if len(modules_with_k) > 1:
                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')

         # defaults for schedulers
         defaults = {}
-        for m in self.unrolled_modules: defaults.update(m.defaults)
+        for m in self.flat_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)

         # note - this is what super().__init__(param_groups, defaults=defaults) does:
@@ -146,7 +135,7 @@ class Modular(torch.optim.Optimizer):

         for p in proc_param_group['params']:
             # updates global per-parameter setting overrides (medium priority)
-            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]

     def state_dict(self):
         all_params = [p for g in self.param_groups for p in g['params']]
@@ -163,7 +152,7 @@ class Modular(torch.optim.Optimizer):
             "params": all_params,
             "groups": groups,
             "defaults": self.defaults,
-            "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
         }
         return state_dict

@@ -183,7 +172,7 @@ class Modular(torch.optim.Optimizer):
            self.add_param_group(group)

        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-        for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
            m._load_state_dict(sd, id_to_tensor)


@@ -201,37 +190,44 @@ class Modular(torch.optim.Optimizer):
                if not p.requires_grad: continue
                for map in self._per_parameter_global_settings[p]: map.update(settings)

-        # create var
+        # create Objective
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)

-        # if closure is None, assume backward has been called and gather grads
-        if closure is None:
-            var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            self.num_evaluations += 1
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)

-        if len(self.modules) == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )

-        # step
-        var = step(var, self.modules)
+        # step with all modules
+        objective = step(objective, self.modules)

-        # apply update
-        if not var.skip_update:
-            with torch.no_grad():
-                torch._foreach_sub_(params, var.get_update())
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()

         # update attributes
-        self.attrs.update(var.attrs)
-        if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-        # hooks
-        for hook in var.post_step_hooks:
-            hook(self, var)
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate

         self.current_step += 1
-        #return var.loss if var.loss is not None else var.loss_approx
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed but that's pointless
         return self._closure_return

     def __repr__(self):
-        return f'Modular({", ".join(str(m) for m in self.modules)})'
+        return f'Optimizer({", ".join(str(m) for m in self.modules)})'

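Putting the renamed pieces together, a closure-free training step with the new Optimizer might look like the sketch below. The constructor signature, the gathering of p.grad when no closure is given, and objective.update_parameters() are taken from the diffs above; the module names tz.m.Adam and tz.m.LR are assumptions used for illustration.

import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
# an nn.Module can be passed directly (see `params: Params | torch.nn.Module` above);
# tz.m.Adam / tz.m.LR are assumed module names
opt = tz.Optimizer(model, tz.m.Adam(), tz.m.LR(1e-2))

x, y = torch.randn(16, 10), torch.randn(16, 1)

for _ in range(3):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # with no closure, step() gathers p.grad for every parameter (see functional.step)
    # and applies the result in-place via objective.update_parameters()
    opt.step()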