torchzero 0.3.13__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. tests/test_opts.py +4 -10
  2. torchzero/core/__init__.py +4 -1
  3. torchzero/core/chain.py +50 -0
  4. torchzero/core/functional.py +37 -0
  5. torchzero/core/modular.py +237 -0
  6. torchzero/core/module.py +12 -599
  7. torchzero/core/reformulation.py +3 -1
  8. torchzero/core/transform.py +7 -5
  9. torchzero/core/var.py +376 -0
  10. torchzero/modules/__init__.py +0 -1
  11. torchzero/modules/adaptive/adahessian.py +2 -2
  12. torchzero/modules/adaptive/esgd.py +2 -2
  13. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  14. torchzero/modules/adaptive/sophia_h.py +2 -2
  15. torchzero/modules/conjugate_gradient/cg.py +16 -16
  16. torchzero/modules/experimental/__init__.py +1 -0
  17. torchzero/modules/experimental/newtonnewton.py +5 -5
  18. torchzero/modules/experimental/spsa1.py +93 -0
  19. torchzero/modules/functional.py +7 -0
  20. torchzero/modules/grad_approximation/__init__.py +1 -1
  21. torchzero/modules/grad_approximation/forward_gradient.py +2 -5
  22. torchzero/modules/grad_approximation/rfdm.py +27 -110
  23. torchzero/modules/line_search/__init__.py +1 -1
  24. torchzero/modules/line_search/_polyinterp.py +3 -1
  25. torchzero/modules/line_search/adaptive.py +3 -3
  26. torchzero/modules/line_search/backtracking.py +1 -1
  27. torchzero/modules/line_search/interpolation.py +160 -0
  28. torchzero/modules/line_search/line_search.py +11 -20
  29. torchzero/modules/line_search/scipy.py +15 -3
  30. torchzero/modules/line_search/strong_wolfe.py +3 -5
  31. torchzero/modules/misc/misc.py +2 -2
  32. torchzero/modules/misc/multistep.py +13 -13
  33. torchzero/modules/quasi_newton/__init__.py +2 -0
  34. torchzero/modules/quasi_newton/quasi_newton.py +15 -6
  35. torchzero/modules/quasi_newton/sg2.py +292 -0
  36. torchzero/modules/restarts/restars.py +5 -4
  37. torchzero/modules/second_order/__init__.py +6 -3
  38. torchzero/modules/second_order/ifn.py +89 -0
  39. torchzero/modules/second_order/inm.py +105 -0
  40. torchzero/modules/second_order/newton.py +103 -193
  41. torchzero/modules/second_order/newton_cg.py +86 -110
  42. torchzero/modules/second_order/nystrom.py +1 -1
  43. torchzero/modules/second_order/rsn.py +227 -0
  44. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  45. torchzero/modules/trust_region/trust_cg.py +6 -4
  46. torchzero/modules/wrappers/optim_wrapper.py +49 -42
  47. torchzero/modules/zeroth_order/__init__.py +1 -1
  48. torchzero/modules/zeroth_order/cd.py +1 -238
  49. torchzero/utils/derivatives.py +19 -19
  50. torchzero/utils/linalg/linear_operator.py +50 -2
  51. torchzero/utils/optimizer.py +2 -2
  52. torchzero/utils/python_tools.py +1 -0
  53. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/METADATA +1 -1
  54. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/RECORD +57 -48
  55. torchzero/modules/higher_order/__init__.py +0 -1
  56. /torchzero/modules/{higher_order → experimental}/higher_order_newton.py +0 -0
  57. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/WHEEL +0 -0
  58. {torchzero-0.3.13.dist-info → torchzero-0.3.15.dist-info}/top_level.txt +0 -0
torchzero/core/module.py CHANGED
@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
  from collections import ChainMap, defaultdict
  from collections.abc import Callable, Iterable, MutableMapping, Sequence
  from operator import itemgetter
- from typing import Any, final, overload, Literal, cast
+ from typing import Any, Literal, cast, final, overload

  import torch

@@ -13,259 +13,14 @@ from ..utils import (
      Params,
      _make_param_groups,
      get_state_vals,
+     vec_to_tensors,
  )
- from ..utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
- from ..utils.python_tools import flatten
+ from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
  from ..utils.linalg.linear_operator import LinearOperator
+ from ..utils.python_tools import flatten
+ from .var import Var


- def _closure_backward(closure, params, retain_graph, create_graph):
-     with torch.enable_grad():
-         if not (retain_graph or create_graph):
-             return closure()
-
-         for p in params: p.grad = None
-         loss = closure(False)
-         grad = torch.autograd.grad(loss, params, retain_graph=retain_graph, create_graph=create_graph)
-         for p,g in zip(params,grad): p.grad = g
-         return loss
-
- # region Vars
- # ----------------------------------- var ----------------------------------- #
- class Var:
-     """
-     Holds parameters, gradient, update, objective function (closure) if supplied, loss, and some other info.
-     Modules take in a ``Var`` object, modify and it is passed to the next module.
-
-     """
-     def __init__(
-         self,
-         params: list[torch.Tensor],
-         closure: Callable | None,
-         model: torch.nn.Module | None,
-         current_step: int,
-         parent: "Var | None" = None,
-         modular: "Modular | None" = None,
-         loss: torch.Tensor | None = None,
-         storage: dict | None = None,
-     ):
-         self.params: list[torch.Tensor] = params
-         """List of all parameters with requires_grad = True."""
-
-         self.closure = closure
-         """A closure that reevaluates the model and returns the loss, None if it wasn't specified"""
-
-         self.model = model
-         """torch.nn.Module object of the model, None if it wasn't specified."""
-
-         self.current_step: int = current_step
-         """global current step, starts at 0. This may not correspond to module current step,
-         for example a module may step every 10 global steps."""
-
-         self.parent: "Var | None" = parent
-         """parent ``Var`` object. When ``self.get_grad()`` is called, it will also set ``parent.grad``.
-         Same with ``self.get_loss()``. This is useful when ``self.params`` are different from ``parent.params``,
-         e.g. when projecting."""
-
-         self.modular: "Modular" = cast(Modular, modular)
-         """Modular optimizer object that created this ``Var``."""
-
-         self.update: list[torch.Tensor] | None = None
-         """
-         current update. Update is assumed to be a transformed gradient, therefore it is subtracted.
-
-         If closure is None, this is initially set to cloned gradient. Otherwise this is set to None.
-
-         At the end ``var.get_update()`` is subtracted from parameters. Therefore if ``var.update`` is ``None``,
-         gradient will be used and calculated if needed.
-         """
-
-         self.grad: list[torch.Tensor] | None = None
-         """gradient with current parameters. If closure is not ``None``, this is set to ``None`` and can be calculated if needed."""
-
-         self.loss: torch.Tensor | Any | None = loss
-         """loss with current parameters."""
-
-         self.loss_approx: torch.Tensor | Any | None = None
-         """loss at a point near current point. This can be useful as some modules only calculate loss at perturbed points,
-         whereas some other modules require loss strictly at current point."""
-
-         self.post_step_hooks: list[Callable[[Modular, Var]]] = []
-         """list of functions to be called after optimizer step.
-
-         This attribute should always be modified in-place (using ``append`` or ``extend``).
-
-         The signature is:
-
-         ```python
-         def hook(optimizer: Modular, var: Vars): ...
-         ```
-         """
-
-         self.is_last: bool = False
-         """
-         Indicates that current module is either last or next-to-last before a learning rate module.
-         This is always False if current module has children or is a child.
-         This is because otherwise the ``is_last`` would be passed to child modules, even though they aren't last.
-         """
-
-         self.nested_is_last: bool = False
-         """
-         Indicates that current module is either last or next-to-last before a learning rate module, for modules
-         that have children. This will be passed to the children unless ``var.clone()`` is used, therefore
-         a child of a last module may also receive ``var.nested_is_last=True``.
-         """
-
-         self.last_module_lrs: list[float] | None = None
-         """
-         List of per-parameter learning rates if current module is next-to-last before a
-         learning rate module, otherwise this is set to None. Ignore this unless you are manually applying
-         update to parameters.
-         """
-
-         self.stop: bool = False
-         """if True, all following modules will be skipped.
-         If this module is a child, it only affects modules at the same level (in the same Chain)."""
-
-         self.skip_update: bool = False
-         """if True, the parameters will not be updated."""
-
-         # self.storage: dict = {}
-         # """Storage for any other data, such as hessian estimates, etc."""
-
-         self.attrs: dict = {}
-         """attributes, Modular.attrs is updated with this after each step. This attribute should always be modified in-place"""
-
-         if storage is None: storage = {}
-         self.storage: dict = storage
-         """additional kwargs passed to closure will end up in this dict. This attribute should always be modified in-place"""
-
-         self.should_terminate: bool | None = None
-         """termination criteria, Modular.should_terminate is set to this after each step if not None"""
-
-     def get_loss(self, backward: bool, retain_graph = None, create_graph: bool = False) -> torch.Tensor | float:
-         """Returns the loss at current parameters, computing it if it hasn't been computed already and assigning ``var.loss``.
-         Do not call this at perturbed parameters. Backward always sets grads to None before recomputing."""
-         if self.loss is None:
-
-             if self.closure is None: raise RuntimeError("closure is None")
-             if backward:
-                 with torch.enable_grad():
-                     self.loss = self.loss_approx = _closure_backward(
-                         closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
-                     )
-
-                 # initializing to zeros_like is equivalent to using zero_grad with set_to_none = False.
-                 # it is technically a more correct approach for when some parameters conditionally receive gradients
-                 # and in this case it shouldn't be slower.
-
-                 # next time closure() is called, it will set grad to None.
-                 # zero_grad(set_to_none=False) shouldn't be used (I should add a warning)
-                 self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
-             else:
-                 self.loss = self.loss_approx = self.closure(False)
-
-         # if self.loss was not None, above branch wasn't executed because loss has already been evaluated, but without backward since self.grad is None.
-         # and now it is requested to be evaluated with backward.
-         if backward and self.grad is None:
-             warnings.warn('get_loss was called with backward=False, and then with backward=True so it had to be re-evaluated, so the closure was evaluated twice where it could have been evaluated once.')
-             if self.closure is None: raise RuntimeError("closure is None")
-
-             with torch.enable_grad():
-                 self.loss = self.loss_approx = _closure_backward(
-                     closure=self.closure, params=self.params, retain_graph=retain_graph, create_graph=create_graph
-                 )
-             self.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
-
-         # set parent grad
-         if self.parent is not None:
-             # the way projections/split work, they make a new closure which evaluates original
-             # closure and projects the gradient, and set it as their var.closure.
-             # then on `get_loss(backward=True)` it is called, so it also sets original parameters gradient.
-             # and we set it to parent var here.
-             if self.parent.loss is None: self.parent.loss = self.loss
-             if self.parent.grad is None and backward:
-                 if all(p.grad is None for p in self.parent.params):
-                     warnings.warn("Parent grad is None after backward.")
-                 self.parent.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.parent.params]
-
-         return self.loss # type:ignore
-
-     def get_grad(self, retain_graph: bool | None = None, create_graph: bool = False) -> list[torch.Tensor]:
-         """Returns the gradient at initial parameters, computing it if it hasn't been computed already and assigning
-         ``var.grad`` and potentially ``var.loss``. Do not call this at perturbed parameters."""
-         if self.grad is None:
-             if self.closure is None: raise RuntimeError("closure is None")
-             self.get_loss(backward=True, retain_graph=retain_graph, create_graph=create_graph) # evaluate and set self.loss and self.grad
-
-         assert self.grad is not None
-         return self.grad
-
-     def get_update(self) -> list[torch.Tensor]:
-         """Returns the update. If update is None, it is initialized by cloning the gradients and assigning to ``var.update``.
-         Computing the gradients may assign ``var.grad`` and ``var.loss`` if they haven't been computed.
-         Do not call this at perturbed parameters."""
-         if self.update is None: self.update = [g.clone() for g in self.get_grad()]
-         return self.update
-
-     def clone(self, clone_update: bool, parent: "Var | None" = None):
-         """Creates a shallow copy of the Vars object, update can optionally be deep-copied (via ``torch.clone``).
-
-         Doesn't copy ``is_last``, ``nested_is_last`` and ``last_module_lrs``. They will always be ``False``/``None``.
-
-         Setting ``parent`` is only if clone's parameters are something different,
-         while clone's closure referes to the same objective but with a "view" on parameters.
-         """
-         copy = Var(params = self.params, closure=self.closure, model=self.model, current_step=self.current_step, parent=parent)
-
-         if clone_update and self.update is not None:
-             copy.update = [u.clone() for u in self.update]
-         else:
-             copy.update = self.update
-
-         copy.grad = self.grad
-         copy.loss = self.loss
-         copy.loss_approx = self.loss_approx
-         copy.closure = self.closure
-         copy.post_step_hooks = self.post_step_hooks
-         copy.stop = self.stop
-         copy.skip_update = self.skip_update
-
-         copy.modular = self.modular
-         copy.attrs = self.attrs
-         copy.storage = self.storage
-         copy.should_terminate = self.should_terminate
-
-         return copy
-
-     def update_attrs_from_clone_(self, var: "Var"):
-         """Updates attributes of this `Vars` instance from a cloned instance.
-         Typically called after a child module has processed a cloned `Vars`
-         object. This propagates any newly computed loss or gradient values
-         from the child's context back to the parent `Vars` if the parent
-         didn't have them computed already.
-
-         Also, as long as ``post_step_hooks`` and ``attrs`` are modified in-place,
-         if the child updates them, the update will affect the parent too.
-         """
-         if self.loss is None: self.loss = var.loss
-         if self.loss_approx is None: self.loss_approx = var.loss_approx
-         if self.grad is None: self.grad = var.grad
-
-         if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-     def zero_grad(self, set_to_none=True):
-         if set_to_none:
-             for p in self.params: p.grad = None
-         else:
-             grads = [p.grad for p in self.params if p.grad is not None]
-             if len(grads) != 0: torch._foreach_zero_(grads)
-
- # endregion
-
-
- # region Module
- # ---------------------------------- module ---------------------------------- #
  class Module(ABC):
      """Abstract base class for an optimizer modules.

@@ -317,9 +72,12 @@ class Module(ABC):
          return self

      def set_child(self, key: str, module: "Module | Sequence[Module]"):
+         from .chain import maybe_chain
          self.children[key] = maybe_chain(module)

      def set_children_sequence(self, modules: "Iterable[Module | Sequence[Module]]", prefix = 'module_'):
+         from .chain import maybe_chain
+
          modules = list(modules)
          for i, m in enumerate(modules):
              self.set_child(f'{prefix}{i}', maybe_chain(m))
@@ -531,7 +289,11 @@
      def reset(self):
          """Resets the internal state of the module (e.g. momentum) and all children. By default clears state and global state."""
          self.state.clear()
+
+         generator = self.global_state.get("generator", None)
          self.global_state.clear()
+         if generator is not None: self.global_state["generator"] = generator
+
          for c in self.children.values(): c.reset()

      def reset_for_online(self):
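The behavioral change in this hunk: `reset()` now keeps the module's stored `torch.Generator` instead of discarding it along with the rest of `global_state`, presumably so the random stream used by stochastic modules continues rather than being re-seeded from scratch by the next `get_generator` call. A standalone sketch of the pattern (not torchzero code):

```python
import torch

# stand-in for Module.global_state
global_state = {"generator": torch.Generator().manual_seed(0), "step": 12}

# clear everything except the generator, so its RNG state survives the reset
generator = global_state.get("generator", None)
global_state.clear()
if generator is not None:
    global_state["generator"] = generator

assert list(global_state) == ["generator"]
```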
@@ -554,82 +316,6 @@
          """``_extra_pack`` return will be passed to this method when loading state_dict.
          This method is called after loading the rest of the state dict"""

-
-
-     # ------------------------------ HELPER METHODS ------------------------------ #
-     @torch.no_grad
-     def Hvp(
-         self,
-         v: Sequence[torch.Tensor],
-         at_x0: bool,
-         var: Var,
-         rgrad: Sequence[torch.Tensor] | None,
-         hvp_method: Literal['autograd', 'forward', 'central'],
-         h: float,
-         normalize: bool,
-         retain_grad: bool,
-     ) -> tuple[Sequence[torch.Tensor], Sequence[torch.Tensor] | None]:
-         """
-         Returns ``(Hvp, rgrad)``, where ``rgrad`` is gradient at current parameters,
-         possibly with ``create_graph=True``, or it may be None with ``hvp_method="central"``.
-         Gradient is set to vars automatically if ``at_x0``, you can always access it with ``vars.get_grad()``
-
-         Single sample example:
-
-         ```python
-         Hvp, _ = self.hvp(v, at_x0=True, rgrad=None, ..., retain_graph=False)
-         ```
-
-         Multiple samples example:
-
-         ```python
-         D = None
-         rgrad = None
-         for i in range(n_samples):
-             v = [torch.randn_like(p) for p in params]
-             Hvp, rgrad = self.hvp(v, at_x0=True, rgrad=rgrad, ..., retain_graph=i < n_samples-1)
-
-             if D is None: D = Hvp
-             else: torch._foreach_add_(D, Hvp)
-
-         if n_samples > 1: torch._foreach_div_(D, n_samples)
-         ```
-
-         Args:
-             v (Sequence[torch.Tensor]): vector in hessian-vector product
-             at_x0 (bool): whether this is being called at original or perturbed parameters.
-             var (Var): Var
-             rgrad (Sequence[torch.Tensor] | None): pass None initially, then pass what this returns.
-             hvp_method (str): hvp method.
-             h (float): finite difference step size
-             normalize (bool): whether to normalize v for finite difference
-             retain_grad (bool): retain grad
-         """
-         # get grad
-         if rgrad is None and hvp_method in ('autograd', 'forward'):
-             if at_x0: rgrad = var.get_grad(create_graph = hvp_method=='autograd')
-             else:
-                 if var.closure is None: raise RuntimeError("Closure is required to calculate HVp")
-                 with torch.enable_grad():
-                     loss = var.closure()
-                     rgrad = torch.autograd.grad(loss, var.params, create_graph = hvp_method=='autograd')
-
-         if hvp_method == 'autograd':
-             assert rgrad is not None
-             Hvp = hvp(var.params, rgrad, v, retain_graph=retain_grad)
-
-         elif hvp_method == 'forward':
-             assert rgrad is not None
-             loss, Hvp = hvp_fd_forward(var.closure, var.params, v, h=h, g_0=rgrad, normalize=normalize)
-
-         elif hvp_method == 'central':
-             loss, Hvp = hvp_fd_central(var.closure, var.params, v, h=h, normalize=normalize)
-
-         else:
-             raise ValueError(hvp_method)
-
-         return Hvp, rgrad
-
      def get_generator(self, device: torch.types.Device, seed: int | None):
          if seed is None: return None

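The removed `Hvp` helper dispatches between an autograd Hessian-vector product and two finite-difference modes (`hvp_fd_forward`, `hvp_fd_central`). For reference, here is a standalone sketch of the central-difference formula those modes approximate, Hv ≈ (∇f(x + h·v) − ∇f(x − h·v)) / (2h); this is not the library's implementation, just the underlying approximation:

```python
import torch

def hvp_central(f, x, v, h=1e-3):
    # approximate Hessian-vector product of a scalar function f at x along v
    def grad_at(y):
        y = y.detach().requires_grad_(True)
        return torch.autograd.grad(f(y), y)[0]
    return (grad_at(x + h * v) - grad_at(x - h * v)) / (2 * h)

# sanity check on f(x) = sum(x**4), whose Hessian is diag(12 * x**2)
x, v = torch.randn(5), torch.randn(5)
approx = hvp_central(lambda t: (t ** 4).sum(), x, v)
exact = 12 * x ** 2 * v
print(torch.allclose(approx, exact, atol=1e-3))
```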
@@ -638,277 +324,4 @@

          return self.global_state['generator']

-     # endregion
-
  Chainable = Module | Sequence[Module]
-
-
- def unroll_modules(*modules: Chainable) -> list[Module]:
-     unrolled = []
-
-     for m in modules:
-         if isinstance(m, Module):
-             unrolled.append(m)
-             unrolled.extend(unroll_modules(list(m.children.values())))
-         else:
-             unrolled.extend(unroll_modules(*m))
-
-     return unrolled
-
-
- # region Modular
- # ---------------------------------- Modular --------------------------------- #
-
- class _EvalCounterClosure:
-     """keeps track of how many times closure has been evaluated, and sets closure return"""
-     __slots__ = ("modular", "closure")
-     def __init__(self, modular: "Modular", closure):
-         self.modular = modular
-         self.closure = closure
-
-     def __call__(self, *args, **kwargs):
-         if self.closure is None:
-             raise RuntimeError("One of the modules requires closure to be passed to the step method")
-
-         v = self.closure(*args, **kwargs)
-
-         # set closure return on 1st evaluation
-         if self.modular._closure_return is None:
-             self.modular._closure_return = v
-
-         self.modular.num_evaluations += 1
-         return v
-
- # have to inherit from Modular to support lr schedulers
- # although Accelerate doesn't work due to converting param_groups to a dict
- class Modular(torch.optim.Optimizer):
-     """Chains multiple modules into an optimizer.
-
-     Args:
-         params (Params | torch.nn.Module): An iterable of parameters to optimize
-             (typically `model.parameters()`), an iterable of parameter group dicts,
-             or a `torch.nn.Module` instance.
-         *modules (Module): A sequence of `Module` instances that define the
-             optimization algorithm steps.
-     """
-     # this is specifically for lr schedulers
-     param_groups: list[ChainMap[str, Any]] # pyright:ignore[reportIncompatibleVariableOverride]
-
-     def __init__(self, params: Params | torch.nn.Module, *modules: Module):
-         if len(modules) == 0: raise RuntimeError("Empty list of modules passed to `Modular`")
-         self.model: torch.nn.Module | None = None
-         """The model whose parameters are being optimized, if a model instance was passed to `__init__`."""
-         if isinstance(params, torch.nn.Module):
-             self.model = params
-             params = params.parameters()
-
-         self.modules = modules
-         """Top-level modules providedduring initialization."""
-
-         self.unrolled_modules = unroll_modules(self.modules)
-         """A flattened list of all modules including all children."""
-
-         param_groups = _make_param_groups(params, differentiable=False)
-         self._per_parameter_global_settings: dict[torch.Tensor, list[MutableMapping[str, Any]]] = {}
-
-         # make sure there is no more than a single learning rate module
-         lr_modules = [m for m in self.unrolled_modules if 'lr' in m.defaults]
-         if len(lr_modules) > 1:
-             warnings.warn(f'multiple learning rate modules detected: {lr_modules}. This may lead to componding of learning rate multiplication with per-parameter learning rates and schedulers.')
-
-         # iterate over all per-parameter settings overrides and check if they are applied at most once
-         for group in param_groups:
-             for k in group:
-                 if k in ('params', 'lr'): continue
-                 modules_with_k = [m for m in self.unrolled_modules if k in m.defaults and k not in m._overridden_keys]
-                 if len(modules_with_k) > 1:
-                     warnings.warn(f'`params` has a `{k}` key, and multiple modules have that key: {modules_with_k}. If you intended to only set `{k}` to one of them, use `module.set_param_groups(params)`')
-
-         # defaults for schedulers
-         defaults = {}
-         for m in self.unrolled_modules: defaults.update(m.defaults)
-         super().__init__(param_groups, defaults=defaults)
-
-         # note - this is what super().__init__(param_groups, defaults=defaults) does:
-
-         # self.defaults = defaults
-         # for param_group in param_groups:
-         # self.add_param_group(param_group)
-
-         # add_param_group adds a ChainMap where defaults are lowest priority,
-         # and entries specifed in param_groups or scheduler are higher priority.
-         # pytorch schedulers do group["lr"] = new_lr, which sets higher priority key.
-         # in each module, settings passed to that module by calling set_param_groups are highest priority
-
-         self.current_step = 0
-         """global step counter for the optimizer."""
-
-         self.num_evaluations = 0
-         """number of times the objective has been evaluated (number of closure calls or number of steps if closure is None)."""
-
-         # reformulations will change the closure to return a different loss (e.g. a sqrt homotopy, gaussian homotopy)
-         # we want to return original loss so this attribute is used
-         self._closure_return = None
-         """on each step, first time a closure is evaluated, this attribute is set to the returned value. `step` method returns this."""
-
-         self.attrs = {}
-         """custom attributes that can be set by modules, for example EMA of weights or best so far"""
-
-         self.should_terminate = False
-         """is set to True by termination criteria modules."""
-
-     def add_param_group(self, param_group: dict[str, Any]):
-         proc_param_group = _make_param_groups([param_group], differentiable=False)[0]
-         self.param_groups.append(ChainMap(proc_param_group, self.defaults))
-
-         for p in proc_param_group['params']:
-             # updates global per-parameter setting overrides (medium priority)
-             self._per_parameter_global_settings[p] = [m.settings[p].maps[1] for m in self.unrolled_modules]
-
-     def state_dict(self):
-         all_params = [p for g in self.param_groups for p in g['params']]
-         id_to_idx = {id(p): i for i,p in enumerate(all_params)}
-
-         groups = []
-         for g in self.param_groups:
-             g = g.copy()
-             g['params'] = [id_to_idx[id(p)] for p in g['params']]
-             groups.append(g)
-
-         state_dict = {
-             "idx_to_id": {v:k for k,v in id_to_idx.items()},
-             "params": all_params,
-             "groups": groups,
-             "defaults": self.defaults,
-             "modules": {i: m.state_dict() for i, m in enumerate(self.unrolled_modules)}
-         }
-         return state_dict
-
-     def load_state_dict(self, state_dict: dict):
-         self.defaults.clear()
-         self.defaults.update(state_dict['defaults'])
-
-         idx_to_param = dict(enumerate(state_dict['params']))
-         groups = []
-         for g in state_dict['groups']:
-             g = g.copy()
-             g['params'] = [idx_to_param[p] for p in g['params']]
-             groups.append(g)
-
-         self.param_groups.clear()
-         for group in groups:
-             self.add_param_group(group)
-
-         id_to_tensor = {state_dict['idx_to_id'][i]: p for i,p in enumerate(state_dict['params'])}
-         for m, sd in zip(self.unrolled_modules, state_dict['modules'].values()):
-             m._load_state_dict(sd, id_to_tensor)
-
-
-     def step(self, closure=None, loss=None, **kwargs): # pyright: ignore[reportIncompatibleMethodOverride]
-         # clear closure return from previous step
-         self._closure_return = None
-
-         # propagate global per-parameter setting overrides
-         for g in self.param_groups:
-             settings = dict(g.maps[0]) # ignore defaults
-             params = settings.pop('params')
-             if not settings: continue
-
-             for p in params:
-                 if not p.requires_grad: continue
-                 for map in self._per_parameter_global_settings[p]: map.update(settings)
-
-         # create var
-         params = [p for g in self.param_groups for p in g['params'] if p.requires_grad]
-         var = Var(params=params, closure=_EvalCounterClosure(self, closure), model=self.model, current_step=self.current_step, modular=self, loss=loss, storage=kwargs)
-
-         # if closure is None, assume backward has been called and gather grads
-         if closure is None:
-             var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-             self.num_evaluations += 1
-
-         n_modules = len(self.modules)
-         if n_modules == 0: raise RuntimeError("There are no modules in this `Modular` optimizer")
-         last_module = self.modules[-1]
-         last_lr = last_module.defaults.get('lr', None)
-
-         # step
-         for i, module in enumerate(self.modules):
-             if i!=0: var = var.clone(clone_update=False)
-
-             # last module, or next to last module before lr
-             if (i == n_modules - 1) or ((i == n_modules - 2) and (last_lr is not None)):
-                 if module.children: var.nested_is_last = True
-                 else: var.is_last = True
-                 if last_lr is not None: var.last_module_lrs = [last_module.settings[p]['lr'] for p in var.params]
-
-             var = module.step(var)
-             if var.stop: break
-
-         # apply update
-         if not var.skip_update:
-             with torch.no_grad():
-                 torch._foreach_sub_(params, var.get_update())
-
-         # update attributes
-         self.attrs.update(var.attrs)
-         if var.should_terminate is not None: self.should_terminate = var.should_terminate
-
-         # hooks
-         for hook in var.post_step_hooks:
-             hook(self, var)
-
-         self.current_step += 1
-         #return var.loss if var.loss is not None else var.loss_approx
-         return self._closure_return
-
-     def __repr__(self):
-         return f'Modular({", ".join(str(m) for m in self.modules)})'
- # endregion
-
- # region Chain
- # ----------------------------------- Chain ---------------------------------- #
- class Chain(Module):
-     """Chain of modules, mostly used internally"""
-     def __init__(self, *modules: Module | Iterable[Module]):
-         super().__init__()
-         flat_modules: list[Module] = flatten(modules)
-         for i, module in enumerate(flat_modules):
-             self.set_child(f'module_{i}', module)
-
-     def update(self, var):
-         # note here that `update` and `apply` shouldn't be used directly
-         # as it will update all modules, and then apply all modules
-         # it is used in specific cases like Chain as trust region hessian module
-         for i in range(len(self.children)):
-             self.children[f'module_{i}'].update(var)
-             if var.stop: break
-         return var
-
-     def apply(self, var):
-         for i in range(len(self.children)):
-             var = self.children[f'module_{i}'].apply(var)
-             if var.stop: break
-         return var
-
-     def step(self, var):
-         for i in range(len(self.children)):
-             var = self.children[f'module_{i}'].step(var)
-             if var.stop: break
-         return var
-
-     def __repr__(self):
-         s = self.__class__.__name__
-         if self.children:
-             if s == 'Chain': s = 'C' # to shorten it
-             s = f'{s}({", ".join(str(m) for m in self.children.values())})'
-         return s
-
- def maybe_chain(*modules: Chainable) -> Module:
-     """Returns a single module directly if only one is provided, otherwise wraps them in a :code:`Chain`."""
-     flat_modules: list[Module] = flatten(modules)
-     if len(flat_modules) == 1:
-         return flat_modules[0]
-     return Chain(*flat_modules)
- # endregion
-
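Taken as a whole, the removals above strip module.py down to the `Module` base class and the `Chainable` alias; `Var`, `Modular`, `Chain`, and `maybe_chain` move to the new `core/var.py`, `core/modular.py`, and `core/chain.py` listed at the top of this diff. Below is a minimal sketch of the `Var` flow those docstrings describe, with a toy closure written in the style the removed `_closure_backward` expects; everything outside the `Var` calls is illustrative only:

```python
import torch
from torchzero.core.var import Var  # new location in 0.3.15 (see the file list)

params = [torch.zeros(3, requires_grad=True)]

def closure(backward=True):
    # returns the loss; when backward=True it is also expected to populate .grad
    loss = ((params[0] - 1.0) ** 2).sum()
    if backward:
        params[0].grad = None
        loss.backward()
    return loss

var = Var(params=params, closure=closure, model=None, current_step=0)
grad = var.get_grad()        # runs the closure with backward, caches var.loss and var.grad
update = var.get_update()    # first access clones the gradient
var.update = [u.sign() for u in update]   # a module writes its transformed update here
# Modular.step() then subtracts var.get_update() from the parameters.
```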
torchzero/core/reformulation.py CHANGED
@@ -3,7 +3,9 @@ from collections.abc import Callable, Sequence

  import torch

- from .module import Chainable, Modular, Module, Var
+ from .chain import Chain
+ from .module import Chainable, Module
+ from .var import Var


  class Reformulation(Module, ABC):
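Combined with the file list (new `core/var.py`, `core/chain.py`, `core/modular.py`, plus a small `core/__init__.py` change), the import layout in 0.3.15 looks like the sketch below. The first group is confirmed by the module.py and reformulation.py hunks above; the `Modular` location is inferred from the file list only, and whether `torchzero.core`'s `__init__` re-exports all of these names is an assumption not shown in this diff:

```python
# Confirmed by the hunks above:
from torchzero.core.var import Var
from torchzero.core.chain import Chain, maybe_chain
from torchzero.core.module import Module, Chainable

# Inferred from the file list (core/modular.py is new, +237 lines, roughly
# matching the Modular class removed from module.py):
from torchzero.core.modular import Modular
```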