torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/natural_gradient.py

@@ -1,12 +1,12 @@
  import torch
- from ...core import Module, Chainable, apply_transform
+ from ...core import Transform

  from ...utils.derivatives import jacobian_wrt, flatten_jacobian
- from ...utils import vec_to_tensors, TensorList
- from ...utils.linalg import linear_operator
- from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
+ from ...utils import vec_to_tensors
+ from ...linalg import linear_operator
+ from .ggt import ggt_update

- class NaturalGradient(Module):
+ class NaturalGradient(Transform):
      """Natural gradient approximated via empirical fisher information matrix.

      To use this, either pass vector of per-sample losses to the step method, or make sure
@@ -27,9 +27,9 @@ class NaturalGradient(Module):
              with a vector that isn't strictly per-sample gradients, but rather for example different losses.
          gn_grad (bool, optional):
              if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively sum weighted by value
-             and is equivalent to squaring the values. This way you can solve least-squares
-             objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
-             This has an effect when ``sqrt=True``, and affects the ``grad`` attribute.
+             and is equivalent to squaring the values. That makes the kernel trick solver incorrect, but for
+             some reason it still works. If False, uses sum of per-sample gradients.
+             This has an effect when ``sqrt=False``, and affects the ``grad`` attribute.
              Defaults to False.
          batched (bool, optional): whether to use vmapping. Defaults to True.
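
The gn_grad option described above replaces the summed per-sample gradients with Gᵀf, which, up to a factor of 2, is exactly the gradient of the squared per-sample values (this is what the docstring means by "equivalent to squaring the values"). A quick standalone check in plain torch; the names residuals and x are illustrative, not from the package:

    import torch

    x = torch.randn(3, requires_grad=True)

    def residuals(x):
        # toy per-sample values f(x) with 4 outputs
        return torch.stack([x[0] - 1, x[1] * x[2], x.sum(), (x ** 2).sum()])

    f = residuals(x)
    G = torch.autograd.functional.jacobian(residuals, x)         # (4, 3) per-sample gradients

    gn_grad = G.T @ f                                            # Gauss-Newton gradient G^T f
    sq_grad = torch.autograd.grad((0.5 * f.pow(2)).sum(), x)[0]  # gradient of the halved squared values

    print(torch.allclose(gn_grad, sq_grad, atol=1e-6))           # True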
 
@@ -41,7 +41,7 @@ class NaturalGradient(Module):
          y = torch.randn(64, 10)

          model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-         opt = tz.Modular(
+         opt = tz.Optimizer(
              model.parameters(),
              tz.m.NaturalGradient(),
              tz.m.LR(3e-2)
@@ -61,7 +61,7 @@ class NaturalGradient(Module):
          y = torch.randn(64, 10)

          model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
-         opt = tz.Modular(
+         opt = tz.Optimizer(
              model.parameters(),
              tz.m.NaturalGradient(),
              tz.m.LR(3e-2)
@@ -84,7 +84,7 @@ class NaturalGradient(Module):
              return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])

          X = torch.tensor([-1.1, 2.5], requires_grad=True)
-         opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
+         opt = tz.Optimizer([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))

          for iter in range(200):
              losses = rosenbrock(X)
@@ -97,20 +97,27 @@ class NaturalGradient(Module):
          super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))

      @torch.no_grad
-     def update(self, var):
-         params = var.params
-         batched = self.defaults['batched']
-         gn_grad = self.defaults['gn_grad']
-
-         closure = var.closure
-         assert closure is not None
-
+     def update_states(self, objective, states, settings):
+         params = objective.params
+         closure = objective.closure
+         fs = settings[0]
+         batched = fs['batched']
+         gn_grad = fs['gn_grad']
+
+         # compute per-sample losses
+         f = objective.loss
+         if f is None:
+             assert closure is not None
+             with torch.enable_grad():
+                 f = objective.get_loss(backward=False) # n_out
+         assert isinstance(f, torch.Tensor)
+
+         # compute per-sample gradients
          with torch.enable_grad():
-             f = var.get_loss(backward=False) # n_out
-             assert isinstance(f, torch.Tensor)
              G_list = jacobian_wrt([f.ravel()], params, batched=batched)

-         var.loss = f.sum()
+         # set scalar loss and it's grad to objective
+         objective.loss = f.sum()

          G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)

          if gn_grad:
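
update_states above builds G, the (n_samples, ndim) matrix of per-sample gradients, by differentiating the vector of per-sample losses. The package does this with its own jacobian_wrt / flatten_jacobian helpers; below is a standalone sketch of the same shape bookkeeping with plain torch.func, using an illustrative loss and a single flat parameter vector:

    import torch
    from torch.func import jacrev

    w = torch.randn(5)

    def per_sample_losses(w):
        # 8 "samples", each contributing one scalar loss
        A = torch.arange(40, dtype=torch.float32).reshape(8, 5)
        return (A @ w).sin()

    f = per_sample_losses(w)           # (8,)   per-sample losses
    G = jacrev(per_sample_losses)(w)   # (8, 5) row i is the gradient of f[i]

    # the sum of the rows of G is the gradient of the summed loss, i.e. what
    # gn_grad=False stores in self.global_state["g"]
    g = jacrev(lambda w: per_sample_losses(w).sum())(w)
    print(torch.allclose(G.sum(0), g))  # True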
@@ -119,13 +126,15 @@ class NaturalGradient(Module):
          else:
              g = self.global_state["g"] = G.sum(0)

-         var.grad = vec_to_tensors(g, params)
+         objective.grads = vec_to_tensors(g, params)

          # set closure to calculate scalar value for line searches etc
-         if var.closure is not None:
+         if closure is not None:
+
              def ngd_closure(backward=True):
+
                  if backward:
-                     var.zero_grad()
+                     objective.zero_grad()
                      with torch.enable_grad():
                          loss = closure(False)
                          if gn_grad: loss = loss.pow(2)
@@ -137,39 +146,52 @@ class NaturalGradient(Module):
                      if gn_grad: loss = loss.pow(2)
                  return loss.sum()

-             var.closure = ngd_closure
+             objective.closure = ngd_closure

      @torch.no_grad
-     def apply(self, var):
-         params = var.params
-         reg = self.defaults['reg']
-         sqrt = self.defaults['sqrt']
+     def apply_states(self, objective, states, settings):
+         params = objective.params
+         fs = settings[0]
+         reg = fs['reg']
+         sqrt = fs['sqrt']

          G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)

          if sqrt:
              # this computes U, S <- SVD(M), then calculate update as U S^-1 Uᵀg,
              # but it computes it through eigendecompotision
-             U, L = lm_adagrad_update(G.H, reg, 0)
-             if U is None or L is None: return var
+             L, U = ggt_update(G.H, damping=reg, rdamping=1e-16, truncate=0, eig_tol=1e-12)
+
+             if U is None or L is None:
+
+                 # fallback to element-wise
+                 g = self.global_state["g"]
+                 g /= G.square().mean(0).sqrt().add(reg)
+                 objective.updates = vec_to_tensors(g, params)
+                 return objective

-             v = lm_adagrad_apply(self.global_state["g"], U, L)
-             var.update = vec_to_tensors(v, params)
-             return var
+             # whiten
+             z = U.T @ self.global_state["g"]
+             v = (U * L.rsqrt()) @ z
+             objective.updates = vec_to_tensors(v, params)
+             return objective

-         GGT = G @ G.H # (n_samples, n_samples)
+         # we need (G^T G)v = g
+         # where g = G^T
+         # so we need to solve (G^T G)v = G^T
+         GGt = G @ G.H # (n_samples, n_samples)

          if reg != 0:
-             GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))
+             GGt.add_(torch.eye(GGt.size(0), device=GGt.device, dtype=GGt.dtype).mul_(reg))

-         z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
+         z, _ = torch.linalg.solve_ex(GGt, torch.ones_like(GGt[0])) # pylint:disable=not-callable
          v = G.H @ z

-         var.update = vec_to_tensors(v, params)
-         return var
+         objective.updates = vec_to_tensors(v, params)
+         return objective


-     def get_H(self, var):
+     def get_H(self, objective=...):
          if "G" not in self.global_state: return linear_operator.ScaledIdentity()
          G = self.global_state['G']
          return linear_operator.AtA(G)
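
When sqrt=False, apply_states never forms the (ndim, ndim) empirical Fisher GᵀG; it solves the small (n_samples, n_samples) system G Gᵀ z = 1 and takes v = Gᵀ z, as sketched in the comments above. A standalone numeric check of that kernel trick with reg = 0 (plain torch, illustrative sizes):

    import torch

    torch.manual_seed(0)
    n_samples, ndim = 8, 20
    G = torch.randn(n_samples, ndim, dtype=torch.float64)  # per-sample gradients
    g = G.sum(0)                                            # = G.T @ ones, the summed gradient

    z = torch.linalg.solve(G @ G.T, torch.ones(n_samples, dtype=torch.float64))
    v = G.T @ z                                             # the step apply_states would produce

    print(torch.allclose(G.T @ (G @ v), g))                 # (G^T G) v == g, up to rounding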
torchzero/modules/adaptive/orthograd.py

@@ -1,13 +1,9 @@
- from operator import itemgetter
- import math
- import warnings
- from collections.abc import Iterable, Sequence
- from typing import Literal
+ from collections.abc import Iterable

  import torch

- from ...core import Target, Transform
- from ...utils import as_tensorlist
+ from ...core import TensorTransform
+ from ...utils import TensorList

  def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
      """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
@@ -19,29 +15,29 @@ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
      reference
          https://arxiv.org/abs/2501.04697
      """
-     params = as_tensorlist(params).with_grad()
+     params = TensorList(params).with_grad()
      grad = params.grad
      grad -= (params.dot(grad)/(params.dot(params) + eps)) * params


- class OrthoGrad(Transform):
+ class OrthoGrad(TensorTransform):
      """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.

      Args:
          eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
          renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-         target (Target, optional): what to set on var. Defaults to 'update'.
      """
-     def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
+     def __init__(self, eps: float = 1e-8, renormalize=True):
          defaults = dict(eps=eps, renormalize=renormalize)
-         super().__init__(defaults, uses_grad=False, target=target)
+         super().__init__(defaults)

-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          eps = settings[0]['eps']
          renormalize = settings[0]['renormalize']

-         params = as_tensorlist(params)
-         target = as_tensorlist(tensors)
+         params = TensorList(params)
+         target = TensorList(tensors)

          scale = params.dot(target)/(params.dot(params) + eps)
          if renormalize:
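
The projection applied by orthograd_ and OrthoGrad.multi_tensor_apply removes the component of the gradient that points along the weights: g ← g − (⟨p, g⟩ / (⟨p, p⟩ + eps)) · p. A single-tensor sketch in plain torch, without the package's TensorList:

    import torch

    p = torch.randn(1000)   # weights
    g = torch.randn(1000)   # gradient
    eps = 1e-30

    g_orth = g - (p.dot(g) / (p.dot(p) + eps)) * p
    print(p.dot(g_orth))    # ~0: the projected gradient is orthogonal to the weights

    # with renormalize=True, OrthoGrad additionally grafts the result back to the
    # original gradient norm (the clamp is only this sketch's divide-by-zero guard):
    g_orth = g_orth * (g.norm() / g_orth.norm().clamp(min=1e-30))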
torchzero/modules/adaptive/psgd/__init__.py (new file)

@@ -0,0 +1,5 @@
+ from .psgd_dense_newton import PSGDDenseNewton
+ from .psgd_kron_newton import PSGDKronNewton
+ from .psgd_kron_whiten import PSGDKronWhiten
+ from .psgd_lra_newton import PSGDLRANewton
+ from .psgd_lra_whiten import PSGDLRAWhiten
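
With the new subpackage __init__.py above, the PSGD variants can be imported directly from it; whether they are also re-exported under tz.m is not shown in this diff:

    from torchzero.modules.adaptive.psgd import PSGDKronWhiten, PSGDLRANewton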
torchzero/modules/adaptive/psgd/_psgd_utils.py (new file)

@@ -0,0 +1,37 @@
+ # pylint:disable=not-callable
+ import warnings
+
+ import torch
+
+ from .psgd import lift2single
+
+
+ def _initialize_lra_state_(tensor: torch.Tensor, state, setting):
+     n = tensor.numel()
+     rank = max(min(setting["rank"], n-1), 1)
+     dtype=tensor.dtype
+     device=tensor.device
+
+     U = torch.randn((n, rank), dtype=dtype, device=device)
+     U *= 0.1**0.5 / torch.linalg.vector_norm(U)
+
+     V = torch.randn((n, rank), dtype=dtype, device=device)
+     V *= 0.1**0.5 / torch.linalg.vector_norm(V)
+
+     if setting["init_scale"] is None:
+         # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+         d = None
+     else:
+         d = torch.ones(n, 1, dtype=dtype, device=device) * setting["init_scale"]
+
+     state["UVd"] = [U, V, d]
+     state["Luvd"] = [lift2single(torch.zeros([], dtype=dtype, device=device)) for _ in range(3)]
+
+
+
+ def _wrap_with_no_backward(opt):
+     """to use original psgd opts with visualbench"""
+     class _Wrapped:
+         def step(self, closure):
+             return opt.step(lambda: closure(False))
+     return _Wrapped()
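
For reference, a standalone sketch of the state that _initialize_lra_state_ builds for the low-rank (LRA) preconditioner: U and V are (n, rank) factors scaled so each has whole-matrix norm sqrt(0.1), and d is an optional (n, 1) diagonal. lift2single and the Luvd entries are omitted; the function name init_lra_factors and the numbers are illustrative, not part of the package:

    import torch

    def init_lra_factors(tensor, rank, init_scale=None):
        n = tensor.numel()
        r = max(min(rank, n - 1), 1)
        U = torch.randn(n, r, dtype=tensor.dtype, device=tensor.device)
        U *= 0.1**0.5 / torch.linalg.vector_norm(U)   # whole-matrix norm == sqrt(0.1)
        V = torch.randn(n, r, dtype=tensor.dtype, device=tensor.device)
        V *= 0.1**0.5 / torch.linalg.vector_norm(V)
        d = None
        if init_scale is not None:
            d = torch.ones(n, 1, dtype=tensor.dtype, device=tensor.device) * init_scale
        return U, V, d

    U, V, d = init_lra_factors(torch.zeros(1000), rank=10, init_scale=1.0)
    print(U.shape, V.shape, d.shape)   # (1000, 10), (1000, 10), (1000, 1)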