torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Chainable, Module
-from ...utils.linalg import cg, linear_operator
+from ...linalg import cg, linear_operator
 from .trust_region import _RADIUS_KEYS, TrustRegionBase, _RadiusStrategy
 
 
@@ -7,9 +7,16 @@ from typing import Any, Literal, Protocol, cast, final, overload
 
 import torch
 
-from ...core import Chainable, Module, Var, apply_transform
-from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
-from ...utils.linalg.linear_operator import LinearOperator
+from ...core import Chainable, Module, Objective
+from ...linalg.linear_operator import LinearOperator
+from ...utils import (
+    TensorList,
+    generic_finfo,
+    generic_vector_norm,
+    safe_dict_update_,
+    tofloat,
+    vec_to_tensors,
+)
 
 
 def _flatten_tensors(tensors: list[torch.Tensor]):
@@ -256,24 +263,24 @@ class TrustRegionBase(Module, ABC):
         """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
         ... # pylint:disable=unnecessary-ellipsis
 
-    def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
+    def trust_region_update(self, objective: Objective, H: LinearOperator | None) -> None:
         """updates the state of this module after H or B have been updated, if necessary"""
 
-    def trust_region_apply(self, var: Var, tensors:list[torch.Tensor], H: LinearOperator | None) -> Var:
-        """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
+    def trust_region_apply(self, objective: Objective, tensors:list[torch.Tensor], H: LinearOperator | None) -> Objective:
+        """Solves the trust region subproblem and outputs ``Objective`` with the solution direction."""
         assert H is not None
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         settings = self.settings[params[0]]
         g = _flatten_tensors(tensors)
 
         max_attempts = settings['max_attempts']
 
         # loss at x_0
-        loss = var.loss
-        closure = var.closure
+        loss = objective.loss
+        closure = objective.closure
         if closure is None: raise RuntimeError("Trust region requires closure")
-        if loss is None: loss = var.get_loss(False)
+        if loss is None: loss = objective.get_loss(False)
         loss = tofloat(loss)
 
         # trust region step and update
@@ -313,38 +320,36 @@ class TrustRegionBase(Module, ABC):
         )
 
         assert d is not None
-        if success: var.update = vec_to_tensors(d, params)
-        else: var.update = params.zeros_like()
+        if success: objective.updates = vec_to_tensors(d, params)
+        else: objective.updates = params.zeros_like()
 
-        return var
+        return objective
 
 
     @final
     @torch.no_grad
-    def update(self, var):
+    def update(self, objective):
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
         if step % self.defaults["update_freq"] == 0:
 
             hessian_module = self.children['hess_module']
-            hessian_module.update(var)
-            H = hessian_module.get_H(var)
+            hessian_module.update(objective)
+            H = hessian_module.get_H(objective)
             self.global_state["H"] = H
 
-            self.trust_region_update(var, H=H)
+            self.trust_region_update(objective, H=H)
 
 
     @final
     @torch.no_grad
-    def apply(self, var):
+    def apply(self, objective):
         H = self.global_state.get('H', None)
 
         # -------------------------------- inner step -------------------------------- #
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)
 
         # ----------------------------------- apply ---------------------------------- #
-        return self.trust_region_apply(var=var, tensors=update, H=H)
+        return self.trust_region_apply(objective=objective, tensors=objective.get_updates(), H=H)
 
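The hunks above illustrate the 0.4.0 pattern used throughout the package: `Var` becomes `Objective`, and a module's work is split between `update` (refresh cached state such as `H`, gated by `update_freq`) and `apply` (consume that state and set `objective.updates`). A minimal sketch of that split, using only attributes visible in this diff (`global_state`, `defaults`, `objective.get_grads()`, `objective.updates`); the class name and its body are hypothetical, not part of torchzero:

```python
import torch
from torchzero.core import Module

class CachedGradModule(Module):
    """Hypothetical module: refreshes a cached gradient copy every `update_freq` steps."""
    def __init__(self, update_freq: int = 10):
        super().__init__(dict(update_freq=update_freq))

    @torch.no_grad
    def update(self, objective):
        # expensive state refresh happens here, mirroring the update_freq gating above
        step = self.global_state.get('step', 0)
        self.global_state['step'] = step + 1
        if step % self.defaults['update_freq'] == 0:
            self.global_state['cached'] = [g.clone() for g in objective.get_grads()]

    @torch.no_grad
    def apply(self, objective):
        # cheap part: turn the cached state into an update (assumes update() ran first)
        objective.updates = [c.clone() for c in self.global_state['cached']]
        return objective
```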
@@ -3,15 +3,17 @@ from functools import partial
 
 import torch
 
-from ...core.module import Module
+from ...core import Module, Objective
 from ...utils import tofloat
 
 
-def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective: Objective, modules, self: Module):
+    assert objective.modular is not None
+    for m in objective.modular.flat_modules:
         if m is not self:
             m.reset()
 
+
 class SVRG(Module):
     """Stochastic variance reduced gradient method (SVRG).
 
@@ -71,7 +73,7 @@ class SVRG(Module):
     ```
     ## Notes
 
-    The SVRG gradient is computed as ``g_b(x) - alpha * g_b(x_0) - g_f(x0.)``, where:
+    The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
     - ``x`` is current parameters
     - ``x_0`` is initial parameters, where full gradient was computed
    - ``g_b`` refers to mini-batch gradient at ``x`` or ``x_0``
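For reference, the corrected formula above is the standard SVRG control-variate estimator. A stand-alone sketch for a single parameter tensor (not torchzero code; the function name and signature are illustrative only):

```python
import torch

def svrg_direction(g_b_x: torch.Tensor,   # mini-batch gradient at current params x
                   g_b_x0: torch.Tensor,  # mini-batch gradient at snapshot x_0 (same batch)
                   g_f_x0: torch.Tensor,  # full-dataset gradient at snapshot x_0
                   alpha: float = 1.0) -> torch.Tensor:
    # variance-reduced gradient: g_b(x) - alpha * (g_b(x_0) - g_f(x_0))
    return g_b_x - alpha * (g_b_x0 - g_f_x0)
```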
@@ -83,17 +85,18 @@ class SVRG(Module):
         defaults = dict(svrg_steps = svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
         super().__init__(defaults)
 
+
     @torch.no_grad
-    def step(self, var):
-        params = var.params
-        closure = var.closure
+    def update(self, objective):
+        params = objective.params
+        closure = objective.closure
         assert closure is not None
 
         if "full_grad" not in self.global_state:
 
             # -------------------------- calculate full gradient ------------------------- #
-            if "full_closure" in var.storage:
-                full_closure = var.storage['full_closure']
+            if "full_closure" in objective.storage:
+                full_closure = objective.storage['full_closure']
                 with torch.enable_grad():
                     full_loss = full_closure()
                     if all(p.grad is None for p in params):
@@ -116,12 +119,12 @@ class SVRG(Module):
 
             # accumulate grads
             accumulator = self.get_state(params, 'accumulator')
-            grad = var.get_grad()
+            grad = objective.get_grads()
             torch._foreach_add_(accumulator, grad)
 
             # accumulate loss
             loss_accumulator = self.global_state.get('loss_accumulator', 0)
-            loss_accumulator += tofloat(var.loss)
+            loss_accumulator += tofloat(objective.loss)
             self.global_state['loss_accumulator'] = loss_accumulator
 
             # on nth step, use the accumulated gradient
@@ -136,10 +139,10 @@ class SVRG(Module):
 
             # otherwise skip update until enough grads are accumulated
             else:
-                var.update = None
-                var.stop = True
-                var.skip_update = True
-                return var
+                objective.updates = None
+                objective.stop = True
+                objective.skip_update = True
+                return
 
 
         svrg_steps = self.defaults['svrg_steps']
@@ -194,7 +197,7 @@ class SVRG(Module):
 
             return closure(False)
 
-        var.closure = svrg_closure
+        objective.closure = svrg_closure
 
         # --- after svrg_steps steps reset so that new full gradient is calculated on next step --- #
         if current_svrg_step >= svrg_steps:
@@ -203,6 +206,6 @@ class SVRG(Module):
             del self.global_state['full_loss']
             del self.global_state['x_0']
             if self.defaults['reset_before_accum']:
-                var.post_step_hooks.append(partial(_reset_except_self, self=self))
+                objective.post_step_hooks.append(partial(_reset_except_self, self=self))
 
-        return var
+    def apply(self, objective): return objective
@@ -1 +1,2 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
+from .reinit import RandomReinitialize
@@ -0,0 +1,83 @@
+from functools import partial
+
+import torch
+
+from ...core import Module
+from ...utils import NumberList, TensorList
+
+
+def _reset_except_self(optimizer, var, self: Module):
+    for m in optimizer.unrolled_modules:
+        if m is not self:
+            m.reset()
+
+class RandomReinitialize(Module):
+    """On each step with probability ``p_reinit`` trigger reinitialization,
+    whereby ``p_weights`` weights are reset to their initial values.
+
+    This modifies the parameters directly. Place it as the first module.
+
+    Args:
+        p_reinit (float, optional): probability to trigger reinitialization on each step. Defaults to 0.01.
+        p_weights (float, optional): probability for each weight to be set to initial value when reinitialization is triggered. Defaults to 0.1.
+        store_every (int | None, optional): if set, stores new initial values every this many steps. Defaults to None.
+        beta (float, optional):
+            whenever ``store_every`` is triggered, uses linear interpolation with this beta.
+            If ``store_every=1``, this can be set to some value close to 1 such as 0.999
+            to reinitialize to slow parameter EMA. Defaults to 0.
+        reset (bool, optional): whether to reset states of other modules on reinitialization. Defaults to False.
+        seed (int | None, optional): random seed.
+    """
+
+    def __init__(
+        self,
+        p_reinit: float = 0.01,
+        p_weights: float = 0.1,
+        store_every: int | None = None,
+        beta: float = 0,
+        reset: bool = False,
+        seed: int | None = None,
+    ):
+        defaults = dict(p_weights=p_weights, p_reinit=p_reinit, store_every=store_every, beta=beta, reset=reset, seed=seed)
+        super().__init__(defaults)
+
+    def update(self, objective):
+        # this stores initial values to per-parameter states
+        p_init = self.get_state(objective.params, "p_init", init="params", cls=TensorList)
+
+        # store new params every store_every steps
+        step = self.global_state.get("step", 0)
+        self.global_state["step"] = step + 1
+
+        store_every = self.defaults["store_every"]
+        if (store_every is not None and step % store_every == 0):
+            beta = self.get_settings(objective.params, "beta", cls=NumberList)
+            p_init.lerp_(objective.params, weight=(1 - beta))
+
+    @torch.no_grad
+    def apply(self, objective):
+        p_reinit = self.defaults["p_reinit"]
+        device = objective.params[0].device
+        generator = self.get_generator(device, self.defaults["seed"])
+
+        # determine whether to trigger reinitialization
+        reinitialize = torch.rand(1, generator=generator, device=device) < p_reinit
+
+        # reinitialize
+        if reinitialize:
+            params = TensorList(objective.params)
+            p_init = self.get_state(params, "p_init", init=params)
+
+
+            # mask with p_weights entries being True
+            p_weights = self.get_settings(params, "p_weights")
+            mask = params.bernoulli_like(p_weights, generator=generator).as_bool()
+
+            # set weights at mask to their initialization
+            params.masked_set_(mask, p_init)
+
+        # reset
+        if self.defaults["reset"]:
+            objective.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+        return objective
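A hedged usage sketch for the new ``RandomReinitialize`` module, placed first in the chain as its docstring advises. The ``tz.m.Adam`` and ``tz.m.LR`` module names and ``model`` are assumptions patterned on the ``Wrap`` docstring example further down, not something this diff documents:

```python
import torchzero as tz

# model is assumed to be an existing torch.nn.Module
opt = tz.Modular(
    model.parameters(),
    tz.m.RandomReinitialize(p_reinit=0.01, p_weights=0.1),  # arguments shown in the new file above
    tz.m.Adam(),   # assumed name, based on torchzero/modules/adaptive/adam.py in the file list
    tz.m.LR(1e-3),
)
```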
@@ -3,7 +3,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform
+from ...core import Module, TensorTransform
 from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states, Metrics
 
 
@@ -21,7 +21,7 @@ def weight_decay_(
     return grad_.add_(params.pow(ord-1).copysign_(params).mul_(weight_decay))
 
 
-class WeightDecay(Transform):
+class WeightDecay(TensorTransform):
     """Weight decay.
 
     Args:
@@ -63,19 +63,19 @@ class WeightDecay(Transform):
     ```
 
     """
-    def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+    def __init__(self, weight_decay: float, ord: int = 2):
 
         defaults = dict(weight_decay=weight_decay, ord=ord)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)
         ord = settings[0]['ord']
 
         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
-class RelativeWeightDecay(Transform):
+class RelativeWeightDecay(TensorTransform):
     """Weight decay relative to the mean absolute value of update, gradient or parameters depending on value of ``norm_input`` argument.
 
     Args:
@@ -117,13 +117,12 @@ class RelativeWeightDecay(Transform):
         ord: int = 2,
         norm_input: Literal["update", "grad", "params"] = "update",
         metric: Metrics = 'mad',
-        target: Target = "update",
     ):
         defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input, metric=metric)
-        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+        super().__init__(defaults, uses_grad=norm_input == 'grad')
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         weight_decay = NumberList(s['weight_decay'] for s in settings)
 
         ord = settings[0]['ord']
@@ -161,9 +160,9 @@ class DirectWeightDecay(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+    def apply(self, objective):
+        weight_decay = self.get_settings(objective.params, 'weight_decay', cls=NumberList)
         ord = self.defaults['ord']
 
-        decay_weights_(var.params, weight_decay, ord)
-        return var
+        decay_weights_(objective.params, weight_decay, ord)
+        return objective
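In 0.4.0 these classes subclass ``TensorTransform`` and override ``multi_tensor_apply`` (previously ``apply_tensors`` on ``Transform``) with the same ``(tensors, params, grads, loss, states, settings)`` signature. A sketch of a custom transform written against that surface; the class name, clipping logic, and constructor are illustrative assumptions, not torchzero API:

```python
import torch
from torchzero.core import TensorTransform

class ClipUpdateByValue(TensorTransform):
    """Hypothetical transform: clamps each update tensor to [-value, value]."""
    def __init__(self, value: float = 1.0):
        super().__init__(dict(value=value))

    @torch.no_grad
    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
        value = settings[0]['value']
        # return the transformed update tensors, mirroring WeightDecay above
        return [t.clamp_(-value, value) for t in tensors]
```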
@@ -3,41 +3,55 @@ from typing import Any
 import torch
 
 from ...core.module import Module
-from ...utils import Params, _copy_param_groups, _make_param_groups
+from ...utils.params import Params, _copy_param_groups, _make_param_groups
 
 
 class Wrap(Module):
     """
     Wraps a pytorch optimizer to use it as a module.
 
-    .. note::
-        Custom param groups are supported only by `set_param_groups`, settings passed to Modular will be ignored.
+    Note:
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
 
     Args:
        opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
-            function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
-            or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+            function that takes in parameters and returns the optimizer, for example ``torch.optim.Adam``
+            or ``lambda parameters: torch.optim.Adam(parameters, lr=1e-3)``
        *args:
        **kwargs:
-            Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+            Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
+        use_param_groups:
+            Whether to pass settings passed to Modular to the wrapped optimizer.
 
-    Example:
-        wrapping pytorch_optimizer.StableAdamW
+            Note that settings to the first parameter are used for all parameters,
+            so if you specified per-parameter settings, they will be ignored.
 
-    .. code-block:: py
+    ### Example:
+    wrapping pytorch_optimizer.StableAdamW
 
-        from pytorch_optimizer import StableAdamW
-        opt = tz.Modular(
-            model.parameters(),
-            tz.m.Wrap(StableAdamW, lr=1),
-            tz.m.Cautious(),
-            tz.m.LR(1e-2)
-        )
+    ```python
 
+    from pytorch_optimizer import StableAdamW
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Wrap(StableAdamW, lr=1),
+        tz.m.Cautious(),
+        tz.m.LR(1e-2)
+    )
+    ```
 
     """
-    def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
-        super().__init__()
+
+    def __init__(
+        self,
+        opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer,
+        *args,
+        use_param_groups: bool = True,
+        **kwargs,
+    ):
+        defaults = dict(use_param_groups=use_param_groups)
+        super().__init__(defaults=defaults)
+
         self._opt_fn = opt_fn
         self._opt_args = args
         self._opt_kwargs = kwargs
@@ -48,12 +62,12 @@ class Wrap(Module):
            self.optimizer = self._opt_fn
 
     def set_param_groups(self, param_groups):
-        self._custom_param_groups = param_groups
+        self._custom_param_groups = _make_param_groups(param_groups, differentiable=False)
         return super().set_param_groups(param_groups)
 
     @torch.no_grad
-    def step(self, var):
-        params = var.params
+    def apply(self, objective):
+        params = objective.params
 
         # initialize opt on 1st step
         if self.optimizer is None:
@@ -61,54 +75,47 @@ class Wrap(Module):
             param_groups = params if self._custom_param_groups is None else self._custom_param_groups
             self.optimizer = self._opt_fn(param_groups, *self._opt_args, **self._opt_kwargs)
 
+        # set optimizer per-parameter settings
+        if self.defaults["use_param_groups"] and objective.modular is not None:
+            for group in self.optimizer.param_groups:
+                first_param = group['params'][0]
+                setting = self.settings[first_param]
+
+                # settings passed in `set_param_groups` are the highest priority
+                # schedulers will override defaults but not settings passed in `set_param_groups`
+                # this is consistent with how Modular does it.
+                if self._custom_param_groups is not None:
+                    setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
+
+                group.update(setting)
+
         # set grad to update
         orig_grad = [p.grad for p in params]
-        for p, u in zip(params, var.get_update()):
+        for p, u in zip(params, objective.get_updates()):
             p.grad = u
 
-        # if this module is last, can step with _opt directly
-        # direct step can't be applied if next module is LR but _opt doesn't support lr,
-        # and if there are multiple different per-parameter lrs (would be annoying to support)
-        if var.is_last and (
-            (var.last_module_lrs is None)
-            or
-            (('lr' in self.optimizer.defaults) and (len(set(var.last_module_lrs)) == 1))
-        ):
-            lr = 1 if var.last_module_lrs is None else var.last_module_lrs[0]
-
-            # update optimizer lr with desired lr
-            if lr != 1:
-                self.optimizer.defaults['__original_lr__'] = self.optimizer.defaults['lr']
-                for g in self.optimizer.param_groups:
-                    g['__original_lr__'] = g['lr']
-                    g['lr'] = g['lr'] * lr
-
-            # step
+        # if this is last module, simply use optimizer to update parameters
+        if objective.modular is not None and self is objective.modular.modules[-1]:
             self.optimizer.step()
 
-            # restore original lr
-            if lr != 1:
-                self.optimizer.defaults['lr'] = self.optimizer.defaults.pop('__original_lr__')
-                for g in self.optimizer.param_groups:
-                    g['lr'] = g.pop('__original_lr__')
-
             # restore grad
             for p, g in zip(params, orig_grad):
                 p.grad = g
 
-            var.stop = True; var.skip_update = True
-            return var
+            objective.stop = True; objective.skip_update = True
+            return objective
 
         # this is not the last module, meaning update is difference in parameters
+        # and passed to next module
         params_before_step = [p.clone() for p in params]
         self.optimizer.step() # step and update params
         for p, g in zip(params, orig_grad):
             p.grad = g
-        var.update = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
+        objective.updates = list(torch._foreach_sub(params_before_step, params)) # set update to difference between params
         for p, o in zip(params, params_before_step):
             p.set_(o) # pyright: ignore[reportArgumentType]
 
-        return var
+        return objective
 
     def reset(self):
         super().reset()
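Design note on the new ``use_param_groups`` behavior shown above: per-parameter settings that ``Modular`` applies (for example from schedulers) are copied into each wrapped-optimizer param group using that group's first parameter, while settings given via ``set_param_groups`` keep priority. A hedged sketch reusing the docstring's own example; ``model`` and the chosen values are assumptions:

```python
from pytorch_optimizer import StableAdamW
import torchzero as tz

# model is assumed to be an existing torch.nn.Module
opt = tz.Modular(
    model.parameters(),
    tz.m.Wrap(StableAdamW, lr=1, use_param_groups=True),  # forwards Modular's per-group settings
    tz.m.LR(1e-2),
)
```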
@@ -33,13 +33,16 @@ class CD(Module):
         defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
         super().__init__(defaults)
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self, var):
-        closure = var.closure
+    def step(self, objective):
+        closure = objective.closure
         if closure is None:
             raise RuntimeError("CD requires closure")
 
-        params = TensorList(var.params)
+        params = TensorList(objective.params)
         ndim = params.global_numel()
 
         grad_step_size = self.defaults['grad']
@@ -79,7 +82,7 @@ class CD(Module):
         else:
             warnings.warn("CD adaptive=True only works with threepoint=True")
 
-        f_0 = var.get_loss(False)
+        f_0 = objective.get_loss(False)
         params.flat_set_lambda_(idx, lambda x: x + h)
         f_p = closure(False)
 
@@ -117,6 +120,6 @@ class CD(Module):
         # ----------------------------- create the update ---------------------------- #
         update = params.zeros_like()
         update.flat_set_(idx, alpha)
-        var.update = update
-        return var
+        objective.updates = update
+        return objective
 
torchzero/optim/root.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Callable
 
 from abc import abstractmethod
 import torch
-from ..modules.higher_order.multipoint import sixth_order_im1, sixth_order_p6, _solve
+from ..modules.second_order.multipoint import sixth_order_3p, sixth_order_5p, two_point_newton, sixth_order_3pm2, _solve
 
 def make_evaluate(f: Callable[[torch.Tensor], torch.Tensor]):
     def evaluate(x, order) -> tuple[torch.Tensor, ...]:
@@ -53,7 +53,7 @@ class Newton(RootBase):
     def one_iteration(self, x, evaluate): return newton(x, evaluate, self.lstsq)
 
 
-class SixthOrderP6(RootBase):
+class SixthOrder3P(RootBase):
     """sixth-order iterative method
 
     Abro, Hameer Akhtar, and Muhammad Mujtaba Shaikh. "A new time-efficient and convergent nonlinear solver." Applied Mathematics and Computation 355 (2019): 516-536.
@@ -62,4 +62,4 @@ class SixthOrderP6(RootBase):
     def one_iteration(self, x, evaluate):
         def f(x): return evaluate(x, 0)[0]
         def f_j(x): return evaluate(x, 1)
-        return sixth_order_p6(x, f, f_j, self.lstsq)
+        return sixth_order_3p(x, f, f_j, self.lstsq)
@@ -3,7 +3,8 @@ from collections.abc import Callable, Iterable
 
 import torch
 
-from ...utils import flatten, get_params
+from ...utils import flatten
+from ...utils.optimizer import get_params
 
 class Split(torch.optim.Optimizer):
     """Steps will all `optimizers`, also has a check that they have no duplicate parameters.