torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/projections/projection.py
@@ -1,29 +1,35 @@
 import math
-from functools import partial
+import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Iterable
+from collections import defaultdict, ChainMap
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
 from typing import Any, Literal
-import warnings
+
 import torch
 
 from ...core import Chainable, Module, Var
-from ...utils import vec_to_tensors
+from ...utils import vec_to_tensors, set_storage_
 
 
-def _make_projected_closure(closure, var: Var, projection: "Projection",
+def _make_projected_closure(closure, project_fn, unproject_fn,
                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):
-
     def projected_closure(backward=True):
-        unprojected_params = projection.unproject(projected_params, var, current='params')
+        # unproject projected params
+        unprojected_params = unproject_fn(projected_tensors=projected_params, current='params')
 
+        # set actual model parameters to suggested parameters
         with torch.no_grad():
             for p, new_p in zip(params, unprojected_params):
                 p.set_(new_p) # pyright: ignore[reportArgumentType]
 
+        # evaluate closure with suggested parameters
         if backward:
            loss = closure()
            grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            projected_grads = projection.project(grads, var, current='grads')
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = project_fn(grads, current='grads')
            for p, g in zip(projected_params, projected_grads):
                p.grad = g
 
@@ -34,27 +40,44 @@ def _make_projected_closure(closure, var: Var, projection: "Projection",
 
     return projected_closure
 
-def _projected_get_grad_override(
-    retain_graph: bool | None = None,
-    create_graph: bool = False,
-    projection: Any = ...,
-    unprojected_var: Any = ...,
-    self: Any = ...,
-):
-    assert isinstance(projection, Projection)
-    assert isinstance(unprojected_var, Var)
-    assert isinstance(self, Var)
-
-    if self.grad is not None: return self.grad
-    grads = unprojected_var.get_grad(retain_graph, create_graph)
-    projected_grads = list(projection.project(grads, self, current='grads'))
-    self.grad = projected_grads
-    for p, g in zip(self.params, projected_grads):
-        p.grad = g
-    return self.grad
-
-
-class Projection(Module, ABC):
+class _FakeProjectedClosure:
+    """This is used when project_params is False. Then the closure is meant to only be used to evaluate the initial gradient.
+    It should just evaluate original closure, project the gradients, and set them to fake params.
+
+    I made it into a class so that it can know and raise when it evaluates closure more than once.
+    """
+    __slots__ = ('closure', 'project_fn', 'params', 'fake_params', 'evaluated')
+    def __init__(self, closure, project_fn, params: list[torch.Tensor], fake_params: list[torch.Tensor]):
+        self.closure = closure
+        self.project_fn = project_fn
+        self.params = params
+        self.fake_params = fake_params
+        self.evaluated = False
+
+    def __call__(self, backward: bool = True):
+        if self.evaluated:
+            raise RuntimeError("set project_params to True if projected modules require closure.")
+        self.evaluated = True
+
+        # evaluate closure with suggested parameters
+        if backward:
+
+            loss = self.closure()
+            grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in self.params]
+
+            # project gradients on backward and set to projected parameter .grad attributes
+            projected_grads = self.project_fn(grads, current='grads')
+            for p, g in zip(self.fake_params, projected_grads):
+                p.grad = g
+
+        else:
+            loss = self.closure(False)
+
+        return loss
+
+
+
+class ProjectionBase(Module, ABC):
     """
     Base class for projections.
     This is an abstract class, to use it, subclass it and override `project` and `unproject`.
@@ -84,52 +107,120 @@ class Projection(Module, ABC):
         self._project_grad = project_grad
         self._projected_params = None
 
+        self._states: dict[str, list[dict[str, Any]]] = {}
+        """per-parameter states for each projection target"""
+
     @abstractmethod
-    def project(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+    def project(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
 
     @abstractmethod
-    def unproject(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
-        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""
+    def unproject(
+        self,
+        projected_tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | None,
+        states: list[dict[str, Any]],
+        settings: list[ChainMap[str, Any]],
+        current: str,
+    ) -> Iterable[torch.Tensor]:
+        """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`.
+
+        Args:
+            projected_tensors (list[torch.Tensor]): projected tensors to unproject.
+            params (list[torch.Tensor]): original, unprojected parameters.
+            grads (list[torch.Tensor] | None): original, unprojected gradients
+            loss (torch.Tensor | None): loss at initial point.
+            states (list[dict[str, Any]]): list of state dictionaries per each UNPROJECTED tensor.
+            settings (list[ChainMap[str, Any]]): list of setting dictionaries per each UNPROJECTED tensor.
+            current (str): string representing what is being unprojected, e.g. "params", "grads" or "update".
+
+        Returns:
+            Iterable[torch.Tensor]: unprojected tensors of the same shape as params
+        """
 
     @torch.no_grad
     def step(self, var: Var):
+        params = var.params
+        settings = [self.settings[p] for p in params]
+
+        def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.project(
+                tensors=tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
         projected_var = var.clone(clone_update=False)
+
+        closure = var.closure
+
+        # if this is True, update and grad were projected simultaneously under current="grads"
+        # so update will have to be unprojected with current="grads"
         update_is_grad = False
 
-        # closure will calculate projected update and grad if needed
-        if self._project_params and var.closure is not None:
-            if self._project_update and var.update is not None: projected_var.update = list(self.project(var.update, var=var, current='update'))
+        # if closure is provided and project_params=True, make new closure that evaluates projected params
+        # that also means projected modules can evaluate grad/update at will, it shouldn't be computed here
+        # but if it has already been computed, it should be projected
+        if self._project_params and closure is not None:
+
+            if self._project_update and var.update is not None:
+                # project update only if it already exists
+                projected_var.update = _project(var.update, current='update')
+
             else:
+                # update will be set to gradients on var.get_grad()
+                # therefore projection will happen with current="grads"
                 update_is_grad = True
-            if self._project_grad and var.grad is not None: projected_var.grad = list(self.project(var.grad, var=var, current='grads'))
 
-        # project update and grad, unprojected attributes are deleted
+            # project grad only if it already exists
+            if self._project_grad and var.grad is not None:
+                projected_var.grad = _project(var.grad, current='grads')
+
+        # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
                 if var.update is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
                     grad = var.get_grad()
-                    projected_var.grad = list(self.project(grad, var=var, current='grads'))
-                    if self._project_grad: projected_var.update = [g.clone() for g in projected_var.grad]
-                    else: projected_var.update = projected_var.grad.copy() # don't clone because grad shouldn't be used
+                    projected_var.grad = _project(grad, current='grads')
+                    projected_var.update = [g.clone() for g in projected_var.grad]
                     del var.update
                     update_is_grad = True
 
                 else:
+                    # update exists so it needs to be projected
                     update = var.get_update()
-                    projected_var.update = list(self.project(update, var=var, current='update'))
+                    projected_var.update = _project(update, current='update')
                     del update, var.update
 
             if self._project_grad and projected_var.grad is None:
+                # projected_vars.grad may have been projected simultaneously with update
+                # but if that didn't happen, it is projected here
                grad = var.get_grad()
-                projected_var.grad = list(self.project(grad, var=var, current='grads'))
+                projected_var.grad = _project(grad, current='grads')
+
 
         original_params = None
         if self._project_params:
             original_params = [p.clone() for p in var.params]
-            projected_params = self.project(var.params, var=var, current='params')
+            projected_params = _project(var.params, current='params')
 
         else:
             # make fake params for correct shapes and state storage
@@ -146,32 +237,44 @@ class Projection(Module, ABC):
             for empty_p, new_p in zip(self._projected_params, projected_params):
                 empty_p.set_(new_p.view_as(new_p).requires_grad_()) # pyright: ignore[reportArgumentType]
 
+        projected_params = self._projected_params
+        # projected_settings = [self.settings[p] for p in projected_params]
+
+        def _unproject(projected_tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
+            states = self._states.setdefault(current, [{} for _ in params])
+            return list(self.unproject(
+                projected_tensors=projected_tensors,
+                params=params,
+                grads=var.grad,
+                loss=var.loss,
+                states=states,
+                settings=settings,
+                current=current,
+            ))
+
         # project closure
         if self._project_params:
-            closure = var.closure; params = var.params
-            projected_var.closure = _make_projected_closure(closure, var=var, projection=self, params=params,
-                                                            projected_params=self._projected_params)
+            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+                                                            params=params, projected_params=projected_params)
+
+        elif closure is not None:
+            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+                                                          params=params, fake_params=projected_params)
 
         else:
             projected_var.closure = None
 
-        # step
-        projected_var.params = self._projected_params
-        projected_var.get_grad = partial(
-            _projected_get_grad_override,
-            projection=self,
-            unprojected_var=var,
-            self=projected_var,
-        )
+        # ----------------------------------- step ----------------------------------- #
+        projected_var.params = projected_params
         projected_var = self.children['modules'].step(projected_var)
 
         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
         if not self._project_params:
             for p in self._projected_params:
-                p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]
+                set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
 
-        # unproject
+        # --------------------------------- unproject -------------------------------- #
         unprojected_var = projected_var.clone(clone_update=False)
         unprojected_var.closure = var.closure
         unprojected_var.params = var.params
@@ -179,16 +282,12 @@
 
         if self._project_update:
             assert projected_var.update is not None
-            unprojected_var.update = list(self.unproject(projected_var.update, var=var, current='grads' if update_is_grad else 'update'))
+            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
             del projected_var.update
 
-        # unprojecting grad doesn't make sense?
-        # if self._project_grad:
-        #     assert projected_var.grad is not None
-        #     unprojected_var.grad = list(self.unproject(projected_var.grad, var=var))
-
         del projected_var
 
+        # original params are stored if params are projected
         if original_params is not None:
             for p, o in zip(unprojected_var.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]
@@ -197,48 +296,43 @@
 
 
 
-class FlipConcatProjection(Projection):
-    """
-    for testing
-    """
-
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-        super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-    @torch.no_grad
-    def project(self, tensors, var, current):
-        return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]
-
-    @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return vec_to_tensors(vec=tensors[0].flip(0), reference=var.params)
-
-
-class NoopProjection(Projection):
-    """an example projection which doesn't do anything for testing"""
-
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+# basic examples
+class VectorProjection(ProjectionBase):
+    """projection that concatenates all parameters into a vector"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        return tensors
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [torch.cat([t.ravel() for t in tensors])]
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return tensors
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=projected_tensors[0], reference=params)
 
-class MultipyProjection(Projection):
-    """an example projection which multiplies everything by 2"""
 
-    def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
+class ScalarProjection(ProjectionBase):
+    """projetion that splits all parameters into individual scalars"""
+    def __init__(
+        self,
+        modules: Chainable,
+        project_update=True,
+        project_params=True,
+        project_grad=True,
+    ):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        return torch._foreach_mul(tensors, 2)
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        return [s for t in tensors for s in t.ravel().unbind(0)]
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        return torch._foreach_div(tensors, 2)
+    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+        return vec_to_tensors(vec=torch.stack(projected_tensors), reference=params)
 
torchzero/modules/quasi_newton/__init__.py
@@ -9,20 +9,28 @@ from .cg import (
     PolakRibiere,
     ProjectedGradientMethod,
 )
+from .diagonal_quasi_newton import (
+    DNRTR,
+    DiagonalBFGS,
+    DiagonalQuasiCauchi,
+    DiagonalSR1,
+    DiagonalWeightedQuasiCauchi,
+    NewDQN,
+)
 from .lbfgs import LBFGS
 from .lsr1 import LSR1
-from .olbfgs import OnlineLBFGS
+# from .olbfgs import OnlineLBFGS
 
 # from .experimental import ModularLBFGS
 from .quasi_newton import (
     BFGS,
     DFP,
+    ICUM,
     PSB,
     SR1,
     SSVM,
     BroydenBad,
     BroydenGood,
-    ColumnUpdatingMethod,
     FletcherVMM,
     GradientCorrection,
     Greenstadt1,
@@ -33,4 +41,6 @@ from .quasi_newton import (
     Pearson,
     ProjectedNewtonRaphson,
     ThomasOptimalMethod,
+    ShorR,
 )
+from .trust_region import CubicRegularization, TrustCG, TrustRegionBase
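
Note on the projection.py changes above: the base class is renamed from Projection to ProjectionBase, and project/unproject no longer receive a Var object; instead they take the tensors to transform plus explicit params, grads, loss, per-target states, settings, and a current string. The following is a minimal sketch of a custom projection written against those new signatures, modeled on the VectorProjection example in the diff. The class name NegateProjection is hypothetical, and the import paths are inferred from the file list; only the method signatures and the base-class behavior come from the diff itself.

# hypothetical example, not part of torchzero: run wrapped modules on negated tensors
import torch

from torchzero.core import Chainable
from torchzero.modules.projections.projection import ProjectionBase


class NegateProjection(ProjectionBase):
    """illustration only: the wrapped modules see negated params/grads/updates"""
    def __init__(self, modules: Chainable, project_update=True, project_params=True, project_grad=True):
        super().__init__(modules, project_update=project_update,
                         project_params=project_params, project_grad=project_grad)

    @torch.no_grad
    def project(self, tensors, params, grads, loss, states, settings, current):
        # called separately with current='params', 'grads' and 'update'
        return [t.neg() for t in tensors]

    @torch.no_grad
    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # must return tensors with the same shapes as `params`
        return [t.neg() for t in projected_tensors]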