torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/utils/optimizer.py
@@ -1,3 +1,4 @@
+ from abc import ABC, abstractmethod
  from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
  from typing import Any, Literal, TypeVar, overload

@@ -132,65 +133,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
      return values


-
- def loss_at_params(closure, params: Iterable[torch.Tensor],
-                    new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-     params = TensorList(params)
-
-     old_params = params.clone() if restore else None
-
-     if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
-         # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
-         # when restoring params will be set back to old_params so its fine
-         if restore: params.set_(new_params)
-         else: params.copy_(new_params) # type:ignore
-
-     else:
-         new_params = totensor(new_params)
-         params.from_vec_(new_params)
-
-     if backward: loss = closure()
-     else: loss = closure(False)
-
-     if restore:
-         assert old_params is not None
-         params.set_(old_params)
-
-     return tofloat(loss)
-
- def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-     params = TensorList(params)
-     old_params = params.clone() if restore else None
-     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-     grad = params.ensure_grad_().grad
-
-     if restore:
-         assert old_params is not None
-         params.set_(old_params)
-
-     return loss, grad
-
- def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-     return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
- def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-     params = TensorList(params)
-     old_params = params.clone() if restore else None
-     loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-     grad = params.ensure_grad_().grad.to_vec()
-
-     if restore:
-         assert old_params is not None
-         params.set_(old_params)
-
-     return loss, grad
-
- def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-     return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-
-
- class Optimizer(torch.optim.Optimizer):
+ class Optimizer(torch.optim.Optimizer, ABC):
      """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.

      Args:
@@ -251,21 +194,10 @@ class Optimizer(torch.optim.Optimizer):

          return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]

-     def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-         return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
-
-     def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
-         return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-     def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
-         return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
-
-     def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
-         return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-     def grad_vec_at_params(self, closure, params: Any, restore=False):
-         return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]

+     # shut up pylance
+     @abstractmethod
+     def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]

  def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
      if set_to_none:
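
The practical effect of the `ABC` change above is that subclasses must now override `step(closure)` themselves. A minimal sketch of such a subclass, assuming the base class keeps `torch.optim.Optimizer`'s usual `(params, defaults)` constructor; the import path and the descent rule are illustrative, not taken from the diff:

```python
import torch
from torchzero.utils.optimizer import Optimizer  # assumed import path

class PlainSGD(Optimizer):
    """Illustrative subclass: with step() abstract, this override is mandatory."""
    def __init__(self, params, lr: float = 1e-2):
        super().__init__(params, dict(lr=lr))  # assumes the usual (params, defaults) signature

    @torch.no_grad
    def step(self, closure):
        # the closure is expected to zero grads, evaluate the loss and call backward()
        with torch.enable_grad():
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.sub_(p.grad, alpha=group['lr'])
        return loss
```
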
@@ -281,4 +213,53 @@ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
              else:
                  grad.requires_grad_(False)

-         torch._foreach_zero_(grads)
+         torch._foreach_zero_(grads)
+
+
+ @overload
+ def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                   key: str, *,
+                   must_exist: bool = False, init: Init = torch.zeros_like,
+                   cls: type[ListLike] = list) -> ListLike: ...
+ @overload
+ def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                   key: list[str] | tuple[str,...], *,
+                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                   cls: type[ListLike] = list) -> list[ListLike]: ...
+ @overload
+ def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                   key: str, key2: str, *keys: str,
+                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                   cls: type[ListLike] = list) -> list[ListLike]: ...
+
+ def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                   key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                   must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                   cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+     # single key, return single cls
+     if isinstance(key, str) and key2 is None:
+         values = cls()
+         for i,s in enumerate(states):
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 s[key] = _make_initial_state_value(tensors[i], init, i)
+             values.append(s[key])
+         return values
+
+     # multiple keys
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls() for _ in keys]
+     for i,s in enumerate(states):
+         for k_i, key in enumerate(keys):
+             if key not in s:
+                 if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                 k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                 s[key] = _make_initial_state_value(tensors[i], k_init, i)
+             values[k_i].append(s[key])
+
+     return values
+
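
The new `unpack_states` helper replaces the per-key loops modules previously wrote by hand. A hedged usage sketch; the import path is an assumption based on the file layout in the RECORD further down, while the semantics follow the function body above:

```python
import torch
from torchzero.utils.optimizer import unpack_states  # assumed import path

params = [torch.randn(3, requires_grad=True), torch.randn(4, 4, requires_grad=True)]
states = [{} for _ in params]  # one mutable state mapping per tensor

# single key -> a single list; missing entries are created with init (zeros_like by default)
exp_avg = unpack_states(states, params, 'exp_avg')

# several keys -> one list per key, optionally with one initializer per key
exp_avg, exp_avg_sq = unpack_states(states, params, 'exp_avg', 'exp_avg_sq',
                                    init=[torch.zeros_like, torch.ones_like])

# must_exist=True raises KeyError instead of initializing missing entries
```
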
torchzero/utils/python_tools.py
@@ -1,7 +1,7 @@
  import functools
  import operator
- from typing import Any, TypeVar
- from collections.abc import Iterable, Callable
+ from typing import Any, TypeVar, overload
+ from collections.abc import Iterable, Callable, Mapping, MutableSequence
  from collections import UserDict


@@ -17,8 +17,8 @@ def flatten(iterable: Iterable) -> list[Any]:
      raise TypeError(f'passed object is not an iterable, {type(iterable) = }')

  X = TypeVar("X")
- # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
- def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+ # def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
+ def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
      """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
      return functools.reduce(operator.iconcat, x, [])

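A tiny illustration of `reduce_dim`, following directly from its definition above (only the import path is assumed):

```python
from torchzero.utils.python_tools import reduce_dim  # assumed import path

assert reduce_dim([[1, 2], [3], []]) == [1, 2, 3]
assert reduce_dim((["a"], ["b", "c"])) == ["a", "b", "c"]
```
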
@@ -38,3 +38,16 @@ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
      if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
      return self.__class__(fn(i, other, *args, **kwargs) for i in self)

+ ListLike = TypeVar('ListLike', bound=MutableSequence)
+ @overload
+ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, *, cls:type[ListLike]=list) -> ListLike: ...
+ @overload
+ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str, *keys:str, cls:type[ListLike]=list) -> list[ListLike]: ...
+ def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None = None, *keys:str, cls:type[ListLike]=list) -> ListLike | list[ListLike]:
+     k1 = (key,) if isinstance(key, str) else tuple(key)
+     k2 = () if key2 is None else (key2,)
+     keys = k1 + k2 + keys
+
+     values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
+     if len(values) == 1: return values[0]
+     return values
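
`unpack_dicts` is the settings-side counterpart of `unpack_states`: it gathers one value per dict for each key, with no initialization. A hedged sketch (import path assumed, values illustrative):

```python
from torchzero.utils.python_tools import unpack_dicts  # assumed import path

settings = [dict(lr=1e-3, momentum=0.9), dict(lr=1e-2, momentum=0.0)]

lrs = unpack_dicts(settings, 'lr')                       # [0.001, 0.01]
lrs, momenta = unpack_dicts(settings, 'lr', 'momentum')  # one list per key
```
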
{torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.9
+ Version: 0.3.10
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
@@ -157,13 +157,14 @@ for epoch in range(100):
    * `NewtonCG`: Matrix-free Newton's method with conjugate gradient solver.
    * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
    * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+   * `HigherOrderNewton`: Higher order Newton's method with trust region.

  * **Quasi-Newton**: Approximate second-order optimization methods.
    * `LBFGS`: Limited-memory BFGS.
    * `LSR1`: Limited-memory SR1.
    * `OnlineLBFGS`: Online LBFGS.
-   * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
-   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+   * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.

  * **Line Search**:
    * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
@@ -312,20 +313,20 @@ not in the module itself. Also both per-parameter settings and state are stored

  ```python
  import torch
- from torchzero.core import Module, Vars
+ from torchzero.core import Module, Var

  class HeavyBall(Module):
      def __init__(self, momentum: float = 0.9, dampening: float = 0):
          defaults = dict(momentum=momentum, dampening=dampening)
          super().__init__(defaults)

-     def step(self, vars: Vars):
-         # a module takes a Vars object, modifies it or creates a new one, and returns it
-         # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
+     def step(self, var: Var):
+         # a module takes a Var object, modifies it or creates a new one, and returns it
+         # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
          # for now we are only interested in update, and we will apply the heavyball rule to it.

-         params = vars.params
-         update = vars.get_update() # list of tensors
+         params = var.params
+         update = var.get_update() # list of tensors

          exp_avg_list = []
          for p, u in zip(params, update):
@@ -346,16 +347,15 @@ class HeavyBall(Module):
              # and it is part of self.state
              exp_avg_list.append(buf.clone())

-         # set new update to vars
-         vars.update = exp_avg_list
-         return vars
+         # set new update to var
+         var.update = exp_avg_list
+         return var
  ```
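
A hypothetical way to wire the custom module into a training loop; `tz.Modular` and `tz.m.LR` are assumed from torchzero's public API (the exact argument conventions may differ), and the closure follows the `backward`-argument convention described further down in this README:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
opt = tz.Modular(model.parameters(), [HeavyBall(momentum=0.9), tz.m.LR(1e-2)])

def closure(backward=True):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```
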

  There are some specialized base modules that make it much easier to implement specific things.

  * `GradApproximator` for gradient approximations
  * `LineSearch` for line searches
- * `Preconditioner` for preconditioners
  * `Projection` for projections like GaLore or into the Fourier domain.
  * `QuasiNewtonH` for full-matrix quasi-newton methods that update the hessian inverse approximation (because they are all very similar)
  * `ConguateGradientBase` for conjugate gradient methods; basically the only difference between them is how beta is calculated.
@@ -376,4 +376,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci

  They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.

- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+ Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
torchzero-0.3.10.dist-info/RECORD
@@ -0,0 +1,139 @@
+ docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
+ tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
+ tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
+ tests/test_opts.py,sha256=VSko5fUuACo_y6iab_akke0gMhCUEEUJ9ahpBqWBoM4,41715
+ tests/test_tensorlist.py,sha256=SwzLKLrs2ppMtm_7UrfTDTlD-ObZd7JQ_FNHbp059tc,72460
+ tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
+ tests/test_vars.py,sha256=MqCJXrbj-C75APm1heykzcEWewinihlSjekkYDx-TFk,6726
+ torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
+ torchzero/core/__init__.py,sha256=Zib_4is13LFAabp_7VU8QXZpQEEZGzsH94vgRI0HxAg,150
+ torchzero/core/module.py,sha256=Yfzn48dDbxYZJLpWnLYFIbqBb4sB3GekSZ7QGIplYAg,27525
+ torchzero/core/transform.py,sha256=yK1wYgp03THzRN9y_f9-5q2nonEZMa0CfDFAdOxnqEU,11778
+ torchzero/modules/__init__.py,sha256=8C73_dFzfWUWhii1UF86FUy8x75RPiAVLAm4sLTikBg,359
+ torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
+ torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
+ torchzero/modules/clipping/clipping.py,sha256=XKFKvzNgsvuYUmvHyulE6PkZv_aeLQjp0CgtFj0013s,12516
+ torchzero/modules/clipping/ema_clipping.py,sha256=MGouZEN0BorliHAZhue0afhC3AhZJ6wrnwBRzDTHjX4,5978
+ torchzero/modules/clipping/growth_clipping.py,sha256=50c1YOUPVL8eWzH6zJINnNP68oiZkDcq7rR6HnWjVFc,6674
+ torchzero/modules/experimental/__init__.py,sha256=zxxNKPZHnkVnx1ZjKNX_nkV4Wc_EdODM6qJGn7Pgb3w,766
+ torchzero/modules/experimental/absoap.py,sha256=-KwQXmI12hvHbMGPHM0APAxDQztlFhlSOG55KK6PvpI,9901
+ torchzero/modules/experimental/adadam.py,sha256=o0KPLaF4J7L_Ty71RNgsysk6IEuC4DRE5nGQkGIP_dA,4078
+ torchzero/modules/experimental/adamY.py,sha256=LZabWX_vccDaG6_UVZl9ALJ-3nCZu-NEygJQ_Bwzel8,4018
+ torchzero/modules/experimental/adasoap.py,sha256=XtxEvBWYdcqfWnQqOFa_-SrOwd_nXHzLftiw-YXDACQ,7408
+ torchzero/modules/experimental/curveball.py,sha256=JdgojuSYLNe9u3bmqcYrFm8brUD4kvKm9XYx78GzpKI,3257
+ torchzero/modules/experimental/diagonal_higher_order_newton.py,sha256=u4-a5qJ_97XiZUDlClE2cASkBsx_NTJNPk6BWWybiqE,7158
+ torchzero/modules/experimental/eigendescent.py,sha256=0cM1p4rYbrpwBNXgBEMblVyX0xBWTzojSC1EsUnXH6k,4707
+ torchzero/modules/experimental/etf.py,sha256=FsLOCmQf24PPoRf5wsRUjVqk32uW9uTzaf1ERjFxRK8,5744
+ torchzero/modules/experimental/gradmin.py,sha256=UixSLdca4ekYHOipEivdXfBAV-uEL9TZm5nCFXVaNco,3684
+ torchzero/modules/experimental/newton_solver.py,sha256=3dZ7FG-2vGxJKkFF9P2LCs-LI_epcvZbyNtJOtw47pg,3055
+ torchzero/modules/experimental/newtonnewton.py,sha256=QCGnY_CFo0i_NUB7D-6ezeNpG6wLkTD5lHBiakFIqbM,3033
+ torchzero/modules/experimental/reduce_outward_lr.py,sha256=VFjcTpmLwpfhUR8u_5rbzPgHVR6K3fvti7jVy1DnsYU,1300
+ torchzero/modules/experimental/soapy.py,sha256=7qsh9Y9U9oeQDwuDSVqnz71AD0nUYY3q0XN2XoMFWaw,6721
+ torchzero/modules/experimental/spectral.py,sha256=SN7tToIpmna0IZ1NgObvqEbO48NnVbwqRwKi8ROsb7s,7374
+ torchzero/modules/experimental/structured_newton.py,sha256=CWfVJ2LPZUuz1bMnlgOM6tlYPd2etjgLDIcyAfAG_y8,3464
+ torchzero/modules/experimental/subspace_preconditioners.py,sha256=9Tl1PCN9crFUvVn6343GHoI3kv6CVnUWP1dfhwUvAFU,5130
+ torchzero/modules/experimental/tada.py,sha256=84YcLhG34CbWq84L-AUj-A4uxpzdIVayaARHRm2f9b8,1564
+ torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
+ torchzero/modules/grad_approximation/fdm.py,sha256=cUgy98Bz0Br4q6ViNxn6EVOZX2jE0nDXVZLUGhxpDcA,3589
+ torchzero/modules/grad_approximation/forward_gradient.py,sha256=cNgx8kc8r0fWj8xdU2b85W3fenNDQZKuIsJLM3UzSig,3867
+ torchzero/modules/grad_approximation/grad_approximator.py,sha256=TODFUwBgTmjfbnO6Sc833fnvYzYaqqYTEba_13s-qOI,2906
+ torchzero/modules/grad_approximation/rfdm.py,sha256=VsRlf95JnG6HdlIsJANcfJjMk7c_B9a5-fH9dSTBA10,11328
+ torchzero/modules/higher_order/__init__.py,sha256=W94CY8K1NFxs9TPi415UssKVKz5MV_bH9adax1uZsYM,50
+ torchzero/modules/higher_order/higher_order_newton.py,sha256=BwiSlcGobam04SgWFcB1p_-TSuzu2rWgGVnmvP6Si9k,9751
+ torchzero/modules/line_search/__init__.py,sha256=nkOUPLe88wE91ICEhprl2pJsvaKtbI3KzYOdT83AGsg,253
+ torchzero/modules/line_search/backtracking.py,sha256=ZgeLAYqrw-6BeEGp8wmOgFoLtUKROF7w7LpAREe0xZU,7704
+ torchzero/modules/line_search/line_search.py,sha256=CfOENZgAPSdyv1wvSbhw6gdpfbQnXGdOnLsq29wjvzU,7229
+ torchzero/modules/line_search/scipy.py,sha256=SvDCZ1DPOLZcSeOFvf3tXAf1ty-9qRVfGFMWVF5q708,2293
+ torchzero/modules/line_search/strong_wolfe.py,sha256=xOU4XFekh4TIepm9ztJTYpcGucEMPwAeb_cDK4Rp0ho,7620
+ torchzero/modules/line_search/trust_region.py,sha256=xUZApOTW4uXFBk_Uq_YBktiXcoSAKdDc6O5vjTwquGw,3101
+ torchzero/modules/lr/__init__.py,sha256=kh2k_tma-oTOALR6AlD5XHdTPSMgU4A04Oa0hAqrEpI,89
+ torchzero/modules/lr/adaptive.py,sha256=6s06Gvu1UmoT89hrMkXvJWHkEOMNcy5mMiyxy3V9lQs,3904
+ torchzero/modules/lr/lr.py,sha256=1gU2QzMA5PV2KkzOkxxrZZKGcz-Kbjyp7WNurOM36ys,2655
+ torchzero/modules/momentum/__init__.py,sha256=pSD7vxu8PySrYOSHQMi3C9heYdcQr8y6WC_rwMybZm0,544
+ torchzero/modules/momentum/averaging.py,sha256=NmRodxsSekEDGIuFGDYOvJL-WkdMN3YF-naBdtfjxx8,3247
+ torchzero/modules/momentum/cautious.py,sha256=JuaFYfyf9S3nTcqeZz5ylXKepqi0eqglOAQ0uNG0eT8,7373
+ torchzero/modules/momentum/ema.py,sha256=qJV__nIbcD9e8qvwbvsATnYkQrdnmMiA91ju52IqSxw,10699
+ torchzero/modules/momentum/experimental.py,sha256=eYnP6NmBDegwX9XC_dYMJP3vquBpM1LyQc03v3vW6-8,6900
+ torchzero/modules/momentum/matrix_momentum.py,sha256=LR12UugXM8ocwTB8zBYpt03oZeZU0cb0UoFR6qO34V8,6818
+ torchzero/modules/momentum/momentum.py,sha256=4Pgk-3HM7Av_ILT6oXtvnM1CB1yit8AkFnYWLvnUAqs,2655
+ torchzero/modules/ops/__init__.py,sha256=hxMZFSXX7xvitXkuBiYykVGX3p03Xprm_QA2CMg4eW8,1601
+ torchzero/modules/ops/accumulate.py,sha256=yKNgw8ZsaVRPjuzPzLJOvALkjik0aWx30Eu91FefRoA,3741
+ torchzero/modules/ops/binary.py,sha256=98jyjkJ8BPuSH-mb4g2BnFi6UzvRZRf__Pt-jnD3pNU,9690
+ torchzero/modules/ops/debug.py,sha256=zueWyNVvaJmxRD8QG8m_ys9jc7zRfSr8GAuxqz5dDTI,851
+ torchzero/modules/ops/misc.py,sha256=GmnKDjMXaTUjPcC5e7Jftk6k2NQ0Ivv4ceUApxMckIQ,15978
+ torchzero/modules/ops/multi.py,sha256=T1aVaRr6bLWvjoj1cyxaDncROypT6rmmmji8mvBHczo,5713
+ torchzero/modules/ops/reduce.py,sha256=reGvusJyCzM8VdHbWyJRYFePPBXfVP0jZeXIEKGIJGc,5668
+ torchzero/modules/ops/split.py,sha256=eM4Qsz6pMNF22bk3NF2rtvyxSOt9U-EyYxMAyjvTrMQ,2265
+ torchzero/modules/ops/switch.py,sha256=ddsxq4bsH86iWW6mMdcQw3c0mU1s2FA-PRZpVOia7PY,2506
+ torchzero/modules/ops/unary.py,sha256=3ysDHXs6snsQNBj3c288BT8G6T30Nvo0QM3PcdfQ2ww,4888
+ torchzero/modules/ops/utility.py,sha256=8XFjQO4ghCmGD2H-lYTgaBzik_9pB3Uxt7xCxQrv5Ig,3080
+ torchzero/modules/optimizers/__init__.py,sha256=BbT2nhIt4p74t1cO8ziQgzqZHaLvyuleXQbccugd06M,554
+ torchzero/modules/optimizers/adagrad.py,sha256=NHpWcnIRM2LyPnNtDVTdluX4n1qqqWs9IHpFD8uYeLo,5500
+ torchzero/modules/optimizers/adam.py,sha256=u6ieXHn_lHZozmGiKhSA73pApI83eeTNIyOrxBTFL1o,4009
+ torchzero/modules/optimizers/lion.py,sha256=4yy6d0SLpGXndu8NCuYhdsNshMEYhONu_FPYXdupA_s,1119
+ torchzero/modules/optimizers/muon.py,sha256=exbp7wVpIryiOxmbf9RAfZ9a6XXuOWTUqdjn-i57Fq4,9628
+ torchzero/modules/optimizers/orthograd.py,sha256=cN5g7OusfeUlh0jn0kjkvpcVjqR01eGoi9WK1sSPnug,2021
+ torchzero/modules/optimizers/rmsprop.py,sha256=jM5ohfABYUljy2RrtG_bY9PMHNzSkROYjqFPxnlXE6o,4309
+ torchzero/modules/optimizers/rprop.py,sha256=d0R8UR-f9Pb64VrsJegrCyteLYa5TAmgObjgirqLaBo,11030
+ torchzero/modules/optimizers/shampoo.py,sha256=hmfgPghzmjmba3PH1vLzaz0lOvLiIX9rCKrT71YZb40,8420
+ torchzero/modules/optimizers/soap.py,sha256=7adybqncrkt31rNveQwXp8eeZKWf0LDhC5wt7GbmDcM,11052
+ torchzero/modules/optimizers/sophia_h.py,sha256=He9YrHeaQhiz4CJm-3H_d_M07MGTsP663v8wx4BlaZI,4273
+ torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
+ torchzero/modules/projections/dct.py,sha256=0tswjgta3mE5D5Yjw9mJWqPBDga0OIe3lKlwd1AXASc,2369
+ torchzero/modules/projections/fft.py,sha256=wNDZP5-3b2-bND3qH2yvX3SqFaljbLkPTQ1gUnlH5fU,2955
+ torchzero/modules/projections/galore.py,sha256=etaG2gxazxuDEu-e2r7lKIKMTPEGGS5Vi7HXccmD3kY,241
+ torchzero/modules/projections/projection.py,sha256=QUV_Gi6QlPiWEmcc7rwucr2yuYwYFGvSRUAT4uucqMY,10049
+ torchzero/modules/projections/structural.py,sha256=f8-72zViXJ6S2gxDagkrrul9IaOPsYXZmX8sFLYkxCc,5683
+ torchzero/modules/quasi_newton/__init__.py,sha256=Yc-NV__cJCiYLr2BZG4VsYa3VVq4gCxBMcirQEXSNIo,630
+ torchzero/modules/quasi_newton/cg.py,sha256=lvmwJNTR7AEcpDIvpcLnMrZrOLwNld8GFAC19CcTKoY,11661
+ torchzero/modules/quasi_newton/lbfgs.py,sha256=BDiv3f7qN8-Nhs8LqtWwk7Wwv88NtXXYle5WwKeekm4,9198
+ torchzero/modules/quasi_newton/lsr1.py,sha256=A0Pstikb6JrQbwM5RZjLw9WJEHiMRy3PsPF1_iLkrK4,6053
+ torchzero/modules/quasi_newton/olbfgs.py,sha256=Tz2eubiN7OXGN1mbXT4VKPd9kynpXzcLas7mrvBax-k,8333
+ torchzero/modules/quasi_newton/quasi_newton.py,sha256=4hRII9GFE5MzNtXkHH_T1hEJ1T8T4-Q4A4MXlhf64mc,25142
+ torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
+ torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=oLbJ96sl-2XBwLbJrrTZiLJIhKhTPOD6-wny7hbSno4,10767
+ torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
+ torchzero/modules/second_order/newton.py,sha256=ZYIcLpifcOHL_KRC6YoNs-MJQKM39urXUQzReWnWeXE,6583
+ torchzero/modules/second_order/newton_cg.py,sha256=YAEAD_8YU_H8Y-o6JI0Ywgk-kpAQOFBQm2Bjzaz9Bjs,2865
+ torchzero/modules/second_order/nystrom.py,sha256=aM6dlDv7znGYNXZgKN6B6AhZ1Tpp01JMs83B1hcXE3w,6061
+ torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
+ torchzero/modules/smoothing/gaussian.py,sha256=KbCgRXGntdPbt4-ojalrHkniYgYXk2294b-2C4MIFi8,6109
+ torchzero/modules/smoothing/laplacian.py,sha256=Vp2EnCQhyfGc3CbyOLc6_ZiVx_jvnOISf9vlHkIH4Jo,4998
+ torchzero/modules/weight_decay/__init__.py,sha256=j2Vq3DDxLYIPJmXWgAJ6dL-rXzcDEZxxvhJqRT3H0-U,95
+ torchzero/modules/weight_decay/weight_decay.py,sha256=UFL9W5w5nzTZGWvCwyGLe9UWBKN8FTClme1Klt7XZPw,3034
+ torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
+ torchzero/modules/wrappers/optim_wrapper.py,sha256=-wNI-fN8eaMSkvPIcPa34yxH0St5aLn7jaaLeh2DUsM,3569
+ torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
+ torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
+ torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
+ torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ torchzero/optim/wrappers/directsearch.py,sha256=Y2-7Sy4mYRPXPh0FTlsY_XOk5pCGjZsnbrlWCPZNp6A,10141
+ torchzero/optim/wrappers/fcmaes.py,sha256=TQvIktXV8ldy6smBX-S7ZcQEbSmSZyj567TuYShbvJg,3513
+ torchzero/optim/wrappers/mads.py,sha256=lC7edtrFS37PgmX7z9-eoqw6prl0k5BDB4NVBVQXJWE,2945
+ torchzero/optim/wrappers/nevergrad.py,sha256=qslMb-4_kfjU3Dd0UbbzE2SdLViil3Qjo2v0FtPE3Fg,4000
+ torchzero/optim/wrappers/nlopt.py,sha256=AaVEKfjbrt5DFION44_-g-jQAoVi4lCvBBPU5UDGO9Q,8151
+ torchzero/optim/wrappers/optuna.py,sha256=YN1I3rzsi20A9963pWNWd7W75FkxalVb5z5fCRQeWA0,2280
+ torchzero/optim/wrappers/scipy.py,sha256=pR26v8v0a-o2u0sbsKXpZ9JUKqXMaaI8gGLI8xYx3-s,19239
+ torchzero/utils/__init__.py,sha256=7beAjXvnmBQoy5hwYHY_PBUtrrbYb9Z7-KrYgfcFkPE,844
+ torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
+ torchzero/utils/derivatives.py,sha256=sAVd0Q1xmIPpo_AxRuoow66Hy_3goX_9o3lQK_1TyW0,16909
+ torchzero/utils/numberlist.py,sha256=cbG0UsSb9WCRxVhw8sd7Yf0bDy_gSqtghiJtkUxIO6U,6139
+ torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
+ torchzero/utils/optimizer.py,sha256=r52qu6pEcRH4lCXVlLxW5IweA6L-VrQj6RCMfdhzRpw,12466
+ torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
+ torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
+ torchzero/utils/python_tools.py,sha256=T5W7MpR7pNXiWSVw7gj-UuE9Ch0p9LRWuUZfg9Vtb-I,2794
+ torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
+ torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
+ torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
+ torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
+ torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
+ torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
+ torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
+ torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
+ torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
+ torchzero-0.3.10.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
+ torchzero-0.3.10.dist-info/METADATA,sha256=_J7AbrIa-nD6UWbuydCwxAnSpKcC9O1Vp_rM896ZkYQ,14081
+ torchzero-0.3.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ torchzero-0.3.10.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
+ torchzero-0.3.10.dist-info/RECORD,,
{torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.8.0)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

torchzero/core/preconditioner.py
@@ -1,138 +0,0 @@
- from abc import ABC, abstractmethod
- from collections import ChainMap, defaultdict
- from collections.abc import Mapping, Sequence
- from typing import Any, overload, final
-
- import torch
-
- from .module import Module, Chainable, Vars
- from .transform import apply, Transform, Target
- from ..utils import TensorList, vec_to_tensors
-
- class Preconditioner(Transform):
-     """Abstract class for a preconditioner."""
-     def __init__(
-         self,
-         defaults: dict | None,
-         uses_grad: bool,
-         concat_params: bool = False,
-         update_freq: int = 1,
-         scale_first: bool = False,
-         inner: Chainable | None = None,
-         target: Target = "update",
-     ):
-         if defaults is None: defaults = {}
-         defaults.update(dict(__update_freq=update_freq, __concat_params=concat_params, __scale_first=scale_first))
-         super().__init__(defaults, uses_grad=uses_grad, target=target)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @abstractmethod
-     def update(self, tensors: list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]):
-         """updates the preconditioner with `tensors`, any internal state should be stored using `keys`"""
-
-     @abstractmethod
-     def apply(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, states: list[dict[str, Any]], settings: Sequence[Mapping[str, Any]]) -> list[torch.Tensor]:
-         """applies preconditioner to `tensors`, any internal state should be stored using `keys`"""
-
-
-     def _tensor_wise_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-         step = self.global_state.get('__step', 0)
-         states = [self.state[p] for p in params]
-         settings = [self.settings[p] for p in params]
-         global_settings = settings[0]
-         update_freq = global_settings['__update_freq']
-
-         scale_first = global_settings['__scale_first']
-         scale_factor = 1
-         if scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)
-
-         # update preconditioner
-         if step % update_freq == 0:
-             self.update(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-         # step with inner
-         if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-
-         # apply preconditioner
-         tensors = self.apply(tensors=tensors, params=params, grads=grads, states=states, settings=settings)
-
-         # scale initial step, when preconditioner might not have been applied
-         if scale_first and step == 0:
-             torch._foreach_mul_(tensors, scale_factor)
-
-         self.global_state['__step'] = step + 1
-         return tensors
-
-     def _concat_transform(self, tensors:list[torch.Tensor], params:list[torch.Tensor], grads:list[torch.Tensor] | None, vars:Vars) -> list[torch.Tensor]:
-         step = self.global_state.get('__step', 0)
-         tensors_vec = torch.cat([t.ravel() for t in tensors])
-         params_vec = torch.cat([p.ravel() for p in params])
-         grads_vec = [torch.cat([g.ravel() for g in grads])] if grads is not None else None
-
-         states = [self.state[params[0]]]
-         settings = [self.settings[params[0]]]
-         global_settings = settings[0]
-         update_freq = global_settings['__update_freq']
-
-         scale_first = global_settings['__scale_first']
-         scale_factor = 1
-         if scale_first and step == 0:
-             # initial step size guess from pytorch LBFGS
-             scale_factor = 1 / tensors_vec.abs().sum().clip(min=1)
-             scale_factor = scale_factor.clip(min=torch.finfo(tensors_vec.dtype).eps)
-
-         # update preconditioner
-         if step % update_freq == 0:
-             self.update(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)
-
-         # step with inner
-         if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors=tensors, params=params, grads=grads, vars=vars)
-             tensors_vec = torch.cat([t.ravel() for t in tensors]) # have to recat
-
-         # apply preconditioner
-         tensors_vec = self.apply(tensors=[tensors_vec], params=[params_vec], grads=grads_vec, states=states, settings=settings)[0]
-
-         # scale initial step, when preconditioner might not have been applied
-         if scale_first and step == 0:
-             tensors_vec *= scale_factor
-
-         tensors = vec_to_tensors(vec=tensors_vec, reference=tensors)
-         self.global_state['__step'] = step + 1
-         return tensors
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         concat_params = self.settings[params[0]]['__concat_params']
-         if concat_params: return self._concat_transform(tensors, params, grads, vars)
-         return self._tensor_wise_transform(tensors, params, grads, vars)
-
- class TensorwisePreconditioner(Preconditioner, ABC):
-     @abstractmethod
-     def update_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]):
-         """update preconditioner with `tensor`"""
-
-     @abstractmethod
-     def apply_tensor(self, tensor: torch.Tensor, param:torch.Tensor, grad: torch.Tensor | None, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
-         """apply preconditioner to `tensor`"""
-
-     @final
-     def update(self, tensors, params, grads, states, settings):
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             self.update_tensor(t, p, g, state, setting)
-
-     @final
-     def apply(self, tensors, params, grads, states, settings):
-         preconditioned = []
-         if grads is None: grads = [None]*len(tensors)
-         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
-             preconditioned.append(self.apply_tensor(t, p, g, state, setting))
-         return preconditioned
-