torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
@@ -6,10 +6,10 @@ from typing import Literal
 import torch
 
 from ...core import Chainable, Transform, HVPMethod
-from ...utils import vec_to_tensors
+from ...utils import vec_to_tensors_
 from ...linalg.linear_operator import Sketched
 
-from .newton import _newton_step
+from .newton import _newton_update_state_, _newton_solve
 
 def _qr_orthonormalize(A:torch.Tensor):
     m,n = A.shape
@@ -20,12 +20,10 @@ def _qr_orthonormalize(A:torch.Tensor):
     q, _ = torch.linalg.qr(A) # pylint:disable=not-callable
     return q
 
+
 def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
-def _gaussian_sketch(m, n, dtype, device, generator):
-    return torch.randn(m, n, dtype=dtype, device=device, generator=generator) / math.sqrt(m)
-
 def _rademacher_sketch(m, n, dtype, device, generator):
     rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
     return rademacher.mul_(1 / math.sqrt(m))
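For context, the remaining sketch helpers differ mainly in how close their columns are to orthonormal: `_qr_orthonormalize` yields an exactly orthonormal basis, while the scaled Rademacher sketch is only approximately orthonormal in high dimension. A minimal standalone check in plain PyTorch (the function names here are illustrative, not torchzero's):

```python
import math
import torch

def rademacher_sketch(m, n, generator=None):
    # entries are +/- 1/sqrt(m), so S.T @ S approximates the identity for large m
    r = torch.bernoulli(torch.full((m, n), 0.5), generator=generator).mul_(2).sub_(1)
    return r / math.sqrt(m)

def orthonormal_sketch(m, n, generator=None):
    # QR of a Gaussian matrix gives exactly orthonormal columns
    q, _ = torch.linalg.qr(torch.randn(m, n, generator=generator))
    return q

m, n = 10_000, 16
S_rad, S_orth = rademacher_sketch(m, n), orthonormal_sketch(m, n)
eye = torch.eye(n)
print((S_rad.T @ S_rad - eye).abs().max())    # small but nonzero: approximately orthonormal
print((S_orth.T @ S_orth - eye).abs().max())  # ~ machine precision: exactly orthonormal
```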
@@ -37,11 +35,10 @@ class SubspaceNewton(Transform):
         sketch_size (int):
             size of the random sketch. This many hessian-vector products will need to be evaluated each step.
         sketch_type (str, optional):
+            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
             - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-            - "rademacher" - approximately orthonormal scaled random rademacher basis.
-            - "gaussian" - random gaussian (not orthonormal) basis.
-            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt.
-            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction (default).
+            - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
+            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
         damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
         hvp_method (str, optional):
             How to compute hessian-matrix product:
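The new "common_directions" default says the basis of past steepest-descent directions is orthonormalized on-line with Gram-Schmidt. As a rough illustration of what that entails (plain PyTorch, hypothetical helper, not torchzero's implementation):

```python
import torch

def append_direction(Q, g, tol=1e-10):
    """Append direction g to the orthonormal basis Q (stored as columns), Gram-Schmidt style."""
    g = g.clone()
    if Q is not None:
        g -= Q @ (Q.T @ g)   # remove components along the existing basis
        g -= Q @ (Q.T @ g)   # second pass for numerical stability
    norm = g.norm()
    if norm < tol:           # direction already (nearly) in the span, skip it
        return Q
    g = (g / norm).unsqueeze(1)
    return g if Q is None else torch.cat([Q, g], dim=1)

Q = None
for _ in range(5):
    Q = append_direction(Q, torch.randn(100))
print((Q.T @ Q - torch.eye(Q.shape[1])).abs().max())  # columns stay orthonormal
```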
@@ -73,7 +70,7 @@ class SubspaceNewton(Transform):
 
     RSN with line search
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RSN(),
         tz.m.Backtracking()
@@ -82,7 +79,7 @@ class SubspaceNewton(Transform):
 
     RSN with trust region
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.RSN()),
     )
@@ -97,14 +94,14 @@ class SubspaceNewton(Transform):
     def __init__(
         self,
         sketch_size: int,
-        sketch_type: Literal["orthonormal", "gaussian", "common_directions", "mixed"] = "mixed",
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
         damping:float=0,
+        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        update_freq: int = 1,
+        precompute_inverse: bool = False,
+        use_lstsq: bool = True,
         hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
-        use_lstsq: bool = True,
-        update_freq: int = 1,
-        H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
-        eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
         seed: int | None = None,
         inner: Chainable | None = None,
     ):
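A hedged usage sketch of the 0.4.1 signature, assuming ``tz.m.RSN`` is the exported alias of ``SubspaceNewton`` (as the docstring examples suggest) and that it forwards these keyword arguments unchanged:

```python
import torch
import torchzero as tz

# hypothetical example; argument names follow the __init__ shown above
opt = tz.Optimizer(
    model.parameters(),
    tz.m.RSN(
        sketch_size=32,
        sketch_type="common_directions",  # new default in 0.4.1
        eigval_fn=torch.abs,              # e.g. force the sketched Hessian to be positive definite
        update_freq=1,
        precompute_inverse=False,
        use_lstsq=True,
    ),
    tz.m.Backtracking(),
)
```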
@@ -128,10 +125,7 @@ class SubspaceNewton(Transform):
         sketch_type = fs["sketch_type"]
         hvp_method = fs["hvp_method"]
 
-        if sketch_type in ('normal', 'gaussian'):
-            S = _gaussian_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
-
-        elif sketch_type == "rademacher":
+        if sketch_type == "rademacher":
             S = _rademacher_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
         elif sketch_type == 'orthonormal':
@@ -187,7 +181,7 @@ class SubspaceNewton(Transform):
         # form and orthogonalize sketching matrix
         S = torch.stack([g, slow_ema, fast_ema, prev_dir], dim=1)
         if sketch_size > 4:
-            S_random = _gaussian_sketch(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator)
+            S_random = torch.randn(ndim, sketch_size - 3, device=device, dtype=dtype, generator=generator) / math.sqrt(ndim)
             S = torch.cat([S, S_random], dim=1)
 
         S = _qr_orthonormalize(S)
@@ -200,38 +194,41 @@ class SubspaceNewton(Transform):
                                hvp_method=fs["hvp_method"], h=fs["h"])
         H_sketched = S.T @ HS
 
-        self.global_state["H_sketched"] = H_sketched
+        # update state
+        _newton_update_state_(
+            state = self.global_state,
+            H = H_sketched,
+            damping = fs["damping"],
+            eigval_fn = fs["eigval_fn"],
+            precompute_inverse = fs["precompute_inverse"],
+            use_lstsq = fs["use_lstsq"]
+
+        )
+
         self.global_state["S"] = S
 
     def apply_states(self, objective, states, settings):
-        S: torch.Tensor = self.global_state["S"]
+        updates = objective.get_updates()
+        fs = settings[0]
 
-        d_proj = _newton_step(
-            objective=objective,
-            H=self.global_state["H_sketched"],
-            damping=self.defaults["damping"],
-            H_tfm=self.defaults["H_tfm"],
-            eigval_fn=self.defaults["eigval_fn"],
-            use_lstsq=self.defaults["use_lstsq"],
-            g_proj = lambda g: S.T @ g
-        )
+        S = self.global_state["S"]
+        b = torch.cat([t.ravel() for t in updates])
+        b_proj = S.T @ b
+
+        d_proj = _newton_solve(b=b_proj, state=self.global_state, use_lstsq=fs["use_lstsq"])
 
         d = S @ d_proj
-        objective.updates = vec_to_tensors(d, objective.params)
+        vec_to_tensors_(d, updates)
         return objective
 
     def get_H(self, objective=...):
-        eigval_fn = self.defaults["eigval_fn"]
-        H_sketched: torch.Tensor = self.global_state["H_sketched"]
-        S: torch.Tensor = self.global_state["S"]
-
-        if eigval_fn is not None:
-            try:
-                L, Q = torch.linalg.eigh(H_sketched) # pylint:disable=not-callable
-                L: torch.Tensor = eigval_fn(L)
-                H_sketched = Q @ L.diag_embed() @ Q.mH
+        if "H" in self.global_state:
+            H_sketched = self.global_state["H"]
 
-            except torch.linalg.LinAlgError:
-                pass
+        else:
+            L = self.global_state["L"]
+            Q = self.global_state["Q"]
+            H_sketched = Q @ L.diag_embed() @ Q.mH
 
+        S: torch.Tensor = self.global_state["S"]
         return Sketched(S, H_sketched)
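The rewritten ``apply_states`` now projects the flattened update into the sketch basis, solves the small system stored by ``_newton_update_state_``, and lifts the result back to parameter space. A self-contained sketch of that algebra in plain tensors (not the torchzero helpers), assuming an orthonormal sketch ``S``:

```python
import torch

n, k = 50, 8
H = torch.randn(n, n); H = H @ H.T + torch.eye(n)  # stand-in Hessian (SPD for the example)
g = torch.randn(n)                                 # flattened gradient / update

S, _ = torch.linalg.qr(torch.randn(n, k))          # orthonormal sketch: columns span the subspace
H_sketched = S.T @ (H @ S)                         # k x k projected Hessian (k Hessian-vector products)
b_proj = S.T @ g                                   # projected right-hand side

# solve the small system; lstsq mirrors the use_lstsq=True default
d_proj = torch.linalg.lstsq(H_sketched, b_proj.unsqueeze(1)).solution.squeeze(1)
d = S @ d_proj                                     # lift the subspace Newton step back to R^n
print(d.shape)                                     # torch.Size([50])
```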
@@ -74,7 +74,7 @@ class LaplacianSmoothing(TensorTransform):
 
     ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LaplacianSmoothing(),
         tz.m.LR(1e-2),
@@ -7,15 +7,14 @@ from typing import Literal, cast
 
 import torch
 
-from ...core import Chainable, Modular, Module, Objective
+from ...core import Chainable, Optimizer, Module, Objective
 from ...core.reformulation import Reformulation
 from ...utils import Distributions, NumberList, TensorList
 from ..termination import TerminationCriteriaBase, make_termination_criteria
 
 
 def _reset_except_self(objective: Objective, modules, self: Module):
-    assert objective.modular is not None
-    for m in objective.modular.flat_modules:
+    for m in modules:
         if m is not self:
             m.reset()
 
@@ -8,7 +8,7 @@ import torch
 from ...core import Chainable, TensorTransform
 from ...utils import NumberList, TensorList, tofloat, unpack_dicts, unpack_states
 from ...linalg.linear_operator import ScaledIdentity
-from ..functional import epsilon_step_size
+from ..opt_utils import epsilon_step_size
 
 def _acceptable_alpha(alpha, param:torch.Tensor):
     finfo = torch.finfo(param.dtype)
@@ -16,7 +16,7 @@ def _acceptable_alpha(alpha, param:torch.Tensor):
         return False
     return True
 
-def _get_H(self: TensorTransform, var):
+def _get_scaled_identity_H(self: TensorTransform, var):
     n = sum(p.numel() for p in var.params)
     p = var.params[0]
     alpha = self.global_state.get('alpha', 1)
@@ -87,7 +87,7 @@ class PolyakStepSize(TensorTransform):
         return tensors
 
     def get_H(self, objective):
-        return _get_H(self, objective)
+        return _get_scaled_identity_H(self, objective)
 
 
 def _bb_short(s: TensorList, y: TensorList, sy, eps):
@@ -176,7 +176,7 @@ class BarzilaiBorwein(TensorTransform):
         prev_g.copy_(g)
 
     def get_H(self, objective):
-        return _get_H(self, objective)
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
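For reference, the standard Barzilai-Borwein step sizes (presumably what ``_bb_short`` and its "long" counterpart compute) are, with ``s = x_k - x_{k-1}`` and ``y = g_k - g_{k-1}``, the long step ``(s·s)/(s·y)`` and the short step ``(s·y)/(y·y)``. A tiny numeric illustration on a quadratic (plain PyTorch, not the torchzero helpers):

```python
import torch

# toy quadratic f(x) = 0.5 * x^T A x, so the gradient is A @ x
A = torch.diag(torch.tensor([1.0, 10.0, 100.0]))
x_prev = torch.tensor([1.0, 1.0, 1.0])
x      = torch.tensor([0.9, 0.5, 0.1])
g_prev, g = A @ x_prev, A @ x

s, y = x - x_prev, g - g_prev
bb_long  = (s @ s) / (s @ y)   # "long" BB step
bb_short = (s @ y) / (y @ y)   # "short" BB step
print(bb_long.item(), bb_short.item())  # both lie within the inverse-curvature range [1/100, 1]
```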
@@ -288,7 +288,7 @@ class BBStab(TensorTransform):
         prev_g.copy_(g)
 
     def get_H(self, objective):
-        return _get_H(self, objective)
+        return _get_scaled_identity_H(self, objective)
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -384,4 +384,4 @@ class AdGD(TensorTransform):
         return tensors
 
     def get_H(self, objective):
-        return _get_H(self, objective)
+        return _get_scaled_identity_H(self, objective)
@@ -51,7 +51,7 @@ class Warmup(TensorTransform):
 
     .. code-block:: python
 
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.LR(1e-2),
@@ -90,7 +90,7 @@ class WarmupNormClip(TensorTransform):
 
     .. code-block:: python
 
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.Adam(),
             tz.m.WarmupNormClip(steps=1000)
@@ -109,7 +109,7 @@ class CubicRegularization(TrustRegionBase):
 
     .. code-block:: python
 
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.CubicRegularization(tz.m.Newton()),
         )
@@ -44,7 +44,7 @@ class LevenbergMarquardt(TrustRegionBase):
     Gauss-Newton with Levenberg-Marquardt trust-region
 
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.GaussNewton()),
     )
@@ -52,7 +52,7 @@ class LevenbergMarquardt(TrustRegionBase):
 
     LM-SR1
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.LevenbergMarquardt(tz.m.SR1(inverse=False)),
     )
@@ -47,7 +47,7 @@ class TrustCG(TrustRegionBase):
 
     .. code-block:: python
 
-        opt = tz.Modular(
+        opt = tz.Optimizer(
             model.parameters(),
             tz.m.TrustCG(hess_module=tz.m.SR1(inverse=False)),
         )
@@ -8,8 +8,7 @@ from ...utils import tofloat
 
 
 def _reset_except_self(objective: Objective, modules, self: Module):
-    assert objective.modular is not None
-    for m in objective.modular.flat_modules:
+    for m in modules:
         if m is not self:
             m.reset()
 
@@ -45,7 +44,7 @@ class SVRG(Module):
     ## Examples:
     SVRG-LBFGS
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.LBFGS(),
@@ -55,7 +54,7 @@ class SVRG(Module):
 
     For extra variance reduction one can use Online versions of algorithms, although it won't always help.
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SVRG(len(dataloader)),
         tz.m.Online(tz.m.LBFGS()),
@@ -64,7 +63,7 @@ class SVRG(Module):
 
     Variance reduction can also be applied to gradient estimators.
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.SPSA(),
         tz.m.SVRG(100),
@@ -6,8 +6,8 @@ from ...core import Module
 from ...utils import NumberList, TensorList
 
 
-def _reset_except_self(optimizer, var, self: Module):
-    for m in optimizer.unrolled_modules:
+def _reset_except_self(objective, modules, self: Module):
+    for m in modules:
         if m is not self:
             m.reset()
 
@@ -33,7 +33,7 @@ class WeightDecay(TensorTransform):
 
     Adam with non-decoupled weight decay
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.WeightDecay(1e-3),
         tz.m.Adam(),
@@ -44,7 +44,7 @@ class WeightDecay(TensorTransform):
     Adam with decoupled weight decay that still scales with learning rate
     ```python
 
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.WeightDecay(1e-3),
@@ -54,7 +54,7 @@ class WeightDecay(TensorTransform):
 
     Adam with fully decoupled weight decay that doesn't scale with learning rate
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.LR(1e-3),
@@ -93,7 +93,7 @@ class RelativeWeightDecay(TensorTransform):
 
     Adam with non-decoupled relative weight decay
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.RelativeWeightDecay(1e-1),
         tz.m.Adam(),
@@ -103,7 +103,7 @@ class RelativeWeightDecay(TensorTransform):
 
     Adam with decoupled relative weight decay
     ```python
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Adam(),
         tz.m.RelativeWeightDecay(1e-1),
@@ -11,7 +11,7 @@ class Wrap(Module):
     Wraps a pytorch optimizer to use it as a module.
 
     Note:
-        Custom param groups are supported only by ``set_param_groups``, settings passed to Modular will be applied to all parameters.
+        Custom param groups are supported only by ``set_param_groups``, settings passed to Optimizer will be applied to all parameters.
 
     Args:
         opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
@@ -21,7 +21,7 @@ class Wrap(Module):
         **kwargs:
             Extra args to be passed to opt_fn. The function is called as ``opt_fn(parameters, *args, **kwargs)``.
         use_param_groups:
-            Whether to pass settings passed to Modular to the wrapped optimizer.
+            Whether to pass settings passed to Optimizer to the wrapped optimizer.
 
             Note that settings to the first parameter are used for all parameters,
             so if you specified per-parameter settings, they will be ignored.
@@ -32,7 +32,7 @@ class Wrap(Module):
     ```python
 
     from pytorch_optimizer import StableAdamW
-    opt = tz.Modular(
+    opt = tz.Optimizer(
         model.parameters(),
         tz.m.Wrap(StableAdamW, lr=1),
         tz.m.Cautious(),
@@ -83,7 +83,7 @@ class Wrap(Module):
 
         # settings passed in `set_param_groups` are the highest priority
         # schedulers will override defaults but not settings passed in `set_param_groups`
-        # this is consistent with how Modular does it.
+        # this is consistent with how Optimizer does it.
         if self._custom_param_groups is not None:
             setting = {k:v for k,v in setting if k not in self._custom_param_groups[0]}
 
@@ -29,7 +29,7 @@ class CD(Module):
         whether to use three points (three function evaluatins) to determine descent direction.
         if False, uses two points, but then ``adaptive`` can't be used. Defaults to True.
     """
-    def __init__(self, h:float=1e-3, grad:bool=True, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
+    def __init__(self, h:float=1e-3, grad:bool=False, adaptive:bool=True, index:Literal['cyclic', 'cyclic2', 'random']="cyclic2", threepoint:bool=True,):
         defaults = dict(h=h, grad=grad, adaptive=adaptive, index=index, threepoint=threepoint)
         super().__init__(defaults)