torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
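Of the 163 changed files, only one diff is expanded below; from its contents (the ``Adagrad``, ``AdagradNorm`` and ``FullMatrixAdagrad`` classes) it appears to be ``torchzero/modules/adaptive/adagrad.py`` (item 27). The file list also shows the linear-algebra helpers moving from ``torchzero/utils/linalg/`` to a new top-level ``torchzero/linalg/`` package (items 19, 22, 23, 161). A minimal, hypothetical compatibility sketch for downstream code that imports those submodules directly; whether anything is re-exported at package level is an assumption, not something this diff confirms:

```python
# Hypothetical shim for code that must import torchzero's linear-algebra helpers
# across both releases. The submodule name is taken from the file list above
# (torchzero/{utils/linalg -> linalg}/solve.py); nothing else is assumed.
try:
    from torchzero.linalg import solve  # 0.4.0 layout
except ImportError:
    from torchzero.utils.linalg import solve  # 0.3.15 layout
```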
@@ -1,62 +1,14 @@
-from operator import itemgetter
 from typing import Literal
-
 import torch
+
 from ...core import (
     Chainable,
-    Module,
-    Target,
-    TensorwiseTransform,
-    Transform,
-    Var,
-    apply_transform,
+    TensorTransform,
 )
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ...utils.linalg import matrix_power_eigh
-from ..functional import add_power_, lerp_power_, root, epsilon_step_size
-from ...utils.linalg.linear_operator import Dense
-
-def adagrad_(
-    tensors_: TensorList,
-    sq_sum_: TensorList,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    pow: float = 2,
-    use_sqrt: bool = True,
-    divide: bool = False,
-
-    decay: float | None = None,
-    beta: float | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
-
-    if beta is None or step == 1: sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-    else: sq_sum_ = lerp_power_(tensors_, exp_avg_pow_=sq_sum_, beta=beta, pow=pow)
-    if decay is not None:
-        sq_sum_.mul_(1-decay)
-
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-    if divide: sq_sum_ = sq_sum_ / max(step, 1)
-
-    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+from ...utils import NumberList, TensorList, unpack_dicts
+from ...linalg.matrix_power import matrix_power as _matrix_power, MatrixPowerMethod
 
-    return tensors_
-
-
-
-class Adagrad(Transform):
+class Adagrad(TensorTransform):
     """Adagrad, divides by sum of past squares of gradients.
 
     This implementation is identical to ``torch.optim.Adagrad``.
@@ -72,103 +24,53 @@ class Adagrad(Transform):
     """
     def __init__(
         self,
+
+        # hyperparams
         lr_decay: float = 0,
         initial_accumulator_value: float = 0,
         eps: float = 1e-10,
         alpha: float = 1,
-        pow: float = 2,
-        use_sqrt: bool = True,
-        divide: bool=False,
-        beta:float | None = None,
-        decay: float | None = None,
+
+        # tfms
         inner: Chainable | None = None,
+        accumulator_tfm: Chainable | None = None
     ):
-        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
-        super().__init__(defaults=defaults, uses_grad=False)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["accumulator_tfm"]
+        super().__init__(defaults=defaults, inner=inner)
 
-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child('accumulator', accumulator_tfm)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = TensorList(tensors)
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+    def single_tensor_initialize(self, tensor, param, grad, loss, state, setting):
+        state["accumulator"] = torch.full_like(tensor, fill_value=setting["initial_accumulator_value"])
 
-        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
-
-        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
-
-        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-        # initialize accumulator on 1st step
-        if step == 1:
-            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-        return adagrad_(
-            tensors,
-            sq_sum_=sq_sum,
-            alpha=alpha,
-            lr_decay=lr_decay,
-            eps=eps,
-            step=step,
-            pow=pow,
-            use_sqrt=use_sqrt,
-            divide=divide,
-
-            beta = self.defaults["beta"],
-            decay = self.defaults["decay"],
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )
-
-
-def lerp(start, end, weight):
-    return start + weight * (end - start)
-
-def adagrad_norm_(
-    tensors_: TensorList,
-    accumulator: float | torch.Tensor,
-    alpha: float | NumberList,
-    lr_decay: float | NumberList,
-    eps: float | NumberList,
-    step: int,
-    use_sqrt: bool = True,
-    divide: bool = False,
-
-    decay: float | None = None,
-    beta: float | None = None,
-
-    # inner args
-    inner: Module | None = None,
-    params: list[torch.Tensor] | None = None,
-    grads: list[torch.Tensor] | None = None,
-):
-    """returns `tensors_`"""
-    clr = alpha / (1 + step * lr_decay)
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        torch._foreach_addcmul_([state["accumulator"] for state in states], tensors, tensors)
+        self.increment_counter("step", start=0)
 
-    gg = tensors_.dot(tensors_)
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors_ = TensorList(tensors)
+        step = self.global_state["step"] # 0 on first apply
+        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)
 
-    if beta is None or step == 1: accumulator += gg
-    else: accumulator = lerp(accumulator, gg, 1-beta)
+        accumulator = [state["accumulator"] for state in states]
+        accumulator = TensorList(self.inner_step_tensors(
+            "accumulator", tensors=accumulator, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
 
-    if decay is not None:
-        accumulator *= 1-decay
+        denom = accumulator.sqrt().add_(eps)
+        tensors_ /= denom
 
-    if inner is not None:
-        assert params is not None
-        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+        clr = alpha / (1 + step * lr_decay)
+        tensors_.lazy_mul_(clr)
 
-    if divide: accumulator = accumulator / max(step, 1)
+        return tensors_
 
-    if use_sqrt: tensors_.div_(eps + accumulator.sqrt()).mul_(clr)
-    else: tensors_.div_(eps + accumulator).mul_(clr)
 
-    return tensors_, accumulator
 
-class AdagradNorm(Transform):
+class AdagradNorm(TensorTransform):
     """Adagrad-Norm, divides by sum of past means of squares of gradients.
 
     Args:
@@ -176,7 +78,6 @@ class AdagradNorm(Transform):
         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
         eps (float, optional): division epsilon. Defaults to 1e-10.
         alpha (float, optional): step size. Defaults to 1.
-        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
     """
@@ -185,71 +86,104 @@ class AdagradNorm(Transform):
         lr_decay: float = 0,
         initial_accumulator_value: float = 0,
         eps: float = 1e-10,
-        alpha: float = 1,
-        pow: float = 2,
-        use_sqrt: bool = True,
-        divide: bool=False,
         beta:float | None = None,
-        decay: float | None = None,
+        beta_debias: bool = True,
+        layerwise: bool = True,
+        use_sqrt: bool = True,
+        alpha: float = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide, beta=beta, decay=decay)
-        super().__init__(defaults=defaults, uses_grad=False)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults=defaults, inner=inner)
+
+    @torch.no_grad
+    def multi_tensor_initialize(self, tensors, params, grads, loss, states, settings):
+
+        # layerwise initialize in each state
+        if settings[0]["layerwise"]:
+            for tensor, state, setting in zip(tensors, states, settings):
+
+                initial_accumulator_value = setting["initial_accumulator_value"]
+                state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)
+
+        # global initialize in global state
+        else:
+            initial_accumulator_value = settings[0]["initial_accumulator_value"]
+            tensor = tensors[0]
+            self.global_state["accumulator"] = torch.tensor(initial_accumulator_value, device=tensor.device, dtype=tensor.dtype)
+
+    def _get_accumulator(self, states, settings) -> torch.Tensor | TensorList:
+        layerwise = settings[0]["layerwise"]
+        if layerwise:
+            return TensorList(s["accumulator"] for s in states)
+
+        return self.global_state["accumulator"]
 
-        if inner is not None:
-            self.set_child('inner', inner)
+    @torch.no_grad
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        accumulator = self._get_accumulator(states, settings)
+        self.increment_counter("step", start=0)
+
+        # compute squared gradient norm (gg)
+        if isinstance(accumulator, TensorList): gg = tensors.tensorwise_dot(tensors)
+        else: gg = tensors.dot(tensors)
+
+        # update the accumulator
+        beta = settings[0]["beta"]
+        if beta is None: accumulator.add_(gg) # pyright:ignore[reportArgumentType]
+        else: accumulator.lerp_(gg, weight=1-beta) # pyright:ignore[reportArgumentType, reportCallIssue]
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
+        accumulator = self._get_accumulator(states, settings)
+        eps, alpha, lr_decay = unpack_dicts(settings, "eps", "alpha", "lr_decay", cls=NumberList)
+        step = self.global_state["step"] # 0 on 1st step
+        fs = settings[0]
+        beta = fs["beta"]
 
-        use_sqrt, divide, initial_accumulator_value = itemgetter('use_sqrt', 'divide', "initial_accumulator_value")(settings[0])
+        # ------------------------ debias if beta is not None ------------------------ #
+        if fs["beta_debias"] and beta is not None:
+            accumulator = accumulator / (1 - beta ** (step + 1))
 
-        accumulator = self.global_state.get("accumulator", initial_accumulator_value)
 
-        d, self.global_state["accumulator"] = adagrad_norm_(
-            tensors,
-            accumulator=accumulator,
-            alpha=alpha,
-            lr_decay=lr_decay,
-            eps=eps,
-            step=step,
-            use_sqrt=use_sqrt,
-            divide=divide,
+        # ---------------------------- compute denominator --------------------------- #
+        if fs["use_sqrt"]:
+            denom = accumulator.sqrt().add_(eps) # pyright:ignore[reportArgumentType]
+        else:
+            denom = accumulator + eps # pyright:ignore[reportOperatorIssue]
 
-            beta = self.defaults["beta"],
-            decay = self.defaults["decay"],
-            # inner args
-            inner=self.children.get("inner", None),
-            params=params,
-            grads=grads,
-        )
 
-        return d
+        # ---------------------------- compute the update ---------------------------- #
+        tensors /= denom
+        clr = alpha / (1 + step * lr_decay) # lr decay
+        tensors.lazy_mul_(clr)
 
+        return tensors
 
-class FullMatrixAdagrad(TensorwiseTransform):
+
+
+class FullMatrixAdagrad(TensorTransform):
     """Full-matrix version of Adagrad, can be customized to make RMSprop or Adam (see examples).
 
     Note:
         A more memory-efficient version equivalent to full matrix Adagrad on last n gradients is implemented in ``tz.m.LMAdagrad``.
 
     Args:
-        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
-        decay (float | None, optional): decay for gradient outer product accumulators. Defaults to None.
-        sqrt (bool, optional): whether to take the square root of the accumulator. Defaults to True.
-        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
+        reg (float, optional): regularization, scale of identity matrix added to accumulator. Defaults to 1e-12.
         precond_freq (int, optional): frequency of updating the inverse square root of the accumulator. Defaults to 1.
+        beta (float | None, optional): momentum for gradient outer product accumulators. if None, uses sum. Defaults to None.
+        beta_debias (bool, optional): whether to use debiasing, only has effect when ``beta`` is not ``None``. Defaults to True.
         init (Literal[str], optional):
             how to initialize the accumulator.
            - "identity" - with identity matrix (default).
            - "zeros" - with zero matrix.
            - "ones" - with matrix of ones.
            -"GGT" - with the first outer product
-        divide (bool, optional): whether to divide the accumulator by number of gradients in it. Defaults to False.
+        matrix_power (float, optional): accumulator matrix power. Defaults to -1/2.
+        concat_params (bool, optional): if False, each parameter will have it's own accumulator. Defaults to True.
         inner (Chainable | None, optional): inner modules to apply preconditioning to. Defaults to None.
 
     ## Examples:
@@ -284,73 +218,89 @@ class FullMatrixAdagrad(TensorwiseTransform):
     """
     def __init__(
         self,
+        reg: float = 1e-12,
+        precond_freq: int = 1,
         beta: float | None = None,
-        decay: float | None = None,
-        sqrt: bool = True,
+        beta_debias: bool=True,
+        init: Literal["identity", "zeros", "GGT"] = "identity",
+        matrix_power: float = -1/2,
+        matrix_power_method: MatrixPowerMethod = "eigh_abs",
         concat_params=True,
-        precond_freq: int = 1,
-        init: Literal["identity", "zeros", "ones", "GGT"] = "identity",
-        reg: float = 1e-12,
-        divide: bool = False,
+
         inner: Chainable | None = None,
+        accumulator_tfm: Chainable | None = None
     ):
-        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, precond_freq=precond_freq, init=init, divide=divide, reg=reg)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner,)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner'], defaults["concat_params"], defaults["accumulator_tfm"]
+        super().__init__(defaults=defaults, inner=inner, concat_params=concat_params)
+
+        self.set_child("accumulator", accumulator_tfm)
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_update(self, tensor, param, grad, loss, state, setting):
+
         G = tensor.ravel()
-        GG = torch.outer(G, G)
-        decay = setting['decay']
+        GGᵀ = torch.outer(G, G)
+
+        # initialize
+        if "accumulator" not in state:
+            init = setting['init']
+            if init == 'identity': state['accumulator'] = torch.eye(GGᵀ.size(0), device=GGᵀ.device, dtype=GGᵀ.dtype)
+            elif init == 'zeros': state['accumulator'] = torch.zeros_like(GGᵀ)
+            elif init == 'GGT': state['accumulator'] = GGᵀ.clone()
+            else: raise ValueError(init)
+
+        # update
         beta = setting['beta']
-        init = setting['init']
+        accumulator: torch.Tensor = state["accumulator"]
 
-        if 'GG' not in state:
-            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
-            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-            elif init == 'ones': state['GG'] = torch.ones_like(GG)
-            elif init == 'GGT': state['GG'] = GG.clone()
-            else: raise ValueError(init)
-            if decay is not None: state['GG'].mul_(decay)
+        if beta is None: accumulator.add_(GGᵀ)
+        else: accumulator.lerp_(GGᵀ, 1-beta)
 
-        if beta is not None: state['GG'].lerp_(GG, 1-beta)
-        else: state['GG'].add_(GG)
-        state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
+        # update number of GG in accumulator for divide
+        state['num_GGTs'] = state.get('num_GGTs', 0) + 1
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         step = state.get('step', 0)
         state['step'] = step + 1
 
-        GG: torch.Tensor = state['GG']
-        sqrt = setting['sqrt']
-        divide = setting['divide']
+        accumulator: torch.Tensor = state['accumulator']
+        accumulator = self.inner_step_tensors("accumulator", [accumulator], clone=True, must_exist=False)[0]
+
         precond_freq = setting['precond_freq']
         reg = setting['reg']
+        beta = setting["beta"]
 
-        if divide: GG = GG/state.get('i', 1)
-
+        # add regularizer
         if reg != 0:
-            GG = GG + torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype).mul_(reg)
+            device = accumulator.device; dtype = accumulator.dtype
+            accumulator = accumulator + torch.eye(accumulator.size(0), device=device, dtype=dtype).mul_(reg)
 
+        # for single value use sqrt
         if tensor.numel() == 1:
-            GG = GG.squeeze()
-            if sqrt: return tensor / GG.sqrt()
-            return tensor / GG
+            dir = tensor.mul_(accumulator.squeeze() ** setting["matrix_power"])
 
-        try:
-            if sqrt:
+        # otherwise use matrix inverse square root
+        else:
+
+            # compute inverse square root and store to state
+            try:
                 if "B" not in state or step % precond_freq == 0:
-                    B = state["B"] = matrix_power_eigh(GG, -1/2)
+                    B = state["B"] = _matrix_power(accumulator, setting["matrix_power"], method=setting["matrix_power_method"])
                 else:
                     B = state["B"]
 
-            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+                dir = (B @ tensor.ravel()).view_as(tensor)
+
+            # fallback to diagonal Adagrad on fail
+            except torch.linalg.LinAlgError:
+                dir = tensor.mul_(accumulator.diagonal() ** setting["matrix_power"])
 
-        except torch.linalg.LinAlgError:
-            # fallback to diagonal AdaGrad
-            denom = GG.diagonal()
-            if sqrt: denom = denom.sqrt()
-            return tensor.div_(denom + max(reg, 1e-12))
+        # debias
+        if setting["beta_debias"] and beta is not None:
+            num_GGTs = state.get('num_GGTs', 1)
+            bias_correction = 1 - beta ** num_GGTs
+            dir *= bias_correction ** 0.5
 
-        return (B @ tensor.ravel()).view_as(tensor)
+        return dir
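
For reference, the preconditioning performed by the new ``single_tensor_update`` / ``single_tensor_apply`` pair can be written out with plain tensor ops. This is a sketch, not library code: it covers only the default path (``beta=None``, ``init="identity"``, ``matrix_power=-1/2``, no inner transforms), and the eigendecomposition with absolute eigenvalues is an assumption about what ``_matrix_power(..., method="eigh_abs")`` computes.

```python
import torch

def full_matrix_adagrad_step(g: torch.Tensor, accumulator: torch.Tensor,
                             reg: float = 1e-12, matrix_power: float = -0.5) -> torch.Tensor:
    """One step on a flattened gradient g of shape (n,), with an (n, n) accumulator."""
    # single_tensor_update, beta=None branch: accumulate the gradient outer product
    accumulator += torch.outer(g, g)

    # single_tensor_apply: regularize the accumulator with reg * I ...
    A = accumulator + reg * torch.eye(accumulator.size(0),
                                      dtype=accumulator.dtype, device=accumulator.device)

    # ... raise it to matrix_power (eigh with absolute eigenvalues is an assumed
    # stand-in for the "eigh_abs" method), then precondition the gradient
    eigvals, Q = torch.linalg.eigh(A)
    B = (Q * eigvals.abs().pow(matrix_power)) @ Q.T
    return B @ g

# toy usage: precondition a 3-element gradient over a few steps
acc = torch.eye(3)  # mirrors init="identity" (the default); init="zeros" would start from torch.zeros(3, 3)
for _ in range(5):
    g = torch.randn(3)
    update = full_matrix_adagrad_step(g, acc)
```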