torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
--- a/torchzero/modules/adaptive/muon.py
+++ b/torchzero/modules/adaptive/muon.py
@@ -10,8 +10,10 @@ from ...linalg.orthogonalize import orthogonalize as _orthogonalize, OrthogonalizeMethod
 def reverse_dims(t:torch.Tensor):
     return t.permute(*reversed(range(t.ndim)))

-def _is_at_least_2d(p: torch.Tensor):
-    if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
+def _is_at_least_2d(p: torch.Tensor, channel_first:bool):
+    if p.ndim < 2: return False
+    if channel_first and (p.size(0) > 1) and (p.size(1) > 1): return True
+    if (not channel_first) and (p.size(-2) > 1) and (p.size(-1) > 1): return True
     return False

 def _orthogonalize_format(
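For orientation, here is how the new predicate behaves (a minimal sketch with made-up shapes, not taken from the package's tests): with `channel_first=True` the first two dimensions must both be non-trivial, otherwise the last two must be.

```py
import torch

w_conv = torch.randn(16, 3, 3, 3)  # conv weight: matrix dims up front
w_batch = torch.randn(8, 4, 4)     # batched matrices: matrix dims at the end
b = torch.randn(10, 1)             # a column vector is never "at least 2D" here

assert _is_at_least_2d(w_conv, channel_first=True)
assert _is_at_least_2d(w_batch, channel_first=False)
assert not _is_at_least_2d(b, channel_first=True)
```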
@@ -19,6 +21,7 @@ def _orthogonalize_format(
     method: OrthogonalizeMethod,
     channel_first: bool,
 ):
+    """orthogonalize the first two dims if channel_first, otherwise the last two"""
     if channel_first:
         return reverse_dims(_orthogonalize(reverse_dims(tensor), method=method))

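A quick sanity check of the `reverse_dims` trick used above (shapes are illustrative): reversing every dimension moves the first two dims to the back, so a single routine that orthogonalizes the trailing dims covers both layouts.

```py
import torch

def reverse_dims(t: torch.Tensor):
    return t.permute(*reversed(range(t.ndim)))

w = torch.randn(16, 3, 5, 5)   # e.g. a conv weight (C_out, C_in, kH, kW)
r = reverse_dims(w)            # (kW, kH, C_in, C_out)
assert r.shape == (5, 5, 3, 16)
# orthogonalizing the last two dims of `r` acts on the (C_in, C_out) planes,
# i.e. the first two dims of the original channel-first tensor
```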
@@ -69,7 +72,7 @@ orthogonalize_grads_(
     are considered batch dimensions.
     """
     for p in params:
-        if (p.grad is not None) and _is_at_least_2d(p.grad):
+        if (p.grad is not None) and _is_at_least_2d(p.grad, channel_first=channel_first):
             X = _orthogonalize_format(p.grad, method=method, channel_first=channel_first)
             if dual_norm_correction: X = _dual_norm_correction(X, p.grad, channel_first=False)
             p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
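A hedged usage sketch of this function (keyword names come from the hunk above; `method` and `dual_norm_correction` are assumed to have defaults, and the model, inputs, and optimizer are placeholders):

```py
loss = model(inputs).square().mean()   # placeholder forward pass
loss.backward()
orthogonalize_grads_(model.parameters(), channel_first=True)
sgd.step()                             # any torch.optim optimizer
```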
@@ -100,7 +103,7 @@ class Orthogonalize(TensorTransform):

     standard Muon with Adam fallback
     ```py
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.head.parameters(),
         tz.m.Split(
             # apply muon only to 2D+ parameters
@@ -131,7 +134,7 @@ class Orthogonalize(TensorTransform):

         if not orthogonalize: return tensor

-        if _is_at_least_2d(tensor):
+        if _is_at_least_2d(tensor, channel_first=channel_first):

             X = _orthogonalize_format(tensor, method, channel_first=channel_first)

@@ -173,7 +176,7 @@ class MuonAdjustLR(Transform):
         alphas = [s['alpha'] for s in settings]
         channel_first = [s["channel_first"] for s in settings]
         tensors_alphas = [
-            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t)
+            (t, adjust_lr_for_muon(a, t.shape, cf)) for t, a, cf in zip(tensors, alphas, channel_first) if _is_at_least_2d(t, channel_first=cf)
         ]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in tensors_alphas]
--- a/torchzero/modules/adaptive/natural_gradient.py
+++ b/torchzero/modules/adaptive/natural_gradient.py
@@ -4,7 +4,7 @@ from ...core import Transform
 from ...utils.derivatives import jacobian_wrt, flatten_jacobian
 from ...utils import vec_to_tensors
 from ...linalg import linear_operator
-from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
+from .ggt import ggt_update

 class NaturalGradient(Transform):
     """Natural gradient approximated via empirical fisher information matrix.
@@ -41,7 +41,7 @@ class NaturalGradient(Transform):
         y = torch.randn(64, 10)

         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -61,7 +61,7 @@ class NaturalGradient(Transform):
         y = torch.randn(64, 10)

         model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.NaturalGradient(),
             tz.m.LR(3e-2)
@@ -84,7 +84,7 @@ class NaturalGradient(Transform):
             return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

         X = torch.tensor([-1.1, 2.5], requires_grad=True)
-        opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
+        opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

         for iter in range(200):
             losses = rosenbrock(X)
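This least-squares usage works because, for a residual vector r, the matrix GᵀG assembled from per-residual gradients is exactly the Gauss-Newton Hessian of the scalar loss ½‖r‖² (with G playing the role of the Jacobian J; the abs() on the residuals does not change GᵀG since the signs square away):

```latex
\nabla^2_{\mathrm{GN}}\,\tfrac{1}{2}\lVert r(x)\rVert^2 = J^\top J,
\qquad J_{ij} = \partial r_i / \partial x_j .
```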
@@ -99,18 +99,24 @@ class NaturalGradient(Transform):
     @torch.no_grad
     def update_states(self, objective, states, settings):
         params = objective.params
+        closure = objective.closure
         fs = settings[0]
         batched = fs['batched']
         gn_grad = fs['gn_grad']

-        closure = objective.closure
-        assert closure is not None
+        # compute per-sample losses
+        f = objective.loss
+        if f is None:
+            assert closure is not None
+            with torch.enable_grad():
+                f = objective.get_loss(backward=False) # n_out
+            assert isinstance(f, torch.Tensor)

+        # compute per-sample gradients
         with torch.enable_grad():
-            f = objective.get_loss(backward=False) # n_out
-            assert isinstance(f, torch.Tensor)
             G_list = jacobian_wrt([f.ravel()], params, batched=batched)

+        # set the scalar loss and its grad on the objective
         objective.loss = f.sum()
         G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

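A minimal standalone sketch of what the `jacobian_wrt`/`flatten_jacobian` pair produces here, using raw autograd (the torchzero helpers are internal; this is illustrative only):

```py
import torch

params = [torch.randn(3, requires_grad=True)]
f = params[0] ** 2                        # per-sample losses, shape (3,)

rows = []
for fi in f:                              # one backward per sample
    grads = torch.autograd.grad(fi, params, retain_graph=True)
    rows.append(torch.cat([g.ravel() for g in grads]))

G = torch.stack(rows)                     # (n_samples, ndim), like global_state["G"]
```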
@@ -123,8 +129,10 @@
         objective.grads = vec_to_tensors(g, params)

         # set closure to calculate scalar value for line searches etc
-        if objective.closure is not None:
+        if closure is not None:
+
             def ngd_closure(backward=True):
+
                 if backward:
                     objective.zero_grad()
                     with torch.enable_grad():
@@ -152,22 +160,31 @@ class NaturalGradient(Transform):
         if sqrt:
             # this computes U, S <- SVD(M), then calculates the update as U S^-1 Uᵀg,
            # but does so through an eigendecomposition
-            U, L = lm_adagrad_update(G.H, reg, 0)
-            if U is None or L is None: return objective
+            L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)
+
+            if U is None or L is None:
+
+                # fallback to element-wise
+                g = self.global_state["g"]
+                g /= G.square().mean(0).sqrt().add(reg)
+                objective.updates = vec_to_tensors(g, params)
+                return objective

-            v = lm_adagrad_apply(self.global_state["g"], U, L)
+            # whiten
+            z = U.T @ self.global_state["g"]
+            v = (U * L.rsqrt()) @ z
             objective.updates = vec_to_tensors(v, params)
             return objective

         # we need to solve (Gᵀ G) v = g, where g = Gᵀ 1 (the summed gradient);
         # equivalently, solve the smaller system (G Gᵀ) z = 1 and set v = Gᵀ z
-        GGT = G @ G.H # (n_samples, n_samples)
+        GGt = G @ G.H # (n_samples, n_samples)

         if reg != 0:
-            GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))
+            GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))

-        z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
+        z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
         v = G.H @ z

         objective.updates = vec_to_tensors(v, params)
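A quick numerical check of the Gram-matrix trick used above (a standalone sketch, not package code): because (GᵀG + λI)Gᵀ = Gᵀ(GGᵀ + λI), solving the small n_samples×n_samples system and mapping back through Gᵀ solves the regularized full system exactly.

```py
import torch

ns, nd, lam = 8, 100, 1e-3
G = torch.randn(ns, nd, dtype=torch.float64)
ones = torch.ones(ns, dtype=torch.float64)

z = torch.linalg.solve(G @ G.T + lam * torch.eye(ns, dtype=torch.float64), ones)
v = G.T @ z   # candidate solution of the nd-dimensional system

lhs = (G.T @ G + lam * torch.eye(nd, dtype=torch.float64)) @ v
assert torch.allclose(lhs, G.T @ ones)
```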
--- /dev/null
+++ b/torchzero/modules/adaptive/psgd/__init__.py
@@ -0,0 +1,5 @@
+from .psgd_dense_newton import PSGDDenseNewton
+from .psgd_kron_newton import PSGDKronNewton
+from .psgd_kron_whiten import PSGDKronWhiten
+from .psgd_lra_newton import PSGDLRANewton
+from .psgd_lra_whiten import PSGDLRAWhiten
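Hypothetical usage, assuming these classes end up re-exported under `tz.m` like the package's other adaptive modules (not verified against the wheel; constructor arguments omitted):

```py
opt = tz.Optimizer(
    model.parameters(),
    tz.m.PSGDKronWhiten(),
    tz.m.LR(1e-3),
)
```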
--- /dev/null
+++ b/torchzero/modules/adaptive/psgd/_psgd_utils.py
@@ -0,0 +1,37 @@
+# pylint:disable=not-callable
+import warnings
+
+import torch
+
+from .psgd import lift2single
+
+
+def _initialize_lra_state_(tensor: torch.Tensor, state, setting):
+    n = tensor.numel()
+    rank = max(min(setting["rank"], n-1), 1)
+    dtype = tensor.dtype
+    device = tensor.device
+
+    U = torch.randn((n, rank), dtype=dtype, device=device)
+    U *= 0.1**0.5 / torch.linalg.vector_norm(U)
+
+    V = torch.randn((n, rank), dtype=dtype, device=device)
+    V *= 0.1**0.5 / torch.linalg.vector_norm(V)
+
+    if setting["init_scale"] is None:
+        # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+        d = None
+    else:
+        d = torch.ones(n, 1, dtype=dtype, device=device) * setting["init_scale"]
+
+    state["UVd"] = [U, V, d]
+    state["Luvd"] = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+
+
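For reference, in PSGD's low-rank-approximation preconditioner the state initialized above parameterizes Q = (I + U Vᵀ) diag(d), applied as P = QᵀQ (this matches Xi-Lin Li's reference implementation; with init_scale=None, d is deferred until the first step). A sketch materializing the dense form for a tiny problem, illustrative only and not the package API:

```py
import torch

n, rank = 6, 2
U = torch.randn(n, rank) * 0.1**0.5
V = torch.randn(n, rank) * 0.1**0.5
d = torch.full((n, 1), 0.5)

Q = (torch.eye(n) + U @ V.T) @ torch.diag(d.squeeze(1))  # Q = (I + U Vᵀ) diag(d)
P = Q.T @ Q                                              # preconditioner P = QᵀQ
```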
+def _wrap_with_no_backward(opt):
+    """to use original psgd opts with visualbench"""
+    class _Wrapped:
+        def step(self, closure):
+            return opt.step(lambda: closure(False))
+    return _Wrapped()
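A sketch of how this wrapper is presumably used: visualbench-style closures accept a `backward` flag, while the stock PSGD optimizers expect a plain closure, so the wrapper pins `backward=False` and lets the inner optimizer handle gradients itself (names below are placeholders):

```py
def closure(backward=True):
    loss = compute_loss()      # placeholder objective
    if backward:
        loss.backward()
    return loss

wrapped = _wrap_with_no_backward(stock_psgd_opt)  # stock_psgd_opt: hypothetical
wrapped.step(closure)          # inner optimizer receives a no-backward closure
```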