torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/torchzero/modules/step_size/lr.py
@@ -0,0 +1,154 @@
+"""Learning rate"""
+import torch
+import random
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
+
+def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
+    """multiplies by lr if lr is not 1"""
+    if generic_ne(lr, 1):
+        if inplace: return tensors.mul_(lr)
+        return tensors * lr
+    return tensors
+
+class LR(Transform):
+    """Learning rate. Adding this module also adds support for LR schedulers."""
+    def __init__(self, lr: float):
+        defaults=dict(lr=lr)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
+
+class StepSize(Transform):
+    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
+    def __init__(self, step_size: float, key = 'step_size'):
+        defaults={"key": key, key: step_size}
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
+
+
+def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+    """returns warm up lr scalar"""
+    if step > steps: return end_lr
+    return start_lr + (end_lr - start_lr) * (step / steps)
+
+class Warmup(Transform):
+    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+
+    Args:
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+        start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
+        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
+
+    Example:
+        Adam with 1000 steps warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+                tz.m.Warmup(steps=1000)
+            )
+
+    """
+    def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
+        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+
+        tensors = lazy_lr(
+            TensorList(tensors),
+            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+            inplace=True
+        )
+        self.global_state['step'] = step + 1
+        return tensors
+
+class WarmupNormClip(Transform):
+    """Warmup via clipping of the update norm.
+
+    Args:
+        start_norm (_type_, optional): maximal norm on the first step. Defaults to 1e-5.
+        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+
+    Example:
+        Adam with 1000 steps norm clip warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.WarmupNormClip(steps=1000)
+                tz.m.LR(1e-2),
+            )
+    """
+    def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
+        defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+        if step > num_steps: return tensors
+
+        tensors = TensorList(tensors)
+        norm = tensors.global_vector_norm()
+        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
+        if norm > current_max_norm:
+            tensors.mul_(current_max_norm / norm)
+
+        self.global_state['step'] = step + 1
+        return tensors
+
+
+class RandomStepSize(Transform):
+    """Uses random global or layer-wise step size from `low` to `high`.
+
+    Args:
+        low (float, optional): minimum learning rate. Defaults to 0.
+        high (float, optional): maximum learning rate. Defaults to 1.
+        parameterwise (bool, optional):
+            if True, generate random step size for each parameter separately,
+            if False generate one global random step size. Defaults to False.
+    """
+    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
+        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        s = settings[0]
+        parameterwise = s['parameterwise']
+
+        seed = s['seed']
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = random.Random(seed)
+        generator: random.Random = self.global_state['generator']
+
+        if parameterwise:
+            low, high = unpack_dicts(settings, 'low', 'high')
+            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+        else:
+            low = s['low']
+            high = s['high']
+            lr = generator.uniform(low, high)
+
+        torch._foreach_mul_(tensors, lr)
+        return tensors
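For orientation, a minimal usage sketch of the new step-size modules, following the tz.Modular / tz.m pattern shown in the docstrings above (torchzero imported as tz and the tz.m.Adam module are taken from those docstring examples, not from this hunk):

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-2),           # scales the update; also enables LR schedulers
        tz.m.Warmup(steps=1000), # ramps the multiplier from start_lr to end_lr
    )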
--- /dev/null
+++ b/torchzero/modules/termination/__init__.py
@@ -0,0 +1,14 @@
+from .termination import (
+    TerminateAfterNEvaluations,
+    TerminateAfterNSeconds,
+    TerminateAfterNSteps,
+    TerminateAll,
+    TerminateAny,
+    TerminateByGradientNorm,
+    TerminateByUpdateNorm,
+    TerminateOnLossReached,
+    TerminateOnNoImprovement,
+    TerminationCriteriaBase,
+    TerminateNever,
+    make_termination_criteria
+)
--- /dev/null
+++ b/torchzero/modules/termination/termination.py
@@ -0,0 +1,207 @@
+import time
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import cast
+
+import torch
+
+from ...core import Module, Var
+from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
+
+
+class TerminationCriteriaBase(Module):
+    def __init__(self, defaults:dict | None = None, n: int = 1):
+        if defaults is None: defaults = {}
+        safe_dict_update_(defaults, {"_n": n})
+        super().__init__(defaults)
+
+    @abstractmethod
+    def termination_criteria(self, var: Var) -> bool:
+        ...
+
+    def should_terminate(self, var: Var) -> bool:
+        n_bad = self.global_state.get('_n_bad', 0)
+        n = self.defaults['_n']
+
+        if self.termination_criteria(var):
+            n_bad += 1
+            if n_bad >= n:
+                self.global_state['_n_bad'] = 0
+                return True
+
+        else:
+            n_bad = 0
+
+        self.global_state['_n_bad'] = n_bad
+        return False
+
+
+    def update(self, var):
+        var.should_terminate = self.should_terminate(var)
+        if var.should_terminate: self.global_state['_n_bad'] = 0
+
+    def apply(self, var):
+        return var
+
+
+class TerminateAfterNSteps(TerminationCriteriaBase):
+    def __init__(self, steps:int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        max_steps = self.defaults['steps']
+        return step >= max_steps
+
+class TerminateAfterNEvaluations(TerminationCriteriaBase):
+    def __init__(self, maxevals:int):
+        defaults = dict(maxevals=maxevals)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        maxevals = self.defaults['maxevals']
+        return var.modular.num_evaluations >= maxevals
+
+class TerminateAfterNSeconds(TerminationCriteriaBase):
+    def __init__(self, seconds:float, sec_fn = time.time):
+        defaults = dict(seconds=seconds, sec_fn=sec_fn)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        max_seconds = self.defaults['seconds']
+        sec_fn = self.defaults['sec_fn']
+
+        if 'start' not in self.global_state:
+            self.global_state['start'] = sec_fn()
+            return False
+
+        seconds_passed = sec_fn() - self.global_state['start']
+        return seconds_passed >= max_seconds
+
+
+
+class TerminateByGradientNorm(TerminationCriteriaBase):
+    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
+        defaults = dict(tol=tol, ord=ord)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        tol = self.defaults['tol']
+        ord = self.defaults['ord']
+        return TensorList(var.get_grad()).global_metric(ord) <= tol
+
+
+class TerminateByUpdateNorm(TerminationCriteriaBase):
+    """update is calculated as parameter difference"""
+    def __init__(self, tol:float = 1e-8, n: int = 3, ord: Metrics = 2):
+        defaults = dict(tol=tol, ord=ord)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tol = self.defaults['tol']
+        ord = self.defaults['ord']
+
+        p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+        if step == 0:
+            p_prev.copy_(var.params)
+            return False
+
+        should_terminate = (p_prev - var.params).global_metric(ord) <= tol
+        p_prev.copy_(var.params)
+        return should_terminate
+
+
+class TerminateOnNoImprovement(TerminationCriteriaBase):
+    def __init__(self, tol:float = 1e-8, n: int = 10):
+        defaults = dict(tol=tol)
+        super().__init__(defaults, n=n)
+
+    def termination_criteria(self, var):
+        tol = self.defaults['tol']
+
+        f = tofloat(var.get_loss(False))
+        if 'f_min' not in self.global_state:
+            self.global_state['f_min'] = f
+            return False
+
+        f_min = self.global_state['f_min']
+        d = f_min - f
+        should_terminate = d <= tol
+        self.global_state['f_min'] = min(f, f_min)
+        return should_terminate
+
+class TerminateOnLossReached(TerminationCriteriaBase):
+    def __init__(self, value: float):
+        defaults = dict(value=value)
+        super().__init__(defaults)
+
+    def termination_criteria(self, var):
+        value = self.defaults['value']
+        return var.get_loss(False) <= value
+
+class TerminateAny(TerminationCriteriaBase):
+    def __init__(self, *criteria: TerminationCriteriaBase):
+        super().__init__()
+
+        self.set_children_sequence(criteria)
+
+    def termination_criteria(self, var: Var) -> bool:
+        for c in self.get_children_sequence():
+            if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
+
+        return False
+
+class TerminateAll(TerminationCriteriaBase):
+    def __init__(self, *criteria: TerminationCriteriaBase):
+        super().__init__()
+
+        self.set_children_sequence(criteria)
+
+    def termination_criteria(self, var: Var) -> bool:
+        for c in self.get_children_sequence():
+            if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
+
+        return True
+
+class TerminateNever(TerminationCriteriaBase):
+    def __init__(self):
+        super().__init__()
+
+    def termination_criteria(self, var): return False
+
+def make_termination_criteria(
+    ftol: float | None = None,
+    gtol: float | None = None,
+    stol: float | None = None,
+    maxiter: int | None = None,
+    maxeval: int | None = None,
+    maxsec: float | None = None,
+    target_loss: float | None = None,
+    extra: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
+    n: int = 3,
+):
+    criteria: list[TerminationCriteriaBase] = []
+
+    if ftol is not None: criteria.append(TerminateOnNoImprovement(ftol, n=n))
+    if gtol is not None: criteria.append(TerminateByGradientNorm(gtol, n=n))
+    if stol is not None: criteria.append(TerminateByUpdateNorm(stol, n=n))
+
+    if maxiter is not None: criteria.append(TerminateAfterNSteps(maxiter))
+    if maxeval is not None: criteria.append(TerminateAfterNEvaluations(maxeval))
+    if maxsec is not None: criteria.append(TerminateAfterNSeconds(maxsec))
+
+    if target_loss is not None: criteria.append(TerminateOnLossReached(target_loss))
+
+    if extra is not None:
+        if isinstance(extra, TerminationCriteriaBase): criteria.append(extra)
+        else: criteria.extend(extra)
+
+    if len(criteria) == 0: return TerminateNever()
+    if len(criteria) == 1: return criteria[0]
+    return TerminateAny(*criteria)
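A sketch of how these pieces compose, derived only from the code in this hunk (how the resulting module is wired into an optimizer is not shown here):

    # hypothetical: combine several stopping rules into one module
    criteria = make_termination_criteria(ftol=1e-10, gtol=1e-9, maxiter=1000)
    # equivalent to:
    # TerminateAny(
    #     TerminateOnNoImprovement(1e-10, n=3),
    #     TerminateByGradientNorm(1e-9, n=3),
    #     TerminateAfterNSteps(1000),
    # )
    # on each update() the module sets var.should_terminate to True once any
    # criterion has held for n consecutive checks.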
--- /dev/null
+++ b/torchzero/modules/trust_region/__init__.py
@@ -0,0 +1,5 @@
+from .trust_region import TrustRegionBase
+from .cubic_regularization import CubicRegularization
+from .trust_cg import TrustCG
+from .levenberg_marquardt import LevenbergMarquardt
+from .dogleg import Dogleg
--- /dev/null
+++ b/torchzero/modules/trust_region/cubic_regularization.py
@@ -0,0 +1,170 @@
+# pylint:disable=not-callable
+from collections.abc import Callable
+
+import torch
+
+from ...core import Chainable, Module
+from ...utils import TensorList, vec_to_tensors
+from ...utils.linalg.linear_operator import LinearOperator
+from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+
+# code from https://github.com/konstmish/opt_methods/blob/master/optmethods/second_order/cubic.py
+# ported to pytorch and linear operator
+def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_params_plus_x_fn: Callable | None, it_max=100, epsilon=1e-8, ):
+    """
+    Solve min_z <g, z-x> + 1/2<z-x, H(z-x)> + M/3 ||z-x||^3
+
+    For explanation of Cauchy point, see "Gradient Descent
+    Efficiently Finds the Cubic-Regularized Non-Convex Newton Step"
+    https://arxiv.org/pdf/1612.00547.pdf
+    Other potential implementations can be found in paper
+    "Adaptive cubic regularisation methods"
+    https://people.maths.ox.ac.uk/cartis/papers/ARCpI.pdf
+    """
+    solver_it = 1
+    newton_step = H.solve(g).neg_()
+    if M == 0:
+        return newton_step, solver_it
+
+    def cauchy_point(g, H:LinearOperator, M):
+        if torch.linalg.vector_norm(g) == 0 or M == 0:
+            return 0 * g
+        g_dir = g / torch.linalg.vector_norm(g)
+        H_g_g = H.matvec(g_dir) @ g_dir
+        R = -H_g_g / (2*M) + torch.sqrt((H_g_g/M)**2/4 + torch.linalg.vector_norm(g)/M)
+        return -R * g_dir
+
+    def conv_criterion(s, r):
+        """
+        The convergence criterion is an increasing and concave function in r
+        and it is equal to 0 only if r is the solution to the cubic problem
+        """
+        s_norm = torch.linalg.vector_norm(s)
+        return 1/s_norm - 1/r
+
+    # Solution s satisfies ||s|| >= Cauchy_radius
+    r_min = torch.linalg.vector_norm(cauchy_point(g, H, M))
+
+    if (loss_at_params_plus_x_fn is not None) and (f > loss_at_params_plus_x_fn(newton_step)):
+        return newton_step, solver_it
+
+    r_max = torch.linalg.vector_norm(newton_step)
+    if r_max - r_min < epsilon:
+        return newton_step, solver_it
+
+    # id_matrix = torch.eye(g.size(0), device=g.device, dtype=g.dtype)
+    s_lam = None
+    for _ in range(it_max):
+        r_try = (r_min + r_max) / 2
+        lam = r_try * M
+        s_lam = H.add_diagonal(lam).solve(g).neg()
+        # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
+        solver_it += 1
+        crit = conv_criterion(s_lam, r_try)
+        if torch.abs(crit) < epsilon:
+            return s_lam, solver_it
+        if crit < 0:
+            r_min = r_try
+        else:
+            r_max = r_try
+        if r_max - r_min < epsilon:
+            break
+    assert s_lam is not None
+    return s_lam, solver_it
+
+
+class CubicRegularization(TrustRegionBase):
+    """Cubic regularization.
+
+    Args:
+        hess_module (Module | None, optional):
+            A module that maintains a hessian approximation (not hessian inverse!).
+            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+            When using quasi-newton methods, set `inverse=False` when constructing them.
+        eta (float, optional):
+            if ratio of actual to predicted rediction is larger than this, step is accepted.
+            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        maxiter (float, optional): maximum iterations when solving cubic subproblem, defaults to 1e-7.
+        eps (float, optional): epsilon for the solver, defaults to 1e-8.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        max_attempts (max_attempts, optional):
+            maximum number of trust region size size reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        fallback (bool, optional):
+            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
+            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
+        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
+
+
+    Examples:
+        Cubic regularized newton
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.CubicRegularization(tz.m.Newton()),
+            )
+
+    """
+    def __init__(
+        self,
+        hess_module: Chainable,
+        eta: float= 0.0,
+        nplus: float = 3.5,
+        nminus: float = 0.25,
+        rho_good: float = 0.99,
+        rho_bad: float = 1e-4,
+        init: float = 1,
+        max_attempts: int = 10,
+        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+        maxiter: int = 100,
+        eps: float = 1e-8,
+        check_decrease:bool=False,
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(maxiter=maxiter, eps=eps, check_decrease=check_decrease)
+        super().__init__(
+            defaults=defaults,
+            hess_module=hess_module,
+            eta=eta,
+            nplus=nplus,
+            nminus=nminus,
+            rho_good=rho_good,
+            rho_bad=rho_bad,
+            init=init,
+            max_attempts=max_attempts,
+            radius_strategy=radius_strategy,
+            update_freq=update_freq,
+            inner=inner,
+
+            boundary_tol=None,
+            radius_fn=None,
+        )
+
+    def trust_solve(self, f, g, H, radius, params, closure, settings):
+        params = TensorList(params)
+
+        loss_at_params_plus_x_fn = None
+        if settings['check_decrease']:
+            def closure_plus_x(x):
+                x_unflat = vec_to_tensors(x, params)
+                params.add_(x_unflat)
+                loss_x = closure(False)
+                params.sub_(x_unflat)
+                return loss_x
+            loss_at_params_plus_x_fn = closure_plus_x
+
+
+        d, _ = ls_cubic_solver(f=f, g=g, H=H, M=1/radius, loss_at_params_plus_x_fn=loss_at_params_plus_x_fn,
+                               it_max=settings['maxiter'], epsilon=settings['eps'])
+        return d.neg_()
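The solver above only relies on matvec, solve and add_diagonal from its operator argument. A self-contained sketch with a dense stand-in, assuming ls_cubic_solver is imported from the module in this hunk (the DenseOp class below is illustrative, not the library's LinearOperator):

    import torch

    class DenseOp:
        # toy dense-matrix operator exposing the three methods ls_cubic_solver calls
        def __init__(self, A): self.A = A
        def matvec(self, x): return self.A @ x
        def solve(self, b): return torch.linalg.solve(self.A, b)
        def add_diagonal(self, lam):
            return DenseOp(self.A + lam * torch.eye(len(self.A), dtype=self.A.dtype))

    H = DenseOp(torch.tensor([[3.0, 0.2], [0.2, 1.5]]))
    g = torch.tensor([1.0, -2.0])
    # CubicRegularization.trust_solve passes M = 1/radius, so a larger trust
    # radius means a weaker cubic penalty
    step, n_solves = ls_cubic_solver(f=None, g=g, H=H, M=1.0, loss_at_params_plus_x_fn=None)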
--- /dev/null
+++ b/torchzero/modules/trust_region/dogleg.py
@@ -0,0 +1,92 @@
+# pylint:disable=not-callable
+import torch
+
+from ...core import Chainable, Module
+from ...utils import TensorList, vec_to_tensors
+from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
+
+class Dogleg(TrustRegionBase):
+    """Dogleg trust region algorithm.
+
+
+    Args:
+        hess_module (Module | None, optional):
+            A module that maintains a hessian approximation (not hessian inverse!).
+            This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
+            When using quasi-newton methods, set `inverse=False` when constructing them.
+        eta (float, optional):
+            if ratio of actual to predicted rediction is larger than this, step is accepted.
+            When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+        nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+        nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+        rho_good (float, optional):
+            if ratio of actual to predicted rediction is larger than this, trust region size is multiplied by `nplus`.
+        rho_bad (float, optional):
+            if ratio of actual to predicted rediction is less than this, trust region size is multiplied by `nminus`.
+        init (float, optional): Initial trust region value. Defaults to 1.
+        update_freq (int, optional): frequency of updating the hessian. Defaults to 1.
+        max_attempts (max_attempts, optional):
+            maximum number of trust region size size reductions per step. A zero update vector is returned when
+            this limit is exceeded. Defaults to 10.
+        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
+
+    """
+    def __init__(
+        self,
+        hess_module: Chainable,
+        eta: float= 0.0,
+        nplus: float = 2,
+        nminus: float = 0.25,
+        rho_good: float = 0.75,
+        rho_bad: float = 0.25,
+        boundary_tol: float | None = None,
+        init: float = 1,
+        max_attempts: int = 10,
+        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
+        update_freq: int = 1,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict()
+        super().__init__(
+            defaults=defaults,
+            hess_module=hess_module,
+            eta=eta,
+            nplus=nplus,
+            nminus=nminus,
+            rho_good=rho_good,
+            rho_bad=rho_bad,
+            boundary_tol=boundary_tol,
+            init=init,
+            max_attempts=max_attempts,
+            radius_strategy=radius_strategy,
+            update_freq=update_freq,
+            inner=inner,
+
+            radius_fn=torch.linalg.vector_norm,
+        )
+
+    def trust_solve(self, f, g, H, radius, params, closure, settings):
+        if radius > 2: radius = self.global_state['radius'] = 2
+        eps = torch.finfo(g.dtype).tiny * 2
+
+        gHg = g.dot(H.matvec(g))
+        if gHg <= eps:
+            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable
+
+        p_cauchy = (g.dot(g) / gHg) * g
+        p_newton = H.solve(g)
+
+        a = p_newton - p_cauchy
+        b = p_cauchy
+
+        aa = a.dot(a)
+        if aa < eps:
+            return (radius / torch.linalg.vector_norm(g)) * g # pylint:disable=not-callable
+
+        ab = a.dot(b)
+        bb = b.dot(b)
+        c = bb - radius**2
+        discriminant = (2*ab)**2 - 4*aa*c
+        beta = (-2*ab + torch.sqrt(discriminant.clip(min=0))) / (2 * aa)
+        return p_cauchy + beta * (p_newton - p_cauchy)
+
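A quick numeric check of the algebra in trust_solve (illustrative only, not package code): beta is the positive root of ||b + beta*a||^2 = radius^2, which places the dogleg step exactly on the trust-region boundary between the Cauchy point and the Newton step:

    import torch

    a = torch.tensor([0.6, -0.3])   # stands in for p_newton - p_cauchy
    b = torch.tensor([0.2, 0.1])    # stands in for p_cauchy
    radius = 0.5
    aa, ab, bb = a.dot(a), a.dot(b), b.dot(b)
    c = bb - radius**2
    beta = (-2*ab + torch.sqrt(((2*ab)**2 - 4*aa*c).clip(min=0))) / (2*aa)
    print(torch.linalg.vector_norm(b + beta*a))  # ~0.5, i.e. on the boundary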