torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/projections/projection.py
@@ -0,0 +1,218 @@
+ import math
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable
+ from typing import Any, Literal
+ import warnings
+ import torch
+
+ from ...core import Chainable, Module, Vars
+ from ...utils import vec_to_tensors
+
+
+ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
+                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):
+
+     def projected_closure(backward=True):
+         unprojected_params = projection.unproject(projected_params, vars, current='params')
+
+         with torch.no_grad():
+             for p, new_p in zip(params, unprojected_params):
+                 p.set_(new_p) # pyright: ignore[reportArgumentType]
+
+         if backward:
+             loss = closure()
+             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+             projected_grads = projection.project(grads, vars, current='grads')
+             for p, g in zip(projected_params, projected_grads):
+                 p.grad = g
+
+         else:
+             loss = closure(False)
+
+         return loss
+
+     return projected_closure
+
+
+ class Projection(Module, ABC):
+     """
+     Base class for projections.
+     This is an abstract class, to use it, subclass it and override `project` and `unproject`.
+
+     Args:
+         modules (Chainable): modules that will be applied in the projected domain.
+         project_update (bool, optional): whether to project the update. Defaults to True.
+         project_params (bool, optional):
+             whether to project the params. This is necessary for modules that use closure. Defaults to False.
+         project_grad (bool, optional): whether to project the gradients (separately from update). Defaults to False.
+         defaults (dict[str, Any] | None, optional): dictionary with defaults. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         modules: Chainable,
+         project_update=True,
+         project_params=False,
+         project_grad=False,
+         defaults: dict[str, Any] | None = None,
+     ):
+         super().__init__(defaults)
+         self.set_child('modules', modules)
+         self.global_state['current_step'] = 0
+         self._project_update = project_update
+         self._project_params = project_params
+         self._project_grad = project_grad
+         self._projected_params = None
+
+     @abstractmethod
+     def project(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+
+     @abstractmethod
+     def unproject(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+         """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+
+     @torch.no_grad
+     def step(self, vars: Vars):
+         projected_vars = vars.clone(clone_update=False)
+         update_is_grad = False
+
+         # closure will calculate projected update and grad if needed
+         if self._project_params and vars.closure is not None:
+             if self._project_update and vars.update is not None: projected_vars.update = list(self.project(vars.update, vars=vars, current='update'))
+             else:
+                 update_is_grad = True
+             if self._project_grad and vars.grad is not None: projected_vars.grad = list(self.project(vars.grad, vars=vars, current='grads'))
+
+         # project update and grad, unprojected attributes are deleted
+         else:
+             if self._project_update:
+                 if vars.update is None:
+                     # update is None, meaning it will be set to `grad`.
+                     # we can project grad and use it for update
+                     grad = vars.get_grad()
+                     projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
+                     if self._project_grad: projected_vars.update = [g.clone() for g in projected_vars.grad]
+                     else: projected_vars.update = projected_vars.grad.copy() # don't clone because grad shouldn't be used
+                     del vars.update
+                     update_is_grad = True
+
+                 else:
+                     update = vars.get_update()
+                     projected_vars.update = list(self.project(update, vars=vars, current='update'))
+                     del update, vars.update
+
+             if self._project_grad and projected_vars.grad is None:
+                 grad = vars.get_grad()
+                 projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
+
+         original_params = None
+         if self._project_params:
+             original_params = [p.clone() for p in vars.params]
+             projected_params = self.project(vars.params, vars=vars, current='params')
+
+         else:
+             # make fake params for correct shapes and state storage
+             # they reuse update or grad storage for memory efficiency
+             projected_params = projected_vars.update if projected_vars.update is not None else projected_vars.grad
+             assert projected_params is not None
+
+         if self._projected_params is None:
+             # 1st step - create objects for projected_params. They have to remain the same python objects
+             # to support per-parameter states which are stored by ids.
+             self._projected_params = [p.view_as(p).requires_grad_() for p in projected_params]
+         else:
+             # set storage to new fake params while ID remains the same
+             for empty_p, new_p in zip(self._projected_params, projected_params):
+                 empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]
+
+         # project closure
+         if self._project_params:
+             closure = vars.closure; params = vars.params
+             projected_vars.closure = _make_projected_closure(closure, vars=vars, projection=self, params=params,
+                                                              projected_params=self._projected_params)
+
+         else:
+             projected_vars.closure = None
+
+         # step
+         projected_vars.params = self._projected_params
+         projected_vars = self.children['modules'].step(projected_vars)
+
+         # empty fake params storage
+         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
+         if not self._project_params:
+             for p in self._projected_params:
+                 p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]
+
+         # unproject
+         unprojected_vars = projected_vars.clone(clone_update=False)
+         unprojected_vars.closure = vars.closure
+         unprojected_vars.params = vars.params
+         if unprojected_vars.grad is None: unprojected_vars.grad = vars.grad
+
+         if self._project_update:
+             assert projected_vars.update is not None
+             unprojected_vars.update = list(self.unproject(projected_vars.update, vars=vars, current='grads' if update_is_grad else 'update'))
+             del projected_vars.update
+
+         # unprojecting grad doesn't make sense?
+         # if self._project_grad:
+         #     assert projected_vars.grad is not None
+         #     unprojected_vars.grad = list(self.unproject(projected_vars.grad, vars=vars))
+
+         del projected_vars
+
+         if original_params is not None:
+             for p, o in zip(unprojected_vars.params, original_params):
+                 p.set_(o) # pyright: ignore[reportArgumentType]
+
+         return unprojected_vars
+
+
+
+ class FlipConcatProjection(Projection):
+     """
+     for testing
+     """
+
+     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         return vec_to_tensors(vec=tensors[0].flip(0), reference=vars.params)
+
+
+ class NoopProjection(Projection):
+     """an example projection which doesn't do anything for testing"""
+
+     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         return tensors
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         return tensors
+
+ class MultipyProjection(Projection):
+     """an example projection which multiplies everything by 2"""
+
+     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         return torch._foreach_mul(tensors, 2)
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         return torch._foreach_div(tensors, 2)
+
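Note: the Projection base class above only asks a subclass for a matching project/unproject pair; closure rewriting, the fake projected parameters, and per-parameter state handling are inherited. A minimal sketch of a custom subclass, assuming only the interfaces shown in this hunk (the NegateProjection name is hypothetical, and the import path simply mirrors the new file's location):

import torch
from torchzero.modules.projections.projection import Projection  # path of the file added above, assumed importable as-is

class NegateProjection(Projection):
    """Hypothetical example: the wrapped modules see negated tensors."""
    def __init__(self, modules, project_update=True, project_params=False, project_grad=False):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, vars, current):
        # negation is its own inverse, so unproject applies the same map
        return [-t for t in tensors]

    @torch.no_grad
    def unproject(self, tensors, vars, current):
        return [-t for t in tensors]

Because unproject(project(x)) returns x, the wrapped modules' output is mapped back to the original parameter space unchanged apart from whatever the modules themselves did.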
torchzero/modules/projections/structural.py
@@ -0,0 +1,151 @@
+ import math
+
+ import numpy as np
+ import torch
+
+ from ...core import Chainable
+ from ...utils import vec_to_tensors, TensorList
+ from ..optimizers.shampoo import _merge_small_dims
+ from .projection import Projection
+
+
+ class VectorProjection(Projection):
+     """
+     flattens and concatenates all parameters into a vector
+     """
+     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         return vec_to_tensors(vec=tensors[0], reference=vars.params)
+
+
+
+ class TensorizeProjection(Projection):
+     """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
+     def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
+         defaults = dict(max_side=max_side)
+         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         params = vars.params
+         max_side = self.settings[params[0]]['max_side']
+         num_elems = sum(t.numel() for t in tensors)
+
+         if num_elems < max_side:
+             self.global_state['remainder'] = 0
+             # return 1d
+             return [torch.cat([t.view(-1) for t in tensors])]
+
+
+         # determine appropriate shape to reshape into
+         ndims = math.ceil(math.log(num_elems, max_side)) # determine number of dims
+         dim_size = math.ceil(num_elems ** (1/ndims)) # average size of a dim with ndims
+         dims = [dim_size for _ in range(ndims)]
+         required_elems = math.prod(dims)
+
+         # add few extra zeros to vec to match a reshapable size
+         remainder = required_elems-num_elems
+         if remainder > 0: tensors = tensors + [torch.zeros(remainder, dtype=tensors[0].dtype, device=tensors[0].device)]
+         self.global_state['remainder'] = remainder
+
+         # flatten and reshape
+         vec = torch.cat([t.view(-1) for t in tensors])
+         return [vec.view(dims)]
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         remainder = self.global_state['remainder']
+         # warnings.warn(f'{tensors[0].shape = }')
+         vec = tensors[0].view(-1)
+         if remainder > 0: vec = vec[:-remainder]
+         return vec_to_tensors(vec, vars.params)
+
+ class BlockPartition(Projection):
+     """splits parameters into blocks (for now flatttens them and chunks)"""
+     def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
+         defaults = dict(max_size=max_size, batched=batched)
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         partitioned = []
+         for p,t in zip(vars.params, tensors):
+             settings = self.settings[p]
+             max_size = settings['max_size']
+             n = t.numel()
+             if n <= max_size:
+                 partitioned.append(t)
+                 continue
+
+             t_flat = t.view(-1)
+
+             batched = settings['batched']
+             num_chunks = math.ceil(n / max_size)
+
+             if batched:
+                 chunks_size = num_chunks * max_size
+                 if num_chunks * max_size > n:
+                     t_flat = torch.cat([t_flat, torch.zeros(n-chunks_size, dtype=t_flat.dtype, device=t_flat.device)])
+                 partitioned.append(t_flat.view(num_chunks, -1))
+
+             else:
+                 partitioned.extend(t_flat.chunk(num_chunks))
+
+         return partitioned
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         ti = iter(tensors)
+         unprojected = []
+         for p in vars.params:
+             settings = self.settings[p]
+             n = p.numel()
+
+             if settings['batched']:
+                 unprojected.append(next(ti).view(-1)[:n].view_as(p))
+
+             else:
+                 chunks = []
+                 t_n = 0
+                 while t_n < n:
+                     t = next(ti)
+                     chunks.append(t)
+                     t_n += t.numel()
+
+                 assert t_n == n
+                 unprojected.append(torch.cat(chunks).view_as(p))
+
+         return unprojected
+
+
+ class TensorNormsProjection(Projection):
+     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
+
+     @torch.no_grad
+     def project(self, tensors, vars, current):
+         orig = self.get_state(f'{current}_orig', params=vars.params)
+         torch._foreach_copy_(orig, tensors)
+
+         norms = torch._foreach_norm(tensors)
+         self.get_state(f'{current}_orig_norms', params=vars.params, init=norms, cls=TensorList).set_(norms)
+
+         return [torch.stack(norms)]
+
+     @torch.no_grad
+     def unproject(self, tensors, vars, current):
+         orig = self.get_state(f'{current}_orig', params=vars.params)
+         orig_norms = torch.stack(self.get_state(f'{current}_orig_norms', params=vars.params))
+         target_norms = tensors[0]
+
+         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
+
+         torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
+         return orig
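To make the reshape arithmetic in TensorizeProjection.project easier to follow, the same computation can be traced standalone (the sizes below are illustrative, not taken from the package):

import math

num_elems, max_side = 1000, 100                     # hypothetical flattened size and side-length limit
ndims = math.ceil(math.log(num_elems, max_side))    # 2, since 100**2 >= 1000
dim_size = math.ceil(num_elems ** (1 / ndims))      # 32
dims = [dim_size] * ndims                           # [32, 32]
remainder = math.prod(dims) - num_elems             # 24 zeros are appended before vec.view(dims)

So a 1000-element vector is padded to 1024 elements and viewed as a 32x32 tensor; unproject drops the 24 padded zeros again.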
torchzero/modules/quasi_newton/__init__.py
@@ -1,4 +1,7 @@
- r"""
- This includes modules that compute a step direction via quasi-newton methods.
- """
- # from .hv_inv_fdm import HvInvFDM
+ from .cg import PolakRibiere, FletcherReeves, HestenesStiefel, DaiYuan, LiuStorey, ConjugateDescent, HagerZhang, HybridHS_DY
+ from .lbfgs import LBFGS
+ from .olbfgs import OnlineLBFGS
+ # from .experimental import ModularLBFGS
+
+ from .quasi_newton import BFGS, SR1, DFP, BroydenGood, BroydenBad, Greenstadt1, Greenstadt2, ColumnUpdatingMethod, ThomasOptimalMethod, PSB, Pearson2, SSVM
+ from .lsr1 import LSR1
torchzero/modules/quasi_newton/cg.py
@@ -0,0 +1,218 @@
+ from abc import ABC, abstractmethod
+
+ import torch
+
+ from ...core import Chainable, Transform, apply
+ from ...utils import TensorList, as_tensorlist
+
+
+ class ConguateGradientBase(Transform, ABC):
+     """all CGs are the same except beta calculation"""
+     def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+         if defaults is None: defaults = {}
+         defaults['reset_interval'] = reset_interval
+         defaults['clip_beta'] = clip_beta
+         super().__init__(defaults, uses_grad=False)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     def initialize(self, p: TensorList, g: TensorList):
+         """runs on first step when prev_grads and prev_dir are not available"""
+
+     @abstractmethod
+     def get_beta(self, p: TensorList, g: TensorList, prev_g: TensorList, prev_d: TensorList) -> float | torch.Tensor:
+         """returns beta"""
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         tensors = as_tensorlist(tensors)
+         params = as_tensorlist(params)
+
+         step = self.global_state.get('step', 0)
+         prev_dir, prev_grads = self.get_state('prev_dir', 'prev_grad', params=params, cls=TensorList)
+
+         # initialize on first step
+         if step == 0:
+             self.initialize(params, tensors)
+             prev_dir.copy_(tensors)
+             prev_grads.copy_(tensors)
+             self.global_state['step'] = step + 1
+             return tensors
+
+         # get beta
+         beta = self.get_beta(params, tensors, prev_grads, prev_dir)
+         if self.settings[params[0]]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+         prev_grads.copy_(tensors)
+
+         # inner step
+         if 'inner' in self.children:
+             tensors = as_tensorlist(apply(self.children['inner'], tensors, params, grads, vars))
+
+         # calculate new direction with beta
+         dir = tensors.add_(prev_dir.mul_(beta))
+         prev_dir.copy_(dir)
+
+         # resetting
+         self.global_state['step'] = step + 1
+         reset_interval = self.settings[params[0]]['reset_interval']
+         if reset_interval is not None and (step+1) % reset_interval == 0:
+             self.reset()
+
+         return dir
+
+ # ------------------------------- Polak-Ribière ------------------------------ #
+ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
+     denom = prev_g.dot(prev_g)
+     if denom == 0: return 0
+     return g.dot(g - prev_g) / denom
+
+ class PolakRibiere(ConguateGradientBase):
+     """Polak-Ribière-Polyak nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this."""
+     def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return polak_ribiere_beta(g, prev_g)
+
+ # ------------------------------ Fletcher–Reeves ----------------------------- #
+ def fletcher_reeves_beta(gg, prev_gg):
+     if prev_gg == 0: return 0
+     return gg / prev_gg
+
+ class FletcherReeves(ConguateGradientBase):
+     """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def initialize(self, p, g):
+         self.global_state['prev_gg'] = g.dot(g)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         gg = g.dot(g)
+         beta = fletcher_reeves_beta(gg, self.global_state['prev_gg'])
+         self.global_state['prev_gg'] = gg
+         return beta
+
+ # ----------------------------- Hestenes–Stiefel ----------------------------- #
+ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+     grad_diff = g - prev_g
+     denom = prev_d.dot(grad_diff)
+     if denom == 0: return 0
+     return (g.dot(grad_diff) / denom).neg()
+
+
+ class HestenesStiefel(ConguateGradientBase):
+     """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return hestenes_stiefel_beta(g, prev_d, prev_g)
+
+
+ # --------------------------------- Dai–Yuan --------------------------------- #
+ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+     denom = prev_d.dot(g - prev_g)
+     if denom == 0: return 0
+     return (g.dot(g) / denom).neg()
+
+ class DaiYuan(ConguateGradientBase):
+     """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return dai_yuan_beta(g, prev_d, prev_g)
+
+
+ # -------------------------------- Liu-Storey -------------------------------- #
+ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
+     denom = prev_g.dot(prev_d)
+     if denom == 0: return 0
+     return g.dot(g - prev_g) / denom
+
+ class LiuStorey(ConguateGradientBase):
+     """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return liu_storey_beta(g, prev_d, prev_g)
+
+ # ----------------------------- Conjugate Descent ---------------------------- #
+ class ConjugateDescent(Transform):
+     """Conjugate Descent (CD). This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, inner: Chainable | None = None):
+         super().__init__(defaults={}, uses_grad=False)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         g = as_tensorlist(tensors)
+
+         prev_d = self.get_state('prev_dir', params=params, cls=TensorList, init = torch.zeros_like)
+         if 'denom' not in self.global_state:
+             self.global_state['denom'] = torch.tensor(0.).to(g[0])
+
+         prev_gd = self.global_state.get('prev_gd', 0)
+         if prev_gd == 0: beta = 0
+         else: beta = g.dot(g) / prev_gd
+
+         # inner step
+         if 'inner' in self.children:
+             g = as_tensorlist(apply(self.children['inner'], g, params, grads, vars))
+
+         dir = g.add_(prev_d.mul_(beta))
+         prev_d.copy_(dir)
+         self.global_state['prev_gd'] = g.dot(dir)
+         return dir
+
+
+ # -------------------------------- Hager-Zhang ------------------------------- #
+ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
+     g_diff = g - prev_g
+     denom = prev_d.dot(g_diff)
+     if denom == 0: return 0
+
+     term1 = 1/denom
+     # term2
+     term2 = (g_diff - (2 * prev_d * (g_diff.pow(2).global_sum()/denom))).dot(g)
+     return (term1 * term2).neg()
+
+
+ class HagerZhang(ConguateGradientBase):
+     """Hager-Zhang nonlinear conjugate gradient method,
+     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return hager_zhang_beta(g, prev_d, prev_g)
+
+
+ # ----------------------------------- HS-DY ---------------------------------- #
+ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
+     grad_diff = g - prev_g
+     denom = prev_d.dot(grad_diff)
+     if denom == 0: return 0
+
+     # Dai-Yuan
+     dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
+
+     # Hestenes–Stiefel
+     hs_beta = (g.dot(grad_diff) / denom).neg().clamp(min=0)
+
+     return max(0, min(dy_beta, hs_beta)) # type:ignore
+
+ class HybridHS_DY(ConguateGradientBase):
+     """HS-DY hybrid conjugate gradient method.
+     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+     def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+     def get_beta(self, p, g, prev_g, prev_d):
+         return hs_dy_beta(g, prev_d, prev_g)
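For reference, the beta helpers in this file correspond to the standard nonlinear conjugate gradient coefficients, written here in textbook form with g_k the current gradient and d_{k-1} the previous direction (the module applies its own sign and clipping conventions on top, e.g. the .neg() calls and the clip_beta option seen above):

\beta_k^{FR} = \frac{g_k^\top g_k}{g_{k-1}^\top g_{k-1}}, \qquad
\beta_k^{PR} = \frac{g_k^\top (g_k - g_{k-1})}{g_{k-1}^\top g_{k-1}}, \qquad
\beta_k^{HS} = \frac{g_k^\top (g_k - g_{k-1})}{d_{k-1}^\top (g_k - g_{k-1})}, \qquad
\beta_k^{DY} = \frac{g_k^\top g_k}{d_{k-1}^\top (g_k - g_{k-1})},

with the new search direction d_k = -g_k + \beta_k d_{k-1}. As each docstring notes, the step length along d_k is expected to come from a line search module placed after the CG module.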
torchzero/modules/quasi_newton/experimental/__init__.py
@@ -0,0 +1 @@
+ from .modular_lbfgs import ModularLBFGS