torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
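Several of the entries above are moves rather than new files: the linear-algebra helpers leave torchzero/utils/linalg for a top-level torchzero/linalg package, and the single torchzero/optim/wrappers/scipy.py module becomes a torchzero/optim/wrappers/scipy/ package. A minimal sketch of keeping an import working across both layouts; it assumes only that the public module paths match the wheel paths listed above:

import importlib

try:
    qr = importlib.import_module("torchzero.linalg.qr")         # 0.4.1 layout
except ModuleNotFoundError:
    qr = importlib.import_module("torchzero.utils.linalg.qr")   # 0.3.15 layout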
@@ -6,7 +6,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...core import Chainable, Module, TensorTransform, Transform, Objective
 from ...utils import (
     Distributions,
     Metrics,
@@ -19,15 +19,15 @@ from ...utils import (
 )
 
 
-class Previous(TensorwiseTransform):
+class Previous(TensorTransform):
     """Maintains an update from n steps back, for example if n=1, returns previous update"""
-    def __init__(self, n=1, target: Target = 'update'):
+    def __init__(self, n=1):
         defaults = dict(n=n)
-        super().__init__(uses_grad=False, defaults=defaults, target=target)
+        super().__init__(defaults=defaults)
 
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         n = setting['n']
 
         if 'history' not in state:
@@ -38,13 +38,13 @@ class Previous(TensorwiseTransform):
         return state['history'][0]
 
 
-class LastDifference(Transform):
+class LastDifference(TensorTransform):
     """Outputs difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, target=target)
+    def __init__(self,):
+        super().__init__()
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
         difference = torch._foreach_sub(tensors, prev_tensors)
         for p, c in zip(prev_tensors, tensors): p.set_(c)
@@ -53,16 +53,16 @@ class LastDifference(Transform):
 class LastGradDifference(Module):
     """Outputs difference between past two gradients."""
     def __init__(self):
-        super().__init__({})
+        super().__init__()
 
     @torch.no_grad
-    def step(self, var):
-        grad = var.get_grad()
-        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+    def apply(self, objective):
+        grad = objective.get_grads()
+        prev_grad = self.get_state(objective.params, 'prev_grad') # initialized to 0
         difference = torch._foreach_sub(grad, prev_grad)
         for p, c in zip(prev_grad, grad): p.copy_(c)
-        var.update = list(difference)
-        return var
+        objective.updates = list(difference)
+        return objective
 
 class LastParamDifference(Module):
     """Outputs difference between past two parameters, which is the effective previous update."""
@@ -70,36 +70,36 @@ class LastParamDifference(Module):
         super().__init__({})
 
     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+    def apply(self, objective):
+        params = objective.params
+        prev_params = self.get_state(objective.params, 'prev_params') # initialized to 0
         difference = torch._foreach_sub(params, prev_params)
         for p, c in zip(prev_params, params): p.copy_(c)
-        var.update = list(difference)
-        return var
+        objective.updates = list(difference)
+        return objective
 
 
 
-class LastProduct(Transform):
+class LastProduct(TensorTransform):
     """Outputs difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
+    def __init__(self):
+        super().__init__()
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
         prod = torch._foreach_mul(tensors, prev)
         for p, c in zip(prev, tensors): p.set_(c)
         return prod
 
-class LastRatio(Transform):
-    """Outputs ratio between past two updates, the numerator is determined by :code:`numerator` argument."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+class LastRatio(TensorTransform):
+    """Outputs ratio between past two updates, the numerator is determined by ``numerator`` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur'):
         defaults = dict(numerator=numerator)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
         numerator = settings[0]['numerator']
         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
@@ -107,14 +107,14 @@ class LastRatio(Transform):
         for p, c in zip(prev, tensors): p.set_(c)
         return ratio
 
-class LastAbsoluteRatio(Transform):
-    """Outputs ratio between absolute values of past two updates the numerator is determined by :code:`numerator` argument."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+class LastAbsoluteRatio(TensorTransform):
+    """Outputs ratio between absolute values of past two updates the numerator is determined by ``numerator`` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8):
         defaults = dict(numerator=numerator, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
         numerator = settings[0]['numerator']
         eps = NumberList(s['eps'] for s in settings)
@@ -127,139 +127,139 @@ class LastAbsoluteRatio(Transform):
         for p, c in zip(prev, tensors): p.set_(c)
         return ratio
 
-class GradSign(Transform):
+class GradSign(TensorTransform):
     """Copies gradient sign to update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
+    def __init__(self):
+        super().__init__(uses_grad=True)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [t.copysign_(g) for t,g in zip(tensors, grads)]
 
-class UpdateSign(Transform):
+class UpdateSign(TensorTransform):
     """Outputs gradient with sign copied from the update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
+    def __init__(self):
+        super().__init__(uses_grad=True)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
 
-class GraftToGrad(Transform):
+class GraftToGrad(TensorTransform):
     """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
 
-class GraftGradToUpdate(Transform):
+class GraftGradToUpdate(TensorTransform):
     """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6, target: Target = 'update'):
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-6):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
 
 
-class GraftToParams(Transform):
-    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
-    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4, target: Target = 'update'):
+class GraftToParams(TensorTransform):
+    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than ``eps``."""
+    def __init__(self, tensorwise:bool=False, ord:Metrics=2, eps:float = 1e-4):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
 
-class Relative(Transform):
-    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
-    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+class Relative(TensorTransform):
+    """Multiplies update by absolute parameter values to make it relative to their magnitude, ``min_value`` is minimum allowed value to avoid getting stuck at 0."""
+    def __init__(self, min_value:float = 1e-4):
         defaults = dict(min_value=min_value)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
         torch._foreach_mul_(tensors, mul)
         return tensors
 
 class FillLoss(Module):
-    """Outputs tensors filled with loss value times :code:`alpha`"""
+    """Outputs tensors filled with loss value times ``alpha``"""
     def __init__(self, alpha: float = 1, backward: bool = True):
         defaults = dict(alpha=alpha, backward=backward)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        alpha = self.get_settings(var.params, 'alpha')
-        loss = var.get_loss(backward=self.defaults['backward'])
-        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
-        return var
-
-class MulByLoss(Module):
-    """Multiplies update by loss times :code:`alpha`"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+    def apply(self, objective):
+        alpha = self.get_settings(objective.params, 'alpha')
+        loss = objective.get_loss(backward=self.defaults['backward'])
+        objective.updates = [torch.full_like(p, loss*a) for p,a in zip(objective.params, alpha)]
+        return objective
+
+class MulByLoss(TensorTransform):
+    """Multiplies update by loss times ``alpha``"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_loss=True)
 
     @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.defaults['backward'])
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(var.update, mul)
-        return var
+        torch._foreach_mul_(tensors, mul)
+        return tensors
 
-class DivByLoss(Module):
-    """Divides update by loss times :code:`alpha`"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+class DivByLoss(TensorTransform):
+    """Divides update by loss times ``alpha``"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-16, backward: bool = True):
         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_loss=True)
 
     @torch.no_grad
-    def step(self, var):
-        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-        loss = var.get_loss(backward=self.defaults['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(var.update, mul)
-        return var
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        alpha, min_value = unpack_dicts(settings, 'alpha', 'min_value')
+        denom = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(tensors, denom)
+        return tensors
 
 
-class NoiseSign(Transform):
+class NoiseSign(TensorTransform):
     """Outputs random tensors with sign copied from the update."""
     def __init__(self, distribution:Distributions = 'normal', variance:float | None = None):
         defaults = dict(distribution=distribution, variance=variance)
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         variance = unpack_dicts(settings, 'variance')
         return TensorList(tensors).sample_like(settings[0]['distribution'], variance=variance).copysign_(tensors)
 
-class HpuEstimate(Transform):
+class HpuEstimate(TensorTransform):
     """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
     def __init__(self):
         defaults = dict()
-        super().__init__(defaults, uses_grad=False)
+        super().__init__(defaults)
 
     def reset_for_online(self):
         super().reset_for_online()
         self.clear_state_keys('prev_params', 'prev_update')
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
         s = torch._foreach_sub(params, prev_params)
         y = torch._foreach_sub(tensors, prev_update)
@@ -269,50 +269,48 @@ class HpuEstimate(Transform):
         self.store(params, 'y', y)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return [self.state[p]['y'] for p in params]
 
 class RandomHvp(Module):
-    """Returns a hessian-vector product with a random vector"""
+    """Returns a hessian-vector product with a random vector, optionally times vector"""
 
     def __init__(
         self,
         n_samples: int = 1,
         distribution: Distributions = "normal",
         update_freq: int = 1,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        zHz: bool = False,
+        hvp_method: Literal["autograd", "fd_forward", "central"] = "autograd",
         h=1e-3,
+        seed: int | None = None
     ):
-        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+        defaults = locals().copy()
+        del defaults['self']
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        settings = self.settings[params[0]]
-        n_samples = settings['n_samples']
-        distribution = settings['distribution']
-        hvp_method = settings['hvp_method']
-        h = settings['h']
-        update_freq = settings['update_freq']
+    def apply(self, objective):
+        params = TensorList(objective.params)
 
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
         D = None
+        update_freq = self.defaults['update_freq']
         if step % update_freq == 0:
 
-            rgrad = None
-            for i in range(n_samples):
-                u = params.sample_like(distribution=distribution, variance=1)
-
-                Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
-                                                        h=h, normalize=True, retain_graph=i < n_samples-1)
-
-                if D is None: D = Hvp
-                else: torch._foreach_add_(D, Hvp)
+            D, _ = objective.hutchinson_hessian(
+                rgrad = None,
+                at_x0 = True,
+                n_samples = self.defaults['n_samples'],
+                distribution = self.defaults['distribution'],
+                hvp_method = self.defaults['hvp_method'],
+                h = self.defaults['h'],
+                zHz = self.defaults["zHz"],
+                generator = self.get_generator(params[0].device, self.defaults["seed"]),
+            )
 
-            if n_samples > 1: torch._foreach_div_(D, n_samples)
         if update_freq != 1:
             assert D is not None
             D_buf = self.get_state(params, "D", cls=TensorList)
@@ -321,8 +319,8 @@ class RandomHvp(Module):
         if D is None:
             D = self.get_state(params, "D", cls=TensorList)
 
-        var.update = list(D)
-        return var
+        objective.updates = list(D)
+        return objective
 
 @torch.no_grad
 def _load_best_parameters(params: Sequence[torch.Tensor], best_params: Sequence[torch.Tensor]):
@@ -344,7 +342,7 @@ class SaveBest(Module):
            return (1 - x)**2 + (100 * (y - x**2))**2
 
        xy = torch.tensor((-1.1, 2.5), requires_grad=True)
-       opt = tz.Modular(
+       opt = tz.Optimizer(
            [xy],
            tz.m.NAG(0.999),
            tz.m.LR(1e-6),
@@ -370,14 +368,14 @@ class SaveBest(Module):
         super().__init__()
 
     @torch.no_grad
-    def step(self, var):
-        loss = tofloat(var.get_loss(False))
+    def apply(self, objective):
+        loss = tofloat(objective.get_loss(False))
         lowest_loss = self.global_state.get('lowest_loss', float("inf"))
 
         if loss < lowest_loss:
             self.global_state['lowest_loss'] = loss
-            best_params = var.attrs['best_params'] = [p.clone() for p in var.params]
-            var.attrs['best_loss'] = loss
-            var.attrs['load_best_params'] = partial(_load_best_parameters, params=var.params, best_params=best_params)
+            best_params = objective.attrs['best_params'] = [p.clone() for p in objective.params]
+            objective.attrs['best_loss'] = loss
+            objective.attrs['load_best_params'] = partial(_load_best_parameters, params=objective.params, best_params=best_params)
 
-        return var
+        return objective
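Taken together, the hunks above rename the per-tensor transform base classes to TensorTransform (from TensorwiseTransform/Transform), rename the hooks apply_tensor/apply_tensors/update_tensors to single_tensor_apply/multi_tensor_apply/multi_tensor_update, replace Module.step(var) with Module.apply(objective), and rename the entry point tz.Modular to tz.Optimizer. Below is a minimal sketch of a custom transform written against the 0.4.x names that appear in this diff; the Negate module itself, and chaining a user-defined transform alongside tz.m.LR, are illustrative assumptions rather than code from the package:

import torch
import torchzero as tz
from torchzero.core import TensorTransform

class Negate(TensorTransform):
    """Hypothetical transform that flips the sign of the incoming update."""
    def __init__(self):
        super().__init__()  # no per-parameter defaults

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # `tensors` holds the current update; return the transformed update.
        return [t.neg() for t in tensors]

xy = torch.tensor((-1.1, 2.5), requires_grad=True)
opt = tz.Optimizer([xy], Negate(), tz.m.LR(1e-2))  # tz.Modular is now tz.Optimizer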