torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -0,0 +1,284 @@
+ from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
+ from typing import Any, Literal, TypeVar, overload
+
+ import torch
+
+ from .tensorlist import TensorList
+ from .numberlist import NumberList
+ from .torch_tools import tofloat, totensor
+
+ ListLike = TypeVar('ListLike', bound=MutableSequence)
+
+ ParamFilter = Literal["has_grad", "requires_grad", "all"] | Callable[[torch.Tensor], bool]
+ def _param_filter(param: torch.Tensor, mode: ParamFilter):
+     if callable(mode): return mode(param)
+     if mode == 'has_grad': return param.grad is not None
+     if mode == 'requires_grad': return param.requires_grad
+     if mode == 'all': return True
+     raise ValueError(f"Unknown mode {mode}")
+
+ def get_params(
+     param_groups: Iterable[Mapping[str, Any]],
+     mode: ParamFilter = 'requires_grad',
+     cls: type[ListLike] = TensorList,
+ ) -> ListLike:
+     return cls(p for g in param_groups for p in g['params'] if _param_filter(p, mode)) # type:ignore[reportCallIssue]
+
+
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: str, *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> ListLike: ...
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: list[str] | tuple[str,...], *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> list[ListLike]: ...
+ @overload
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]], key: str, key2: str, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> list[ListLike]: ...
+
+ def get_group_vals(param_groups: Iterable[Mapping[str, Any]],
+                    key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+     # single key, return single cls
+     if isinstance(key, str) and key2 is None:
+         values = cls()
+         for group in param_groups:
+             num_params = len([p for p in group['params'] if _param_filter(p, mode)])
+             if num_params > 0:
+                 group_value = group[key]
+                 values.extend(group_value for _ in range(num_params))
+         return values
+
+     # multiple keys
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls() for _ in keys]
+     for group in param_groups:
+         num_params = len([p for p in group['params'] if _param_filter(p, mode)])
+         if num_params > 0:
+             for i,key in enumerate(keys):
+                 group_value = group[key]
+                 values[i].extend(group_value for _ in range(num_params))
+     return values
+
+ _InitLiterals = Literal['param', 'grad']
+ Init = _InitLiterals | Any | list[_InitLiterals | Any] | tuple[_InitLiterals | Any]
+
+ def _make_initial_state_value(param: torch.Tensor, init: Init, i: int | None):
+     if callable(init): return init(param)
+     if isinstance(init, torch.Tensor): return init.detach().clone()
+
+     if isinstance(init, str):
+         if init in ('param','params'): return param.detach().clone()
+         if init in ('grad', 'grads'):
+             if param.grad is None: raise RuntimeError('init is set to "grad", but param.grad is None')
+             return param.grad.detach().clone()
+
+     if isinstance(init, (list,tuple)):
+         if i is None: raise RuntimeError(f'init is per-parameter ({type(init)}) but parameter index i is None')
+         return _make_initial_state_value(param, init[i], None)
+
+     return init
+
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str, *,
+                    must_exist: bool = False, init: Init = torch.zeros_like,
+                    cls: type[ListLike] = list) -> ListLike: ...
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: list[str] | tuple[str,...], *,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> list[ListLike]: ...
+ @overload
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str, key2: str, *keys: str,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> list[ListLike]: ...
+
+ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], params: Sequence[torch.Tensor],
+                    key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                    cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+     # single key, return single cls
+     if isinstance(key, str) and key2 is None:
+         values = cls()
+         for i, param in enumerate(params):
+             s = state[param]
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 s[key] = _make_initial_state_value(param, init, i)
+             values.append(s[key])
+         return values
+
+     # multiple keys
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls() for _ in keys]
+     for i, param in enumerate(params):
+         s = state[param]
+         for k_i, key in enumerate(keys):
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                 s[key] = _make_initial_state_value(param, k_init, i)
+             values[k_i].append(s[key])
+
+     return values
+
+
+
+ def loss_at_params(closure, params: Iterable[torch.Tensor],
+                    new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
+     params = TensorList(params)
+
+     old_params = params.clone() if restore else None
+
+     if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
+         # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
+         # when restoring params will be set back to old_params so its fine
+         if restore: params.set_(new_params)
+         else: params.copy_(new_params) # type:ignore
+
+     else:
+         new_params = totensor(new_params)
+         params.from_vec_(new_params)
+
+     if backward: loss = closure()
+     else: loss = closure(False)
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return tofloat(loss)
+
+ def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
+     params = TensorList(params)
+     old_params = params.clone() if restore else None
+     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
+     grad = params.ensure_grad_().grad
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return loss, grad
+
+ def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
+     return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
+
+ def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
+     params = TensorList(params)
+     old_params = params.clone() if restore else None
+     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
+     grad = params.ensure_grad_().grad.to_vec()
+
+     if restore:
+         assert old_params is not None
+         params.set_(old_params)
+
+     return loss, grad
+
+ def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
+     return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
+
+
+
+ class Optimizer(torch.optim.Optimizer):
+     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
+
+     Args:
+         params (iterable): an iterable of :class:`torch.Tensor` s or
+             :class:`dict` s. Specifies what Tensors should be optimized.
+         defaults (dict | None): a dict containing default values of optimization
+             options (used when a parameter group doesn't specify them).
+     """
+     def __init__(self, params, defaults: dict[str, Any] | None = None, **_defaults):
+         if defaults is None: defaults = {}
+         defaults.update(_defaults)
+
+         super().__init__(params, defaults)
+         self.global_state = self.state[self.param_groups[0]['params'][0]]
+         """state of the 1st parameter, used as global state; this mirrors how L-BFGS stores its global state in PyTorch, so that it is handled by ``state_dict``/``load_state_dict``"""
+
+     def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
+         return get_params(self.param_groups, mode, cls)
+
+     @overload
+     def group_vals(self, key: str, *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike: ...
+     @overload
+     def group_vals(self, key: list[str] | tuple[str,...], *,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
+     @overload
+     def group_vals(self, key: str, key2: str, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
+
+     def group_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike | list[ListLike]:
+         return get_group_vals(self.param_groups, key, key2, *keys, mode = mode, cls = cls) # pyright:ignore[reportArgumentType]
+
+
+     @overload
+     def state_vals(self, key: str, *,
+                    init: Init = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> ListLike: ...
+     @overload
+     def state_vals(self, key: list[str] | tuple[str,...], *,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
+     @overload
+     def state_vals(self, key: str, key2: str, *keys: str,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> list[ListLike]: ...
+
+     def state_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                    init: Init | Sequence[Init] = torch.zeros_like,
+                    mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
+                    cls: type[ListLike] = TensorList) -> ListLike | list[ListLike]:
+
+         if isinstance(mode, (list,tuple)): params = mode
+         else: params = self.get_params(mode)
+
+         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
+
+     def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
+         return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
+
+     def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
+         return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
+
+     def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
+         return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
+
+     def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
+         return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
+
+     def grad_vec_at_params(self, closure, params: Any, restore=False):
+         return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]
+
+
+ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
+     if set_to_none:
+         for p in params:
+             p.grad = None
+
+     else:
+         grads = [p.grad for p in params if p.grad is not None]
+         for grad in grads:
+             # taken from torch.optim.Optimizer.zero_grad
+             if grad.grad_fn is not None:
+                 grad.detach_()
+             else:
+                 grad.requires_grad_(False)
+
+         torch._foreach_zero_(grads)
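
The `Optimizer` helper above mostly threads `self.param_groups` and `self.state` through `get_params`, `get_group_vals` and `get_state_vals`. As a rough illustration of how those pieces fit together, here is a hypothetical momentum-SGD sketched on top of it; the import path is inferred from the file name torchzero/utils/optimizer.py, and `MomentumSGD` itself is not part of the package:

    import torch
    from torchzero.utils.optimizer import Optimizer  # module path assumed from the file list above

    class MomentumSGD(Optimizer):
        def __init__(self, params, lr=1e-2, momentum=0.9):
            super().__init__(params, lr=lr, momentum=momentum)  # extra kwargs become group defaults

        @torch.no_grad()
        def step(self, closure=None):
            loss = None
            if closure is not None:
                with torch.enable_grad():
                    loss = closure()
            params = self.get_params('has_grad')                           # parameters that have a .grad
            lrs, mus = self.group_vals('lr', 'momentum', mode='has_grad')  # per-parameter hyperparameters
            bufs = self.state_vals('momentum_buffer', mode='has_grad')     # zeros_like-initialized state
            for p, lr, mu, buf in zip(params, lrs, mus, bufs):
                buf.mul_(mu).add_(p.grad)   # update the momentum buffer in place
                p.sub_(buf, alpha=lr)       # take the descent step
            return loss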
@@ -0,0 +1,40 @@
+ import optuna
+
+ from ..core import Chain, Module
+
+ from ..modules import (
+     EMA,
+     NAG,
+     Cautious,
+     ClipNorm,
+     ClipNormGrowth,
+     ClipValue,
+     ClipValueGrowth,
+     Debias,
+     Normalize,
+ )
+
+
+ def get_momentum(trial: optuna.Trial, prefix: str, conditional: bool=True) -> list[Module]:
+     cond = trial.suggest_categorical(f'{prefix}_use_momentum', [True,False]) if conditional else True
+     if cond:
+         beta = trial.suggest_float(f'{prefix}_beta', -1, 2)
+         dampening = trial.suggest_float(f'{prefix}_dampening', -1, 2)
+         lerp = trial.suggest_categorical(f'{prefix}_use_lerp', [True, False])
+         nag = trial.suggest_categorical(f'{prefix}_use_NAG', [True, False])
+         debiased = trial.suggest_categorical(f'{prefix}_debiased', [True, False])
+         if nag:
+             m = NAG(beta, dampening, lerp)
+             if debiased: m = Chain(m, Debias(beta1=beta))
+         else:
+             m = EMA(beta, dampening, debiased=debiased, lerp=lerp)
+         return [m]
+     return []
+
+ def get_clip_value(trial: optuna.Trial, prefix: str, conditional: bool=True) -> list[Module]:
+     cond = trial.suggest_categorical(f'{prefix}_use_clip_value', [True,False]) if conditional else True
+     if cond:
+         return [ClipValue(value = trial.suggest_float(f'{prefix}_clip_value', 0, 10))]
+     return []
+
+
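These Optuna helpers sample a small momentum/clipping configuration and return it as a list of torchzero modules. A hypothetical objective wiring them together might look like the sketch below; the import path is assumed from the file name torchzero/utils/optuna_tools.py, and the step that turns the module list into an actual optimizer is only indicated, since the modular entry point is not shown in this hunk:

    import optuna
    from torchzero.utils.optuna_tools import get_momentum, get_clip_value  # path assumed

    def objective(trial: optuna.Trial) -> float:
        modules = []
        modules += get_clip_value(trial, prefix='pre')   # optionally a ClipValue module
        modules += get_momentum(trial, prefix='main')    # optionally EMA/NAG (+ Debias)
        # ... build a torchzero optimizer from `modules`, train, and return a validation metric
        return 0.0  # placeholder

    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=20)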
@@ -0,0 +1,149 @@
+ from typing import Any
+ from collections.abc import Sequence, Iterable, Mapping
+ import warnings
+ import torch, numpy as np
+
+
+
+ Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Mapping[str, Any]]
+
+ def _validate_params_are_unique_(params: Sequence[torch.Tensor]):
+     # this is from pytorch add_param_group
+     if len(params) != len(set(params)):
+         warnings.warn(
+             "optimizer contains a parameter group with duplicate parameters; "
+             "in future, this will cause an error; "
+             "see github.com/pytorch/pytorch/issues/40967 for more information",
+             stacklevel=3,
+         )
+
+ def _validate_param_is_differentiable_(tensor: torch.Tensor | Any):
+     """Raises if `tensor` is a non-leaf tensor that doesn't retain grad, since such a tensor can't be optimized. Taken from torch.optim.Optimizer; only called when `differentiable` is False."""
+     if not (tensor.is_leaf or tensor.retains_grad):
+         raise ValueError("can't optimize a non-leaf Tensor")
+
+ def _validate_at_least_one_param_requires_grad_(params: Iterable[torch.Tensor]):
+     params = list(params)
+     if not any(p.requires_grad for p in params):
+         warnings.warn(
+             "Parameter group contains no parameters which require gradients. "
+             "Note for gradient-free optimizers, they still only optimize parameters with requires_grad=True, "
+             "so if needed, use `with torch.no_grad():` context instead.", stacklevel=3)
+
+
+
+ def _copy_param_groups(param_groups: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     """copies param_groups, doesn't copy the tensors."""
+     new_param_group = []
+
+     for g in param_groups:
+         assert isinstance(g, dict)
+         g_copy = g.copy()
+
+         for k in ('params', 'updates', 'grads'):
+             if k in g_copy:
+                 assert isinstance(g_copy[k], list)
+                 g_copy[k] = g_copy[k].copy()
+
+         new_param_group.append(g_copy)
+
+     return new_param_group
+
+ def _process_param_group_(param_group: dict[str, Any]) -> dict[str, Any]:
+     """makes sure `param_group["params"]` is a list of tensors, and sets `param_group["param_names"]` if params are named."""
+     if 'params' not in param_group: raise KeyError("Param group doesn't have a `params` key.")
+
+     if isinstance(param_group['params'], torch.Tensor): param_group['params'] = [param_group['params']]
+
+     tensors: list[torch.Tensor] = []
+     names: list[str] | None = []
+
+     for p in param_group['params']:
+         if isinstance(p, torch.Tensor):
+             tensors.append(p)
+
+         elif isinstance(p, tuple):
+             if len(p) != 2:
+                 raise ValueError(f'named_parameters must be a tuple of (name, tensor), got length {len(p)} tuple')
+             if (not isinstance(p[0], str)) or (not isinstance(p[1], torch.Tensor)):
+                 raise ValueError(f'named_parameters must be a tuple of (name, tensor), got {[type(a) for a in p]}')
+             names.append(p[0])
+             tensors.append(p[1])
+
+         else:
+             raise ValueError(f'Parameters must be tensors or tuples (name, tensor), got parameter of type {type(p)}')
+
+     if len(tensors) == 0: warnings.warn('got an empty parameter group')
+
+     param_group['params'] = tensors
+
+     if len(names) != 0:
+         if len(names) != len(tensors):
+             raise ValueError(f"Number of parameters {len(tensors)} doesn't match number of names {len(names)}")
+         param_group['param_names'] = names
+
+     return param_group
+
+ def _make_param_groups(params: Params, differentiable: bool) -> list[dict[str, Any]]:
+     params = list(params)
+
+     param_groups: list[dict[str, Any]] = [dict(p) for p in params if isinstance(p, Mapping)]
+     tensors = [p for p in params if isinstance(p, torch.Tensor)]
+     named_tensors = [p for p in params if isinstance(p, tuple)]
+
+     if len(tensors) != 0: param_groups.append({"params": tensors})
+     if len(named_tensors) != 0: param_groups.append({"params": named_tensors})
+
+     # process param_groups
+     for g in param_groups:
+         _process_param_group_(g)
+
+     # validate
+     all_params = [p for g in param_groups for p in g['params']]
+     _validate_params_are_unique_(all_params)
+     _validate_at_least_one_param_requires_grad_(all_params)
+     if not differentiable:
+         for p in all_params: _validate_param_is_differentiable_(p)
+
+     return param_groups
+
+ def _add_defaults_to_param_groups_(param_groups: list[dict[str, Any]], defaults: dict[str, Any]) -> list[dict[str, Any]]:
+     for group in param_groups:
+         for k, v in defaults.items():
+             if k not in group:
+                 group[k] = v
+     return param_groups
+
+ def _add_updates_grads_to_param_groups_(param_groups: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     for group in param_groups:
+         if 'updates' in group: raise ValueError('updates in group')
+         group['updates'] = [None for _ in group['params']]
+
+         if 'grads' in group: raise ValueError('grads in group')
+         group['grads'] = [None for _ in group['params']]
+
+     return param_groups
+
+
+ def _set_update_and_grad_(
+     param_groups: list[dict[str, Any]],
+     updates: list[torch.Tensor] | None,
+     grads: list[torch.Tensor] | None,
+ ) -> list[dict[str, Any]]:
+     if updates is None and grads is None: return param_groups
+
+     updates_iter = iter(updates) if updates is not None else None
+     grads_iter = iter(grads) if grads is not None else None
+
+     for group in param_groups:
+         group_params = group['params']
+         group_updates = group['updates']
+         group_grads = group['grads']
+
+         for i, param in enumerate(group_params):
+             if not param.requires_grad: continue
+             if updates_iter is not None: group_updates[i] = next(updates_iter)
+             if grads_iter is not None: group_grads[i] = next(grads_iter)
+
+     return param_groups
+
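The params.py helpers above accept plain tensors, (name, tensor) pairs, and explicit group dicts, normalize them into param groups, and then fill in defaults. A short hypothetical sketch of that flow follows; the import path is assumed from the file name torchzero/utils/params.py, and these helpers are underscored (private), so this is illustration only:

    import torch
    from torchzero.utils.params import _make_param_groups, _add_defaults_to_param_groups_  # path assumed

    model = torch.nn.Linear(4, 2)

    # explicit group dicts and (name, tensor) pairs can be mixed in one params argument
    params = [
        {"params": [model.weight], "lr": 1e-3},  # explicit group keeps its own lr
        ("bias", model.bias),                    # named parameter -> recorded in 'param_names'
    ]
    groups = _make_param_groups(params, differentiable=False)
    groups = _add_defaults_to_param_groups_(groups, {"lr": 1e-2, "momentum": 0.9})

    for g in groups:
        print(len(g["params"]), g.get("param_names"), g["lr"])  # missing keys were filled from defaults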
@@ -1,25 +1,40 @@
- import functools
- import operator
- from typing import Any, TypeVar
- from collections.abc import Iterable
-
- import torch
-
- def _flatten_no_check(iterable: Iterable) -> list[Any]:
-     """Flatten an iterable of iterables, returns a flattened list. Note that if `iterable` is not Iterable, this will return `[iterable]`."""
-     if isinstance(iterable, Iterable):
-         return [a for i in iterable for a in _flatten_no_check(i)]
-     return [iterable]
-
- def flatten(iterable: Iterable) -> list[Any]:
-     """Flatten an iterable of iterables, returns a flattened list. If `iterable` is not iterable, raises a TypeError."""
-     if isinstance(iterable, Iterable): return [a for i in iterable for a in _flatten_no_check(i)]
-     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
-
- X = TypeVar("X")
- # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
- def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
-     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
-     return functools.reduce(operator.iconcat, x, [])
-
- _ScalarLoss = int | float | bool | torch.Tensor
+ import functools
+ import operator
+ from typing import Any, TypeVar
+ from collections.abc import Iterable, Callable
+ from collections import UserDict
+
+
+ def _flatten_no_check(iterable: Iterable) -> list[Any]:
+     """Flatten an iterable of iterables, returns a flattened list. Note that if `iterable` is not Iterable, this will return `[iterable]`."""
+     if isinstance(iterable, Iterable) and not isinstance(iterable, str):
+         return [a for i in iterable for a in _flatten_no_check(i)]
+     return [iterable]
+
+ def flatten(iterable: Iterable) -> list[Any]:
+     """Flatten an iterable of iterables, returns a flattened list. If `iterable` is not iterable, raises a TypeError."""
+     if isinstance(iterable, Iterable): return [a for i in iterable for a in _flatten_no_check(i)]
+     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
+
+ X = TypeVar("X")
+ # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+ def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
+     return functools.reduce(operator.iconcat, x, [])
+
+ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+     """generic equals function that supports scalars and lists of numbers"""
+     if isinstance(x, (int,float)):
+         if isinstance(y, (int,float)): return x==y
+         return all(i==x for i in y)
+     if isinstance(y, (int,float)):
+         return all(i==y for i in x)
+     return all(i==j for i,j in zip(x,y))
+
+ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
+     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
+     Otherwise applies `fn` to this sequence and `other`.
+     Returns a new sequence with return values of the callable."""
+     if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
+     return self.__class__(fn(i, other, *args, **kwargs) for i in self)
+
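The updated python_tools hunk drops the torch dependency, stops flatten from recursing into strings, and adds generic_eq; zipmap takes self, so it is presumably meant to be bound to a sequence class (such as TensorList) rather than called directly and is skipped here. A small sketch of the expected behaviour of the other helpers, with the import path assumed from the file list above:

    from torchzero.utils.python_tools import flatten, reduce_dim, generic_eq  # path assumed

    print(flatten([1, [2, [3, "ab"]]]))       # [1, 2, 3, 'ab'] - strings are no longer recursed into
    print(reduce_dim([[1, 2], [3], [4, 5]]))  # [1, 2, 3, 4, 5] - removes exactly one nesting level
    print(generic_eq(2, [2, 2, 2]))           # True  - scalar compared element-wise against a list
    print(generic_eq([1, 2], [1, 3]))         # False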