torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/step_size/lr.py

@@ -2,7 +2,7 @@
 import torch
 import random
 
-from ...core import Transform
+from ...core import TensorTransform
 from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
 
 def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
@@ -12,24 +12,24 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
         return tensors * lr
     return tensors
 
-class LR(Transform):
+class LR(TensorTransform):
    """Learning rate. Adding this module also adds support for LR schedulers."""
    def __init__(self, lr: float):
        defaults=dict(lr=lr)
-       super().__init__(defaults, uses_grad=False)
+       super().__init__(defaults)
 
    @torch.no_grad
-   def apply_tensors(self, tensors, params, grads, loss, states, settings):
+   def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
 
-class StepSize(Transform):
+class StepSize(TensorTransform):
    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
    def __init__(self, step_size: float, key = 'step_size'):
        defaults={"key": key, key: step_size}
-       super().__init__(defaults, uses_grad=False)
+       super().__init__(defaults)
 
    @torch.no_grad
-   def apply_tensors(self, tensors, params, grads, loss, states, settings):
+   def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
 
 
@@ -38,8 +38,8 @@ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberLi
    if step > steps: return end_lr
    return start_lr + (end_lr - start_lr) * (step / steps)
 
-class Warmup(Transform):
-   """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+class Warmup(TensorTransform):
+   """Learning rate warmup, linearly increases learning rate multiplier from ``start_lr`` to ``end_lr`` over ``steps`` steps.
 
    Args:
        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
@@ -51,7 +51,7 @@ class Warmup(Transform):
 
    .. code-block:: python
 
-       opt = tz.Modular(
+       opt = tz.Optimizer(
            model.parameters(),
            tz.m.Adam(),
            tz.m.LR(1e-2),
@@ -64,7 +64,7 @@ class Warmup(Transform):
        super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-   def apply_tensors(self, tensors, params, grads, loss, states, settings):
+   def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)
@@ -77,7 +77,7 @@ class Warmup(Transform):
        self.global_state['step'] = step + 1
        return tensors
 
-class WarmupNormClip(Transform):
+class WarmupNormClip(TensorTransform):
    """Warmup via clipping of the update norm.
 
    Args:
@@ -90,7 +90,7 @@ class WarmupNormClip(Transform):
 
    .. code-block:: python
 
-       opt = tz.Modular(
+       opt = tz.Optimizer(
            model.parameters(),
            tz.m.Adam(),
            tz.m.WarmupNormClip(steps=1000)
@@ -102,7 +102,7 @@ class WarmupNormClip(Transform):
        super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-   def apply_tensors(self, tensors, params, grads, loss, states, settings):
+   def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
        num_steps = settings[0]['steps']
        step = self.global_state.get('step', 0)
@@ -118,8 +118,8 @@ class WarmupNormClip(Transform):
        return tensors
 
 
-class RandomStepSize(Transform):
-   """Uses random global or layer-wise step size from `low` to `high`.
+class RandomStepSize(TensorTransform):
+   """Uses random global or layer-wise step size from ``low`` to ``high``.
 
    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
@@ -133,7 +133,7 @@ class RandomStepSize(Transform):
        super().__init__(defaults, uses_grad=False)
 
    @torch.no_grad
-   def apply_tensors(self, tensors, params, grads, loss, states, settings):
+   def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        s = settings[0]
        parameterwise = s['parameterwise']
 
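The main change in this file is the renamed per-tensor hook: ``Transform.apply_tensors`` becomes ``TensorTransform.multi_tensor_apply``, and ``uses_grad=False`` is dropped from most constructor calls. A hypothetical third-party module written against the 0.4.x names, mirroring ``LR`` above; the class name and scaling logic are illustrative, not part of the package:

```python
import torch
from torchzero.core import TensorTransform

class ConstantScale(TensorTransform):
    """Multiplies every update tensor by a fixed factor (toy example)."""
    def __init__(self, scale: float = 0.5):
        super().__init__(dict(scale=scale))

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        # `settings` is a per-parameter sequence of dicts, as in LR above
        torch._foreach_mul_(tensors, [s['scale'] for s in settings])
        return tensors
```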
torchzero/modules/termination/termination.py

@@ -1,11 +1,11 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import cast
+from typing import cast, final
 
 import torch
 
-from ...core import Module, Var
+from ...core import Module, Objective
 from ...utils import Metrics, TensorList, safe_dict_update_, tofloat
 
 
@@ -16,14 +16,15 @@ class TerminationCriteriaBase(Module):
        super().__init__(defaults)
 
    @abstractmethod
-   def termination_criteria(self, var: Var) -> bool:
+   def termination_criteria(self, objective: Objective) -> bool:
        ...
 
-   def should_terminate(self, var: Var) -> bool:
+   @final
+   def should_terminate(self, objective: Objective) -> bool:
        n_bad = self.global_state.get('_n_bad', 0)
        n = self.defaults['_n']
 
-       if self.termination_criteria(var):
+       if self.termination_criteria(objective):
            n_bad += 1
            if n_bad >= n:
                self.global_state['_n_bad'] = 0
@@ -36,12 +37,12 @@ class TerminationCriteriaBase(Module):
        return False
 
 
-   def update(self, var):
-       var.should_terminate = self.should_terminate(var)
-       if var.should_terminate: self.global_state['_n_bad'] = 0
+   def update(self, objective):
+       objective.should_terminate = self.should_terminate(objective)
+       if objective.should_terminate: self.global_state['_n_bad'] = 0
 
-   def apply(self, var):
-       return var
+   def apply(self, objective):
+       return objective
 
 
 class TerminateAfterNSteps(TerminationCriteriaBase):
@@ -49,7 +50,7 @@ class TerminateAfterNSteps(TerminationCriteriaBase):
        defaults = dict(steps=steps)
        super().__init__(defaults)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1
 
@@ -61,16 +62,17 @@ class TerminateAfterNEvaluations(TerminationCriteriaBase):
        defaults = dict(maxevals=maxevals)
        super().__init__(defaults)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        maxevals = self.defaults['maxevals']
-       return var.modular.num_evaluations >= maxevals
+       assert objective.modular is not None
+       return objective.modular.num_evaluations >= maxevals
 
 class TerminateAfterNSeconds(TerminationCriteriaBase):
    def __init__(self, seconds:float, sec_fn = time.time):
        defaults = dict(seconds=seconds, sec_fn=sec_fn)
        super().__init__(defaults)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        max_seconds = self.defaults['seconds']
        sec_fn = self.defaults['sec_fn']
 
@@ -88,10 +90,10 @@ class TerminateByGradientNorm(TerminationCriteriaBase):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        tol = self.defaults['tol']
        ord = self.defaults['ord']
-       return TensorList(var.get_grad()).global_metric(ord) <= tol
+       return TensorList(objective.get_grads()).global_metric(ord) <= tol
 
 
 class TerminateByUpdateNorm(TerminationCriteriaBase):
@@ -100,20 +102,20 @@ class TerminateByUpdateNorm(TerminationCriteriaBase):
        defaults = dict(tol=tol, ord=ord)
        super().__init__(defaults, n=n)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1
 
        tol = self.defaults['tol']
        ord = self.defaults['ord']
 
-       p_prev = self.get_state(var.params, 'p_prev', cls=TensorList)
+       p_prev = self.get_state(objective.params, 'p_prev', cls=TensorList)
        if step == 0:
-           p_prev.copy_(var.params)
+           p_prev.copy_(objective.params)
            return False
 
-       should_terminate = (p_prev - var.params).global_metric(ord) <= tol
-       p_prev.copy_(var.params)
+       should_terminate = (p_prev - objective.params).global_metric(ord) <= tol
+       p_prev.copy_(objective.params)
        return should_terminate
 
 
@@ -122,10 +124,10 @@ class TerminateOnNoImprovement(TerminationCriteriaBase):
        defaults = dict(tol=tol)
        super().__init__(defaults, n=n)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        tol = self.defaults['tol']
 
-       f = tofloat(var.get_loss(False))
+       f = tofloat(objective.get_loss(False))
        if 'f_min' not in self.global_state:
            self.global_state['f_min'] = f
            return False
@@ -141,9 +143,9 @@ class TerminateOnLossReached(TerminationCriteriaBase):
        defaults = dict(value=value)
        super().__init__(defaults)
 
-   def termination_criteria(self, var):
+   def termination_criteria(self, objective):
        value = self.defaults['value']
-       return var.get_loss(False) <= value
+       return objective.get_loss(False) <= value
 
 class TerminateAny(TerminationCriteriaBase):
    def __init__(self, *criteria: TerminationCriteriaBase):
@@ -151,9 +153,9 @@ class TerminateAny(TerminationCriteriaBase):
 
        self.set_children_sequence(criteria)
 
-   def termination_criteria(self, var: Var) -> bool:
+   def termination_criteria(self, objective: Objective) -> bool:
        for c in self.get_children_sequence():
-           if cast(TerminationCriteriaBase, c).termination_criteria(var): return True
+           if cast(TerminationCriteriaBase, c).termination_criteria(objective): return True
 
        return False
 
@@ -163,9 +165,9 @@ class TerminateAll(TerminationCriteriaBase):
 
        self.set_children_sequence(criteria)
 
-   def termination_criteria(self, var: Var) -> bool:
+   def termination_criteria(self, objective: Objective) -> bool:
        for c in self.get_children_sequence():
-           if not cast(TerminationCriteriaBase, c).termination_criteria(var): return False
+           if not cast(TerminationCriteriaBase, c).termination_criteria(objective): return False
 
        return True
 
@@ -173,7 +175,7 @@ class TerminateNever(TerminationCriteriaBase):
    def __init__(self):
        super().__init__()
 
-   def termination_criteria(self, var): return False
+   def termination_criteria(self, objective): return False
 
 def make_termination_criteria(
    ftol: float | None = None,
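All criteria now receive an ``Objective`` instead of a ``Var``, and ``should_terminate`` is marked ``@final``, so a custom criterion only overrides ``termination_criteria``. A hypothetical criterion following that pattern; the class name and the non-finite check are illustrative, and the import path is assumed from the file shown in this diff:

```python
import math

from torchzero.modules.termination.termination import TerminationCriteriaBase
from torchzero.utils import tofloat

class TerminateOnNonFiniteLoss(TerminationCriteriaBase):
    """Stops once the loss becomes NaN or infinite (toy example)."""
    def __init__(self):
        super().__init__()

    def termination_criteria(self, objective):
        # objective.get_loss(False) evaluates the closure without a backward pass
        return not math.isfinite(tofloat(objective.get_loss(False)))
```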
torchzero/modules/trust_region/cubic_regularization.py

@@ -5,7 +5,7 @@ import torch
 
 from ...core import Chainable, Module
 from ...utils import TensorList, vec_to_tensors
-from ...utils.linalg.linear_operator import LinearOperator
+from ...linalg.linear_operator import LinearOperator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -58,7 +58,7 @@ def ls_cubic_solver(f, g:torch.Tensor, H:LinearOperator, M: float, loss_at_param
    for _ in range(it_max):
        r_try = (r_min + r_max) / 2
        lam = r_try * M
-       s_lam = H.add_diagonal(lam).solve(g).neg()
+       s_lam = H.solve_plus_diag(g, lam).neg()
        # s_lam = -torch.linalg.solve(B + lam*id_matrix, g)
        solver_it += 1
        crit = conv_criterion(s_lam, r_try)
@@ -109,7 +109,7 @@ class CubicRegularization(TrustRegionBase):
 
    .. code-block:: python
 
-       opt = tz.Modular(
+       opt = tz.Optimizer(
            model.parameters(),
            tz.m.CubicRegularization(tz.m.Newton()),
        )
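``H.add_diagonal(lam).solve(g)`` is replaced by the fused ``H.solve_plus_diag(g, lam)``; both compute the shifted solve from the commented-out ``torch.linalg.solve`` line above. A dense-matrix sketch of that operation for reference; the helper name here is hypothetical:

```python
import torch

def solve_plus_diag_dense(H: torch.Tensor, g: torch.Tensor, lam: float) -> torch.Tensor:
    # x = (H + lam * I)^{-1} g, the regularized Newton solve used inside the bisection loop
    eye = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    return torch.linalg.solve(H + lam * eye, g)
```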
torchzero/modules/trust_region/levenberg_marquardt.py

@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...utils.linalg import linear_operator
+from ...linalg import linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -32,38 +32,31 @@ class LevenbergMarquardt(TrustRegionBase):
        max_attempts (max_attempts, optional):
            maximum number of trust region size size reductions per step. A zero update vector is returned when
            this limit is exceeded. Defaults to 10.
+       adaptive (bool, optional):
+           if True, trust radius is multiplied by square root of gradient norm.
        fallback (bool, optional):
            if ``True``, when ``hess_module`` maintains hessian inverse which can't be inverted efficiently, it will
            be inverted anyway. When ``False`` (default), a ``RuntimeError`` will be raised instead.
        inner (Chainable | None, optional): preconditioning is applied to output of thise module. Defaults to None.
 
-   Examples:
-       Gauss-Newton with Levenberg-Marquardt trust-region
+   ### Examples:
 
-       .. code-block:: python
+   Gauss-Newton with Levenberg-Marquardt trust-region
 
-           opt = tz.Modular(
-               model.parameters(),
-               tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
-           )
+   ```python
+   opt = tz.Optimizer(
+       model.parameters(),
+       tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
+   )
+   ```
 
-       LM-SR1
-
-       .. code-block:: python
-
-           opt = tz.Modular(
-               model.parameters(),
-               tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
-           )
-
-       First order trust region (hessian is assumed to be identity)
-
-       .. code-block:: python
-
-           opt = tz.Modular(
-               model.parameters(),
-               tz.m.LevenbergMarquardt(tz.m.Identity()),
-           )
+   LM-SR1
+   ```python
+   opt = tz.Optimizer(
+       model.parameters(),
+       tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
+   )
+   ```
 
    """
    def __init__(
@@ -78,11 +71,12 @@ class LevenbergMarquardt(TrustRegionBase):
        max_attempts: int = 10,
        radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
        y: float = 0,
+       adaptive: bool = False,
        fallback: bool = False,
        update_freq: int = 1,
        inner: Chainable | None = None,
    ):
-       defaults = dict(y=y, fallback=fallback)
+       defaults = dict(y=y, fallback=fallback, adaptive=adaptive)
        super().__init__(
            defaults=defaults,
            hess_module=hess_module,
@@ -103,6 +97,7 @@
 
    def trust_solve(self, f, g, H, radius, params, closure, settings):
        y = settings['y']
+       adaptive = settings["adaptive"]
 
        if isinstance(H, linear_operator.DenseInverse):
            if settings['fallback']:
@@ -117,12 +112,14 @@
            )
 
        reg = 1/radius
+       if adaptive: reg = reg * torch.linalg.vector_norm(g).sqrt()
+
        if y == 0:
-           return H.add_diagonal(reg).solve(g)
+           return H.solve_plus_diag(g, reg) # pyright:ignore[reportAttributeAccessIssue]
 
        diag = H.diagonal()
        diag = torch.where(diag < torch.finfo(diag.dtype).tiny * 2, 1, diag)
        if y != 1: diag = (diag*y) + (1-y)
-       return H.add_diagonal(diag*reg).solve(g)
+       return H.solve_plus_diag(g, diag*reg)
 
 
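Besides the docstring conversion to markdown, this file gains an ``adaptive`` flag that scales the regularization ``1/radius`` by the square root of the gradient norm before the shifted solve. A usage sketch based on the docstring examples above; it assumes an existing ``model``:

```python
import torchzero as tz

# Gauss-Newton with the new adaptive Levenberg-Marquardt damping
opt = tz.Optimizer(
    model.parameters(),
    tz.m.LevenbergMarquardt(tz.m.GaussNewton(), adaptive=True),
)
```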
torchzero/modules/trust_region/trust_cg.py

@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...utils.linalg import cg, linear_operator
+from ...linalg import cg, linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -47,7 +47,7 @@ class TrustCG(TrustRegionBase):
 
    .. code-block:: python
 
-       opt = tz.Modular(
+       opt = tz.Optimizer(
            model.parameters(),
            tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
        )
torchzero/modules/trust_region/trust_region.py

@@ -7,9 +7,16 @@ from typing import Any, Literal, Protocol, cast, final, overload
 
 import torch
 
-from ...core import Chainable, Module, Var, apply_transform
-from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
-from ...utils.linalg.linear_operator import LinearOperator
+from ...core import Chainable, Module, Objective
+from ...linalg.linear_operator import LinearOperator
+from ...utils import (
+    TensorList,
+    generic_finfo,
+    generic_vector_norm,
+    safe_dict_update_,
+    tofloat,
+    vec_to_tensors,
+)
 
 
 def _flatten_tensors(tensors: list[torch.Tensor]):
@@ -256,24 +263,24 @@ class TrustRegionBase(Module, ABC):
        """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
        ... # pylint:disable=unnecessary-ellipsis
 
-   def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
+   def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
        """updates the state of this module after H or B have been updated, if necessary"""
 
-   def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
-       """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
+   def trust_region_apply(self, objective: Objective, tensors:list[torch.Tensor], H: LinearOperator | None) -> Objective:
+       """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
        assert H is not None
 
-       params = TensorList(var.params)
+       params = TensorList(objective.params)
        settings = self.settings[params[0]]
        g = _flatten_tensors(tensors)
 
        max_attempts = settings['max_attempts']
 
        # loss at x_0
-       loss = var.loss
-       closure = var.closure
+       loss = objective.loss
+       closure = objective.closure
        if closure is None: raise RuntimeError("Trust region requires closure")
-       if loss is None: loss = var.get_loss(False)
+       if loss is None: loss = objective.get_loss(False)
        loss = tofloat(loss)
 
        # trust region step and update
@@ -313,38 +320,36 @@ class TrustRegionBase(Module, ABC):
        )
 
        assert d is not None
-       if success: var.update = vec_to_tensors(d, params)
-       else: var.update = params.zeros_like()
+       if success: objective.updates = vec_to_tensors(d, params)
+       else: objective.updates = params.zeros_like()
 
-       return var
+       return objective
 
 
    @final
    @torch.no_grad
-   def update(self, var):
+   def update(self, objective):
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1
 
        if step % self.defaults["update_freq"] == 0:
 
            hessian_module = self.children['hess_module']
-           hessian_module.update(var)
-           H = hessian_module.get_H(var)
+           hessian_module.update(objective)
+           H = hessian_module.get_H(objective)
            self.global_state["H"] = H
 
-           self.trust_region_update(var, H=H)
+           self.trust_region_update(objective, H=H)
 
 
    @final
    @torch.no_grad
-   def apply(self, var):
+   def apply(self, objective):
        H = self.global_state.get('H', None)
 
        # -------------------------------- inner step -------------------------------- #
-       update = var.get_update()
-       if 'inner' in self.children:
-           update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+       objective = self.inner_step("inner", objective, must_exist=False)
 
        # ----------------------------------- apply ---------------------------------- #
-       return self.trust_region_apply(var=var, tensors=update, H=H)
+       return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)
 
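The ``Var``-based flow becomes an ``Objective`` with split ``update``/``apply`` methods, ``var.update`` becomes ``objective.updates``, and the manual ``apply_transform`` on the inner child is replaced by ``self.inner_step(...)``. A hypothetical minimal module written against that split pattern, mirroring the signatures used by TrustRegionBase and SVRG in this diff; the class is not part of the package and the empty-defaults constructor call is an assumption:

```python
import torch
from torchzero.core import Module

class NegateUpdate(Module):
    """Flips the sign of the current update (toy example)."""
    def __init__(self):
        super().__init__(dict())

    @torch.no_grad
    def update(self, objective):
        pass  # nothing to precompute for this toy transform

    @torch.no_grad
    def apply(self, objective):
        objective.updates = [u.neg() for u in objective.get_updates()]
        return objective
```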
torchzero/modules/variance_reduction/svrg.py

@@ -3,15 +3,16 @@ from functools import partial
 
 import torch
 
-from ...core.module import Module
+from ...core import Module, Objective
 from ...utils import tofloat
 
 
-def _reset_except_self(optimizer, var, self: Module):
-   for m in optimizer.unrolled_modules:
+def _reset_except_self(objective: Objective, modules, self: Module):
+   for m in modules:
        if m is not self:
            m.reset()
 
+
 class SVRG(Module):
    """Stochastic variance reduced gradient method (SVRG).
 
@@ -43,7 +44,7 @@ class SVRG(Module):
    ## Examples:
    SVRG-LBFGS
    ```python
-   opt = tz.Modular(
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.LBFGS(),
@@ -53,7 +54,7 @@ class SVRG(Module):
 
    For extra variance reduction one can use Online versions of algorithms, although it won't always help.
    ```python
-   opt = tz.Modular(
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.SVRG(len(dataloader)),
        tz.m.Online(tz.m.LBFGS()),
@@ -62,7 +63,7 @@ class SVRG(Module):
 
    Variance reduction can also be applied to gradient estimators.
    ```python
-   opt = tz.Modular(
+   opt = tz.Optimizer(
        model.parameters(),
        tz.m.SPSA(),
        tz.m.SVRG(100),
@@ -71,7 +72,7 @@ class SVRG(Module):
    ```
    ## Notes
 
-   The SVRG gradient is computed as ``g_b(x) - alpha * g_b(x_0) - g_f(x0.)``, where:
+   The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
    - ``x`` is current parameters
    - ``x_0`` is initial parameters, where full gradient was computed
    - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
@@ -83,17 +84,18 @@ class SVRG(Module):
        defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
        super().__init__(defaults)
 
+
    @torch.no_grad
-   def step(self, var):
-       params = var.params
-       closure = var.closure
+   def update(self, objective):
+       params = objective.params
+       closure = objective.closure
        assert closure is not None
 
        if "full_grad" not in self.global_state:
 
            # -------------------------- calculate full gradient ------------------------- #
-           if "full_closure" in var.storage:
-               full_closure = var.storage['full_closure']
+           if "full_closure" in objective.storage:
+               full_closure = objective.storage['full_closure']
                with torch.enable_grad():
                    full_loss = full_closure()
                    if all(p.grad is None for p in params):
@@ -116,12 +118,12 @@ class SVRG(Module):
 
            # accumulate grads
            accumulator = self.get_state(params, 'accumulator')
-           grad = var.get_grad()
+           grad = objective.get_grads()
            torch._foreach_add_(accumulator, grad)
 
            # accumulate loss
            loss_accumulator = self.global_state.get('loss_accumulator', 0)
-           loss_accumulator += tofloat(var.loss)
+           loss_accumulator += tofloat(objective.loss)
            self.global_state['loss_accumulator'] = loss_accumulator
 
            # on nth step, use the accumulated gradient
@@ -136,10 +138,10 @@ class SVRG(Module):
 
            # otherwise skip update until enough grads are accumulated
            else:
-               var.update = None
-               var.stop = True
-               var.skip_update = True
-               return var
+               objective.updates = None
+               objective.stop = True
+               objective.skip_update = True
+               return
 
 
        svrg_steps = self.defaults['svrg_steps']
@@ -194,7 +196,7 @@ class SVRG(Module):
 
            return closure(False)
 
-       var.closure = svrg_closure
+       objective.closure = svrg_closure
 
        # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
        if current_svrg_step >= svrg_steps:
@@ -203,6 +205,6 @@ class SVRG(Module):
            del self.global_state['full_loss']
            del self.global_state['x_0']
            if self.defaults['reset_before_accum']:
-               var.post_step_hooks.append(partial(_reset_except_self, self=self))
+               objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
-       return var
+   def apply(self, objective): return objective
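The Notes section now states the variance-reduced gradient correctly as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``. A small sketch of that estimator applied to per-parameter gradient lists; the function name is illustrative:

```python
import torch

def svrg_gradient(g_b_x, g_b_x0, g_f_x0, alpha: float = 1.0):
    """g_b_x, g_b_x0: mini-batch gradients at x and x_0; g_f_x0: full gradient at x_0."""
    return [gb - alpha * (gb0 - gf0) for gb, gb0, gf0 in zip(g_b_x, g_b_x0, g_f_x0)]
```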
torchzero/modules/weight_decay/__init__.py

@@ -1 +1,2 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .reinit import RandomReinitialize