torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
tests/test_tensorlist.py
CHANGED
@@ -1567,13 +1567,6 @@ def test_where(simple_tl: TensorList):
     assert_tl_allclose(result_module, expected_tl)
 
 
-    # Test inplace where_ (needs TensorList other)
-    tl_copy = simple_tl.clone()
-    result_inplace = tl_copy.where_(condition_tl, other_tl)
-    assert result_inplace is tl_copy
-    assert_tl_allclose(tl_copy, expected_tl)
-
-
 def test_masked_fill(simple_tl: TensorList):
     mask_tl = simple_tl.lt(0)
     fill_value_scalar = 99.0
@@ -1600,7 +1593,6 @@ def test_select_set_(simple_tl: TensorList):
     mask_tl = simple_tl.gt(0.5)
     value_scalar = -1.0
     value_list_scalar = [-1.0, -2.0, -3.0]
-    value_tl = simple_tl.clone().mul_(0.1)
 
     # Set with scalar value
     tl_copy_scalar = simple_tl.clone()
tests/test_utils_optimizer.py
CHANGED
torchzero/__init__.py
CHANGED
torchzero/core/__init__.py
CHANGED
@@ -1,5 +1,8 @@
-from .
-from .modular import Modular
+from .transform import TensorTransform, Transform
 from .module import Chainable, Module
-from .
-
+from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
+
+# order is important to avoid circular imports
+from .modular import Optimizer
+from .functional import apply, step, step_tensors, update
+from .chain import Chain, maybe_chain
torchzero/core/chain.py
CHANGED
@@ -2,36 +2,33 @@ from collections.abc import Iterable
 
 from ..utils.python_tools import flatten
 from .module import Module, Chainable
-
+from .functional import _chain_step
 
 class Chain(Module):
-    """Chain
+    """Chain modules, mostly used internally"""
     def __init__(self, *modules: Module | Iterable[Module]):
         super().__init__()
         flat_modules: list[Module] = flatten(modules)
         for i, module in enumerate(flat_modules):
             self.set_child(f'module_{i}', module)
 
-    def update(self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            var = self.children[f'module_{i}'].step(var)
-            if var.stop: break
-        return var
+    def update(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return
+        return self.children['module_0'].update(objective)
+
+    def apply(self, objective):
+        if len(self.children) > 1:
+            raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+        if len(self.children) == 0: return objective
+        return self.children['module_0'].apply(objective)
+
+    def step(self, objective):
+        children = [self.children[f'module_{i}'] for i in range(len(self.children))]
+        return _chain_step(objective, children)
 
     def __repr__(self):
         s = self.__class__.__name__
@@ -41,7 +38,7 @@ class Chain(Module):
         return s
 
 def maybe_chain(*modules: Chainable) -> Module:
-    """Returns a single module directly if only one is provided, otherwise wraps them in a
+    """Returns a single module directly if only one is provided, otherwise wraps them in a ``Chain``."""
    flat_modules: list[Module] = flatten(modules)
    if len(flat_modules) == 1:
        return flat_modules[0]
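
A small composition sketch based only on the behaviour visible in this hunk: maybe_chain returns a lone module unchanged and otherwise wraps the modules in a Chain, whose step walks its children through _chain_step, while update/apply tolerate at most one child. The `modules` argument stands in for torchzero Module instances, none of which are named here:

    from torchzero.core import Chain, maybe_chain

    def chain_or_unwrap(modules):
        # `modules` is a list of torchzero Module instances (placeholders)
        combined = maybe_chain(*modules)
        # one module comes back unchanged; several are wrapped in a Chain whose
        # step() runs module_0, module_1, ... sequentially via _chain_step,
        # while update()/apply() raise if there is more than one child
        assert isinstance(combined, Chain) or len(modules) == 1
        return combined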
torchzero/core/functional.py
CHANGED
@@ -1,37 +1,103 @@
-from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from collections.abc import Mapping, Sequence, Iterable, Callable
+from typing import TYPE_CHECKING, Any
+
+import torch
+
+from .objective import Objective
 
 if TYPE_CHECKING:
     from .module import Module
-    from .
+    from .transform import Transform
+
+
 
+def update(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> None:
+    if states is None:
+        assert settings is None
+        module.update(objective)
 
-
-
+    else:
+        assert settings is not None
+        module.update_states(objective, states, settings)
 
-
-
-
+def apply(
+    objective: "Objective",
+    module: "Transform",
+    states: list[dict[str, Any]] | None = None,
+    settings: Sequence[Mapping[str, Any]] | None = None,
+) -> "Objective":
+    if states is None:
+        assert settings is None
+        return module.apply(objective)
 
-
-
-
-    # n_modules = len(modules)
-    # if n_modules == 0: return var.clone(clone_update=False)
-    # last_module = modules[-1]
-    # last_lr = last_module.defaults.get('lr', None)
+    else:
+        assert settings is not None
+        return module.apply_states(objective, states, settings)
 
+def _chain_step(objective: "Objective", modules: "Sequence[Module]"):
+    """steps with ``modules`` and returns updated objective, this is used within ``step`` and within ``Chain.step``"""
     # step
     for i, module in enumerate(modules):
-        if i!=0:
+        if i!=0: objective = objective.clone(clone_updates=False)
+
+        objective = module.step(objective)
+        if objective.stop: break
+
+    return objective
+
+def step(objective: "Objective", modules: "Module | Sequence[Module]"):
+    """doesn't apply hooks!"""
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    if len(modules) == 0:
+        raise RuntimeError("`modules` is an empty sequence")
+
+    # if closure is None, assume backward has been called and gather grads
+    if objective.closure is None:
+        objective.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in objective.params]
+
+    # step and return
+    return _chain_step(objective, modules)
+
+
+def step_tensors(
+    modules: "Module | Sequence[Module]",
+    tensors: Sequence[torch.Tensor],
+    params: Iterable[torch.Tensor] | None = None,
+    grads: Sequence[torch.Tensor] | None = None,
+    loss: torch.Tensor | None = None,
+    closure: Callable | None = None,
+    objective: "Objective | None" = None
+) -> list[torch.Tensor]:
+    if objective is not None:
+        if any(i is not None for i in (params, grads, loss, closure)):
+            raise RuntimeError("Specify either `objective` or `(params, grads, loss, closure)`")
+
+    if not isinstance(modules, Sequence):
+        modules = (modules, )
+
+    # make fake params if they are only used for shapes
+    if params is None:
+        params = [t.view_as(t).requires_grad_() for t in tensors]
+
+    # create objective
+    if objective is None:
+        objective = Objective(params=params, loss=loss, closure=closure)
+
+    if grads is not None:
+        objective.grads = list(grads)
 
-
-        # if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-        #     if len(module.children) != 0 or is_nested: var.nested_is_last = True
-        #     else: var.is_last = True
-        # if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
+    objective.updates = list(tensors)
 
-
-
+    # step with modules
+    # this won't update parameters in-place because objective.Optimizer is None
+    objective = _chain_step(objective, modules)
 
-    return
+    # return updates
+    return objective.get_updates()
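
A hedged usage sketch of the new functional entry point, inferred only from the step_tensors signature and body above; the `module` argument is a placeholder for any torchzero Module, since this hunk names none:

    import torch
    from torchzero.core import step_tensors

    def transform_grads(module, grads: list[torch.Tensor]) -> list[torch.Tensor]:
        # step_tensors builds fake params from the tensor shapes, runs the
        # module chain on the tensors, and returns the resulting update
        # tensors; parameters are never modified in-place on this path
        return step_tensors(module, grads)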
torchzero/core/modular.py
CHANGED
@@ -1,38 +1,27 @@
 
 import warnings
-from
-from collections import
-from
-from operator import itemgetter
-from typing import TYPE_CHECKING, Any, Literal, cast, final, overload
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any
 
 import torch
 
-from ..utils import
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .module import Chainable, Module
-from .var import Var
+from ..utils.params import Params, _make_param_groups
 from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+
 
 class _EvalCounterClosure:
     """keeps track of how many times closure has been evaluated, and sets closure return"""
     __slots__ = ("modular", "closure")
-    def __init__(self, modular: "
+    def __init__(self, modular: "Optimizer", closure):
         self.modular = modular
         self.closure = closure
 
     def __call__(self, *args, **kwargs):
         if self.closure is None:
-            raise RuntimeError("
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")
 
         v = self.closure(*args, **kwargs)
 
@@ -44,22 +33,22 @@ class _EvalCounterClosure:
         return v
 
 
-def
-
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []
 
     for m in modules:
         if isinstance(m, Module):
-
-
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
         else:
-
+            flat.extend(flatten_modules(*m))
 
-    return
+    return flat
 
 
-# have to inherit from
+# have to inherit from Optimizer to support lr schedulers
 # although Accelerate doesn't work due to converting param_groups to a dict
-class
+class Optimizer(torch.optim.Optimizer):
     """Chains multiple modules into an optimizer.
 
     Args:
@@ -73,7 +62,7 @@ class Modular(torch.optim.Optimizer):
     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
 
     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
-        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Optimizer`")
         self.model: torch.nn.Module | None = None
         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
         if isinstance(params, torch.nn.Module):
@@ -83,7 +72,7 @@ class Modular(torch.optim.Optimizer):
         self.modules = modules
         """Top-level modules providedduring initialization."""
 
-        self.
+        self.flat_modules = flatten_modules(self.modules)
         """A flattened list of all modules including all children."""
 
         param_groups = _make_param_groups(params, differentiable=False)
@@ -92,7 +81,7 @@ class Modular(torch.optim.Optimizer):
         Each element in the list is ChainDict's 2nd map of a module."""
 
         # make sure there is no more than a single learning rate module
-        lr_modules = [m for m in self.
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
         if len(lr_modules) > 1:
             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')
 
@@ -100,13 +89,13 @@ class Modular(torch.optim.Optimizer):
         for group in param_groups:
             for k in group:
                 if k in ('params', 'lr'): continue
-                modules_with_k = [m for m in self.
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
                 if len(modules_with_k) > 1:
                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
 
         # defaults for schedulers
         defaults = {}
-        for m in self.
+        for m in self.flat_modules: defaults.update(m.defaults)
         super().__init__(param_groups, defaults=defaults)
 
         # note - this is what super().__init__(param_groups, defaults=defaults) does:
@@ -146,7 +135,7 @@ class Modular(torch.optim.Optimizer):
 
         for p in proc_param_group['params']:
             # updates global per-parameter setting overrides (medium priority)
-            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]
 
     def state_dict(self):
         all_params = [p for g in self.param_groups for p in g['params']]
@@ -163,7 +152,7 @@ class Modular(torch.optim.Optimizer):
             "params": all_params,
             "groups": groups,
             "defaults": self.defaults,
-            "modules": {i: m.state_dict() for i, m in enumerate(self.
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
         }
         return state_dict
 
@@ -183,7 +172,7 @@ class Modular(torch.optim.Optimizer):
             self.add_param_group(group)
 
         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-        for m, sd in zip(self.
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
             m._load_state_dict(sd, id_to_tensor)
 
 
@@ -201,37 +190,44 @@ class Modular(torch.optim.Optimizer):
             if not p.requires_grad: continue
             for map in self._per_parameter_global_settings[p]: map.update(settings)
 
-        # create
+        # create Objective
         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-        var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
 
-
-        if closure is None:
-
-            self.num_evaluations += 1
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)
 
-
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )
 
-        # step
-
+        # step with all modules
+        objective = step(objective, self.modules)
 
-        # apply update
-
-
-
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()
 
         # update attributes
-        self.attrs.update(
-        if
-
-        # hooks
-        for hook in var.post_step_hooks:
-            hook(self, var)
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate
 
         self.current_step += 1
-
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed but that's pointless
         return self._closure_return
 
     def __repr__(self):
-        return f'
+        return f'Optimizer({", ".join(str(m) for m in self.modules)})'
 
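
A hedged sketch of how the renamed entry point is meant to be used, based only on what is visible above: Optimizer (formerly Modular) subclasses torch.optim.Optimizer, takes a model or params plus torchzero modules, and when stepped without a closure gathers the existing .grad tensors via functional.step. The concrete modules are placeholders and the full step signature is not shown in this diff:

    import torch
    from torchzero.core import Optimizer

    def build_and_step(modules, data: torch.Tensor):
        # `modules` is a sequence of torchzero Module instances (placeholders)
        model = torch.nn.Linear(10, 1)
        opt = Optimizer(model, *modules)   # Optimizer replaces the old Modular class

        loss = model(data).pow(2).mean()
        loss.backward()
        # with no closure, the step gathers the existing .grad tensors
        # (see the `if objective.closure is None` branch in functional.step)
        opt.step()
        return loss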