torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/termination/termination.py (+32 -30)
@@ -1,11 +1,11 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import cast
+from typing import cast, final
 
 import torch
 
-from ...core import Module, Var
+from ...core import Module, Objective
 from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
 
 
@@ -16,14 +16,15 @@ class TerminationCriteriaBase(Module):
         super().__init__(defaults)
 
     @abstractmethod
-    def termination_criteria(self, var: Var) -> bool:
+    def termination_criteria(self, objective: Objective) -> bool:
         ...
 
-    def should_terminate(self, var: Var) -> bool:
+    @final
+    def should_terminate(self, objective: Objective) -> bool:
         n_bad = self.global_state.get('_n_bad', 0)
         n = self.defaults['_n']
 
-        if self.termination_criteria(var):
+        if self.termination_criteria(objective):
             n_bad += 1
             if n_bad >= n:
                 self.global_state['_n_bad'] = 0
@@ -36,12 +37,12 @@ class TerminationCriteriaBase(Module):
         return False
 
 
-    def update(self, var):
-        var.should_terminate = self.should_terminate(var)
-        if var.should_terminate: self.global_state['_n_bad'] = 0
+    def update(self, objective):
+        objective.should_terminate = self.should_terminate(objective)
+        if objective.should_terminate: self.global_state['_n_bad'] = 0
 
-    def apply(self, var):
-        return var
+    def apply(self, objective):
+        return objective
 
 
 class TerminateAfterNSteps(TerminationCriteriaBase):
@@ -49,7 +50,7 @@ class TerminateAfterNSteps(TerminationCriteriaBase):
         defaults = dict(steps=steps)
         super().__init__(defaults)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
@@ -61,16 +62,17 @@ class TerminateAfterNEvaluations(TerminationCriteriaBase):
         defaults = dict(maxevals=maxevals)
         super().__init__(defaults)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         maxevals = self.defaults['maxevals']
-        return var.modular.num_evaluations >= maxevals
+        assert objective.modular is not None
+        return objective.modular.num_evaluations >= maxevals
 
 class TerminateAfterNSeconds(TerminationCriteriaBase):
     def __init__(self, seconds:float, sec_fn = time.time):
         defaults = dict(seconds=seconds, sec_fn=sec_fn)
         super().__init__(defaults)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         max_seconds = self.defaults['seconds']
         sec_fn = self.defaults['sec_fn']
 
@@ -88,10 +90,10 @@ class TerminateByGradientNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']
         ord = self.defaults['ord']
-        return TensorList(var.get_grad()).global_metric(ord) <= tol
+        return TensorList(objective.get_grads()).global_metric(ord) <= tol
 
 
 class TerminateByUpdateNorm(TerminationCriteriaBase):
@@ -100,20 +102,20 @@ class TerminateByUpdateNorm(TerminationCriteriaBase):
         defaults = dict(tol=tol, ord=ord)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
         tol = self.defaults['tol']
         ord = self.defaults['ord']
 
-        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+        p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
         if step == 0:
-            p_prev.copy_(var.params)
+            p_prev.copy_(objective.params)
             return False
 
-        should_terminate = (p_prev - var.params).global_metric(ord) <= tol
-        p_prev.copy_(var.params)
+        should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
+        p_prev.copy_(objective.params)
         return should_terminate
 
 
@@ -122,10 +124,10 @@ class TerminateOnNoImprovement(TerminationCriteriaBase):
         defaults = dict(tol=tol)
         super().__init__(defaults, n=n)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         tol = self.defaults['tol']
 
-        f = tofloat(var.get_loss(False))
+        f = tofloat(objective.get_loss(False))
         if 'f_min' not in self.global_state:
             self.global_state['f_min'] = f
             return False
@@ -141,9 +143,9 @@ class TerminateOnLossReached(TerminationCriteriaBase):
         defaults = dict(value=value)
         super().__init__(defaults)
 
-    def termination_criteria(self, var):
+    def termination_criteria(self, objective):
         value = self.defaults['value']
-        return var.get_loss(False) <= value
+        return objective.get_loss(False) <= value
 
 class TerminateAny(TerminationCriteriaBase):
     def __init__(self, *criteria: TerminationCriteriaBase):
 
         self.set_children_sequence(criteria)
 
-    def termination_criteria(self, var: Var) -> bool:
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
+            if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True
 
         return False
 
@@ -163,9 +165,9 @@ class TerminateAll(TerminationCriteriaBase):
 
         self.set_children_sequence(criteria)
 
-    def termination_criteria(self, var: Var) -> bool:
+    def termination_criteria(self, objective: Objective) -> bool:
         for c in self.get_children_sequence():
-            if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
+            if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False
 
         return True
 
@@ -173,7 +175,7 @@ class TerminateNever(TerminationCriteriaBase):
     def __init__(self):
         super().__init__()
 
-    def termination_criteria(self, var): return False
+    def termination_criteria(self, objective): return False
 
 def make_termination_criteria(
     ftol: float | None = None,
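The rename is mechanical, but the contract is now tighter: subclasses override only ``termination_criteria(objective)``, while the ``@final`` ``should_terminate`` keeps the patience counter (``_n``/``_n_bad``). A minimal sketch of a custom criterion against the 0.4.0 API; the import path is read off the file list above and is an assumption, only the class and method names come from this diff:

```python
# Sketch only: a custom termination criterion under the new Objective-based API.
from torchzero.core import Objective
from torchzero.modules.termination.termination import TerminationCriteriaBase


class TerminateOnSmallLoss(TerminationCriteriaBase):
    def __init__(self, value: float):
        # defaults dict is stored by the base Module, as in TerminateOnLossReached
        super().__init__(dict(value=value))

    def termination_criteria(self, objective: Objective) -> bool:
        # get_loss(False) evaluates the closure without requiring a backward pass
        return objective.get_loss(False) <= self.defaults['value']
```
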
torchzero/modules/trust_region/cubic_regularization.py (+2 -2)
@@ -5,7 +5,7 @@ import torch
 
 from ...core import Chainable, Module
 from ...utils import TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import LinearOperator
+from ...linalg.linear_operator import LinearOperator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -58,7 +58,7 @@ def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_param
     for _ in range(it_max):
         r_try = (r_min + r_max) / 2
         lam = r_try * M
-        s_lam = H.add_diagonal(lam).solve(g).neg()
+        s_lam = H.solve_plus_diag(g, lam).neg()
         # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
         solver_it += 1
         crit = conv_criterion(s_lam, r_try)
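Here and in the Levenberg-Marquardt and trust-region hunks below, the old two-step ``H.add_diagonal(lam).solve(g)`` chain becomes a single ``H.solve_plus_diag(g, lam)`` call. The intended semantics, read off the two call sites rather than the ``LinearOperator`` implementation itself, amount to solving the diagonally shifted system, illustrated here with a plain dense matrix:

```python
# Dense stand-in for what both call styles compute: x = (H + lam*I)^{-1} g.
import torch

H = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])          # stand-in for the Hessian operator
g = torch.tensor([1.0, 2.0])
lam = 0.5

x = torch.linalg.solve(H + lam * torch.eye(2), g)  # old: H.add_diagonal(lam).solve(g)
                                                   # new: H.solve_plus_diag(g, lam)
```

In the Levenberg-Marquardt module below the same call additionally scales the shift by ``sqrt(||g||)`` when the new ``adaptive`` flag is set.
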
torchzero/modules/trust_region/levenberg_marquardt.py (+25 -28)
@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...utils.linalg import linear_operator
+from ...linalg import linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -32,38 +32,31 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts (max_attempts, optional):
             maximum number of trust region size size reductions per step. A zero update vector is returned when
             this limit is exceeded. Defaults to 10.
+        adaptive (bool, optional):
+            if True, trust radius is multiplied by square root of gradient norm.
         fallback (bool, optional):
             if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
             be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
         inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
 
-    Examples:
-        Gauss-Newton with Levenberg-Marquardt trust-region
+    ### Examples:
 
-        .. code-block:: python
+    Gauss-Newton with Levenberg-Marquardt trust-region
 
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
-            )
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+    )
+    ```
 
-        LM-SR1
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
-            )
-
-        First order trust region (hessian is assumed to be identity)
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.LevenbergMarquardt(tz.m.Identity()),
-            )
+    LM-SR1
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+    )
+    ```
 
     """
     def __init__(
@@ -78,11 +71,12 @@ class LevenbergMarquardt(TrustRegionBase):
         max_attempts: int = 10,
         radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
         y: float = 0,
+        adaptive: bool = False,
         fallback: bool = False,
         update_freq: int = 1,
         inner: Chainable | None = None,
     ):
-        defaults = dict(y=y, fallback=fallback)
+        defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
         super().__init__(
             defaults=defaults,
             hess_module=hess_module,
@@ -103,6 +97,7 @@ class LevenbergMarquardt(TrustRegionBase):
 
     def trust_solve(self, f, g, H, radius, params, closure, settings):
         y = settings['y']
+        adaptive = settings["adaptive"]
 
         if isinstance(H, linear_operator.DenseInverse):
             if settings['fallback']:
 
@@ -117,12 +112,14 @@ class LevenbergMarquardt(TrustRegionBase):
             )
 
         reg = 1/radius
+        if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()
+
         if y == 0:
-            return H.add_diagonal(reg).solve(g)
+            return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]
 
         diag = H.diagonal()
         diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
         if y != 1: diag = (diag*y) + (1-y)
-        return H.add_diagonal(diag*reg).solve(g)
+        return H.solve_plus_diag(g, diag*reg)
 
torchzero/modules/trust_region/trust_cg.py (+1 -1)
@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...utils.linalg import cg, linear_operator
+from ...linalg import cg, linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
torchzero/modules/trust_region/trust_region.py (+27 -22)
@@ -7,9 +7,16 @@ from typing import Any, Literal, Protocol, cast, final, overload
 
 import torch
 
-from ...core import Chainable, Module, Var, apply_transform
-from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
-from ...utils.linalg.linear_operator import LinearOperator
+from ...core import Chainable, Module, Objective
+from ...linalg.linear_operator import LinearOperator
+from ...utils import (
+    TensorList,
+    generic_finfo,
+    generic_vector_norm,
+    safe_dict_update_,
+    tofloat,
+    vec_to_tensors,
+)
 
 
 def _flatten_tensors(tensors: list[torch.Tensor]):
@@ -256,24 +263,24 @@ class TrustRegionBase(Module, ABC):
         """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
         ... # pylint:disable=unnecessary-ellipsis
 
-    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
+    def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
         """updates the state of this module after H or B have been updated, if necessary"""
 
-    def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
-        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
+    def trust_region_apply(self, objective: Objective, tensors:list[torch.Tensor], H: LinearOperator | None) -> Objective:
+        """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
         assert H is not None
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         settings = self.settings[params[0]]
         g = _flatten_tensors(tensors)
 
         max_attempts = settings['max_attempts']
 
         # loss at x_0
-        loss = var.loss
-        closure = var.closure
+        loss = objective.loss
+        closure = objective.closure
         if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss = var.get_loss(False)
+        if loss is None: loss = objective.get_loss(False)
         loss = tofloat(loss)
 
         # trust region step and update
@@ -313,38 +320,36 @@ class TrustRegionBase(Module, ABC):
         )
 
         assert d is not None
-        if success: var.update = vec_to_tensors(d, params)
-        else: var.update = params.zeros_like()
+        if success: objective.updates = vec_to_tensors(d, params)
+        else: objective.updates = params.zeros_like()
 
-        return var
+        return objective
 
 
    @final
    @torch.no_grad
-    def update(self, var):
+    def update(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1
 
        if step % self.defaults["update_freq"] == 0:
 
            hessian_module = self.children['hess_module']
-            hessian_module.update(var)
-            H = hessian_module.get_H(var)
+            hessian_module.update(objective)
+            H = hessian_module.get_H(objective)
            self.global_state["H"] = H
 
-            self.trust_region_update(var, H=H)
+            self.trust_region_update(objective, H=H)
 
 
    @final
    @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
        H = self.global_state.get('H', None)
 
        # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)
 
        # ----------------------------------- apply ---------------------------------- #
-        return self.trust_region_apply(var=var, tensors=update, H=H)
+        return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)
 
torchzero/modules/variance_reduction/svrg.py (+21 -18)
@@ -3,15 +3,17 @@ from functools import partial
 
 import torch
 
-from ...core.module import Module
+from ...core import Module, Objective
 from ...utils import tofloat
 
 
-def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()
 
+
 class SVRG(Module):
     """Stochastic variance reduced gradient method (SVRG).
 
@@ -71,7 +73,7 @@ class SVRG(Module):
     ```
     ## Notes
 
-    The SVRG gradient is computed as ``g_b(x) - alpha * g_b(x_0) - g_f(x0.)``, where:
+    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
     - ``x`` is current parameters
     - ``x_0`` is initial parameters, where full gradient was computed
     - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
@@ -83,17 +85,18 @@ class SVRG(Module):
         defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
         super().__init__(defaults)
 
+
     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        closure = var.closure
+    def update(self, objective):
+        params = objective.params
+        closure = objective.closure
         assert closure is not None
 
         if "full_grad" not in self.global_state:
 
             # -------------------------- calculate full gradient ------------------------- #
-            if "full_closure" in var.storage:
-                full_closure = var.storage['full_closure']
+            if "full_closure" in objective.storage:
+                full_closure = objective.storage['full_closure']
                 with torch.enable_grad():
                     full_loss = full_closure()
                     if all(p.grad is None for p in params):
@@ -116,12 +119,12 @@ class SVRG(Module):
 
             # accumulate grads
             accumulator = self.get_state(params, 'accumulator')
-            grad = var.get_grad()
+            grad = objective.get_grads()
             torch._foreach_add_(accumulator, grad)
 
             # accumulate loss
             loss_accumulator = self.global_state.get('loss_accumulator', 0)
-            loss_accumulator += tofloat(var.loss)
+            loss_accumulator += tofloat(objective.loss)
             self.global_state['loss_accumulator'] = loss_accumulator
 
             # on nth step, use the accumulated gradient
@@ -136,10 +139,10 @@ class SVRG(Module):
 
             # otherwise skip update until enough grads are accumulated
             else:
-                var.update = None
-                var.stop = True
-                var.skip_update = True
-                return var
+                objective.updates = None
+                objective.stop = True
+                objective.skip_update = True
+                return
 
 
         svrg_steps = self.defaults['svrg_steps']
@@ -194,7 +197,7 @@ class SVRG(Module):
 
             return closure(False)
 
-        var.closure = svrg_closure
+        objective.closure = svrg_closure
 
         # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
         if current_svrg_step >= svrg_steps:
@@ -203,6 +206,6 @@ class SVRG(Module):
             del self.global_state['full_loss']
             del self.global_state['x_0']
             if self.defaults['reset_before_accum']:
-                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
-        return var
+    def apply(self, objective): return objective
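The docstring correction above is the substantive part: the control variate is applied to the whole difference between the snapshot mini-batch gradient and the snapshot full gradient. Spelled out with throwaway tensors (illustration only, not library code):

```python
# v = g_b(x) - alpha * (g_b(x_0) - g_f(x_0))
import torch

g_b_x  = torch.tensor([0.9, -0.2])   # mini-batch gradient at current params x
g_b_x0 = torch.tensor([1.1, -0.1])   # mini-batch gradient at snapshot x_0
g_f_x0 = torch.tensor([1.0,  0.0])   # full gradient at snapshot x_0
alpha  = 1.0

v = g_b_x - alpha * (g_b_x0 - g_f_x0)  # variance-reduced gradient estimate
```
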
torchzero/modules/weight_decay/__init__.py (+2 -1)
@@ -1 +1,2 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .reinit import RandomReinitialize
torchzero/modules/weight_decay/reinit.py (new file, +83)
@@ -0,0 +1,83 @@
+from functools import partial
+
+import torch
+
+from ...core import Module
+from ...utils import NumberList, TensorList
+
+
+def _reset_except_self(optimizer, var, self: Module):
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
+
+class RandomReinitialize(Module):
+    """On each step with probability ``p_reinit`` trigger reinitialization,
+    whereby ``p_weights`` weights are reset to their initial values.
+
+    This modifies the parameters directly. Place it as the first module.
+
+    Args:
+        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
+        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
+        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
+        beta (float, optional):
+            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
+            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
+            to reinitialize to slow parameter EMA. Defaults to 0.
+        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
+        seed (int | None, optional): random seed.
+    """
+
+    def __init__(
+        self,
+        p_reinit: float = 0.01,
+        p_weights: float = 0.1,
+        store_every: int | None = None,
+        beta: float = 0,
+        reset: bool = False,
+        seed: int | None = None,
+    ):
+        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
+        super().__init__(defaults)
+
+    def update(self, objective):
+        # this stores initial values to per-parameter states
+        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)
+
+        # store new params every store_every steps
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        store_every = self.defaults["store_every"]
+        if (store_every is not None and step % store_every == 0):
+            beta = self.get_settings(objective.params, "beta", cls=NumberList)
+            p_init.lerp_(objective.params, weight=(1 - beta))
+
+    @torch.no_grad
+    def apply(self, objective):
+        p_reinit = self.defaults["p_reinit"]
+        device = objective.params[0].device
+        generator = self.get_generator(device, self.defaults["seed"])
+
+        # determine whether to trigger reinitialization
+        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit
+
+        # reinitialize
+        if reinitialize:
+            params = TensorList(objective.params)
+            p_init = self.get_state(params, "p_init", init=params)
+
+
+            # mask with p_weights entries being True
+            p_weights = self.get_settings(params, "p_weights")
+            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()
+
+            # set weights at mask to their initialization
+            params.masked_set_(mask, p_init)
+
+            # reset
+            if self.defaults["reset"]:
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+        return objective
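Since the docstring asks for the module to be placed first, a plausible way to wire it into a chain looks like the following; ``RandomReinitialize`` and ``tz.Modular`` come from this diff, while pairing it with ``tz.m.Adam``/``tz.m.LR`` (and the ``model``) is assumed for illustration:

```python
# Hypothetical usage: reinitialization runs before the actual update modules.
import torchzero as tz

opt = tz.Modular(
    model.parameters(),                                    # any nn.Module's parameters
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1),
    tz.m.Adam(),                                           # assumed module name
    tz.m.LR(1e-3),                                         # assumed module name
)
```
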
torchzero/modules/weight_decay/weight_decay.py (+12 -13)
@@ -3,7 +3,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import Module, TensorTransform
 from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics
 
 
@@ -21,7 +21,7 @@ def weight_decay_(
     return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))
 
 
-class WeightDecay(Transform):
+class WeightDecay(TensorTransform):
     """Weight decay.
 
     Args:
@@ -63,19 +63,19 @@ class WeightDecay(Transform):
    ```
 
    """
-    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+    def __init__(self, weight_decay: float, ord: int = 2):
 
        defaults = dict(weight_decay=weight_decay, ord=ord)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
        ord = settings[0]['ord']
 
        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
-class RelativeWeightDecay(Transform):
+class RelativeWeightDecay(TensorTransform):
    """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.
 
    Args:
@@ -117,13 +117,12 @@ class RelativeWeightDecay(Transform):
        ord: int = 2,
        norm_input: Literal["update", "grad", "params"] = "update",
        metric: Metrics = 'mad',
-        target: Target = "update",
    ):
        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
-        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+        super().__init__(defaults, uses_grad=norm_input == 'grad')
 
    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        weight_decay = NumberList(s['weight_decay'] for s in settings)
 
        ord = settings[0]['ord']
@@ -161,9 +160,9 @@ class DirectWeightDecay(Module):
        super().__init__(defaults)
 
    @torch.no_grad
-    def step(self, var):
-        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+    def apply(self, objective):
+        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
        ord = self.defaults['ord']
 
-        decay_weights_(var.params, weight_decay, ord)
-        return var
+        decay_weights_(objective.params, weight_decay, ord)
+        return objective
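For reference, the penalty that ``weight_decay_`` adds to the incoming tensors, written with plain tensors; the general-``ord`` term mirrors ``params.pow(ord-1).copysign_(params)``:

```python
# u <- u + wd * |p|**(ord-1) * sign(p); for ord=2 this is the usual u + wd * p.
import torch

p  = torch.tensor([0.5, -2.0])   # parameters
u  = torch.tensor([0.1,  0.3])   # update being transformed
wd = 0.01

u_l2 = u + wd * p                           # ord = 2
u_l3 = u + wd * p.abs().pow(2) * p.sign()   # ord = 3
```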