torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/reduce_outward_lr.py
@@ -1,30 +1,33 @@
  import torch

  from ...core import Target, Transform
- from ...utils import TensorList
+ from ...utils import TensorList, unpack_states, unpack_dicts

  class ReduceOutwardLR(Transform):
-     """
-     When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.

      This means updates that move weights towards zero have higher learning rates.
+
+     .. warning::
+         This sounded good but after testing turns out it sucks.
      """
      def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
          defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
          super().__init__(defaults, uses_grad=use_grad, target=target)

      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          params = TensorList(params)
          tensors = TensorList(tensors)

-         mul = self.get_settings('mul', params=params)
-         s = self.settings[params[0]]
+         mul = [s['mul'] for s in settings]
+         s = settings[0]
          use_grad = s['use_grad']
          invert = s['invert']

-         if use_grad: cur = vars.get_grad()
+         if use_grad: cur = grads
          else: cur = tensors
+         assert cur is not None

          # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
          if invert: mask = (params * cur) > 0
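The rule described in the docstring above, sketched as a standalone plain-PyTorch helper for a single tensor (the function name and the exact sign convention here are illustrative assumptions, not the torchzero internals):

import torch

def reduce_outward(update: torch.Tensor, param: torch.Tensor, mul: float = 0.5) -> torch.Tensor:
    # `update` is an ascent direction that gets subtracted from `param`, so a step
    # moves a weight away from zero wherever param * update < 0; scale those entries down
    outward = (param * update) < 0
    return torch.where(outward, update * mul, update)

p = torch.tensor([1.0, -2.0, 3.0])
u = torch.tensor([0.5, 0.5, -0.5])
print(reduce_outward(u, p))  # tensor([0.5000, 0.2500, -0.2500])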
torchzero/modules/{projections/structural.py → experimental/structural_projections.py}
@@ -6,35 +6,18 @@ import torch
  from ...core import Chainable
  from ...utils import vec_to_tensors, TensorList
  from ..optimizers.shampoo import _merge_small_dims
- from .projection import Projection
+ from ..projections import ProjectionBase


- class VectorProjection(Projection):
-     """
-     flattens and concatenates all parameters into a vector
-     """
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

-     @torch.no_grad
-     def project(self, tensors, vars, current):
-         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-     @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         return vec_to_tensors(vec=tensors[0], reference=vars.params)
-
-
-
- class TensorizeProjection(Projection):
+ class TensorizeProjection(ProjectionBase):
      """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
      def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_side=max_side)
          super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

      @torch.no_grad
-     def project(self, tensors, vars, current):
-         params = vars.params
+     def project(self, tensors, params, grads, loss, states, settings, current):
          max_side = self.settings[params[0]]['max_side']
          num_elems = sum(t.numel() for t in tensors)

@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
          return [vec.view(dims)]

      @torch.no_grad
-     def unproject(self, tensors, vars, current):
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
          remainder = self.global_state['remainder']
          # warnings.warn(f'{tensors[0].shape = }')
-         vec = tensors[0].view(-1)
+         vec = projected_tensors[0].view(-1)
          if remainder > 0: vec = vec[:-remainder]
-         return vec_to_tensors(vec, vars.params)
+         return vec_to_tensors(vec, params)

- class BlockPartition(Projection):
+ class BlockPartition(ProjectionBase):
      """splits parameters into blocks (for now flatttens them and chunks)"""
      def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_size=max_size, batched=batched)
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

      @torch.no_grad
-     def project(self, tensors, vars, current):
+     def project(self, tensors, params, grads, loss, states, settings, current):
          partitioned = []
-         for p,t in zip(vars.params, tensors):
+         for p,t in zip(params, tensors):
              settings = self.settings[p]
              max_size = settings['max_size']
              n = t.numel()
@@ -101,10 +84,10 @@ class BlockPartition(Projection):
          return partitioned

      @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         ti = iter(tensors)
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         ti = iter(projected_tensors)
          unprojected = []
-         for p in vars.params:
+         for p in params:
              settings = self.settings[p]
              n = p.numel()

@@ -124,28 +107,3 @@ class BlockPartition(Projection):

          return unprojected

-
- class TensorNormsProjection(Projection):
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-     @torch.no_grad
-     def project(self, tensors, vars, current):
-         orig = self.get_state(f'{current}_orig', params=vars.params)
-         torch._foreach_copy_(orig, tensors)
-
-         norms = torch._foreach_norm(tensors)
-         self.get_state(f'{current}_orig_norms', params=vars.params, init=norms, cls=TensorList).set_(norms)
-
-         return [torch.stack(norms)]
-
-     @torch.no_grad
-     def unproject(self, tensors, vars, current):
-         orig = self.get_state(f'{current}_orig', params=vars.params)
-         orig_norms = torch.stack(self.get_state(f'{current}_orig_norms', params=vars.params))
-         target_norms = tensors[0]
-
-         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-         torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-         return orig
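For orientation, a conceptual plain-PyTorch sketch of what TensorizeProjection does (flatten all parameters into one vector, pad it, view it as a 2D tensor, and undo the reshape later); the helper names are hypothetical and this is not the package implementation:

import math
import torch

def tensorize(params: list[torch.Tensor]) -> tuple[torch.Tensor, int]:
    vec = torch.cat([p.reshape(-1) for p in params])
    side = math.ceil(math.sqrt(vec.numel()))
    remainder = side * side - vec.numel()        # zeros needed to fill the square
    if remainder > 0:
        vec = torch.cat([vec, vec.new_zeros(remainder)])
    return vec.view(side, side), remainder

def untensorize(mat: torch.Tensor, params: list[torch.Tensor], remainder: int) -> list[torch.Tensor]:
    vec = mat.reshape(-1)
    if remainder > 0:
        vec = vec[:-remainder]                   # drop the padding
    out, offset = [], 0
    for p in params:
        out.append(vec[offset:offset + p.numel()].view_as(p))
        offset += p.numel()
    return out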
torchzero/modules/experimental/subspace_preconditioners.py
@@ -5,7 +5,7 @@ import torch

  # import torchzero as tz

- from ...core import Transform, Chainable, apply
+ from ...core import Transform, Chainable, apply_transform
  from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
  from ...utils import TensorList, vec_to_tensors_

@@ -38,15 +38,20 @@ def apply_subspace_preconditioner(
      return basis @ update_projected # d

  class RandomSubspacePreconditioning(Transform):
-     """full matrix rmsprop in random slowly changing subspace"""
+     """Whitens in random slowly changing subspace.
+
+     .. warning::
+         Experimental and this is a barebones implementation.
+
+     """
      def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
          defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
          super().__init__(defaults, uses_grad=False)

          if inner is not None: self.set_child('inner', inner)

-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         settings = settings[0]
          g = torch.cat([t.view(-1) for t in tensors])
          k = settings['k']
          beta = settings['beta']
@@ -65,7 +70,7 @@ class RandomSubspacePreconditioning(Transform):
          update_subspace_preconditioner_(g, basis, accumulator, beta)

          if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors, params, grads, vars)
+             tensors = apply_transform(self.children['inner'], tensors, params, grads)
              g = torch.cat([t.view(-1) for t in tensors])

          try:
@@ -78,9 +83,14 @@ class RandomSubspacePreconditioning(Transform):


  class HistorySubspacePreconditioning(Transform):
-     """full matrix rmsprop in subspace spanned by history of gradient differences
+     """Whitens in subspace spanned by history of gradient differences.
+
+     .. warning::
+         Experimental and this is a barebones implementation.

-     basis_beta is how much basis is allowed to change, and beta is for preconditioner itself in the basis.
+     Args:
+         beta - for preconditioner itself in the basis.
+         basis_beta - how much basis is allowed to change.
      """
      def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
          defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
@@ -88,8 +98,8 @@ class HistorySubspacePreconditioning(Transform):

          if inner is not None: self.set_child('inner', inner)

-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         settings = settings[0]

          g = torch.cat([t.view(-1) for t in tensors])
          k = settings['k']
@@ -122,7 +132,7 @@ class HistorySubspacePreconditioning(Transform):
          update_subspace_preconditioner_(g, basis, accumulator, beta)

          if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors, params, grads, vars)
+             tensors = apply_transform(self.children['inner'], tensors, params, grads)
              g = torch.cat([t.view(-1) for t in tensors])

          try:
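A rough standalone sketch of the idea behind both classes, with matrix_power_eigh replaced by an explicit eigendecomposition (assumptions: the accumulator is an EMA of outer products in the subspace and the basis is orthonormal; this is not the torchzero code):

import torch

def whiten_in_subspace(g, basis, accumulator, beta=0.99, eps=1e-8):
    g_proj = basis.T @ g                                              # project gradient into the k-dim subspace
    accumulator.mul_(beta).add_(torch.outer(g_proj, g_proj), alpha=1 - beta)
    eigvals, eigvecs = torch.linalg.eigh(accumulator)
    inv_sqrt = eigvecs @ torch.diag(eigvals.clamp(min=eps).rsqrt()) @ eigvecs.T
    return basis @ (inv_sqrt @ g_proj)                                # whitened update mapped back to parameter space

d, k = 10, 3
basis, _ = torch.linalg.qr(torch.randn(d, k))                         # random orthonormal basis
accumulator = torch.zeros(k, k)
update = whiten_in_subspace(torch.randn(d), basis, accumulator)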
torchzero/modules/experimental/tensor_adagrad.py (new file)
@@ -0,0 +1,42 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Chainable, TensorwiseTransform
+ from ...utils.linalg import matrix_power_eigh
+
+
+ class TensorAdagrad(TensorwiseTransform):
+     """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
+
+     .. warning::
+         Experimental.
+     """
+     def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
+         defaults = dict(history_size=history_size, reg=reg)
+         super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         reg = setting['reg']
+         if 'history' not in state:
+             state['history'] = deque(maxlen=setting['history_size'])
+
+         g = tensor.view(-1)
+         history = state['history']
+         history.append(g.clone())
+
+         I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
+         g_k = history[0]
+         outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+         if len(history) > 1:
+             for g_k in list(history)[1:]:
+                 outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+
+         state['outer'] = outer.add_(I)
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         outer = state['outer']
+         P = matrix_power_eigh(outer, -1/2)
+         return (P @ tensor.ravel()).view_as(tensor)
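The same preconditioning step written as a standalone function, assuming matrix_power_eigh(A, -1/2) computes the inverse matrix square root of a symmetric matrix (replaced here with an explicit eigendecomposition):

import torch

def precondition(g: torch.Tensor, history: list[torch.Tensor], reg: float = 1e-8) -> torch.Tensor:
    # accumulate outer products weighted by the (clipped) alignment with the current gradient
    H = reg * torch.eye(g.numel(), dtype=g.dtype, device=g.device)
    for g_k in history:
        H += torch.outer(g_k, g_k) * torch.dot(g_k, g).clip(min=reg)
    # apply H^(-1/2) to the gradient
    eigvals, eigvecs = torch.linalg.eigh(H)
    return eigvecs @ torch.diag(eigvals.clamp(min=reg).rsqrt()) @ eigvecs.T @ g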
torchzero/modules/functional.py
@@ -7,8 +7,9 @@ storage is always indicated in the docstring.

  Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
  """
-
- from collections.abc import Callable, Sequence
+ from collections.abc import Callable
+ from typing import overload
+ import torch

  from ..utils import NumberList, TensorList

@@ -206,4 +207,13 @@ def sqrt_centered_ema_sq_(
          ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
      )

+ @overload
+ def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
+ @overload
+ def safe_scaling_(tensors_: TensorList) -> TensorList: ...
+ def safe_scaling_(tensors_: torch.Tensor | TensorList):
+     if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
+     else: scale = 1 / tensors_.abs().global_sum()
+     scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
+     return tensors_.mul_(scale)
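For illustration, a single-tensor equivalent of the new safe_scaling_ helper: the tensor is scaled by 1/sum(|x|), with the scale clipped to [eps, 1] so tiny updates are left alone and huge ones are normalized:

import torch

def safe_scale(x: torch.Tensor) -> torch.Tensor:
    scale = (1 / x.abs().sum()).clip(min=torch.finfo(x.dtype).eps, max=1)
    return x * scale

print(safe_scale(torch.tensor([10.0, -30.0])))  # scaled by 1/40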
torchzero/modules/grad_approximation/fdm.py
@@ -77,8 +77,11 @@ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_
      return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)

  _FD_FUNCS = {
+     "forward": _forward2,
      "forward2": _forward2,
+     "backward": _backward2,
      "backward2": _backward2,
+     "central": _central2,
      "central2": _central2,
      "central3": _central2, # they are the same
      "forward3": _forward3,
@@ -88,19 +91,43 @@ _FD_FUNCS = {


  class FDM(GradApproximator):
-     """Approximate gradients via finite difference method
+     """Approximate gradients via finite difference method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

      Args:
          h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
          formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
-         target (GradTarget, optional): what to set on vars. Defaults to 'closure'.
+         target (GradTarget, optional): what to set on var. Defaults to 'closure'.
+
+     Examples:
+         plain FDM:
+
+         .. code-block:: python
+
+             fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+
+         Any gradient-based method can use FDM-estimated gradients seamlessly.
+
+         .. code-block:: python
+
+             fdm_ncg = tz.Modular(
+                 model.parameters(),
+                 tz.m.FDM(),
+                 # set hvp_method to "forward" so that it
+                 # uses gradient difference instead of autograd
+                 tz.m.NewtonCG(hvp_method="forward"),
+                 tz.m.Backtracking()
+             )
      """
-     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
+     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
          defaults = dict(h=h, formula=formula)
          super().__init__(defaults, target=target)

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss):
          grads = []
          loss_approx = None
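The newly aliased "central" formula is the standard two-point central difference; in scalar form:

def central_difference(f, x: float, h: float = 1e-3) -> float:
    return (f(x + h) - f(x - h)) / (2 * h)

print(central_difference(lambda x: x**2, 3.0))  # ~6.0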
torchzero/modules/grad_approximation/forward_gradient.py
@@ -4,26 +4,36 @@ from typing import Any, Literal

  import torch

- from ...utils import Distributions, NumberList, TensorList, generic_eq
+ from ...utils import Distributions, NumberList, TensorList
  from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
  from .grad_approximator import GradApproximator, GradTarget
  from .rfdm import RandomizedFDM


  class ForwardGradient(RandomizedFDM):
-     """Forward gradient method, same as randomized finite difference but directional derivative is estimated via autograd (as jacobian vector product)
+     """Forward gradient method.
+
+     This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+

      Args:
          n_samples (int, optional): number of random gradient samples. Defaults to 1.
          distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
          beta (float, optional):
-             if not 0, acts as momentum on gradient samples, making the subspace spanned by them change slowly. Defaults to 0.
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
          pre_generate (bool, optional):
-             whether to pre-generate gradient samples before each step. Defaults to True.
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
          jvp_method (str, optional):
-             how to calculate jacobian vector product, note that with `forward` and 'central' this is identical to randomized finite difference. Defaults to 'autograd'.
+             how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
          h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
-         target (GradTarget, optional): what to set on vars. Defaults to "closure".
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     References:
+         Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
      """
      PRE_MULTIPLY_BY_H = False
      def __init__(
@@ -41,7 +51,7 @@ class ForwardGradient(RandomizedFDM):
          self.defaults['jvp_method'] = jvp_method

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss):
          params = TensorList(params)
          loss_approx = None
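The core idea of the forward gradient method, sketched standalone with torch.func.jvp (an illustration, not the RandomizedFDM machinery): sample a random direction v, compute the directional derivative in one forward-mode pass, and use it to scale v.

import torch

def forward_gradient(f, x: torch.Tensor) -> torch.Tensor:
    v = torch.randn_like(x)                     # random tangent direction
    _, dfdv = torch.func.jvp(f, (x,), (v,))     # directional derivative of f at x along v
    return dfdv * v                             # unbiased estimate of the gradient

x = torch.randn(5)
g_est = forward_gradient(lambda x: (x ** 2).sum(), x)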
torchzero/modules/grad_approximation/grad_approximator.py
@@ -5,7 +5,7 @@ from typing import Any, Literal

  import torch

- from ...core import Module, Vars
+ from ...core import Module, Var

  GradTarget = Literal['update', 'grad', 'closure']
  _Scalar = torch.Tensor | float
@@ -14,53 +14,98 @@ class GradApproximator(Module, ABC):
      """Base class for gradient approximations.
      This is an abstract class, to use it, subclass it and override `approximate`.

+     GradientApproximator modifies the closure to evaluate the estimated gradients,
+     and further closure-based modules will use the modified closure.
+
      Args:
          defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
          target (str, optional):
-             whether to set `vars.grad`, `vars.update` or 'vars.closure`. Defaults to 'closure'.
-     """
+             whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
+
+     Example:
+
+         Basic SPSA method implementation.
+
+         .. code-block:: python
+
+             class SPSA(GradApproximator):
+                 def __init__(self, h=1e-3):
+                     defaults = dict(h=h)
+                     super().__init__(defaults)
+
+                 @torch.no_grad
+                 def approximate(self, closure, params, loss):
+                     perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                     # evaluate params + perturbation
+                     torch._foreach_add_(params, perturbation)
+                     loss_plus = closure(False)
+
+                     # evaluate params - perturbation
+                     torch._foreach_sub_(params, perturbation)
+                     torch._foreach_sub_(params, perturbation)
+                     loss_minus = closure(False)
+
+                     # restore original params
+                     torch._foreach_add_(params, perturbation)
+
+                     # calculate SPSA gradients
+                     spsa_grads = []
+                     for p, pert in zip(params, perturbation):
+                         settings = self.settings[p]
+                         h = settings['h']
+                         d = (loss_plus - loss_minus) / (2*(h**2))
+                         spsa_grads.append(pert * d)
+
+                     # returns tuple: (grads, loss, loss_approx)
+                     # loss must be with initial parameters
+                     # since we only evaluated loss with perturbed parameters
+                     # we only have loss_approx
+                     return spsa_grads, None, loss_plus
+
+     """
      def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
          super().__init__(defaults)
          self._target: GradTarget = target

      @abstractmethod
-     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, vars: Vars) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
+     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
          """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""

-     def pre_step(self, vars: Vars) -> Vars | None:
+     def pre_step(self, var: Var) -> Var | None:
          """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
          evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-         return vars
+         return var

      @torch.no_grad
-     def step(self, vars):
-         ret = self.pre_step(vars)
-         if isinstance(ret, Vars): vars = ret
+     def step(self, var):
+         ret = self.pre_step(var)
+         if isinstance(ret, Var): var = ret

-         if vars.closure is None: raise RuntimeError("Gradient approximation requires closure")
-         params, closure, loss = vars.params, vars.closure, vars.loss
+         if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
+         params, closure, loss = var.params, var.closure, var.loss

          if self._target == 'closure':

              def approx_closure(backward=True):
                  if backward:
                      # set loss to None because closure might be evaluated at different points
-                     grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, vars=vars)
+                     grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                      for p, g in zip(params, grad): p.grad = g
-                     return l if l is not None else l_approx
+                     return l if l is not None else closure(False)
                  return closure(False)

-             vars.closure = approx_closure
-             return vars
+             var.closure = approx_closure
+             return var

-         # if vars.grad is not None:
-         #     warnings.warn('Using grad approximator when `vars.grad` is already set.')
-         grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, vars=vars)
-         if loss_approx is not None: vars.loss_approx = loss_approx
-         if loss is not None: vars.loss = vars.loss_approx = loss
-         if self._target == 'grad': vars.grad = list(grad)
-         elif self._target == 'update': vars.update = list(grad)
+         # if var.grad is not None:
+         #     warnings.warn('Using grad approximator when `var.grad` is already set.')
+         grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
+         if loss_approx is not None: var.loss_approx = loss_approx
+         if loss is not None: var.loss = var.loss_approx = loss
+         if self._target == 'grad': var.grad = list(grad)
+         elif self._target == 'update': var.update = list(grad)
          else: raise ValueError(self._target)
-         return vars
+         return var

- _FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
+ _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']
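Since gradient approximators require a closure, and the modified approx_closure above calls closure(False) to skip autograd, the closure is expected to accept a backward flag. A typical training-loop closure looks roughly like this (model, criterion, inputs, targets and the closure-based step call are assumptions for illustration):

def closure(backward=True):
    loss = criterion(model(inputs), targets)
    if backward:
        optimizer.zero_grad()
        loss.backward()
    return loss

optimizer.step(closure)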