torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
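One practical consequence of the moves listed above (entries 19, 22, 23 and the removals in 161-164) is that the linear-algebra helpers now live under `torchzero.linalg` instead of `torchzero.utils.linalg`. A minimal sketch of the corresponding import change, assuming these submodules are meant to be imported directly; the exact public exports of 0.4.0 are not shown in this diff:

# torchzero 0.3.14: helpers lived under torchzero.utils.linalg
# import torchzero.utils.linalg.solve as solve

# torchzero 0.4.0: the same files moved to a top-level torchzero.linalg package
import torchzero.linalg.solve as solve          # was torchzero/utils/linalg/solve.py
import torchzero.linalg.linear_operator as lo   # was torchzero/utils/linalg/linear_operator.py

The largest single addition of this release is the new `torchzero/core/modular.py` (entry 11, +233 lines), whose full content is shown in the hunk below.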
torchzero/core/modular.py
@@ -0,0 +1,233 @@
+
+import warnings
+from collections import ChainMap
+from collections.abc import MutableMapping
+from typing import Any
+
+import torch
+
+from ..utils.params import Params, _make_param_groups
+from .functional import step
+from .module import Chainable, Module
+from .objective import Objective
+
+
+class _EvalCounterClosure:
+    """keeps track of how many times the closure has been evaluated, and sets the closure return"""
+    __slots__ = ("modular", "closure")
+    def __init__(self, modular: "Modular", closure):
+        self.modular = modular
+        self.closure = closure
+
+    def __call__(self, *args, **kwargs):
+        if self.closure is None:
+            raise RuntimeError("closure is None in _EvalCounterClosure, and this can't happen")
+
+        v = self.closure(*args, **kwargs)
+
+        # set closure return on 1st evaluation
+        if self.modular._closure_return is None:
+            self.modular._closure_return = v
+
+        self.modular.num_evaluations += 1
+        return v
+
+
+def flatten_modules(*modules: Chainable) -> list[Module]:
+    flat = []
+
+    for m in modules:
+        if isinstance(m, Module):
+            flat.append(m)
+            flat.extend(flatten_modules(list(m.children.values())))
+        else:
+            flat.extend(flatten_modules(*m))
+
+    return flat
+
+
+# have to inherit from torch.optim.Optimizer to support lr schedulers,
+# although Accelerate doesn't work due to converting param_groups to a dict
+class Modular(torch.optim.Optimizer):
+    """Chains multiple modules into an optimizer.
+
+    Args:
+        params (Params | torch.nn.Module): An iterable of parameters to optimize
+            (typically `model.parameters()`), an iterable of parameter group dicts,
+            or a `torch.nn.Module` instance.
+        *modules (Module): A sequence of `Module` instances that define the
+            optimization algorithm steps.
+    """
+    # this is specifically for lr schedulers
+    param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
+
+    def __init__(self, params: Params | torch.nn.Module, *modules: Module):
+        if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
+        self.model: torch.nn.Module | None = None
+        """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
+        if isinstance(params, torch.nn.Module):
+            self.model = params
+            params = params.parameters()
+
+        self.modules = modules
+        """Top-level modules provided during initialization."""
+
+        self.flat_modules = flatten_modules(self.modules)
+        """A flattened list of all modules including all children."""
+
+        param_groups = _make_param_groups(params, differentiable=False)
+        self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
+        """Maps each parameter tensor to a list of per-module global settings.
+        Each element in the list is the 2nd map of a module's settings ChainMap."""
+
+        # make sure there is no more than a single learning rate module
+        lr_modules = [m for m in self.flat_modules if 'lr' in m.defaults]
+        if len(lr_modules) > 1:
+            warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to compounding of learning rate multiplication with per-parameter learning rates and schedulers.')
+
+        # iterate over all per-parameter settings overrides and check that they are applied at most once
+        for group in param_groups:
+            for k in group:
+                if k in ('params', 'lr'): continue
+                modules_with_k = [m for m in self.flat_modules if k in m.defaults and k not in m._overridden_keys]
+                if len(modules_with_k) > 1:
+                    warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to set `{k}` on only one of them, use `module.set_param_groups(params)`')
+
+        # defaults for schedulers
+        defaults = {}
+        for m in self.flat_modules: defaults.update(m.defaults)
+        super().__init__(param_groups, defaults=defaults)
+
+        # note - this is what super().__init__(param_groups, defaults=defaults) does:
+
+        # self.defaults = defaults
+        # for param_group in param_groups:
+        #     self.add_param_group(param_group)
+
+        # add_param_group adds a ChainMap where defaults are lowest priority,
+        # and entries specified in param_groups or by a scheduler are higher priority.
+        # pytorch schedulers do group["lr"] = new_lr, which sets a higher-priority key.
+        # in each module, settings passed to that module by calling set_param_groups are highest priority
+
+        self.current_step = 0
+        """global step counter for the optimizer."""
+
+        self.num_evaluations = 0
+        """number of times the objective has been evaluated (number of closure calls, or number of steps if closure is None)."""
+
+        # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
+        # we want to return the original loss, so this attribute is used
+        self._closure_return = None
+        """on each step, the first time the closure is evaluated, this attribute is set to the returned value. The `step` method returns it."""
+
+        self.attrs = {}
+        """custom attributes that can be set by modules, for example EMA of weights or best so far"""
+
+        self.should_terminate = False
+        """is set to True by termination criteria modules."""
+
+    def add_param_group(self, param_group: dict[str, Any]):
+        proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
+        self.param_groups.append(ChainMap(proc_param_group, self.defaults))
+        # setting param_group[key] = value sets it in the first map (the `proc_param_group`).
+        # therefore lr schedulers override defaults, but not settings passed to individual modules
+        # by `set_param_groups`.
+
+        for p in proc_param_group['params']:
+            # updates global per-parameter setting overrides (medium priority)
+            self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.flat_modules]
+
+    def state_dict(self):
+        all_params = [p for g in self.param_groups for p in g['params']]
+        id_to_idx = {id(p): i for i,p in enumerate(all_params)}
+
+        groups = []
+        for g in self.param_groups:
+            g = g.copy()
+            g['params'] = [id_to_idx[id(p)] for p in g['params']]
+            groups.append(g)
+
+        state_dict = {
+            "idx_to_id": {v:k for k,v in id_to_idx.items()},
+            "params": all_params,
+            "groups": groups,
+            "defaults": self.defaults,
+            "modules": {i: m.state_dict() for i, m in enumerate(self.flat_modules)}
+        }
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        self.defaults.clear()
+        self.defaults.update(state_dict['defaults'])
+
+        idx_to_param = dict(enumerate(state_dict['params']))
+        groups = []
+        for g in state_dict['groups']:
+            g = g.copy()
+            g['params'] = [idx_to_param[p] for p in g['params']]
+            groups.append(g)
+
+        self.param_groups.clear()
+        for group in groups:
+            self.add_param_group(group)
+
+        id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
+        for m, sd in zip(self.flat_modules, state_dict['modules'].values()):
+            m._load_state_dict(sd, id_to_tensor)
+
+
+    def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
+        # clear closure return from the previous step
+        self._closure_return = None
+
+        # propagate global per-parameter setting overrides
+        for g in self.param_groups:
+            settings = dict(g.maps[0]) # ignore defaults
+            params = settings.pop('params')
+            if not settings: continue
+
+            for p in params:
+                if not p.requires_grad: continue
+                for map in self._per_parameter_global_settings[p]: map.update(settings)
+
+        # create Objective
+        params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
+
+        counter_closure = None
+        if closure is not None:
+            counter_closure = _EvalCounterClosure(self, closure)
+
+        objective = Objective(
+            params=params, closure=counter_closure, model=self.model,
+            current_step=self.current_step, modular=self, loss=loss, storage=kwargs
+        )
+
+        # step with all modules
+        objective = step(objective, self.modules)
+
+        # apply update to parameters unless `objective.skip_update = True`
+        # this does:
+        # if not objective.skip_update:
+        #     torch._foreach_sub_(objective.params, objective.get_updates())
+        objective.update_parameters()
+
+        # update attributes
+        self.attrs.update(objective.attrs)
+        if objective.should_terminate is not None:
+            self.should_terminate = objective.should_terminate
+
+        self.current_step += 1
+
+        # apply hooks
+        # this does:
+        # for hook in objective.post_step_hooks:
+        #     hook(objective, modules)
+        objective.apply_post_step_hooks(self.modules)
+
+        # return the first closure evaluation return
+        # could return loss if it was passed, but that's pointless
+        return self._closure_return
+
+    def __repr__(self):
+        return f'Modular({", ".join(str(m) for m in self.modules)})'
+
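The hunk above defines the public entry point of the reworked core: `Modular` chains a sequence of `Module` objects, wraps the user closure in `_EvalCounterClosure`, builds an `Objective`, runs `step(objective, modules)`, applies the resulting update, and returns whatever the closure produced on its first evaluation. A minimal usage sketch under stated assumptions: the `tz.m.Adam` / `tz.m.LR` names and the `backward=True` closure convention come from torchzero's own documentation rather than from this diff, so treat them as illustrative.

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# Assumed module names (modules/adaptive/adam.py and modules/step_size/lr.py exist
# in the file list above); check the torchzero docs for the exact public names.
opt = tz.Modular(model, tz.m.Adam(), tz.m.LR(1e-2))

# Closure with a `backward` flag, following torchzero's documented convention.
# Modular wraps it in _EvalCounterClosure, so every call is counted and the first
# returned value becomes the return value of opt.step().
def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

# Standard PyTorch lr schedulers can be attached because Modular subclasses
# torch.optim.Optimizer and its ChainMap param groups let the scheduler's
# group["lr"] assignment override the module defaults.
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)

for _ in range(30):
    loss = opt.step(closure)
    scheduler.step()

Since termination-criteria modules set `should_terminate` on the objective, a training loop can also poll `opt.should_terminate` to stop early.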