torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
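
The bulk of this release is the new PSGD family under torchzero/modules/adaptive/psgd/ (dense Newton, Kron Newton, Kron whitening, and low-rank Newton preconditioners), whose added files are shown in full below. As orientation, here is a minimal sketch of plugging one of them into a model, assembled from the docstring examples in those files; the tz.Optimizer chaining API and the tz.m.KronWhiten / tz.m.LR aliases are assumed from those docstrings rather than verified against the released wheel:

```py
# Minimal sketch assembled from the docstring examples in the new psgd modules below;
# the tz.m.KronWhiten and tz.m.LR aliases are assumed from those docstrings.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.KronWhiten(),  # Kron whitening preconditioner added in 0.4.1
    tz.m.LR(1e-3),      # fixed learning rate applied after preconditioning
)
```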
torchzero/modules/adaptive/psgd/psgd_dense_newton.py
@@ -0,0 +1,174 @@
+ # pylint:disable=not-callable
+ """all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+ import math
+ import warnings
+ from typing import Literal
+
+ import torch
+
+ from ....core import Chainable, HVPMethod, Transform
+ from ....utils import Distributions, TensorList, vec_to_tensors_
+ from ._psgd_utils import _initialize_lra_state_
+ from .psgd import (
+     lift2single,
+     update_precond_dense_eq,
+     update_precond_dense_q0p5eq1p5,
+     update_precond_dense_qep,
+     update_precond_dense_qeq,
+     update_precond_dense_quad,
+     update_precond_dense_quad4p,
+ )
+
+ # matches
+ class PSGDDenseNewton(Transform):
+     """Dense hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+     Args:
+         init_scale (float | None, optional):
+             initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+         lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+         betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+         damping (float, optional):
+             adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+         grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+         update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+         dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
+         hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
+         h (float, optional):
+             if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
+             Defaults to 1e-3.
+         distribution (Distributions, optional):
+             distribution for random vectors for hessian-vector products. Defaults to 'normal'.
+
+         inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+     ###Examples:
+
+     Pure Dense Newton PSGD:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.DenseNewton(),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     Applying preconditioner to momentum:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.DenseNewton(inner=tz.m.EMA(0.9)),
+         tz.m.LR(1e-3),
+     )
+     ```
+     """
+     def __init__(
+         self,
+         init_scale: float | None = None,
+         lr_preconditioner=0.1,
+         betaL=0.9,
+         damping=1e-9,
+         grad_clip_max_norm=float("inf"),
+         update_probability=1.0,
+         dQ: Literal["QUAD4P", "QUAD", "QEP", "EQ", "QEQ", "Q0p5EQ1p5", "Q0.5EQ1.5"] = "Q0.5EQ1.5",
+
+         hvp_method: HVPMethod = 'autograd',
+         h: float = 1e-3,
+         distribution: Distributions = 'normal',
+
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults["inner"], defaults["self"]
+         super().__init__(defaults, inner=inner)
+
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+
+         # -------------------------------- initialize -------------------------------- #
+         if "Q" not in self.global_state:
+
+             p = objective.params[0]
+             dQ = fs["dQ"]
+             init_scale = fs["init_scale"]
+
+             if init_scale is None:
+                 self.global_state["Q"] = None
+
+             else:
+                 n = sum(p.numel() for p in objective.params)
+                 if dQ == "QUAD4P":
+                     init_scale *= init_scale
+                 self.global_state["Q"] = torch.eye(n, dtype=p.dtype, device=p.device) * init_scale
+
+             self.global_state["L"] = lift2single(torch.zeros([], dtype=p.dtype, device=p.device)) # Lipschitz smoothness constant estimation for the psgd criterion
+
+             if dQ == "QUAD4P":
+                 self.global_state["update_precond"] = update_precond_dense_quad4p
+                 self.global_state["precond_grad"] = lambda Q, g: Q @ g
+                 assert torch.finfo(p.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
+
+             elif dQ == "QUAD":
+                 self.global_state["update_precond"] = update_precond_dense_quad
+                 self.global_state["precond_grad"] = lambda Q, g: Q @ (Q @ g) # Q is symmetric; just save one transpose
+
+             else:
+                 self.global_state["precond_grad"] = lambda Q, g: Q.T @ (Q @ g)
+                 if dQ == "QEP":
+                     self.global_state["update_precond"] = update_precond_dense_qep
+                 elif dQ == "EQ":
+                     self.global_state["update_precond"] = update_precond_dense_eq
+                 elif dQ == "QEQ":
+                     self.global_state["update_precond"] = update_precond_dense_qeq
+                 else:
+                     assert (dQ == "Q0p5EQ1p5") or (dQ == "Q0.5EQ1.5"), f"Invalid choice for dQ: '{dQ}'"
+                     self.global_state["update_precond"] = update_precond_dense_q0p5eq1p5
+
+         # ---------------------------------- update ---------------------------------- #
+         Q = self.global_state["Q"]
+         if (torch.rand([]) < fs["update_probability"]) or Q is None:
+
+             # hessian-vector product
+             vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
+             Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
+
+             v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
+             h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)
+
+             # initialize on the fly
+             if Q is None:
+                 scale = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8)
+                 if fs["dQ"] == "QUAD4P": # Q actually is P in this case
+                     scale *= scale
+                 Q = self.global_state["Q"] = torch.eye(len(v), dtype=v.dtype, device=v.device) * scale
+
+             # update preconditioner
+             self.global_state["update_precond"](
+                 Q=Q,
+                 L=self.global_state["L"],
+                 v=v,
+                 h=h,
+                 lr=fs["lr_preconditioner"],
+                 betaL=fs["betaL"],
+                 damping=fs["damping"],
+             )
+
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         updates = objective.get_updates()
+
+         # cat grads
+         g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
+         pre_grad = self.global_state["precond_grad"](self.global_state["Q"], g)
+
+         # norm clipping
+         grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
+         if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
+             grad_norm = torch.linalg.vector_norm(pre_grad)
+             if grad_norm > grad_clip_max_norm:
+                 pre_grad *= grad_clip_max_norm / grad_norm
+
+         vec_to_tensors_(pre_grad, updates)
+         return objective
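
The docstring above only shows default construction; dQ, update_probability, hvp_method and h are documented but not exercised in its examples. A hedged sketch of a cheaper configuration, with argument names taken from the PSGDDenseNewton signature above and the tz.m.DenseNewton alias from its docstring examples:

```py
# Hypothetical configuration sketch; argument names come from the PSGDDenseNewton
# __init__ above, the tz.m.DenseNewton alias from its docstring examples.
import torchzero as tz

dense_newton = tz.m.DenseNewton(
    dQ="QUAD",                # symmetric-Q geometry (Q @ (Q @ g), saves a transpose)
    update_probability=0.1,   # refresh the preconditioner on roughly 10% of steps
    hvp_method="fd_central",  # finite-difference Hessian-vector products
    h=1e-3,                   # step size used by fd_central / fd_forward
)
```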
torchzero/modules/adaptive/psgd/psgd_kron_newton.py
@@ -0,0 +1,203 @@
+ # pylint:disable=not-callable
+ """all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+ import math
+ import warnings
+ from typing import Literal
+
+ import torch
+
+ from ....core import Chainable, HVPMethod, Transform
+ from ....utils import NumberList, TensorList, Distributions
+ from .psgd import (
+     init_kron,
+     precond_grad_kron,
+     update_precond_kron_newton_eq,
+     update_precond_kron_newton_q0p5eq1p5,
+     update_precond_kron_newton_qep,
+     update_precond_kron_newton_qeq,
+     update_precond_kron_newton_quad,
+     update_precond_kron_newton_quad4p,
+ )
+
+ # matches
+ class PSGDKronNewton(Transform):
+     """Kron hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+     Args:
+         max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
+         max_skew (float, optional):
+             if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
+         init_scale (float | None, optional):
+             initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+         lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+         betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+         damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
+         grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
+         update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+         dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
+         balance_probability (float, optional):
+             probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.
+         hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
+         h (float, optional):
+             if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
+             Defaults to 1e-3.
+         distribution (Distributions, optional):
+             distribution for random vectors for hessian-vector products. Defaults to 'normal'.
+         inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+
+     ###Examples:
+
+     Pure PSGD Kron Newton:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.KronNewton(),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     Applying preconditioner to momentum:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.KronNewton(inner=tz.m.EMA(0.9)),
+         tz.m.LR(1e-3),
+     )
+     ```
+     """
+     def __init__(
+         self,
+         max_dim: int = 10_000,
+         max_skew: float = 1.0,
+         init_scale: float | None = None,
+         lr_preconditioner: float = 0.1,
+         betaL: float = 0.9,
+         damping: float = 1e-9,
+         grad_clip_max_amp: float = float("inf"),
+         update_probability: float = 1.0,
+         dQ: Literal["QEP", "EQ", "QEQ", "QUAD", "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
+         balance_probability: float = 0.01,
+
+         hvp_method: HVPMethod = 'autograd',
+         h: float = 1e-3,
+         distribution: Distributions = 'normal',
+
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults["inner"], defaults["self"]
+         super().__init__(defaults, inner=inner)
+
+
+     def _initialize_state(self, param, state, setting):
+         assert "initialized" not in state
+         state["initialized"] = True
+
+         # initialize preconditioners
+         if setting["init_scale"] is None:
+             warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+             state["QLs_exprs"] = None
+         else:
+             state["QLs_exprs"] = init_kron(
+                 param.squeeze(),
+                 Scale=setting["init_scale"],
+                 max_size=setting["max_dim"],
+                 max_skew=setting["max_skew"],
+                 dQ=setting["dQ"],
+             )
+
+         dQ = setting["dQ"]
+         if dQ == "QUAD4P":
+             assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
+             state["update_precond"] = update_precond_kron_newton_quad4p
+             state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
+
+         else:
+             state["precond_grad"] = precond_grad_kron
+             if dQ == "QEP":
+                 state["update_precond"] = update_precond_kron_newton_qep
+             elif dQ == "EQ":
+                 state["update_precond"] = update_precond_kron_newton_eq
+             elif dQ == "QEQ":
+                 state["update_precond"] = update_precond_kron_newton_qeq
+             elif dQ == "QUAD":
+                 state["update_precond"] = update_precond_kron_newton_quad
+             else:
+                 assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
+                 state["update_precond"] = update_precond_kron_newton_q0p5eq1p5
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+
+         # initialize states
+         for param, state, setting in zip(objective.params, states, settings):
+             if "initialized" not in state:
+                 self._initialize_state(param, state, setting)
+
+         fs = settings[0]
+
+         uninitialized = any(state["QLs_exprs"] is None for state in states)
+         if (torch.rand([]) < fs["update_probability"]) or uninitialized:
+
+             # hessian-vector product
+             vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
+             Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
+
+             # initialize on the fly (why does it use vs?)
+             if uninitialized:
+
+                 scale = (sum([torch.sum(torch.abs(v)**2) for v in vs])/sum([v.numel() for v in vs])) ** (1/4) # (mean(|v|^2))^(1/4)
+
+                 scale = scale * (max([torch.mean((torch.abs(h))**4) for h in Hvs]) + fs["damping"]**4) ** (-1/8) # (mean(|v|^2))^(1/4) * (mean(|h|^4))^(-1/8)
+
+                 for h, state, setting in zip(Hvs, states, settings):
+                     if state["QLs_exprs"] is None:
+                         state["QLs_exprs"] = init_kron(
+                             h.squeeze(),
+                             Scale=scale,
+                             max_size=setting["max_dim"],
+                             max_skew=setting["max_skew"],
+                             dQ=setting["dQ"],
+                         )
+
+             # update preconditioner
+             for v, h, state, setting in zip(vs, Hvs, states, settings):
+                 state["update_precond"](
+                     *state["QLs_exprs"],
+                     v.squeeze(),
+                     h.squeeze(),
+                     lr=setting["lr_preconditioner"],
+                     betaL=setting["betaL"],
+                     damping=setting["damping"],
+                     balance_prob=setting["balance_probability"]
+                 )
+
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+
+         params = objective.params
+         tensors = objective.get_updates()
+         pre_tensors = []
+
+         # precondition
+         for param, tensor, state in zip(params, tensors, states):
+             t = state["precond_grad"](
+                 *state["QLs_exprs"],
+                 tensor.squeeze(),
+             )
+             pre_tensors.append(t.view_as(param))
+
+         # norm clipping
+         grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
+         if grad_clip_max_amp < math.inf:
+             pre_tensors = TensorList(pre_tensors)
+             num_params = sum(t.numel() for t in pre_tensors)
+
+             avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()
+
+             if avg_amp > grad_clip_max_amp:
+                 torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)
+
+         objective.updates = pre_tensors
+         return objective
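
In the module above, max_dim and max_skew decide when a dimension falls back to a diagonal Kronecker factor, which keeps memory bounded for large layers. A hedged sketch, with argument names taken from the PSGDKronNewton signature above and the tz.m.KronNewton alias from its docstring examples:

```py
# Hypothetical sketch; argument names come from the PSGDKronNewton __init__ above,
# the tz.m.KronNewton alias from its docstring examples.
import torchzero as tz

kron_newton = tz.m.KronNewton(
    max_dim=2048,            # dimensions larger than this get a diagonal factor
    max_skew=1.0,            # also go diagonal when dim^2 exceeds numel * max_skew
    update_probability=0.5,  # update Q on roughly half of the steps
)
```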
torchzero/modules/adaptive/psgd/psgd_kron_whiten.py
@@ -0,0 +1,185 @@
+ # pylint:disable=not-callable
+ """all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+ import math
+ import warnings
+ from typing import Literal
+
+ import torch
+
+ from ....core import Chainable, TensorTransform
+ from ....utils import NumberList, TensorList
+ from .psgd import (
+     init_kron,
+     precond_grad_kron,
+     update_precond_kron_whiten_eq,
+     update_precond_kron_whiten_q0p5eq1p5,
+     update_precond_kron_whiten_qep,
+     update_precond_kron_whiten_qeq,
+     update_precond_kron_whiten_quad,
+     update_precond_kron_whiten_quad4p,
+ )
+
+ # matches
+ class PSGDKronWhiten(TensorTransform):
+     """Kron whitening preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+     Args:
+         max_dim (int, optional): dimensions with size larger than this use diagonal preconditioner. Defaults to 10_000.
+         max_skew (float, optional):
+             if memory used by full preconditioner (dim^2) is larger than total number of elements in a parameter times ``max_skew``, it uses a diagonal preconditioner. Defaults to 1.0.
+         init_scale (float | None, optional):
+             initial scale of the preconditioner. If None, determined from magnitude of the first gradient. Defaults to None.
+         lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+         betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+         damping (float, optional): adds small noise to gradient when updating the preconditioner. Defaults to 1e-9.
+         grad_clip_max_amp (float, optional): clips amplitude of the update. Defaults to float("inf").
+         update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+         dQ (str, optional): geometry for preconditioner update. Defaults to "Q0.5EQ1.5".
+         balance_probability (float, optional):
+             probability of balancing the dynamic ranges of the factors of Q to avoid over/under-flow on each step. Defaults to 0.01.
+
+         inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+     ###Examples:
+
+     Pure PSGD Kron:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.KronWhiten(),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     Momentum into preconditioner (whitens momentum):
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.EMA(0.9),
+         tz.m.KronWhiten(),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     Updating the preconditioner from gradients and applying it to momentum:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.KronWhiten(inner=tz.m.EMA(0.9)),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     """
+     def __init__(
+         self,
+         max_dim: int = 10_000,
+         max_skew: float = 1.0,
+         init_scale: float | None = None,
+         lr_preconditioner: float = 0.1,
+         betaL: float = 0.9,
+         damping: float = 1e-9,
+         grad_clip_max_amp: float = float("inf"),
+         update_probability: float = 1.0,
+         dQ: Literal["QEP", "EQ", "QEQ", "QUAD", "Q0.5EQ1.5", "Q0p5EQ1p5", "QUAD4P"] = "Q0.5EQ1.5",
+         balance_probability: float = 0.01,
+
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults["inner"], defaults["self"]
+         super().__init__(defaults, inner=inner)
+
+     @torch.no_grad
+     def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+         # initialize preconditioners
+         if setting["init_scale"] is None:
+             # warnings.warn("FYI: Will set the preconditioner initial scale on the fly. Recommend to set it manually.")
+             state["QLs_exprs"] = None
+         else:
+             state["QLs_exprs"] = init_kron(
+                 param.squeeze(),
+                 Scale=setting["init_scale"],
+                 max_size=setting["max_dim"],
+                 max_skew=setting["max_skew"],
+                 dQ=setting["dQ"],
+             )
+
+         dQ = setting["dQ"]
+         if dQ == "QUAD4P":
+             assert torch.finfo(param.dtype).eps < 1e-6, "Directly fitting P needs at least single precision"
+             state["update_precond"] = update_precond_kron_whiten_quad4p
+             state["precond_grad"] = lambda QL, exprs, G: exprs[0](*QL[0], G) # it's exprA(*Q, G)
+
+         else:
+             state["precond_grad"] = precond_grad_kron
+             if dQ == "QEP":
+                 state["update_precond"] = update_precond_kron_whiten_qep
+             elif dQ == "EQ":
+                 state["update_precond"] = update_precond_kron_whiten_eq
+             elif dQ == "QEQ":
+                 state["update_precond"] = update_precond_kron_whiten_qeq
+             elif dQ == "QUAD":
+                 state["update_precond"] = update_precond_kron_whiten_quad
+             else:
+                 assert (dQ == "Q0.5EQ1.5") or (dQ == "Q0p5EQ1p5"), f"Invalid choice for dQ: '{dQ}'"
+                 state["update_precond"] = update_precond_kron_whiten_q0p5eq1p5
+
+     @torch.no_grad
+     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+
+         # initialize on the fly if not initialized
+         if any(state["QLs_exprs"] is None for state in states):
+
+             scale = max([torch.mean((torch.abs(g))**4) for g in tensors])
+             scale = (scale + settings[0]["damping"]**4)**(-1/8)
+
+             for param, state, setting in zip(params, states, settings):
+                 if state["QLs_exprs"] is None:
+                     state["QLs_exprs"] = init_kron(
+                         param.squeeze(),
+                         Scale=scale,
+                         max_size=setting["max_dim"],
+                         max_skew=setting["max_skew"],
+                         dQ=setting["dQ"],
+                     )
+
+
+         # update preconditioners
+         # (could also try per-parameter probability)
+         if torch.rand([]) < settings[0]["update_probability"]: # update Q
+             for tensor, state, setting in zip(tensors, states, settings):
+                 state["update_precond"](
+                     *state["QLs_exprs"],
+                     tensor.squeeze(),
+                     lr=setting["lr_preconditioner"],
+                     betaL=setting["betaL"],
+                     damping=setting["damping"],
+                     balance_prob=setting["balance_probability"]
+                 )
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+
+         pre_tensors = []
+
+         # precondition
+         for param, tensor, state in zip(params, tensors, states):
+             t = state["precond_grad"](
+                 *state["QLs_exprs"],
+                 tensor.squeeze(),
+             )
+             pre_tensors.append(t.view_as(param))
+
+         # norm clipping
+         grad_clip_max_amp = settings[0]["grad_clip_max_amp"]
+         if grad_clip_max_amp < math.inf:
+             pre_tensors = TensorList(pre_tensors)
+             num_params = sum(t.numel() for t in pre_tensors)
+
+             avg_amp = pre_tensors.dot(pre_tensors.conj()).div(num_params).sqrt()
+
+             if avg_amp > grad_clip_max_amp:
+                 torch._foreach_mul_(pre_tensors, grad_clip_max_amp / avg_amp)
+
+         return pre_tensors
torchzero/modules/adaptive/psgd/psgd_lra_newton.py
@@ -0,0 +1,118 @@
+ # pylint:disable=not-callable
+ """all functions are from https://github.com/lixilinx/psgd_torch/blob/master/psgd.py"""
+ import math
+ import warnings
+
+ import torch
+
+ from ....core import Chainable, HVPMethod, Transform
+ from ....utils import Distributions, TensorList, vec_to_tensors_
+ from .psgd import lift2single, precond_grad_lra, update_precond_lra_newton
+ from ._psgd_utils import _initialize_lra_state_
+
+ # matches
+ class PSGDLRANewton(Transform):
+     """Low rank hessian preconditioner from Preconditioned Stochastic Gradient Descent (see https://github.com/lixilinx/psgd_torch)
+
+     Args:
+         rank (int, optional):
+             Preconditioner has a diagonal part and a low rank part, whose rank is decided by this setting. Defaults to 10.
+         init_scale (float | None, optional):
+             initial scale of the preconditioner. If None, determined based on a heuristic. Defaults to None.
+         lr_preconditioner (float, optional): learning rate of the preconditioner. Defaults to 0.1.
+         betaL (float, optional): EMA factor for the L-smoothness constant wrt Q. Defaults to 0.9.
+         damping (float, optional):
+             adds small noise to hessian-vector product when updating the preconditioner. Defaults to 1e-9.
+         grad_clip_max_norm (float, optional): clips norm of the update. Defaults to float("inf").
+         update_probability (float, optional): probability of updating preconditioner on each step. Defaults to 1.0.
+         hvp_method (HVPMethod, optional): how to compute hessian-vector products. Defaults to 'autograd'.
+         h (float, optional):
+             if ``hvp_method`` is ``"fd_central"`` or ``"fd_forward"``, controls finite difference step size.
+             Defaults to 1e-3.
+         distribution (Distributions, optional):
+             distribution for random vectors for hessian-vector products. Defaults to 'normal'.
+
+         inner (Chainable | None, optional): preconditioning will be applied to output of this module. Defaults to None.
+
+     ###Examples:
+
+     Pure LRA Newton PSGD:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.LRANewton(),
+         tz.m.LR(1e-3),
+     )
+     ```
+
+     Applying preconditioner to momentum:
+     ```py
+     optimizer = tz.Optimizer(
+         model.parameters(),
+         tz.m.LRANewton(inner=tz.m.EMA(0.9)),
+         tz.m.LR(1e-3),
+     )
+     ```
+     """
+     def __init__(
+         self,
+         rank: int = 10,
+         init_scale: float | None = None,
+         lr_preconditioner=0.1,
+         betaL=0.9,
+         damping=1e-9,
+         grad_clip_max_norm=float("inf"),
+         update_probability=1.0,
+
+         hvp_method: HVPMethod = 'autograd',
+         h: float = 1e-3,
+         distribution: Distributions = 'normal',
+
+         inner: Chainable | None = None,
+     ):
+         defaults = locals().copy()
+         del defaults["inner"], defaults["self"]
+         super().__init__(defaults, inner=inner)
+
+     @torch.no_grad
+     def update_states(self, objective, states, settings):
+         fs = settings[0]
+
+         # initialize
+         if "UVd" not in self.global_state:
+             p = torch.cat([t.ravel() for t in objective.params])
+             _initialize_lra_state_(p, self.global_state, fs)
+
+         UVd = self.global_state["UVd"]
+         if (torch.rand([]) < fs["update_probability"]) or (UVd[2] is None):
+
+             # hessian-vector product
+             vs = TensorList(objective.params).sample_like(distribution=fs["distribution"])
+             Hvs, _ = objective.hessian_vector_product(z=vs, rgrad=None, at_x0=True, hvp_method=fs["hvp_method"], h=fs["h"])
+
+             v = torch.cat([t.ravel() for t in vs]).unsqueeze(1)
+             h = torch.cat([t.ravel() for t in Hvs]).unsqueeze(1)
+
+             if UVd[2] is None:
+                 UVd[2] = (torch.mean(v*v))**(1/4) * (torch.mean(h**4) + fs["damping"]**4)**(-1/8) * torch.ones_like(v)
+
+             # update preconditioner
+             update_precond_lra_newton(UVd=UVd, Luvd=self.global_state["Luvd"], v=v, h=h, lr=fs["lr_preconditioner"], betaL=fs["betaL"], damping=fs["damping"])
+
+
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         updates = objective.get_updates()
+
+         g = torch.cat([t.ravel() for t in updates]).unsqueeze(1) # column vec
+         pre_grad = precond_grad_lra(UVd=self.global_state["UVd"], g=g)
+
+         # norm clipping
+         grad_clip_max_norm = settings[0]["grad_clip_max_norm"]
+         if grad_clip_max_norm < float("inf"): # clip preconditioned gradient
+             grad_norm = torch.linalg.vector_norm(pre_grad)
+             if grad_norm > grad_clip_max_norm:
+                 pre_grad *= grad_clip_max_norm / grad_norm
+
+         vec_to_tensors_(pre_grad, updates)
+         return objective
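
The LRA variant above exposes the rank of the low-rank part alongside the same Hessian-vector-product options as the dense module. A hedged sketch, with argument names taken from the PSGDLRANewton signature above and the tz.m.LRANewton / tz.m.LR aliases from its docstring examples:

```py
# Hypothetical sketch; argument names come from the PSGDLRANewton __init__ above,
# the tz.m.LRANewton and tz.m.LR aliases from its docstring examples.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
optimizer = tz.Optimizer(
    model.parameters(),
    tz.m.LRANewton(
        rank=20,                  # rank of the low-rank part of the preconditioner
        hvp_method="fd_forward",  # finite-difference Hessian-vector products
        h=1e-3,                   # finite-difference step size
    ),
    tz.m.LR(1e-3),
)
```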