torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -81,7 +81,7 @@ class Split(Module):
  Muon with Adam fallback using same hyperparams as https://github.com/KellerJordan/Muon

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.NAG(0.95),
  tz.m.Split(
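Note (not part of the diff): this hunk and many of the hunks below are the same docstring rename, `tz.Modular` → `tz.Optimizer`. A minimal construction sketch, assuming the 0.4.1 class keeps the parameters-then-modules signature the examples show, and using only modules that appear in these hunks:

```python
# Sketch only, not text from the diff: builds an optimizer the way the
# updated 0.4.1 docstrings do, with tz.Optimizer in place of tz.Modular.
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
opt = tz.Optimizer(
    model.parameters(),   # parameters first, as in the docstring examples
    tz.m.Adam(),          # Adam + LR mirrors the Switch example below
    tz.m.LR(1e-3),
)
```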
@@ -19,7 +19,7 @@ class Alternate(Module):

  ```python

- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.Alternate(
  tz.m.Adam(),
@@ -89,7 +89,7 @@ class Switch(Alternate):
  Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.Switch(
  [tz.m.Adam(), tz.m.LR(1e-3)],
@@ -57,7 +57,7 @@ class Cautious(TensorTransform):
  Cautious Adam

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.Adam(),
  tz.m.Cautious(),
@@ -173,7 +173,7 @@ class ScaleByGradCosineSimilarity(TensorTransform):

  Scaled Adam
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.Adam(),
  tz.m.ScaleByGradCosineSimilarity(),
@@ -211,7 +211,7 @@ class ScaleModulesByCosineSimilarity(Module):

  Adam scaled by similarity to RMSprop
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  bench.parameters(),
  tz.m.ScaleModulesByCosineSimilarity(
  main = tz.m.Adam(),
@@ -6,7 +6,7 @@ import torch

  from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import debias, ema_
+ from ..opt_utils import debias, ema_


  class EMA(TensorTransform):
@@ -6,7 +6,7 @@ import torch

  from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import (
+ from ..opt_utils import (
  centered_ema_sq_,
  debias,
  debias_second_momentum,
@@ -144,7 +144,7 @@ class Graft(MultiOperationBase):

  Shampoo grafted to Adam
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.GraftModules(
  direction = tz.m.Shampoo(),
@@ -149,8 +149,11 @@ class ProjectionBase(Module, ABC):
  Iterable[torch.Tensor]: unprojected tensors of the same shape as params
  """

+ def update(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+ def apply(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+
  @torch.no_grad
- def apply(self, objective: Objective):
+ def step(self, objective: Objective):
  params = objective.params
  settings = [self.settings[p] for p in params]

@@ -266,7 +269,7 @@ class ProjectionBase(Module, ABC):

  # ----------------------------------- step ----------------------------------- #
  projected_obj.params = projected_params
- projected_obj = self.children['modules'].apply(projected_obj)
+ projected_obj = self.children['modules'].step(projected_obj)

  # empty fake params storage
  # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -30,4 +30,4 @@ from .quasi_newton import (
  ThomasOptimalMethod,
  )

- from .sg2 import SG2, SPSA2
+ from .sg2 import SG2
@@ -5,7 +5,7 @@ import torch

  from ...utils import TensorList
  from ...linalg.linear_operator import DenseInverse, LinearOperator
- from ..functional import safe_clip
+ from ..opt_utils import safe_clip


  class DampingStrategy(Protocol):
@@ -9,7 +9,7 @@ from .quasi_newton import (
  _InverseHessianUpdateStrategyDefaults,
  )

- from ..functional import safe_clip
+ from ..opt_utils import safe_clip


  def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
@@ -7,7 +7,7 @@ import torch
  from ...core import Chainable, TensorTransform
  from ...utils import TensorList, as_tensorlist, unpack_states
  from ...linalg.linear_operator import LinearOperator
- from ..functional import initial_step_size
+ from ..opt_utils import initial_step_size
  from .damping import DampingStrategyType, apply_damping


@@ -188,7 +188,7 @@ class LBFGS(TensorTransform):

  L-BFGS with line search
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LBFGS(100),
  tz.m.Backtracking()
@@ -197,7 +197,7 @@ class LBFGS(TensorTransform):

  L-BFGS with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.TrustCG(tz.m.LBFGS())
  )
@@ -7,7 +7,7 @@ import torch
  from ...core import Chainable, Module, TensorTransform, Objective, step
  from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
  from ...linalg.linear_operator import LinearOperator
- from ..functional import initial_step_size
+ from ..opt_utils import initial_step_size
  from .damping import DampingStrategyType, apply_damping


@@ -110,7 +110,7 @@ class LSR1(TensorTransform):

  L-SR1 with line search
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.SR1(),
  tz.m.StrongWolfe(c2=0.1, fallback=True)
@@ -119,7 +119,7 @@ class LSR1(TensorTransform):

  L-SR1 with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.TrustCG(tz.m.LSR1())
  )
@@ -8,7 +8,7 @@ import torch
  from ...core import Chainable, Module, TensorTransform, Transform
  from ...utils import TensorList, set_storage_, unpack_states, safe_dict_update_
  from ...linalg import linear_operator
- from ..functional import initial_step_size, safe_clip
+ from ..opt_utils import initial_step_size, safe_clip



@@ -106,11 +106,12 @@ class HessianUpdateStrategy(TensorTransform, ABC):
  scale_first: bool = False,
  concat_params: bool = True,
  inverse: bool = True,
+ uses_loss: bool = False,
  inner: Chainable | None = None,
  ):
  if defaults is None: defaults = {}
  safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, ptol=ptol, ptol_restart=ptol_restart, gtol=gtol, inverse=inverse, beta=beta, restart_interval=restart_interval, scale_first=scale_first))
- super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+ super().__init__(defaults, uses_loss=uses_loss, concat_params=concat_params, update_freq=update_freq, inner=inner)

  def reset_for_online(self):
  super().reset_for_online()
@@ -141,18 +142,22 @@ class HessianUpdateStrategy(TensorTransform, ABC):
  return H

  # ------------------------------ common methods ------------------------------ #
- def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+ def auto_initial_scale(self, s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float | None:
  """returns multiplier to B on 2nd step if ``init_scale='auto'``. H should be divided by this!"""
  ys = y.dot(s)
  yy = y.dot(y)
- if ys != 0 and yy != 0: return yy/ys
- return 1
+ tiny = torch.finfo(ys.dtype).tiny * 2
+ if ys > tiny and yy > tiny: return yy/ys
+ return None

- def reset_P(self, P: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
+ def reset_P(self, P: torch.Tensor, s:torch.Tensor, y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]) -> None:
  """resets ``P`` which is either B or H"""
  set_storage_(P, self.initialize_P(s.numel(), device=P.device, dtype=P.dtype, is_inverse=inverse))
- if init_scale == 'auto': init_scale = self.auto_initial_scale(s,y)
- if init_scale >= 1:
+ if init_scale == 'auto':
+ init_scale = self.auto_initial_scale(s,y)
+ state["scaled"] = init_scale is not None
+
+ if init_scale is not None and init_scale != 1:
  if inverse: P /= init_scale
  else: P *= init_scale

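Note (not part of the diff): the new `auto_initial_scale` keeps the usual yᵀy / yᵀs initial-scaling heuristic for B (H is divided by the same factor), but returns `None` for degenerate (s, y) pairs instead of 1, so callers can defer the scaling via `state["scaled"]`. A standalone sketch of that rule, with no torchzero state handling:

```python
# Sketch of the 0.4.1 auto-scale rule shown in the hunk above (assumption:
# plain 1-D tensors; the package applies this to flattened parameter vectors).
import torch

def auto_initial_scale(s: torch.Tensor, y: torch.Tensor):
    ys, yy = y.dot(s), y.dot(y)
    tiny = torch.finfo(ys.dtype).tiny * 2
    if ys > tiny and yy > tiny:
        return yy / ys   # multiply B by this, or divide H by it
    return None          # degenerate pair: keep the identity scaling for now

s = torch.tensor([1e-2, -3e-3])
y = torch.tensor([5e-3, 1e-3])
print(auto_initial_scale(s, y))   # yy/ys as a 0-dim tensor
```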
@@ -182,6 +187,7 @@ class HessianUpdateStrategy(TensorTransform, ABC):
  state['f_prev'] = loss
  state['p_prev'] = p.clone()
  state['g_prev'] = g.clone()
+ state["scaled"] = False
  return

  state['f'] = loss
@@ -205,9 +211,13 @@ class HessianUpdateStrategy(TensorTransform, ABC):
  if gtol is not None and y.abs().max() <= gtol:
  return

- if step == 2 and init_scale == 'auto':
- if inverse: M /= self.auto_initial_scale(s,y)
- else: M *= self.auto_initial_scale(s,y)
+ # apply automatic initial scale if it hasn't been applied
+ if (not state["scaled"]) and (init_scale == 'auto'):
+ scale = self.auto_initial_scale(s,y)
+ if scale is not None:
+ state["scaled"] = True
+ if inverse: M /= self.auto_initial_scale(s,y)
+ else: M *= self.auto_initial_scale(s,y)

  beta = setting['beta']
  if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
@@ -367,22 +377,21 @@ def bfgs_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  B += term1.sub_(term2)
  return B

- def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+
+ def bfgs_H_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol: float):
  sy = s.dot(y)
  if sy <= tol: return H

- sy_sq = safe_clip(sy**2)
-
- Hy = H@y
- scale1 = (sy + y.dot(Hy)) / sy_sq
- term1 = s.outer(s).mul_(scale1)
+ rho = 1.0 / sy
+ Hy = H @ y

- num2 = (Hy.outer(s)).add_(s.outer(y @ H))
- term2 = num2.div_(sy)
+ term1 = (s.outer(s)).mul_(rho * (1 + rho * y.dot(Hy)))
+ term2 = (Hy.outer(s) + s.outer(Hy)).mul_(rho)

- H += term1.sub_(term2)
+ H.add_(term1).sub_(term2)
  return H

+
  class BFGS(_InverseHessianUpdateStrategyDefaults):
  """Broyden–Fletcher–Goldfarb–Shanno Quasi-Newton method. This is usually the most stable quasi-newton method.

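Note (not part of the diff): both versions of `bfgs_H_` implement the textbook inverse-Hessian BFGS update; 0.4.1 just factors it through ρ = 1/(sᵀy) and keeps the whole update in-place. Restated:

```latex
% Textbook inverse-Hessian BFGS update, with \rho = 1/(s^\top y) as in the 0.4.1 code:
H_{k+1} = (I - \rho\, s y^\top)\, H_k\, (I - \rho\, y s^\top) + \rho\, s s^\top
        = H_k + \rho\,(1 + \rho\, y^\top H_k y)\, s s^\top
              - \rho\,(H_k y\, s^\top + s\, (H_k y)^\top)
```

The 0.4.0 factor `scale1 = (sy + yᵀHy)/sy²` equals ρ(1 + ρ yᵀHy), so the two bodies agree up to floating-point rounding and the removed `safe_clip` of the denominator.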
@@ -428,7 +437,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):
  BFGS with backtracking line search:

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.BFGS(),
  tz.m.Backtracking()
@@ -437,7 +446,7 @@ class BFGS(_InverseHessianUpdateStrategyDefaults):

  BFGS with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LevenbergMarquardt(tz.m.BFGS(inverse=False)),
  )
@@ -505,7 +514,7 @@ class SR1(_InverseHessianUpdateStrategyDefaults):

  SR1 with trust region
  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
  )
@@ -1015,7 +1024,7 @@ class GradientCorrection(TensorTransform):
  L-BFGS with gradient correction

  ```python
- opt = tz.Modular(
+ opt = tz.Optimizer(
  model.parameters(),
  tz.m.LBFGS(inner=tz.m.GradientCorrection()),
  tz.m.Backtracking()
@@ -1154,6 +1163,7 @@ class NewSSM(HessianUpdateStrategy):
  scale_first=scale_first,
  concat_params=concat_params,
  inverse=True,
+ uses_loss=True,
  inner=inner,
  )
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
@@ -1171,13 +1181,18 @@ class NewSSM(HessianUpdateStrategy):

  # this is supposed to be equivalent (and it is)
  def shor_r_(H:torch.Tensor, y:torch.Tensor, alpha:float):
- p = H@y
- #(1-y)^2 (ppT)/(pTq)
- #term = p.outer(p).div_(p.dot(y).clip(min=1e-32))
- term = p.outer(p).div_(safe_clip(p.dot(y)))
- H.sub_(term, alpha=1-alpha**2)
+ Hy = H @ y
+ yHy = safe_clip(y.dot(Hy))
+ term = Hy.outer(Hy).div_(yHy)
+ H.sub_(term, alpha=(1-alpha**2))
  return H

+ # def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
+ # Hy = H @ y
+ # yHy = safe_clip(y.dot(Hy))
+ # H -= (Hy.outer(y) @ H).div_(yHy)
+ # return H
+
  class ShorR(HessianUpdateStrategy):
  """Shor’s r-algorithm.
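Note (not part of the diff): the rewritten `shor_r_` applies the same rank-one downdate as before, just written directly in terms of Hy:

```latex
% Rank-one downdate performed by shor_r_ (H symmetric, y the gradient difference):
H \leftarrow H - (1 - \alpha^{2})\, \frac{(H y)(H y)^{\top}}{y^{\top} H y}
```

The 0.4.0 body formed p = Hy and divided by pᵀy, which equals yᵀHy when H is symmetric, so the two versions agree, matching the in-code comment that they are equivalent.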