torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/reduce.py

@@ -5,10 +5,10 @@ from typing import Any, cast
 
  import torch
 
- from ...core import Chainable, Module, Target, Vars, maybe_chain
+ from ...core import Chainable, Module, Target, Var, maybe_chain
 
 
- class ReduceOperation(Module, ABC):
+ class ReduceOperationBase(Module, ABC):
  """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -26,33 +26,34 @@ class ReduceOperation(Module, ABC):
  raise ValueError('At least one operand must be a module')
 
  @abstractmethod
- def transform(self, vars: Vars, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
  """applies the operation to operands"""
  raise NotImplementedError
 
  @torch.no_grad
- def step(self, vars: Vars) -> Vars:
+ def step(self, var: Var) -> Var:
  # pass cloned update to all module operands
  processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
  for i, v in enumerate(self.operands):
  if f'operand_{i}' in self.children:
  v: Module
- updated_vars = v.step(vars.clone(clone_update=True))
- processed_operands[i] = updated_vars.get_update()
- vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+ updated_var = v.step(var.clone(clone_update=True))
+ processed_operands[i] = updated_var.get_update()
+ var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
- transformed = self.transform(vars, *processed_operands)
- vars.update = transformed
- return vars
+ transformed = self.transform(var, *processed_operands)
+ var.update = transformed
+ return var
 
- class Sum(ReduceOperation):
+ class Sum(ReduceOperationBase):
+ """Outputs sum of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = False
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
  sum = cast(list, sorted_inputs[0])
  if len(sorted_inputs) > 1:
@@ -63,12 +64,14 @@ class Sum(ReduceOperation):
  return sum
 
  class Mean(Sum):
+ """Outputs a mean of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = True
 
 
- class WeightedSum(ReduceOperation):
+ class WeightedSum(ReduceOperationBase):
  USE_MEAN = False
  def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+ """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
  weights = list(weights)
  if len(inputs) != len(weights):
  raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -76,9 +79,9 @@ class WeightedSum(ReduceOperation):
  super().__init__(defaults=defaults, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
- weights = self.settings[vars.params[0]]['weights']
+ weights = self.settings[var.params[0]]['weights']
  sum = cast(list, sorted_inputs[0])
  torch._foreach_mul_(sum, weights[0])
  if len(sorted_inputs) > 1:
@@ -91,14 +94,16 @@ class WeightedSum(ReduceOperation):
 
 
  class WeightedMean(WeightedSum):
+ """Outputs weighted mean of :code:`inputs` that can be modules or numbers."""
  USE_MEAN = True
 
- class Median(ReduceOperation):
+ class Median(ReduceOperationBase):
+ """Outputs median of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  res = []
  lists = [i for i in inputs if isinstance(i, list)]
  floats = [i for i in inputs if isinstance(i, (int,float))]
@@ -106,12 +111,13 @@ class Median(ReduceOperation):
  res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
  return res
 
- class Prod(ReduceOperation):
+ class Prod(ReduceOperationBase):
+ """Outputs product of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
  prod = cast(list, sorted_inputs[0])
  if len(sorted_inputs) > 1:
@@ -120,12 +126,13 @@ class Prod(ReduceOperation):
 
  return prod
 
- class MaximumModules(ReduceOperation):
+ class MaximumModules(ReduceOperationBase):
+ """Outputs elementwise maximum of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
  maximum = cast(list, sorted_inputs[0])
  if len(sorted_inputs) > 1:
@@ -134,12 +141,13 @@ class MaximumModules(ReduceOperation):
 
  return maximum
 
- class MinimumModules(ReduceOperation):
+ class MinimumModules(ReduceOperationBase):
+ """Outputs elementwise minimum of :code:`inputs` that can be modules or numbers."""
  def __init__(self, *inputs: Chainable | float):
  super().__init__({}, *inputs)
 
  @torch.no_grad
- def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
  minimum = cast(list, sorted_inputs[0])
  if len(sorted_inputs) > 1:
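The reduction modules above combine the outputs of several chained modules (or constants) elementwise. A minimal usage sketch; the tz.Modular wrapper, the tz.m namespace and the LR module are assumed from the project README, and the chosen modules and values are purely illustrative:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    # average the updates proposed by Adam and by the sign of the gradient,
    # then apply a fixed step size (illustrative chain, not taken from the package docs)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Mean(tz.m.Adam(), tz.m.Sign()),
        tz.m.LR(1e-2),
    )

    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()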
torchzero/modules/ops/unary.py

@@ -3,79 +3,95 @@ from collections import deque
  import torch
 
  from ...core import TensorwiseTransform, Target, Transform
- from ...utils import TensorList
+ from ...utils import TensorList, unpack_dicts,unpack_states
 
  class UnaryLambda(Transform):
+ """Applies :code:`fn` to input tensors.
+
+ :code:`fn` must accept and return a list of tensors.
+ """
  def __init__(self, fn, target: "Target" = 'update'):
  defaults = dict(fn=fn)
  super().__init__(defaults=defaults, uses_grad=False, target=target)
 
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- return self.settings[params[0]]['fn'](tensors)
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ return settings[0]['fn'](tensors)
 
  class UnaryParameterwiseLambda(TensorwiseTransform):
+ """Applies :code:`fn` to each input tensor.
+
+ :code:`fn` must accept and return a tensor.
+ """
  def __init__(self, fn, target: "Target" = 'update'):
  defaults = dict(fn=fn)
  super().__init__(uses_grad=False, defaults=defaults, target=target)
 
  @torch.no_grad
- def transform(self, tensor, param, grad, vars):
- return self.settings[param]['fn'](tensor)
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
+ return setting['fn'](tensor)
 
  class CustomUnaryOperation(Transform):
+ """Applies :code:`getattr(tensor, name)` to each tensor
+ """
  def __init__(self, name: str, target: "Target" = 'update'):
  defaults = dict(name=name)
  super().__init__(defaults=defaults, uses_grad=False, target=target)
 
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- return getattr(tensors, self.settings[params[0]]['name'])()
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ return getattr(tensors, settings[0]['name'])()
 
 
  class Abs(Transform):
+ """Returns :code:`abs(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_abs_(tensors)
  return tensors
 
  class Sign(Transform):
+ """Returns :code:`sign(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_sign_(tensors)
  return tensors
 
  class Exp(Transform):
+ """Returns :code:`exp(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_exp_(tensors)
  return tensors
 
  class Sqrt(Transform):
+ """Returns :code:`sqrt(input)`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_sqrt_(tensors)
  return tensors
 
  class Reciprocal(Transform):
+ """Returns :code:`1 / input`"""
  def __init__(self, eps = 0, target: "Target" = 'update'):
  defaults = dict(eps = eps)
  super().__init__(defaults, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- eps = self.get_settings('eps', params=params)
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ eps = [s['eps'] for s in settings]
  if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
  torch._foreach_reciprocal_(tensors)
  return tensors
 
  class Negate(Transform):
+ """Returns :code:`- input`"""
  def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  torch._foreach_neg_(tensors)
  return tensors
 
@@ -97,19 +113,19 @@ class NanToNum(Transform):
  super().__init__(defaults, uses_grad=False, target=target)
 
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- nan, posinf, neginf = self.get_settings('nan', 'posinf', 'neginf', params=params)
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
  return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
  class Rescale(Transform):
- """rescale update to (min, max) range"""
+ """Rescales input to :code`(min, max)` range"""
  def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
  defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
  super().__init__(defaults, uses_grad=False, target=target)
 
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- min,max = self.get_settings('min','max', params=params)
- tensorwise = self.settings[params[0]]['tensorwise']
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ min, max = unpack_dicts(settings, 'min','max')
+ tensorwise = settings[0]['tensorwise']
  dim = None if tensorwise else 'global'
- return TensorList(tensors).rescale(min=min, max=max, eps=self.settings[params[0]]['eps'], dim=dim)
+ return TensorList(tensors).rescale(min=min, max=max, eps=settings[0]['eps'], dim=dim)
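The transforms in this file now implement apply_tensors(tensors, params, grads, loss, states, settings) (apply_tensor for the per-tensor variant) instead of the old transform(tensors, params, grads, vars) hook, and read per-parameter options from the settings list. A minimal sketch of a custom transform written against the new hook, in the same style as Reciprocal above (the ClampValues class itself is hypothetical, not part of the package):

    import torch
    from torchzero.core import Transform

    class ClampValues(Transform):
        """Hypothetical transform: clamps each update tensor to [-bound, bound]."""
        def __init__(self, bound: float = 1.0, target='update'):
            super().__init__(dict(bound=bound), uses_grad=False, target=target)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # settings is a list of per-parameter dicts, like in Reciprocal above
            bounds = [s['bound'] for s in settings]
            return [t.clamp_(-b, b) for t, b in zip(tensors, bounds)]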
torchzero/modules/ops/utility.py

@@ -6,107 +6,115 @@ from ...core import Module, Target, Transform
  from ...utils.tensorlist import Distributions, TensorList
 
 
- class Clone(Transform):
- def __init__(self): super().__init__({}, uses_grad=False)
- @torch.no_grad
- def transform(self, tensors, params, grads, vars): return [t.clone() for t in tensors]
-
- class Grad(Module):
+ class Clone(Module):
+ """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
  def __init__(self):
  super().__init__({})
  @torch.no_grad
- def step(self, vars):
- vars.update = [g.clone() for g in vars.get_grad()]
- return vars
+ def step(self, var):
+ var.update = [u.clone() for u in var.get_update()]
+ return var
 
- class Params(Module):
+ class Grad(Module):
+ """Outputs the gradient"""
  def __init__(self):
  super().__init__({})
  @torch.no_grad
- def step(self, vars):
- vars.update = [p.clone() for p in vars.params]
- return vars
+ def step(self, var):
+ var.update = [g.clone() for g in var.get_grad()]
+ return var
 
- class Update(Module):
+ class Params(Module):
+ """Outputs parameters"""
  def __init__(self):
  super().__init__({})
  @torch.no_grad
- def step(self, vars):
- vars.update = [u.clone() for u in vars.get_update()]
- return vars
+ def step(self, var):
+ var.update = [p.clone() for p in var.params]
+ return var
 
  class Zeros(Module):
+ """Outputs zeros"""
  def __init__(self):
  super().__init__({})
  @torch.no_grad
- def step(self, vars):
- vars.update = [torch.zeros_like(p) for p in vars.params]
- return vars
+ def step(self, var):
+ var.update = [torch.zeros_like(p) for p in var.params]
+ return var
 
  class Ones(Module):
+ """Outputs ones"""
  def __init__(self):
  super().__init__({})
  @torch.no_grad
- def step(self, vars):
- vars.update = [torch.ones_like(p) for p in vars.params]
- return vars
+ def step(self, var):
+ var.update = [torch.ones_like(p) for p in var.params]
+ return var
 
  class Fill(Module):
+ """Outputs tensors filled with :code:`value`"""
  def __init__(self, value: float):
  defaults = dict(value=value)
  super().__init__(defaults)
 
  @torch.no_grad
- def step(self, vars):
- vars.update = [torch.full_like(p, self.settings[p]['value']) for p in vars.params]
- return vars
+ def step(self, var):
+ var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
+ return var
 
  class RandomSample(Module):
+ """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
  def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
  defaults = dict(eps=eps, distribution=distribution)
  super().__init__(defaults)
 
  @torch.no_grad
- def step(self, vars):
- vars.update = TensorList(vars.params).sample_like(
- eps=self.get_settings('eps',params=vars.params), distribution=self.settings[vars.params[0]]['distribution']
+ def step(self, var):
+ var.update = TensorList(var.params).sample_like(
+ eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
  )
- return vars
+ return var
 
  class Randn(Module):
+ """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
  def __init__(self):
  super().__init__({})
 
  @torch.no_grad
- def step(self, vars):
- vars.update = [torch.randn_like(p) for p in vars.params]
- return vars
+ def step(self, var):
+ var.update = [torch.randn_like(p) for p in var.params]
+ return var
 
  class Uniform(Module):
+ """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
  def __init__(self, low: float, high: float):
  defaults = dict(low=low, high=high)
  super().__init__(defaults)
 
  @torch.no_grad
- def step(self, vars):
- low,high = self.get_settings('low','high', params=vars.params)
- vars.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(vars.params, low, high)]
- return vars
+ def step(self, var):
+ low,high = self.get_settings(var.params, 'low','high')
+ var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
+ return var
 
  class GradToNone(Module):
+ """Sets :code:`grad` attribute to None on :code:`var`."""
  def __init__(self): super().__init__()
- def step(self, vars):
- vars.grad = None
- return vars
+ def step(self, var):
+ var.grad = None
+ return var
 
  class UpdateToNone(Module):
+ """Sets :code:`update` attribute to None on :code:`var`."""
  def __init__(self): super().__init__()
- def step(self, vars):
- vars.update = None
- return vars
+ def step(self, var):
+ var.update = None
+ return var
 
  class Identity(Module):
+ """A placeholder identity operator that is argument-insensitive."""
  def __init__(self, *args, **kwargs): super().__init__()
- def step(self, vars): return vars
+ def step(self, var): return var
 
- NoOp = Identity
+ NoOp = Identity
+ """A placeholder identity operator that is argument-insensitive."""
torchzero/modules/optimizers/__init__.py

@@ -1,7 +1,18 @@
  from .adagrad import Adagrad, FullMatrixAdagrad
+
+ # from .curveball import CurveBall
+ # from .spectral import SpectralPreconditioner
+ from .adahessian import AdaHessian
  from .adam import Adam
+ from .adan import Adan
+ from .adaptive_heavyball import AdaptiveHeavyBall
+ from .esgd import ESGD
+ from .ladagrad import LMAdagrad
  from .lion import Lion
+ from .mars import MARSCorrection
+ from .msam import MSAM, MSAMObjective
  from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+ from .orthograd import OrthoGrad, orthograd_
  from .rmsprop import RMSprop
  from .rprop import (
  BacktrackOnSignChange,
@@ -10,9 +21,7 @@ from .rprop import (
  SignConsistencyLRs,
  SignConsistencyMask,
  )
+ from .sam import ASAM, SAM
  from .shampoo import Shampoo
  from .soap import SOAP
- from .orthograd import OrthoGrad, orthograd_
  from .sophia_h import SophiaH
- # from .curveball import CurveBall
- # from .spectral import SpectralPreconditioner
torchzero/modules/optimizers/adagrad.py

@@ -1,18 +1,17 @@
  from operator import itemgetter
+ from typing import Literal
 
  import torch
-
  from ...core import (
  Chainable,
  Module,
- Preconditioner,
  Target,
- TensorwisePreconditioner,
+ TensorwiseTransform,
  Transform,
- Vars,
- apply,
+ Var,
+ apply_transform,
  )
- from ...utils import NumberList, TensorList
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ...utils.linalg import matrix_power_eigh
  from ..functional import add_power_, lerp_power_, root
 
@@ -26,12 +25,12 @@ def adagrad_(
  step: int,
  pow: float = 2,
  use_sqrt: bool = True,
+ divide: bool = False,
 
  # inner args
  inner: Module | None = None,
  params: list[torch.Tensor] | None = None,
  grads: list[torch.Tensor] | None = None,
- vars: Vars | None = None,
  ):
  """returns `tensors_`"""
  clr = alpha / (1 + step * lr_decay)
@@ -40,7 +39,9 @@
 
  if inner is not None:
  assert params is not None
- tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+ tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
+
+ if divide: sq_sum_ = sq_sum_ / max(step, 1)
 
  if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
  else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
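The new divide flag rescales the accumulated sum of squared gradients by the step count before the root is taken, so with pow=2 the denominator becomes the root-mean rather than the root-sum of past squared gradients. A scalar sketch of the two behaviours in plain PyTorch (not the package API; alpha=1 and lr_decay=0, so clr == 1):

    import torch

    g = torch.tensor([0.5, -0.2, 0.4])   # made-up gradient sequence for one scalar parameter
    sq_sum = torch.zeros(())
    eps = 1e-10

    for t, gt in enumerate(g, start=1):
        sq_sum = sq_sum + gt**2
        plain   = gt / (sq_sum.sqrt() + eps)          # divide=False: classic Adagrad denominator
        divided = gt / ((sq_sum / t).sqrt() + eps)    # divide=True: RMS of past gradients
        print(t, plain.item(), divided.item())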
@@ -50,7 +51,9 @@
 
 
  class Adagrad(Transform):
- """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
+ """Adagrad, divides by sum of past squares of gradients.
+
+ This implementation is identical to :code:`torch.optim.Adagrad`.
 
  Args:
  lr_decay (float, optional): learning rate decay. Defaults to 0.
@@ -69,29 +72,30 @@ class Adagrad(Transform):
  alpha: float = 1,
  pow: float = 2,
  use_sqrt: bool = True,
+ divide: bool=False,
  inner: Chainable | None = None,
  ):
  defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
- eps = eps, pow=pow, use_sqrt = use_sqrt)
+ eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
  super().__init__(defaults=defaults, uses_grad=False)
 
  if inner is not None:
  self.set_child('inner', inner)
 
  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  tensors = TensorList(tensors)
  step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
- lr_decay,alpha,eps = self.get_settings('lr_decay', 'alpha', 'eps', params=params, cls=NumberList)
+ lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
 
- pow, use_sqrt = itemgetter('pow', 'use_sqrt')(self.settings[params[0]])
+ pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
 
- sq_sum = self.get_state('sq_sum', params=params, cls=TensorList)
+ sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
 
  # initialize accumulator on 1st step
  if step == 1:
- sq_sum.set_(tensors.full_like(self.get_settings('initial_accumulator_value', params=params)))
+ sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
 
  return adagrad_(
  tensors,
@@ -102,45 +106,60 @@ class Adagrad(Transform):
  step=self.global_state["step"],
  pow=pow,
  use_sqrt=use_sqrt,
+ divide=divide,
 
  # inner args
  inner=self.children.get("inner", None),
  params=params,
  grads=grads,
- vars=vars,
  )
 
 
 
- class FullMatrixAdagrad(TensorwisePreconditioner):
- def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=False, update_freq=1, inner: Chainable | None = None):
- defaults = dict(beta=beta, decay=decay)
- super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+ class FullMatrixAdagrad(TensorwiseTransform):
+ def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
+ defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
+ super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)
 
  @torch.no_grad
- def update_tensor(self, tensor, param, grad, state, settings):
+ def update_tensor(self, tensor, param, grad, loss, state, setting):
  G = tensor.ravel()
  GG = torch.outer(G, G)
- decay = settings['decay']
- beta = settings['beta']
-
- if 'GG' not in state: state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+ decay = setting['decay']
+ beta = setting['beta']
+ init = setting['init']
+
+ if 'GG' not in state:
+ if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+ elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+ elif init == 'ones': state['GG'] = torch.ones_like(GG)
+ elif init == 'GGT': state['GG'] = GG.clone()
+ else: raise ValueError(init)
  if decay is not None: state['GG'].mul_(decay)
 
  if beta is not None: state['GG'].lerp_(GG, 1-beta)
  else: state['GG'].add_(GG)
+ state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
 
  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, state, settings):
+ def apply_tensor(self, tensor, param, grad, loss, state, setting):
  GG = state['GG']
+ sqrt = setting['sqrt']
+ divide = setting['divide']
+ if divide: GG = GG/state.get('i', 1)
 
  if tensor.numel() == 1:
- return tensor / (GG**(1/2)).squeeze()
+ GG = GG.squeeze()
+ if sqrt: return tensor / GG.sqrt()
+ return tensor / GG
 
  try:
- B = matrix_power_eigh(GG, -1/2)
+ if sqrt: B = matrix_power_eigh(GG, -1/2)
+ else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
  except torch.linalg.LinAlgError:
- return tensor.div_(tensor.abs().max()) # conservative scaling
+ scale = 1 / tensor.abs().max()
+ return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
 
  return (B @ tensor.ravel()).view_as(tensor)
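FullMatrixAdagrad accumulates outer products of the flattened gradient and preconditions with an inverse root of that accumulator; the new init, sqrt and divide options pick the accumulator's starting value, whether an inverse square root or a plain solve is applied, and whether the accumulator is averaged over the number of accumulated terms. A self-contained sketch of that update in plain PyTorch (illustrative only, not the package API; it starts the accumulator from the first outer product rather than the identity/zeros/ones variants):

    import torch

    def full_matrix_adagrad_step(g, GG=None, i=0, sqrt=True, divide=False):
        """One illustrative full-matrix Adagrad step for a single tensor g."""
        v = g.ravel()
        GG = torch.outer(v, v) if GG is None else GG + torch.outer(v, v)  # accumulate G G^T
        i += 1
        M = GG / i if divide else GG                  # divide=True averages the accumulated outer products
        if sqrt:
            # precondition with M^(-1/2) via an eigendecomposition (matrix_power_eigh plays this role above)
            w, Q = torch.linalg.eigh(M)
            P = Q @ torch.diag(w.clamp(min=1e-12).rsqrt()) @ Q.T
            d = P @ v
        else:
            d = torch.linalg.solve(M, v)              # sqrt=False solves M d = g instead
        return d.view_as(g), GG, i

    g = torch.randn(3, 2)
    d, GG, i = full_matrix_adagrad_step(g)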