torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -2,78 +2,78 @@ from collections import deque

 import torch

-from ...core import Module, Target, Transform
+from ...core import Module, Transform
 from ...utils.tensorlist import Distributions, TensorList
-from ...utils.linalg.linear_operator import ScaledIdentity
+from ...linalg.linear_operator import ScaledIdentity

 class Clone(Module):
     """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [u.clone() for u in var.get_update()]
-        return var
+    def apply(self, objective):
+        objective.updates = [u.clone() for u in objective.get_updates()]
+        return objective

 class Grad(Module):
     """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [g.clone() for g in var.get_grad()]
-        return var
+    def apply(self, objective):
+        objective.updates = [g.clone() for g in objective.get_grads()]
+        return objective

 class Params(Module):
     """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [p.clone() for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [p.clone() for p in objective.params]
+        return objective

 class Zeros(Module):
     """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.zeros_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.zeros_like(p) for p in objective.params]
+        return objective

 class Ones(Module):
     """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.ones_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.ones_like(p) for p in objective.params]
+        return objective

 class Fill(Module):
-    """Outputs tensors filled with :code:`value`"""
+    """Outputs tensors filled with ``value``"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.full_like(p, self.settings[p]['value']) for p in objective.params]
+        return objective

 class RandomSample(Module):
-    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
+    """Outputs tensors filled with random numbers from distribution depending on value of ``distribution``."""
     def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
         defaults = dict(distribution=distribution, variance=variance)
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, var):
+    def apply(self, objective):
         distribution = self.defaults['distribution']
-        variance = self.get_settings(var.params, 'variance')
-        var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
-        return var
+        variance = self.get_settings(objective.params, 'variance')
+        objective.updates = TensorList(objective.params).sample_like(distribution=distribution, variance=variance)
+        return objective

 class Randn(Module):
     """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
@@ -81,43 +81,44 @@ class Randn(Module):
         super().__init__({})

     @torch.no_grad
-    def step(self, var):
-        var.update = [torch.randn_like(p) for p in var.params]
-        return var
+    def apply(self, objective):
+        objective.updates = [torch.randn_like(p) for p in objective.params]
+        return objective

 class Uniform(Module):
-    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
+    """Outputs tensors filled with random numbers from uniform distribution between ``low`` and ``high``."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)

     @torch.no_grad
-    def step(self, var):
-        low,high = self.get_settings(var.params, 'low','high')
-        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
-        return var
+    def apply(self, objective):
+        low,high = self.get_settings(objective.params, 'low','high')
+        objective.updates = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(objective.params, low, high)]
+        return objective

 class GradToNone(Module):
-    """Sets :code:`grad` attribute to None on :code:`var`."""
+    """Sets ``grad`` attribute to None on ``objective``."""
     def __init__(self): super().__init__()
-    def step(self, var):
-        var.grad = None
-        return var
+    def apply(self, objective):
+        objective.grads = None
+        return objective

 class UpdateToNone(Module):
-    """Sets :code:`update` attribute to None on :code:`var`."""
+    """Sets ``update`` attribute to None on ``var``."""
     def __init__(self): super().__init__()
-    def step(self, var):
-        var.update = None
-        return var
+    def apply(self, objective):
+        objective.updates = None
+        return objective

 class Identity(Module):
     """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
     def __init__(self, *args, **kwargs): super().__init__()
-    def step(self, var): return var
-    def get_H(self, var):
-        n = sum(p.numel() for p in var.params)
-        p = var.params[0]
+    def update(self, objective): pass
+    def apply(self, objective): return objective
+    def get_H(self, objective):
+        n = sum(p.numel() for p in objective.params)
+        p = objective.params[0]
         return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

 Noop = Identity
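The hunks above capture the central 0.4.1 API change: `Module.step(var)` becomes `Module.apply(objective)`, and the old `Var` object with its `update`/`grad` attributes becomes `Objective` with `updates`/`grads`. Below is a minimal sketch of a custom module written against the new convention, mirroring the `Fill` pattern shown above; the `Scale` class, its `alpha` setting, and the top-level `torchzero.core` import path are illustrative assumptions, not part of the package.

```python
import torch
from torchzero.core import Module  # assumed public path, matching the relative import in the diff


class Scale(Module):
    """Illustrative 0.4.1-style module: multiplies the incoming update by ``alpha``."""

    def __init__(self, alpha: float = 0.1):
        # per-parameter settings are passed as a defaults dict, as in Fill above
        super().__init__(dict(alpha=alpha))

    @torch.no_grad
    def apply(self, objective):
        # get_updates() falls back to the gradients when no update has been produced yet
        objective.updates = [
            u * self.settings[p]['alpha']
            for u, p in zip(objective.get_updates(), objective.params)
        ]
        return objective
```

Modules that only emit fresh tensors (like `Zeros` or `Randn` above) follow the same shape: assign `objective.updates` and return the objective.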
@@ -30,7 +30,7 @@ def debiased_step_size(
     pow: float = 2,
     alpha: float | NumberList = 1,
 ):
-    """returns multiplier to step size"""
+    """returns multiplier to step size, step starts from 1"""
     if isinstance(beta1, NumberList): beta1 = beta1.fill_none(0)
     if isinstance(beta2, NumberList): beta2 = beta2.fill_none(0)

@@ -6,7 +6,7 @@ from typing import Any, Literal

 import torch

-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from .projection import ProjectionBase

 # TODO
@@ -8,7 +8,7 @@ from typing import Any, Literal

 import torch

-from ...core import Chainable, Module, Var
+from ...core import Chainable, Module, Objective
 from ...utils import set_storage_, vec_to_tensors


@@ -80,7 +80,7 @@ class _FakeProjectedClosure:
 class ProjectionBase(Module, ABC):
     """
     Base class for projections.
-    This is an abstract class, to use it, subclass it and override `project` and `unproject`.
+    This is an abstract class, to use it, subclass it and override ``project`` and ``unproject``.

     Args:
         modules (Chainable): modules that will be applied in the projected domain.
@@ -149,9 +149,12 @@ class ProjectionBase(Module, ABC):
             Iterable[torch.Tensor]: unprojected tensors of the same shape as params
         """

+    def update(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+    def apply(self, objective: Objective): raise RuntimeError("projections don't support update/apply")
+
     @torch.no_grad
-    def step(self, var: Var):
-        params = var.params
+    def step(self, objective: Objective):
+        params = objective.params
         settings = [self.settings[p] for p in params]

         def _project(tensors: list[torch.Tensor], current: Literal['params', 'grads', 'update']):
@@ -159,16 +162,16 @@
             return list(self.project(
                 tensors=tensors,
                 params=params,
-                grads=var.grad,
-                loss=var.loss,
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
             ))

-        projected_var = var.clone(clone_update=False, parent=var)
+        projected_obj = objective.clone(clone_updates=False, parent=objective)

-        closure = var.closure
+        closure = objective.closure

         # if this is True, update and grad were projected simultaneously under current="grads"
         # so update will have to be unprojected with current="grads"
@@ -179,9 +182,9 @@
         # but if it has already been computed, it should be projected
         if self._project_params and closure is not None:

-            if self._project_update and var.update is not None:
+            if self._project_update and objective.updates is not None:
                 # project update only if it already exists
-                projected_var.update = _project(var.update, current='update')
+                projected_obj.updates = _project(objective.updates, current='update')

             else:
                 # update will be set to gradients on var.get_grad()
@@ -189,43 +192,43 @@
                 update_is_grad = True

             # project grad only if it already exists
-            if self._project_grad and var.grad is not None:
-                projected_var.grad = _project(var.grad, current='grads')
+            if self._project_grad and objective.grads is not None:
+                projected_obj.grads = _project(objective.grads, current='grads')

         # otherwise update/grad needs to be calculated and projected here
         else:
             if self._project_update:
-                if var.update is None:
+                if objective.updates is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
-                    grad = var.get_grad()
-                    projected_var.grad = _project(grad, current='grads')
-                    projected_var.update = [g.clone() for g in projected_var.grad]
-                    del var.update
+                    grad = objective.get_grads()
+                    projected_obj.grads = _project(grad, current='grads')
+                    projected_obj.updates = [g.clone() for g in projected_obj.grads]
+                    del objective.updates
                     update_is_grad = True

                 else:
                     # update exists so it needs to be projected
-                    update = var.get_update()
-                    projected_var.update = _project(update, current='update')
-                    del update, var.update
+                    update = objective.get_updates()
+                    projected_obj.updates = _project(update, current='update')
+                    del update, objective.updates

-            if self._project_grad and projected_var.grad is None:
+            if self._project_grad and projected_obj.grads is None:
                 # projected_vars.grad may have been projected simultaneously with update
                 # but if that didn't happen, it is projected here
-                grad = var.get_grad()
-                projected_var.grad = _project(grad, current='grads')
+                grad = objective.get_grads()
+                projected_obj.grads = _project(grad, current='grads')


         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in var.params]
-            projected_params = _project(var.params, current='params')
+            original_params = [p.clone() for p in objective.params]
+            projected_params = _project(objective.params, current='params')

         else:
             # make fake params for correct shapes and state storage
             # they reuse update or grad storage for memory efficiency
-            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
+            projected_params = projected_obj.updates if projected_obj.updates is not None else projected_obj.grads
             assert projected_params is not None

         if self._projected_params is None:
@@ -245,8 +248,8 @@
             return list(self.unproject(
                 projected_tensors=projected_tensors,
                 params=params,
-                grads=var.grad,
-                loss=var.loss,
+                grads=objective.grads,
+                loss=objective.loss,
                 states=states,
                 settings=settings,
                 current=current,
@@ -254,19 +257,19 @@

         # project closure
         if self._project_params:
-            projected_var.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
+            projected_obj.closure = _make_projected_closure(closure, project_fn=_project, unproject_fn=_unproject,
                 params=params, projected_params=projected_params)

         elif closure is not None:
-            projected_var.closure = _FakeProjectedClosure(closure, project_fn=_project,
+            projected_obj.closure = _FakeProjectedClosure(closure, project_fn=_project,
                 params=params, fake_params=projected_params)

         else:
-            projected_var.closure = None
+            projected_obj.closure = None

         # ----------------------------------- step ----------------------------------- #
-        projected_var.params = projected_params
-        projected_var = self.children['modules'].step(projected_var)
+        projected_obj.params = projected_params
+        projected_obj = self.children['modules'].step(projected_obj)

         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -275,24 +278,24 @@
             set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))

         # --------------------------------- unproject -------------------------------- #
-        unprojected_var = projected_var.clone(clone_update=False)
-        unprojected_var.closure = var.closure
-        unprojected_var.params = var.params
-        unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent
+        unprojected_obj = projected_obj.clone(clone_updates=False)
+        unprojected_obj.closure = objective.closure
+        unprojected_obj.params = objective.params
+        unprojected_obj.grads = objective.grads # this may also be set by projected_var since it has var as parent

         if self._project_update:
-            assert projected_var.update is not None
-            unprojected_var.update = _unproject(projected_var.update, current='grads' if update_is_grad else 'update')
-            del projected_var.update
+            assert projected_obj.updates is not None
+            unprojected_obj.updates = _unproject(projected_obj.updates, current='grads' if update_is_grad else 'update')
+            del projected_obj.updates

-        del projected_var
+        del projected_obj

         # original params are stored if params are projected
         if original_params is not None:
-            for p, o in zip(unprojected_var.params, original_params):
+            for p, o in zip(unprojected_obj.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]

-        return unprojected_var
+        return unprojected_obj


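The `project`/`unproject` call sites above also pin down the hook signatures a `ProjectionBase` subclass has to provide (`tensors`/`projected_tensors`, `params`, `grads`, `loss`, `states`, `settings`, `current`). Below is a rough sketch of a subclass that works on flattened views; the constructor arguments of the real base class are not visible in this diff and are omitted, so treat this as an assumption-laden illustration rather than torchzero's documented API.

```python
import torch
from torchzero.modules.projections.projection import ProjectionBase


class VectorizeProjection(ProjectionBase):
    """Illustrative projection: the inner modules see flattened 1-D tensors."""

    def project(self, tensors, params, grads, loss, states, settings, current):
        # called with current in {'params', 'grads', 'update'}, as in step() above
        return [t.reshape(-1) for t in tensors]

    def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
        # restore original parameter shapes before the outer step continues
        return [t.view_as(p) for t, p in zip(projected_tensors, params)]
```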
@@ -30,4 +30,4 @@ from .quasi_newton import (
     ThomasOptimalMethod,
 )

-from .sg2 import SG2, SPSA2
+from .sg2 import SG2
@@ -4,8 +4,8 @@ from typing import Literal, Protocol, overload
 import torch

 from ...utils import TensorList
-from ...utils.linalg.linear_operator import DenseInverse, LinearOperator
-from ..functional import safe_clip
+from ...linalg.linear_operator import DenseInverse, LinearOperator
+from ..opt_utils import safe_clip


 class DampingStrategy(Protocol):
@@ -9,7 +9,7 @@ from .quasi_newton import (
     _InverseHessianUpdateStrategyDefaults,
 )

-from ..functional import safe_clip
+from ..opt_utils import safe_clip


 def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
@@ -4,10 +4,10 @@ from typing import overload

 import torch

-from ...core import Chainable, Transform
+from ...core import Chainable, TensorTransform
 from ...utils import TensorList, as_tensorlist, unpack_states
-from ...utils.linalg.linear_operator import LinearOperator
-from ..functional import initial_step_size
+from ...linalg.linear_operator import LinearOperator
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping


@@ -154,7 +154,7 @@ class LBFGSLinearOperator(LinearOperator):
         return (n, n)


-class LBFGS(Transform):
+class LBFGS(TensorTransform):
     """Limited-memory BFGS algorithm. A line search or trust region is recommended.

     Args:
@@ -188,7 +188,7 @@ class LBFGS(Transform):

     L-BFGS with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LBFGS(100),
         tz.m.Backtracking()
@@ -197,7 +197,7 @@ class LBFGS(Transform):

     L-BFGS with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LBFGS())
     )
@@ -226,7 +226,7 @@ class LBFGS(Transform):
             sy_tol=sy_tol,
             damping = damping,
         )
-        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
+        super().__init__(defaults, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -249,7 +249,7 @@ class LBFGS(Transform):
         self.global_state.pop('step', None)

     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         p = as_tensorlist(params)
         g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
@@ -311,14 +311,14 @@
         y_history.append(y)
         sy_history.append(sy)

-    def get_H(self, var=...):
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         sy_history = self.global_state['sy_history']
         return LBFGSLinearOperator(s_history, y_history, sy_history)

     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']

         tensors = as_tensorlist(tensors)
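The docstring edits above also show the top-level rename `tz.Modular` → `tz.Optimizer`. Below is a hedged end-to-end sketch of the L-BFGS plus backtracking recipe from the updated docstring; the closure accepting a `backward` flag and the `opt.zero_grad()`/`opt.step(closure)` calls follow torchzero's usual closure convention and are assumptions here, not something shown in these hunks.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# L-BFGS with a backtracking line search, as in the updated LBFGS docstring
opt = tz.Optimizer(
    model.parameters(),
    tz.m.LBFGS(100),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # line searches re-evaluate the loss, so all work happens inside the closure;
    # the backward flag lets the optimizer request a function value without gradients
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
```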
@@ -4,10 +4,10 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Module, Transform, Var, apply_transform
+from ...core import Chainable, Module, TensorTransform, Objective, step
 from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
-from ...utils.linalg.linear_operator import LinearOperator
-from ..functional import initial_step_size
+from ...linalg.linear_operator import LinearOperator
+from ..opt_utils import initial_step_size
 from .damping import DampingStrategyType, apply_damping


@@ -76,7 +76,7 @@ class LSR1LinearOperator(LinearOperator):
         return (n, n)


-class LSR1(Transform):
+class LSR1(TensorTransform):
     """Limited-memory SR1 algorithm. A line search or trust region is recommended.

     Args:
@@ -110,7 +110,7 @@ class LSR1(Transform):

     L-SR1 with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SR1(),
         tz.m.StrongWolfe(c2=0.1, fallback=True)
@@ -119,7 +119,7 @@ class LSR1(Transform):

     L-SR1 with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.TrustCG(tz.m.LSR1())
     )
@@ -146,7 +146,7 @@ class LSR1(Transform):
             gtol_restart=gtol_restart,
             damping = damping,
         )
-        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
+        super().__init__(defaults, inner=inner, update_freq=update_freq)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
@@ -167,7 +167,7 @@ class LSR1(Transform):
         self.global_state.pop('step', None)

     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
         p = as_tensorlist(params)
         g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
@@ -225,13 +225,13 @@
         s_history.append(s)
         y_history.append(y)

-    def get_H(self, var=...):
+    def get_H(self, objective=...):
         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
         return LSR1LinearOperator(s_history, y_history)

     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         scale_first = self.defaults['scale_first']

         tensors = as_tensorlist(tensors)