torchzero 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
+++ torchzero/modules/line_search/grid_ls.py
@@ -0,0 +1,158 @@
+ from typing import Any, Literal
+ from collections.abc import Sequence
+
+ import numpy as np
+ import torch
+
+ from ...tensorlist import TensorList
+ from ...core import _ClosureType, OptimizationState
+ from .base_ls import LineSearchBase
+
+ class GridLS(LineSearchBase):
+     """Tests all `lrs` and picks the best one.
+
+     Args:
+         lrs (Sequence[float] | np.ndarray | torch.Tensor): sequence of lrs to test.
+         stop_on_improvement (bool, optional): stop as soon as the loss improves over the current loss. Defaults to False.
+         stop_on_worsened (bool, optional):
+             stop as soon as the loss for the next lr is worse than for the previous one.
+             This assumes that lrs are in ascending order. Defaults to False.
+         log_lrs (bool, optional):
+             save tested lrs and their losses into optimizer._lrs (for debugging).
+             Defaults to False.
+     """
+     def __init__(
+         self,
+         lrs: Sequence[float] | np.ndarray | torch.Tensor,
+         stop_on_improvement=False,
+         stop_on_worsened=False,
+         log_lrs=False,
+     ):
+         super().__init__({}, maxiter=None, log_lrs=log_lrs)
+         self.lrs = lrs
+         self.stop_on_improvement = stop_on_improvement
+         self.stop_on_worsened = stop_on_worsened
+
+     @torch.no_grad
+     def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
+         if state.closure is None: raise ValueError("closure is not set")
+         if state.ascent is None: raise ValueError("ascent_direction is not set")
+
+         if self.stop_on_improvement:
+             if state.fx0 is None: state.fx0 = state.closure(False)
+             self._lowest_loss = state.fx0
+
+         for lr in self.lrs:
+             loss = self._evaluate_lr_(float(lr), state.closure, state.ascent, params)
+
+             # stop if this lr was worse than the best loss so far
+             if self.stop_on_worsened and loss != self._lowest_loss:
+                 break
+
+             # stop if this lr improved on the best loss so far
+             if self.stop_on_improvement and loss == self._lowest_loss:
+                 break
+
+         return float(self._best_lr)
+
+
+
+ class MultiplicativeLS(GridLS):
+     """Starts with the `init` lr, then keeps multiplying it by `mul` until the loss stops decreasing.
+
+     Args:
+         init (float, optional): initial lr. Defaults to 0.001.
+         mul (float, optional): lr multiplier. Defaults to 2.
+         num (int, optional): maximum number of multiplication steps. Defaults to 10.
+         stop_on_improvement (bool, optional): stop as soon as the loss improves over the current loss. Defaults to False.
+         stop_on_worsened (bool, optional):
+             stop as soon as the loss for the next lr is worse than for the previous one.
+             This assumes that lrs are in ascending order. Defaults to True.
+     """
+     def __init__(
+         self,
+         init: float = 0.001,
+         mul: float = 2,
+         num=10,
+         stop_on_improvement=False,
+         stop_on_worsened=True,
+     ):
+         super().__init__(
+             [init * mul**i for i in range(num)],
+             stop_on_improvement=stop_on_improvement,
+             stop_on_worsened=stop_on_worsened,
+         )
+
+ class BacktrackingLS(GridLS):
+     """Tests the `init` lr and keeps multiplying it by `mul` until the loss becomes better than the initial loss.
+
+     Note: this does not include the Armijo–Goldstein condition.
+
+     Args:
+         init (float, optional): initial lr. Defaults to 1.
+         mul (float, optional): lr multiplier. Defaults to 0.5.
+         num (int, optional): maximum number of multiplication steps. Defaults to 10.
+         stop_on_improvement (bool, optional): stop as soon as the loss improves over the current loss. Defaults to True.
+         stop_on_worsened (bool, optional):
+             stop as soon as the loss for the next lr is worse than for the previous one.
+             This assumes that lrs are in ascending order. Defaults to False.
+         log_lrs (bool, optional):
+             save tested lrs and their losses into optimizer._lrs (for debugging).
+             Defaults to False.
+     """
+     def __init__(
+         self,
+         init: float = 1,
+         mul: float = 0.5,
+         num=10,
+         stop_on_improvement=True,
+         stop_on_worsened=False,
+         log_lrs=False,
+     ):
+         super().__init__(
+             [init * mul**i for i in range(num)],
+             stop_on_improvement=stop_on_improvement,
+             stop_on_worsened=stop_on_worsened,
+             log_lrs=log_lrs,
+         )
+
+ class LinspaceLS(GridLS):
+     """Tests all learning rates from a linspace and picks the best one."""
+     def __init__(
+         self,
+         start: float = 0.001,
+         end: float = 2,
+         steps=10,
+         stop_on_improvement=False,
+         stop_on_worsened=False,
+         log_lrs=False,
+     ):
+         super().__init__(
+             torch.linspace(start, end, steps),
+             stop_on_improvement=stop_on_improvement,
+             stop_on_worsened=stop_on_worsened,
+             log_lrs=log_lrs,
+         )
+
+ class ArangeLS(GridLS):
+     """Tests all learning rates from `torch.arange(start, end, step)` and picks the best one."""
+     def __init__(
+         self,
+         start: float = 0.001,
+         end: float = 2,
+         step=0.1,
+         stop_on_improvement=False,
+         stop_on_worsened=False,
+         log_lrs=False,
+     ):
+         super().__init__(
+             torch.arange(start, end, step),
+             stop_on_improvement=stop_on_improvement,
+             stop_on_worsened=stop_on_worsened,
+             log_lrs=log_lrs,
+         )
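
Editor's note: the core operation of these modules, evaluating a list of candidate step sizes along a fixed direction and keeping the best one, can be sketched outside the module system with plain PyTorch. The `grid_line_search` helper and the toy quadratic below are illustrative assumptions, not torchzero API:

import torch

def grid_line_search(closure, params, direction, lrs, stop_on_worsened=False):
    """Try the update `p -= lr * d` for each lr in `lrs` and return the best lr and its loss."""
    best_lr, best_loss = 0.0, closure()
    for lr in lrs:
        for p, d in zip(params, direction):
            p.sub_(d, alpha=float(lr))          # apply the candidate step
        loss = closure()
        for p, d in zip(params, direction):
            p.add_(d, alpha=float(lr))          # undo it
        if loss < best_loss:
            best_lr, best_loss = float(lr), loss
        elif stop_on_worsened:
            break                               # assumes lrs are in ascending order
    return best_lr, best_loss

# toy quadratic: f(x) = ||x - 3||^2, searched along its gradient at the current x
x = torch.zeros(2)
closure = lambda: ((x - 3.0) ** 2).sum().item()
direction = [2.0 * (x - 3.0)]                   # gradient of f at x
print(grid_line_search(closure, [x], direction, lrs=[2.0 ** -i for i in range(8)][::-1]))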
+++ torchzero/modules/line_search/scipy_minimize_scalar.py
@@ -0,0 +1,62 @@
+ import typing
+
+ import torch
+ try:
+     import scipy.optimize as scopt
+ except ModuleNotFoundError:
+     scopt = typing.cast(typing.Any, None)
+
+ from ...tensorlist import TensorList
+ from ...core import OptimizationState
+
+ from .base_ls import LineSearchBase, MaxIterReached
+
+ if typing.TYPE_CHECKING:
+     import scipy.optimize as scopt
+
+ class ScipyMinimizeScalarLS(LineSearchBase):
+     """Line search via `scipy.optimize.minimize_scalar`. All args except `maxiter` are passed through to it.
+
+     Args:
+         method (Optional[str], optional): 'brent', 'golden' or 'bounded'. Defaults to None.
+         maxiter (Optional[int], optional): hard limit on the maximum number of function evaluations. Defaults to None.
+         bracket (optional): bracketing interval for the 'brent' and 'golden' methods. Defaults to None.
+         bounds (optional): bounds for the 'bounded' method. Defaults to None.
+         tol (Optional[float], optional): tolerance for termination, passed to `minimize_scalar`. Defaults to None.
+         options (optional): solver options dict for the chosen method. Defaults to None.
+         log_lrs (bool, optional): logs lrs and values into `_lrs`. Defaults to False.
+     """
+     def __init__(
+         self,
+         method: str | None = None,
+         maxiter: int | None = None,
+         bracket=None,
+         bounds=None,
+         tol: float | None = None,
+         options=None,
+         log_lrs=False,
+     ):
+         if scopt is None: raise ModuleNotFoundError("scipy is not installed")
+         super().__init__({}, maxiter=maxiter, log_lrs=log_lrs)
+         self.method = method
+         self.tol = tol
+         self.bracket = bracket
+         self.bounds = bounds
+         self.options = options
+
+     @torch.no_grad
+     def _find_best_lr(self, state: OptimizationState, params: TensorList) -> float:
+         try:
+             res = scopt.minimize_scalar(
+                 self._evaluate_lr_ensure_float,
+                 args=(state.closure, state.ascent, params),
+                 method=self.method,
+                 tol=self.tol,
+                 bracket=self.bracket,
+                 bounds=self.bounds,
+                 options=self.options,
+             ) # type:ignore
+         except MaxIterReached:
+             pass
+
+         return float(self._best_lr)
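
Editor's note: conceptually, this module hands scipy a one-dimensional function of the step size. The same pattern can be reproduced directly with `scipy.optimize.minimize_scalar` on a toy objective; the function and values below are illustrative, not torchzero internals:

import numpy as np
from scipy.optimize import minimize_scalar

def f(x):
    # toy quadratic objective
    return (x[0] - 1.0) ** 2 + 10.0 * (x[1] + 2.0) ** 2

x = np.array([3.0, 3.0])
d = np.array([2.0 * (x[0] - 1.0), 20.0 * (x[1] + 2.0)])  # gradient of f at x

# minimize f(x - t * d) over the scalar step size t, just as the line search does for its lr
res = minimize_scalar(lambda t: f(x - t * d), method="brent")
print(res.x, f(x - res.x * d))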
+++ torchzero/modules/meta/__init__.py
@@ -0,0 +1,12 @@
+ """Modules that use other modules."""
+ # from .chain import Chain, ChainReturn
+ import sys
+
+ from .alternate import Alternate
+ from .grafting import Graft, IntermoduleCautious, SignGrafting
+ from .return_overrides import ReturnAscent, ReturnClosure, SetGrad
+
+ # if sys.version_info[1] < 12:
+ from .optimizer_wrapper import Wrap, WrapClosure
+ # else:
+ #     from .optimizer_wrapper import Wrap, WrapClosure
+++ torchzero/modules/meta/alternate.py
@@ -0,0 +1,65 @@
+ import random
+ from collections.abc import Iterable
+ from typing import Any, Literal
+
+ from ...core import OptimizerModule, _Chainable
+
+
+ class Alternate(OptimizerModule):
+     """Alternates stepping with multiple modules.
+
+     Args:
+         modules (Iterable[OptimizerModule | Iterable[OptimizerModule]]): modules to alternate between.
+         mode (int | list[int] | tuple[int] | "random", optional):
+             can be an integer - the number of repeats for all modules;
+             a list or tuple of integers - the number of repeats per module;
+             or "random" to pick a module randomly each time. Defaults to 1.
+         seed (int | None, optional): seed for "random" mode. Defaults to None.
+     """
+     def __init__(
+         self,
+         modules: Iterable[_Chainable],
+         mode: int | list[int] | tuple[int] | Literal["random"] = 1,
+         seed: int | None = None
+     ):
+         super().__init__({})
+         modules = list(modules)
+
+         for i, m in enumerate(modules):
+             self._set_child_(i, m)
+
+         self.random = random.Random(seed)
+
+         if isinstance(mode, int): mode = [mode for _ in modules]
+         self.mode: list[int] | tuple[int] | Literal['random'] = mode
+
+         self.cur = 0
+         if self.mode == 'random': self.remaining = 0
+         else:
+             self.remaining = self.mode[0]
+             if len(self.mode) != len(self.children):
+                 raise ValueError(f"got {len(self.children)} modules but {len(mode)} repeats, the counts must match")
+
+     def step(self, state):
+         if self.mode == 'random':
+             module = self.random.choice(list(self.children.values()))
+
+         else:
+             if self.remaining == 0:
+                 self.cur += 1
+                 if self.cur >= len(self.mode):
+                     self.cur = 0
+
+             if self.remaining == 0: self.remaining = self.mode[self.cur]
+
+             module = self.children[self.cur]
+             self.remaining -= 1
+
+         if self.next_module is None:
+             return module.step(state)
+
+         state.ascent = module.return_ascent(state)
+         return self._update_params_or_step_with_next(state)
+
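
Editor's note: for the integer form of `mode`, the bookkeeping above amounts to a repeat-aware round-robin schedule. A standalone sketch of that schedule follows; the `alternation_schedule` helper and the module names are illustrative, not torchzero API:

from itertools import islice

def alternation_schedule(names, repeats):
    """Yield module names in the order Alternate would step them for a list-of-ints `mode`."""
    while True:
        for name, n in zip(names, repeats):
            for _ in range(n):
                yield name

# two modules, the first repeated twice per cycle, the second once
print(list(islice(alternation_schedule(["sgd", "adam"], [2, 1]), 6)))
# ['sgd', 'sgd', 'adam', 'sgd', 'sgd', 'adam']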
+++ torchzero/modules/meta/grafting.py
@@ -0,0 +1,195 @@
+ from collections.abc import Iterable
+ from typing import Literal
+ import torch
+
+ from ...core import OptimizerModule
+ from ...tensorlist import TensorList
+
+
+ class Graft(OptimizerModule):
+     """
+     Optimizer grafting (magnitude#direction).
+     Takes the update of one module and rescales it to have the same norm as the update of another module.
+     Can be applied to all weights at once or layerwise.
+
+     Args:
+         magnitude (OptimizerModule | Iterable[OptimizerModule]):
+             module to take the magnitude from.
+             If a sequence of modules is provided, they will be chained.
+         direction (OptimizerModule | Iterable[OptimizerModule]):
+             module/modules to take the direction from.
+             If a sequence of modules is provided, they will be chained.
+         ord (float, optional): norm order. Defaults to 2.
+         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
+         layerwise (bool, optional): whether to apply grafting layerwise. Defaults to False.
+
+     Reference:
+         *Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C.
+         Learning Rate Grafting: Transferability of Optimizer Tuning.*
+     """
+     def __init__(
+         self,
+         magnitude: OptimizerModule | Iterable[OptimizerModule],
+         direction: OptimizerModule | Iterable[OptimizerModule],
+         ord: float = 2,
+         eps: float = 1e-8,
+         layerwise: bool = False,
+         # TODO: channelwise
+     ):
+         super().__init__({})
+         self._set_child_('magnitude', magnitude)
+         self._set_child_('direction', direction)
+         self.ord = ord
+         self.eps = eps
+         self.layerwise = layerwise
+
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent=True)
+         magnitude = self.children['magnitude'].return_ascent(state_copy)
+
+         if state_copy.grad is not None: state.grad = state_copy.grad
+         if state_copy.fx0 is not None: state.fx0 = state_copy.fx0
+         if state_copy.fx0_approx is not None: state.fx0_approx = state_copy.fx0_approx
+
+         direction = self.children['direction'].return_ascent(state)
+
+         if self.layerwise:
+             M = magnitude.norm(self.ord)
+             D = direction.norm(self.ord)
+             D.select_set_(D == 0, M)
+
+         else:
+             M = magnitude.total_vector_norm(self.ord)
+             D = direction.total_vector_norm(self.ord)
+             if D == 0: D = M
+
+         state.ascent = direction.mul_(M / (D + self.eps))
+         return self._update_params_or_step_with_next(state)
+
+
+
+ class SignGrafting(OptimizerModule):
+     """Weight-wise grafting-like operation where the magnitude of the ascent is taken from the
+     `magnitude` module and the sign from the `sign` module.
+
+     Args:
+         magnitude (OptimizerModule | Iterable[OptimizerModule]):
+             module to take the magnitude from.
+             If a sequence of modules is provided, they will be chained.
+         sign (OptimizerModule | Iterable[OptimizerModule]):
+             module to take the sign from.
+             If a sequence of modules is provided, they will be chained.
+     """
+     def __init__(
+         self,
+         magnitude: OptimizerModule | Iterable[OptimizerModule],
+         sign: OptimizerModule | Iterable[OptimizerModule],
+     ):
+         super().__init__({})
+
+         self._set_child_('magnitude', magnitude)
+         self._set_child_('sign', sign)
+
+
+     @torch.no_grad
+     def step(self, state):
+         state_copy = state.copy(clone_ascent=True)
+         magnitude = self.children['magnitude'].return_ascent(state_copy)
+
+         # make sure to store grad and fx0 if they were calculated
+         state.update_attrs_(state_copy)
+
+         sign = self.children['sign'].return_ascent(state)
+
+         state.ascent = magnitude.copysign_(sign)
+         return self._update_params_or_step_with_next(state)
+
+
+ class IntermoduleCautious(OptimizerModule):
+     """Masks the update for parameters where the updates of two modules or module chains disagree in sign.
+     Optionally normalizes the update by the number of parameters that are not masked.
+
+     Args:
+         main_module (OptimizerModule | Iterable[OptimizerModule]):
+             main module or sequence of modules to chain, whose update is used with the consistency mask applied.
+         compare_module (OptimizerModule | Iterable[OptimizerModule]):
+             module or sequence of modules to chain, whose update is used to compute the consistency mask.
+             Can also be set to `ascent` to compare against the update passed to `main_module`, or `grad` to compare
+             against the gradients.
+         normalize (bool, optional):
+             renormalize the update after masking.
+             Only has an effect when mode is 'zero'. Defaults to False.
+         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+         mode (str, optional):
+             what to do with updates that have inconsistent signs.
+
+             "zero" - set them to zero (as in the paper)
+
+             "grad" - set them to the gradient
+
+             "compare_module" - set them to `compare_module`'s update
+
+             "backtrack" - negate them (same as using the update magnitude and the gradient sign)
+     """
+     def __init__(
+         self,
+         main_module: OptimizerModule | Iterable[OptimizerModule],
+         compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
+         normalize=False,
+         eps=1e-6,
+         mode: Literal["zero", "grad", "backtrack", "compare_module"] = "zero",
+     ):
+         super().__init__({})
+
+         self._set_child_('main', main_module)
+         if isinstance(compare_module, str): self.compare_mode = compare_module
+         else:
+             self._set_child_('compare', compare_module)
+             self.compare_mode = 'module'
+         self.eps = eps
+         self.normalize = normalize
+         self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode
+
+     @torch.no_grad
+     def step(self, state):
+         params = None
+         state_copy = state.copy(clone_ascent=True)
+         ascent = self.children['main'].return_ascent(state_copy)
+         state.update_attrs_(state_copy)
+
+         if self.compare_mode == 'module': compare = self.children['compare'].return_ascent(state)
+         else:
+             params = self.get_params()
+             if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params)
+             elif self.compare_mode == 'grad': compare: TensorList = state.maybe_compute_grad_(params)
+             else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')
+
+         # mask will be > 0 for parameters where both signs are the same
+         mask = (ascent * compare) > 0
+
+         if self.mode == 'backtrack':
+             ascent -= ascent.mul(2).mul_(mask.logical_not_())
+
+         else:
+             # normalize if mode is `zero`
+             if self.normalize and self.mode == 'zero':
+                 fmask = mask.to(ascent[0].dtype)
+                 fmask /= fmask.total_mean() + self.eps
+             else:
+                 fmask = mask
+
+             # apply the mask
+             ascent *= fmask
+
+             if self.mode == 'grad':
+                 params = self.get_params()
+                 ascent += state.maybe_compute_grad_(params) * mask.logical_not_()
+
+             elif self.mode == 'compare_module':
+                 ascent += compare * mask.logical_not_()
+
+         state.ascent = ascent
+         return self._update_params_or_step_with_next(state, params)
+
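
Editor's note: on plain tensors, the two key operations in this file reduce to a norm transfer and a sign-agreement mask. A minimal sketch with made-up values (illustrative only, not torchzero's TensorList API):

import torch

eps = 1e-8
magnitude_update = torch.tensor([0.30, -0.10, 0.20])  # update produced by the "magnitude" module
direction_update = torch.tensor([5.00, -1.00, 2.00])  # update produced by the "direction" module

# Graft: keep direction_update's direction, rescale it to magnitude_update's norm
M = magnitude_update.norm(2)
D = direction_update.norm(2)
grafted = direction_update * (M / (D + eps))
print(grafted.norm(2).item(), M.item())   # the two norms now (approximately) match

# IntermoduleCautious with mode="zero": zero entries where two updates disagree in sign
main = torch.tensor([0.5, -0.2, 0.1])
compare = torch.tensor([1.0, 0.3, -0.4])
mask = (main * compare) > 0                # True where the signs agree
print(main * mask)                         # disagreeing entries are zeroed out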
+++ torchzero/modules/meta/optimizer_wrapper.py
@@ -0,0 +1,173 @@
+ from collections.abc import Callable, Sequence
+ from typing import Any, overload
+
+ import torch
+ from typing_extensions import Concatenate, ParamSpec
+
+ from ...core import OptimizerModule
+ from .return_overrides import SetGrad
+
+ K = ParamSpec('K')
+
+ class Wrap(OptimizerModule):
+     """
+     Wraps any torch.optim.Optimizer.
+
+     Sets the .grad attribute to the current update and steps with the `optimizer`.
+
+     Additionally, if this is not the last module, this takes the update of `optimizer`,
+     undoes it and passes it to the next module instead. That means you can chain multiple
+     optimizers together.
+
+     Args:
+         optimizer (torch.optim.Optimizer): optimizer to wrap,
+             or a callable (class) that constructs the optimizer.
+         kwargs:
+             if a class is passed, kwargs are passed to the constructor.
+             Parameters are passed separately and automatically,
+             which is the point of passing a constructor
+             instead of an optimizer directly.
+
+     This can be constructed in two ways.
+
+     .. code-block:: python
+
+         wrapper = Wrap(torch.optim.SGD(model.parameters(), lr = 0.1))
+         # or
+         wrapper = Wrap(torch.optim.SGD, lr = 0.1)
+     """
+
+     @overload
+     def __init__(self, optimizer: torch.optim.Optimizer): ...
+     @overload
+     # def __init__[**K](
+     def __init__(
+         self,
+         optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
+         *args: K.args,
+         **kwargs: K.kwargs,
+         # optimizer: abc.Callable[..., torch.optim.Optimizer],
+         # *args,
+         # **kwargs,
+     ): ...
+     def __init__(self, optimizer, *args, **kwargs):
+
+         super().__init__({})
+         self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
+         self._args = args
+         self._kwargs = kwargs
+
+     def _initialize_(self, params, set_passed_params):
+         """Initializes this optimizer and all children with the given parameters."""
+         super()._initialize_(params, set_passed_params=set_passed_params)
+         if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
+             self.optimizer = self._optimizer_cls
+         else:
+             self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
+
+     @torch.no_grad
+     def step(self, state):
+         # check attrs
+         # if self.pass_closure:
+         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
+         #     if state.ascent is not None:
+         #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
+
+         params = self.get_params()
+
+         if self.next_module is None:
+             # set grad to the ascent and make a step with the optimizer
+             g = state.maybe_use_grad_(params)
+             params.set_grad_(g)
+             state.fx0 = self.optimizer.step()
+             return state.get_loss()
+
+
+         params_before_step = params.clone()
+
+         g = state.maybe_use_grad_(params)
+         params.set_grad_(g)
+         state.fx0 = self.optimizer.step()
+
+         # calculate the update as the difference in params
+         state.ascent = params_before_step - params
+         params.set_(params_before_step)
+         return self.next_module.step(state)
+
+
+ class WrapClosure(OptimizerModule):
+     """
+     Wraps any torch.optim.Optimizer. This only works together with modules that take a :code:`target = "Closure"` argument.
+     The modified closure will be passed to the optimizer.
+
+     Alternatively, any module can be turned into a closure module by using the :any:`MakeClosure` module,
+     in which case this should be placed after MakeClosure.
+
+     Args:
+         optimizer (torch.optim.Optimizer): optimizer to wrap,
+             or a callable (class) that constructs the optimizer.
+         kwargs:
+             if a class is passed, kwargs are passed to the constructor.
+             Parameters are passed separately and automatically,
+             which is the point of passing a constructor
+             instead of an optimizer directly.
+
+     This can be constructed in two ways.
+
+     .. code-block:: python
+
+         wrapper = WrapClosure(torch.optim.SGD(model.parameters(), lr = 0.1))
+         # or
+         wrapper = WrapClosure(torch.optim.SGD, lr = 0.1)
+
+     """
+
+     @overload
+     def __init__(self, optimizer: torch.optim.Optimizer): ...
+     @overload
+     def __init__(
+         self,
+         optimizer: Callable[Concatenate[Any, K], torch.optim.Optimizer],
+         *args: K.args,
+         **kwargs: K.kwargs,
+         # optimizer: abc.Callable[..., torch.optim.Optimizer],
+         # *args,
+         # **kwargs,
+     ): ...
+     def __init__(self, optimizer, *args, **kwargs):
+
+         super().__init__({})
+         self._optimizer_cls: torch.optim.Optimizer | Callable[..., torch.optim.Optimizer] = optimizer
+         self._args = args
+         self._kwargs = kwargs
+
+     def _initialize_(self, params, set_passed_params):
+         """Initializes this optimizer and all children with the given parameters."""
+         super()._initialize_(params, set_passed_params=set_passed_params)
+         if isinstance(self._optimizer_cls, torch.optim.Optimizer) or not callable(self._optimizer_cls):
+             self.optimizer = self._optimizer_cls
+         else:
+             self.optimizer = self._optimizer_cls(params, *self._args, **self._kwargs)
+
+     @torch.no_grad
+     def step(self, state):
+         # check attrs
+         # if self.pass_closure:
+         #     if state.closure is None: raise ValueError('ClosureOptimizerWrapper requires closure.')
+         #     if state.ascent is not None:
+         #         raise ValueError('pass_closure = True, means ascent must be None (not sure though)')
+
+         params = self.get_params()
+
+         if self.next_module is None:
+             # pass the closure to the wrapped optimizer and make a step with it
+             state.fx0 = self.optimizer.step(state.closure) # type:ignore
+             return state.get_loss()
+
+
+         params_before_step = params.clone()
+         state.fx0 = self.optimizer.step(state.closure) # type:ignore
+
+         # calculate the update as the difference in params
+         state.ascent = params_before_step - params
+         params.set_(params_before_step)
+         return self.next_module.step(state)
+
+