torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/projections/projection.py

@@ -1,29 +1,35 @@
  import math
- from functools import partial
+ import warnings
  from abc import ABC, abstractmethod
- from collections.abc import Iterable
+ from collections import defaultdict, ChainMap
+ from collections.abc import Iterable, Mapping, Sequence
+ from functools import partial
  from typing import Any, Literal
- import warnings
+
  import torch

- from ...core import Chainable, Module, Vars
- from ...utils import vec_to_tensors
+ from ...core import Chainable, Module, Var
+ from ...utils import vec_to_tensors, set_storage_


- def _make_projected_closure(closure, vars: Vars, projection: "Projection",
+ def _make_projected_closure(closure, project_fn, unproject_fn,
                              params: list[torch.Tensor], projected_params: list[torch.Tensor]):
-
      def projected_closure(backward=True):
-         unprojected_params = projection.unproject(projected_params, vars, current='params')
+         # unproject projected params
+         unprojected_params = unproject_fn(projected_tensors=projected_params, current='params')

+         # set actual model parameters to suggested parameters
          with torch.no_grad():
              for p, new_p in zip(params, unprojected_params):
                  p.set_(new_p) # pyright: ignore[reportArgumentType]

+         # evaluate closure with suggested parameters
          if backward:
              loss = closure()
              grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-             projected_grads = projection.project(grads, vars, current='grads')
+
+             # project gradients on backward and set to projected parameter .grad attributes
+             projected_grads = project_fn(grads, current='grads')
              for p, g in zip(projected_params, projected_grads):
                  p.grad = g

@@ -34,27 +40,44 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",

      return projected_closure

- def _projected_get_grad_override(
-     retain_graph: bool | None = None,
-     create_graph: bool = False,
-     projection: Any = ...,
-     unprojected_vars: Any = ...,
-     self: Any = ...,
- ):
-     assert isinstance(projection, Projection)
-     assert isinstance(unprojected_vars, Vars)
-     assert isinstance(self, Vars)
-
-     if self.grad is not None: return self.grad
-     grads = unprojected_vars.get_grad(retain_graph, create_graph)
-     projected_grads = list(projection.project(grads, self, current='grads'))
-     self.grad = projected_grads
-     for p, g in zip(self.params, projected_grads):
-         p.grad = g
-     return self.grad
-
-
- class Projection(Module, ABC):
+ class _FakeProjectedClosure:
+     """This is used when project_params is False. Then the closure is meant to only be used to evaluate the initial gradient.
+     It should just evaluate original closure, project the gradients, and set them to fake params.
+
+     I made it into a class so that it can know and raise when it evaluates closure more than once.
+     """
+     __slots__ = ('closure', 'project_fn', 'params', 'fake_params', 'evaluated')
+     def __init__(self, closure, project_fn, params: list[torch.Tensor], fake_params: list[torch.Tensor]):
+         self.closure = closure
+         self.project_fn = project_fn
+         self.params = params
+         self.fake_params = fake_params
+         self.evaluated = False
+
+     def __call__(self, backward: bool = True):
+         if self.evaluated:
+             raise RuntimeError("set project_params to True if projected modules require closure.")
+         self.evaluated = True
+
+         # evaluate closure with suggested parameters
+         if backward:
+
+             loss = self.closure()
+             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+             # project gradients on backward and set to projected parameter .grad attributes
+             projected_grads = self.project_fn(grads, current='grads')
+             for p, g in zip(self.fake_params, projected_grads):
+                 p.grad = g
+
+         else:
+             loss = self.closure(False)
+
+         return loss
+
+
+
+ class ProjectionBase(Module, ABC):
      """
      Base class for projections.
      This is an abstract class, to use it, subclass it and override `project` and `unproject`.
@@ -84,57 +107,125 @@ class Projection(Module, ABC):
          self._project_grad = project_grad
          self._projected_params = None

+         self._states: dict[str, list[dict[str, Any]]] = {}
+         """per-parameter states for each projection target"""
+
      @abstractmethod
-     def project(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+     def project(
+         self,
+         tensors: list[torch.Tensor],
+         params: list[torch.Tensor],
+         grads: list[torch.Tensor] | None,
+         loss: torch.Tensor | None,
+         states: list[dict[str, Any]],
+         settings: list[ChainMap[str, Any]],
+         current: str,
+     ) -> Iterable[torch.Tensor]:
          """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

      @abstractmethod
-     def unproject(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
-         """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+     def unproject(
+         self,
+         projected_tensors: list[torch.Tensor],
+         params: list[torch.Tensor],
+         grads: list[torch.Tensor] | None,
+         loss: torch.Tensor | None,
+         states: list[dict[str, Any]],
+         settings: list[ChainMap[str, Any]],
+         current: str,
+     ) -> Iterable[torch.Tensor]:
+         """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.
+
+         Args:
+             projected_tensors (list[torch.Tensor]): projected tensors to unproject.
+             params (list[torch.Tensor]): original, unprojected parameters.
+             grads (list[torch.Tensor] | None): original, unprojected gradients
+             loss (torch.Tensor | None): loss at initial point.
+             states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
+             settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
+             current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".
+
+         Returns:
+             Iterable[torch.Tensor]: unprojected tensors of the same shape as params
+         """

      @torch.no_grad
-     def step(self, vars: Vars):
-         projected_vars = vars.clone(clone_update=False)
+     def step(self, var: Var):
+         params = var.params
+         settings = [self.settings[p] for p in params]
+
+         def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+             states = self._states.setdefault(current, [{} for _ in params])
+             return list(self.project(
+                 tensors=tensors,
+                 params=params,
+                 grads=var.grad,
+                 loss=var.loss,
+                 states=states,
+                 settings=settings,
+                 current=current,
+             ))
+
+         projected_var = var.clone(clone_update=False)
+
+         closure = var.closure
+
+         # if this is True, update and grad were projected simultaneously under current="grads"
+         # so update will have to be unprojected with current="grads"
          update_is_grad = False

-         # closure will calculate projected update and grad if needed
-         if self._project_params and vars.closure is not None:
-             if self._project_update and vars.update is not None: projected_vars.update = list(self.project(vars.update, vars=vars, current='update'))
+         # if closure is provided and project_params=True, make new closure that evaluates projected params
+         # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
+         # but if it has already been computed, it should be projected
+         if self._project_params and closure is not None:
+
+             if self._project_update and var.update is not None:
+                 # project update only if it already exists
+                 projected_var.update = _project(var.update, current='update')
+
              else:
+                 # update will be set to gradients on var.get_grad()
+                 # therefore projection will happen with current="grads"
                  update_is_grad = True
-             if self._project_grad and vars.grad is not None: projected_vars.grad = list(self.project(vars.grad, vars=vars, current='grads'))

-         # project update and grad, unprojected attributes are deleted
+             # project grad only if it already exists
+             if self._project_grad and var.grad is not None:
+                 projected_var.grad = _project(var.grad, current='grads')
+
+         # otherwise update/grad needs to be calculated and projected here
          else:
              if self._project_update:
-                 if vars.update is None:
+                 if var.update is None:
                      # update is None, meaning it will be set to `grad`.
                      # we can project grad and use it for update
-                     grad = vars.get_grad()
-                     projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
-                     if self._project_grad: projected_vars.update = [g.clone() for g in projected_vars.grad]
-                     else: projected_vars.update = projected_vars.grad.copy() # don't clone because grad shouldn't be used
-                     del vars.update
+                     grad = var.get_grad()
+                     projected_var.grad = _project(grad, current='grads')
+                     projected_var.update = [g.clone() for g in projected_var.grad]
+                     del var.update
                      update_is_grad = True

                  else:
-                     update = vars.get_update()
-                     projected_vars.update = list(self.project(update, vars=vars, current='update'))
-                     del update, vars.update
+                     # update exists so it needs to be projected
+                     update = var.get_update()
+                     projected_var.update = _project(update, current='update')
+                     del update, var.update
+
+             if self._project_grad and projected_var.grad is None:
+                 # projected_vars.grad may have been projected simultaneously with update
+                 # but if that didn't happen, it is projected here
+                 grad = var.get_grad()
+                 projected_var.grad = _project(grad, current='grads')

-             if self._project_grad and projected_vars.grad is None:
-                 grad = vars.get_grad()
-                 projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))

          original_params = None
          if self._project_params:
-             original_params = [p.clone() for p in vars.params]
-             projected_params = self.project(vars.params, vars=vars, current='params')
+             original_params = [p.clone() for p in var.params]
+             projected_params = _project(var.params, current='params')

          else:
              # make fake params for correct shapes and state storage
              # they reuse update or grad storage for memory efficiency
-             projected_params = projected_vars.update if projected_vars.update is not None else projected_vars.grad
+             projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
              assert projected_params is not None

          if self._projected_params is None:
@@ -146,99 +237,102 @@ class Projection(Module, ABC):
              for empty_p, new_p in zip(self._projected_params, projected_params):
                  empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]

+         projected_params = self._projected_params
+         # projected_settings = [self.settings[p] for p in projected_params]
+
+         def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+             states = self._states.setdefault(current, [{} for _ in params])
+             return list(self.unproject(
+                 projected_tensors=projected_tensors,
+                 params=params,
+                 grads=var.grad,
+                 loss=var.loss,
+                 states=states,
+                 settings=settings,
+                 current=current,
+             ))
+
          # project closure
          if self._project_params:
-             closure = vars.closure; params = vars.params
-             projected_vars.closure = _make_projected_closure(closure, vars=vars, projection=self, params=params,
-                                                              projected_params=self._projected_params)
+             projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+                                                             params=params, projected_params=projected_params)
+
+         elif closure is not None:
+             projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+                                                           params=params, fake_params=projected_params)

          else:
-             projected_vars.closure = None
-
-         # step
-         projected_vars.params = self._projected_params
-         projected_vars.get_grad = partial(
-             _projected_get_grad_override,
-             projection=self,
-             unprojected_vars=vars,
-             self=projected_vars,
-         )
-         projected_vars = self.children['modules'].step(projected_vars)
+             projected_var.closure = None
+
+         # ----------------------------------- step ----------------------------------- #
+         projected_var.params = projected_params
+         projected_var = self.children['modules'].step(projected_var)

          # empty fake params storage
          # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
          if not self._project_params:
              for p in self._projected_params:
-                 p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]
+                 set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

-         # unproject
-         unprojected_vars = projected_vars.clone(clone_update=False)
-         unprojected_vars.closure = vars.closure
-         unprojected_vars.params = vars.params
-         unprojected_vars.grad = vars.grad
+         # --------------------------------- unproject -------------------------------- #
+         unprojected_var = projected_var.clone(clone_update=False)
+         unprojected_var.closure = var.closure
+         unprojected_var.params = var.params
+         unprojected_var.grad = var.grad

          if self._project_update:
-             assert projected_vars.update is not None
-             unprojected_vars.update = list(self.unproject(projected_vars.update, vars=vars, current='grads' if update_is_grad else 'update'))
-             del projected_vars.update
+             assert projected_var.update is not None
+             unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
+             del projected_var.update

-         # unprojecting grad doesn't make sense?
-         # if self._project_grad:
-         #     assert projected_vars.grad is not None
-         #     unprojected_vars.grad = list(self.unproject(projected_vars.grad, vars=vars))
-
-         del projected_vars
+         del projected_var

+         # original params are stored if params are projected
          if original_params is not None:
-             for p, o in zip(unprojected_vars.params, original_params):
+             for p, o in zip(unprojected_var.params, original_params):
                  p.set_(o) # pyright: ignore[reportArgumentType]

-         return unprojected_vars
-
-
-
- class FlipConcatProjection(Projection):
-     """
-     for testing
-     """
-
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-     @torch.no_grad
-     def project(self, tensors, vars, current):
-         return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
-
-     @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         return vec_to_tensors(vec=tensors[0].flip(0), reference=vars.params)
+         return unprojected_var


- class NoopProjection(Projection):
-     """an example projection which doesn't do anything for testing"""

-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+ # basic examples
+ class VectorProjection(ProjectionBase):
+     """projection that concatenates all parameters into a vector"""
+     def __init__(
+         self,
+         modules: Chainable,
+         project_update=True,
+         project_params=True,
+         project_grad=True,
+     ):
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

      @torch.no_grad
-     def project(self, tensors, vars, current):
-         return tensors
+     def project(self, tensors, params, grads, loss, states, settings, current):
+         return [torch.cat([t.ravel() for t in tensors])]

      @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         return tensors
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         return vec_to_tensors(vec=projected_tensors[0], reference=params)

- class MultipyProjection(Projection):
-     """an example projection which multiplies everything by 2"""

-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+ class ScalarProjection(ProjectionBase):
+     """projetion that splits all parameters into individual scalars"""
+     def __init__(
+         self,
+         modules: Chainable,
+         project_update=True,
+         project_params=True,
+         project_grad=True,
+     ):
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

      @torch.no_grad
-     def project(self, tensors, vars, current):
-         return torch._foreach_mul(tensors, 2)
+     def project(self, tensors, params, grads, loss, states, settings, current):
+         return [s for t in tensors for s in t.ravel().unbind(0)]

      @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         return torch._foreach_div(tensors, 2)
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)

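For a sense of how the reworked API above is meant to be used: custom projections now subclass ProjectionBase and implement the new project/unproject signatures instead of the old Vars-based ones. The sketch below is hypothetical (it is not part of the package) and assumes the 0.3.11 module paths shown in this diff; it mirrors the VectorProjection/ScalarProjection examples, using negation as a trivially invertible transform.

# Hypothetical usage sketch, not package code. Assumes the 0.3.11 layout shown above:
# ProjectionBase is defined in torchzero/modules/projections/projection.py and
# Chainable is exported from torchzero.core.
import torch
from torchzero.core import Chainable
from torchzero.modules.projections.projection import ProjectionBase

class NegateProjection(ProjectionBase):
    """Toy projection: inner modules operate on negated tensors; negation is its own inverse."""
    def __init__(self, modules: Chainable, project_update=True, project_params=True, project_grad=True):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        # called with current in {"params", "grads", "update"}, possibly several times per step
        return [t.neg() for t in tensors]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # must return tensors with the same shapes as the original params
        return [t.neg() for t in projected_tensors]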
torchzero/modules/quasi_newton/__init__.py

@@ -1,7 +1,46 @@
- from .cg import PolakRibiere, FletcherReeves, HestenesStiefel, DaiYuan, LiuStorey, ConjugateDescent, HagerZhang, HybridHS_DY
+ from .cg import (
+     ConjugateDescent,
+     DaiYuan,
+     FletcherReeves,
+     HagerZhang,
+     HestenesStiefel,
+     HybridHS_DY,
+     LiuStorey,
+     PolakRibiere,
+     ProjectedGradientMethod,
+ )
+ from .diagonal_quasi_newton import (
+     DNRTR,
+     DiagonalBFGS,
+     DiagonalQuasiCauchi,
+     DiagonalSR1,
+     DiagonalWeightedQuasiCauchi,
+     NewDQN,
+ )
  from .lbfgs import LBFGS
- from .olbfgs import OnlineLBFGS
- # from .experimental import ModularLBFGS
+ from .lsr1 import LSR1
+ # from .olbfgs import OnlineLBFGS

- from .quasi_newton import BFGS, SR1, DFP, BroydenGood, BroydenBad, Greenstadt1, Greenstadt2, ColumnUpdatingMethod, ThomasOptimalMethod, PSB, Pearson2, SSVM
- from .lsr1 import LSR1
+ # from .experimental import ModularLBFGS
+ from .quasi_newton import (
+     BFGS,
+     DFP,
+     ICUM,
+     PSB,
+     SR1,
+     SSVM,
+     BroydenBad,
+     BroydenGood,
+     FletcherVMM,
+     GradientCorrection,
+     Greenstadt1,
+     Greenstadt2,
+     Horisho,
+     McCormick,
+     NewSSM,
+     Pearson,
+     ProjectedNewtonRaphson,
+     ThomasOptimalMethod,
+     ShorR,
+ )
+ from .trust_region import CubicRegularization, TrustCG, TrustRegionBase
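The reorganized __init__ above means the new quasi-Newton, diagonal quasi-Newton, and trust-region modules are importable directly from the subpackage. A small illustration follows; only the import itself is grounded in the re-exports above, and how these modules are chained into an optimizer follows torchzero's modular API, which is not shown here.

# These names are re-exported by torchzero/modules/quasi_newton/__init__.py in 0.3.11.
from torchzero.modules.quasi_newton import (
    LBFGS,                # retained from 0.3.9
    DiagonalBFGS,         # new diagonal quasi-Newton family
    TrustCG,              # new trust-region machinery
    CubicRegularization,
)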