torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/core/modular.py (new file)
@@ -0,0 +1,233 @@
+
+import warnings
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any
+
+import torch
+
+from ..utils.params import Params, _make_param_groups
+from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+
+
+class _EvalCounterClosure:
+    """keeps track of how many times the closure has been evaluated, and records the closure return"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")
+
+        v = self.closure(*args, **kwargs)
+
+        # set closure return on 1st evaluation
+        if self.modular._closure_return is None:
+            self.modular._closure_return = v
+
+        self.modular.num_evaluations += 1
+        return v
+
+
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []
+
+    for m in modules:
+        if isinstance(m, Module):
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
+        else:
+            flat.extend(flatten_modules(*m))
+
+    return flat
+
+
+# have to inherit from torch.optim.Optimizer to support lr schedulers
+# although Accelerate still doesn't work because it converts param_groups to a dict
+class Modular(torch.optim.Optimizer):
+    """Chains multiple modules into an optimizer.
+
+    Args:
+        params (Params | torch.nn.Module): An iterable of parameters to optimize
+            (typically `model.parameters()`), an iterable of parameter group dicts,
+            or a `torch.nn.Module` instance.
+        *modules (Module): A sequence of `Module` instances that define the
+            optimization algorithm steps.
+    """
+    # this is specifically for lr schedulers
+    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
+
+    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+        self.model: torch.nn.Module | None = None
+        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
+        if isinstance(params, torch.nn.Module):
+            self.model = params
+            params = params.parameters()
+
+        self.modules = modules
+        """Top-level modules provided during initialization."""
+
+        self.flat_modules = flatten_modules(self.modules)
+        """A flattened list of all modules, including all children."""
+
+        param_groups = _make_param_groups(params, differentiable=False)
+        self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
+        """Maps each parameter tensor to a list of per-module global settings.
+        Each element in the list is the second map of a module's settings ChainMap."""
+
+        # make sure there is no more than a single learning rate module
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
+        if len(lr_modules) > 1:
+            warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')
+
+        # iterate over all per-parameter settings overrides and check that each is applied at most once
+        for group in param_groups:
+            for k in group:
+                if k in ('params', 'lr'): continue
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
+                if len(modules_with_k) > 1:
+                    warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` on one of them, use `module.set_param_groups(params)`')
+
+        # defaults for schedulers
+        defaults = {}
+        for m in self.flat_modules: defaults.update(m.defaults)
+        super().__init__(param_groups, defaults=defaults)
+
+        # note - this is what super().__init__(param_groups, defaults=defaults) does:
+
+        # self.defaults = defaults
+        # for param_group in param_groups:
+        #     self.add_param_group(param_group)
+
+        # add_param_group adds a ChainMap where defaults are lowest priority,
+        # and entries specified in param_groups or by a scheduler are higher priority.
+        # pytorch schedulers do group["lr"] = new_lr, which sets a higher priority key.
+        # in each module, settings passed to that module by calling set_param_groups are highest priority
+
+        self.current_step = 0
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls, or number of steps if closure is None)."""
+
+        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+        # we want to return the original loss, so this attribute is used
+        self._closure_return = None
+        """on each step, the first time the closure is evaluated, this attribute is set to the returned value. The `step` method returns this."""
+
+        self.attrs = {}
+        """custom attributes that can be set by modules, for example an EMA of weights or the best value so far"""
+
+        self.should_terminate = False
+        """set to True by termination criteria modules."""
+
+    def add_param_group(self, param_group: dict[str, Any]):
+        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
+        self.param_groups.append(ChainMap(proc_param_group, self.defaults))
+        # setting param_group[key] = value sets it on the first map (the `proc_param_group`).
+        # therefore lr schedulers override defaults, but not settings passed to individual modules
+        # by `set_param_groups`.
+
+        for p in proc_param_group['params']:
+            # updates global per-parameter setting overrides (medium priority)
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]
+
+    def state_dict(self):
+        all_params = [p for g in self.param_groups for p in g['params']]
+        id_to_idx = {id(p): i for i,p in enumerate(all_params)}
+
+        groups = []
+        for g in self.param_groups:
+            g = g.copy()
+            g['params'] = [id_to_idx[id(p)] for p in g['params']]
+            groups.append(g)
+
+        state_dict = {
+            "idx_to_id": {v:k for k,v in id_to_idx.items()},
+            "params": all_params,
+            "groups": groups,
+            "defaults": self.defaults,
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
+        }
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        self.defaults.clear()
+        self.defaults.update(state_dict['defaults'])
+
+        idx_to_param = dict(enumerate(state_dict['params']))
+        groups = []
+        for g in state_dict['groups']:
+            g = g.copy()
+            g['params'] = [idx_to_param[p] for p in g['params']]
+            groups.append(g)
+
+        self.param_groups.clear()
+        for group in groups:
+            self.add_param_group(group)
+
+        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
+            m._load_state_dict(sd, id_to_tensor)
+
+
+    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+        # clear closure return from previous step
+        self._closure_return = None
+
+        # propagate global per-parameter setting overrides
+        for g in self.param_groups:
+            settings = dict(g.maps[0]) # ignore defaults
+            params = settings.pop('params')
+            if not settings: continue
+
+            for p in params:
+                if not p.requires_grad: continue
+                for map in self._per_parameter_global_settings[p]: map.update(settings)
+
+        # create Objective
+        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
+
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)
+
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )
+
+        # step with all modules
+        objective = step(objective, self.modules)
+
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()
+
+        # update attributes
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate
+
+        self.current_step += 1
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed, but that's pointless
+        return self._closure_return
+
+    def __repr__(self):
+        return f'Modular({", ".join(str(m) for m in self.modules)})'
+
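
For orientation, below is a minimal usage sketch of the `Modular` optimizer added in this release. It relies only on the constructor and `step(closure)` signatures shown in the diff above; the `tz.m.Adam()` and `tz.m.LR()` module names are assumptions used for illustration and may not match torchzero's actual public module names.

# Hedged usage sketch of torchzero 0.4.0's Modular optimizer.
# Assumptions (not confirmed by this diff): the top-level package exposes
# `Modular`, and a module namespace `tz.m` provides `Adam` and `LR` classes.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Modular accepts an nn.Module (or parameters / param-group dicts),
# followed by any number of Module instances that are chained in order.
opt = tz.Modular(model, tz.m.Adam(), tz.m.LR(1e-2))

x, y = torch.randn(64, 10), torch.randn(64, 1)

def closure():
    opt.zero_grad()  # inherited from torch.optim.Optimizer
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

# step() runs every module on an Objective, applies the resulting update,
# and returns the value of the first closure evaluation of this step.
loss = opt.step(closure)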