torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/ops/reduce.py
@@ -0,0 +1,149 @@
+ """"""
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterable, Sequence
+ from typing import Any, cast
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Vars, maybe_chain
+
+
+ class ReduceOperation(Module, ABC):
+     """Base class for reduction operations like Sum, Prod, Maximum. This is an abstract class, subclass it and override `transform` method to use it."""
+     def __init__(self, defaults: dict[str, Any] | None, *operands: Chainable | Any):
+         super().__init__(defaults=defaults)
+
+         self.operands = []
+         for i, v in enumerate(operands):
+
+             if isinstance(v, (Module, Sequence)):
+                 self.set_child(f'operand_{i}', v)
+                 self.operands.append(self.children[f'operand_{i}'])
+             else:
+                 self.operands.append(v)
+
+         if not self.children:
+             raise ValueError('At least one operand must be a module')
+
+     @abstractmethod
+     def transform(self, vars: Vars, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+         """applies the operation to operands"""
+         raise NotImplementedError
+
+     @torch.no_grad
+     def step(self, vars: Vars) -> Vars:
+         # pass cloned update to all module operands
+         processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
+
+         for i, v in enumerate(self.operands):
+             if f'operand_{i}' in self.children:
+                 v: Module
+                 updated_vars = v.step(vars.clone(clone_update=True))
+                 processed_operands[i] = updated_vars.get_update()
+                 vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+
+         transformed = self.transform(vars, *processed_operands)
+         vars.update = transformed
+         return vars
+
+ class Sum(ReduceOperation):
+     USE_MEAN = False
+     def __init__(self, *inputs: Chainable | float):
+         super().__init__({}, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+         sum = cast(list, sorted_inputs[0])
+         if len(sorted_inputs) > 1:
+             for v in sorted_inputs[1:]:
+                 torch._foreach_add_(sum, v)
+
+         if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
+         return sum
+
+ class Mean(Sum):
+     USE_MEAN = True
+
+
+ class WeightedSum(ReduceOperation):
+     USE_MEAN = False
+     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
+         weights = list(weights)
+         if len(inputs) != len(weights):
+             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
+         defaults = dict(weights=weights)
+         super().__init__(defaults=defaults, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+         weights = self.settings[vars.params[0]]['weights']
+         sum = cast(list, sorted_inputs[0])
+         torch._foreach_mul_(sum, weights[0])
+         if len(sorted_inputs) > 1:
+             for v, w in zip(sorted_inputs[1:], weights[1:]):
+                 if isinstance(v, (int, float)): torch._foreach_add_(sum, v*w)
+                 else: torch._foreach_add_(sum, v, alpha=w)
+
+         if self.USE_MEAN and len(sorted_inputs) > 1: torch._foreach_div_(sum, len(sorted_inputs))
+         return sum
+
+
+ class WeightedMean(WeightedSum):
+     USE_MEAN = True
+
+ class Median(ReduceOperation):
+     def __init__(self, *inputs: Chainable | float):
+         super().__init__({}, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         res = []
+         lists = [i for i in inputs if isinstance(i, list)]
+         floats = [i for i in inputs if isinstance(i, (int,float))]
+         for tensors in zip(*lists):
+             res.append(torch.median(torch.stack(tensors + tuple(torch.full_like(tensors[0], f) for f in floats)), dim=0))
+         return res
+
+ class Prod(ReduceOperation):
+     def __init__(self, *inputs: Chainable | float):
+         super().__init__({}, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+         prod = cast(list, sorted_inputs[0])
+         if len(sorted_inputs) > 1:
+             for v in sorted_inputs[1:]:
+                 torch._foreach_mul_(prod, v)
+
+         return prod
+
+ class MaximumModules(ReduceOperation):
+     def __init__(self, *inputs: Chainable | float):
+         super().__init__({}, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+         maximum = cast(list, sorted_inputs[0])
+         if len(sorted_inputs) > 1:
+             for v in sorted_inputs[1:]:
+                 torch._foreach_maximum_(maximum, v)
+
+         return maximum
+
+ class MinimumModules(ReduceOperation):
+     def __init__(self, *inputs: Chainable | float):
+         super().__init__({}, *inputs)
+
+     @torch.no_grad
+     def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
+         minimum = cast(list, sorted_inputs[0])
+         if len(sorted_inputs) > 1:
+             for v in sorted_inputs[1:]:
+                 torch._foreach_minimum_(minimum, v)
+
+         return minimum
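
The reduction modules above all follow the same in-place foreach pattern: operands are either lists of tensors (one per parameter) or plain scalars, and the first tensor-list operand is accumulated into. A minimal standalone sketch of that pattern, with made-up tensors and weights (it mirrors WeightedSum.transform and Sum.transform but does not touch torchzero internals):

import torch

# two "updates", one tensor per parameter, as two module chains would produce
update_a = [torch.ones(3), torch.ones(2, 2)]
update_b = [torch.full((3,), 2.0), torch.full((2, 2), 2.0)]
weights = (0.9, 0.1)

# weighted sum accumulated in place into update_a, as in WeightedSum.transform
torch._foreach_mul_(update_a, weights[0])
torch._foreach_add_(update_a, update_b, alpha=weights[1])

# scalar operands are added as plain numbers, as in Sum.transform
torch._foreach_add_(update_a, 0.5)
print([t.flatten()[0].item() for t in update_a])  # 1.6 for every element
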
torchzero/modules/ops/split.py
@@ -0,0 +1,75 @@
+ from collections.abc import Callable
+ from typing import cast
+
+ import torch
+
+ from ...core import Chainable, Module, Vars
+
+
+ def _split(
+     module: Module,
+     idxs,
+     params,
+     vars: Vars,
+ ):
+     split_params = [p for i,p in enumerate(params) if i in idxs]
+
+     split_grad = None
+     if vars.grad is not None:
+         split_grad = [g for i,g in enumerate(vars.grad) if i in idxs]
+
+     split_update = None
+     if vars.update is not None:
+         split_update = [u for i,u in enumerate(vars.update) if i in idxs]
+
+     split_vars = vars.clone(clone_update=False)
+     split_vars.params = split_params
+     split_vars.grad = split_grad
+     split_vars.update = split_update
+
+     split_vars = module.step(split_vars)
+
+     if (vars.grad is None) and (split_vars.grad is not None):
+         vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+     if split_vars.update is not None:
+
+         if vars.update is None:
+             if vars.grad is None: vars.update = [cast(torch.Tensor, None) for _ in vars.params]
+             else: vars.update = [g.clone() for g in vars.grad]
+
+         for idx, u in zip(idxs, split_vars.update):
+             vars.update[idx] = u
+
+     vars.update_attrs_from_clone_(split_vars)
+     return vars
+
+ class Split(Module):
+     """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
+     def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
+         defaults = dict(filter=filter)
+         super().__init__(defaults)
+
+         if true is not None: self.set_child('true', true)
+         if false is not None: self.set_child('false', false)
+
+     def step(self, vars):
+
+         params = vars.params
+         filter = self.settings[params[0]]['filter']
+
+         true_idxs = []
+         false_idxs = []
+         for i,p in enumerate(params):
+             if filter(p): true_idxs.append(i)
+             else: false_idxs.append(i)
+
+         if 'true' in self.children:
+             true = self.children['true']
+             vars = _split(true, idxs=true_idxs, params=params, vars=vars)
+
+         if 'false' in self.children:
+             false = self.children['false']
+             vars = _split(false, idxs=false_idxs, params=params, vars=vars)
+
+         return vars
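
The routing in `Split` comes down to partitioning parameter indices with a user-supplied predicate and stepping each child on its own slice of `vars`. A standalone sketch of that index partitioning with a made-up filter (sending matrix-shaped parameters one way and everything else the other):

import torch

params = [torch.randn(10, 10), torch.randn(10), torch.randn(5, 5)]
filter_fn = lambda p: p.ndim >= 2   # e.g. route matrices to orthogonalizing modules

true_idxs = [i for i, p in enumerate(params) if filter_fn(p)]
false_idxs = [i for i, p in enumerate(params) if not filter_fn(p)]
print(true_idxs, false_idxs)  # [0, 2] [1]

# the parameter slices each child module would see, as built in _split
true_params = [p for i, p in enumerate(params) if i in true_idxs]
false_params = [p for i, p in enumerate(params) if i in false_idxs]
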
torchzero/modules/ops/switch.py
@@ -0,0 +1,68 @@
+ from collections.abc import Iterable, Sequence
+ from typing import Any
+
+ import torch
+
+ from ...core import Chainable, Module
+
+
+ class Alternate(Module):
+     """alternate between stepping with `modules`"""
+     LOOP = True
+     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
+         if isinstance(steps, Iterable):
+             steps = list(steps)
+             if len(steps) != len(modules):
+                 raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+
+         self.set_children_sequence(modules)
+         self.global_state['current_module_idx'] = 0
+         self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
+
+     @torch.no_grad
+     def step(self, vars):
+         # get current module
+         current_module_idx = self.global_state.setdefault('current_module_idx', 0)
+         module = self.children[f'module_{current_module_idx}']
+
+         # step
+         vars = module.step(vars.clone(clone_update=False))
+
+         # number of steps until next module
+         steps = self.settings[vars.params[0]]['steps']
+         if isinstance(steps, int): steps = [steps]*len(self.children)
+
+         if 'steps_to_next' not in self.global_state:
+             self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
+
+         self.global_state['steps_to_next'] -= 1
+
+         # switch to next module
+         if self.global_state['steps_to_next'] == 0:
+             self.global_state['current_module_idx'] += 1
+
+             # loop to first module (or keep using last module on Switch)
+             if self.global_state['current_module_idx'] > len(self.children) - 1:
+                 if self.LOOP: self.global_state['current_module_idx'] = 0
+                 else: self.global_state['current_module_idx'] = len(self.children) - 1
+
+             self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
+
+         return vars
+
+ class Switch(Alternate):
+     """switch to next module after some steps"""
+     LOOP = False
+     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
+
+         if isinstance(steps, Iterable):
+             steps = list(steps)
+             if len(steps) != len(modules) - 1:
+                 raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+
+             steps.append(1)
+
+         super().__init__(*modules, steps=steps)
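
The bookkeeping in `Alternate.step` is a per-module step budget that is decremented on every call and then either wraps back to the first module (Alternate) or sticks on the last one (Switch). A standalone sketch of that scheduling logic with placeholder module names, independent of torchzero's Module machinery:

def make_scheduler(names, steps, loop=True):
    # mirrors current_module_idx / steps_to_next from Alternate's global_state
    state = {"idx": 0, "left": steps[0]}
    def next_name():
        name = names[state["idx"]]
        state["left"] -= 1
        if state["left"] == 0:
            state["idx"] += 1
            if state["idx"] >= len(names):
                state["idx"] = 0 if loop else len(names) - 1
            state["left"] = steps[state["idx"]]
        return name
    return next_name

step = make_scheduler(["sgd", "lbfgs"], steps=[2, 1], loop=True)
print([step() for _ in range(6)])  # ['sgd', 'sgd', 'lbfgs', 'sgd', 'sgd', 'lbfgs']
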
torchzero/modules/ops/unary.py
@@ -0,0 +1,115 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import TensorwiseTransform, Target, Transform
+ from ...utils import TensorList
+
+ class UnaryLambda(Transform):
+     def __init__(self, fn, target: "Target" = 'update'):
+         defaults = dict(fn=fn)
+         super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         return self.settings[params[0]]['fn'](tensors)
+
+ class UnaryParameterwiseLambda(TensorwiseTransform):
+     def __init__(self, fn, target: "Target" = 'update'):
+         defaults = dict(fn=fn)
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+     @torch.no_grad
+     def transform(self, tensor, param, grad, vars):
+         return self.settings[param]['fn'](tensor)
+
+ class CustomUnaryOperation(Transform):
+     def __init__(self, name: str, target: "Target" = 'update'):
+         defaults = dict(name=name)
+         super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         return getattr(tensors, self.settings[params[0]]['name'])()
+
+
+ class Abs(Transform):
+     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         torch._foreach_abs_(tensors)
+         return tensors
+
+ class Sign(Transform):
+     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         torch._foreach_sign_(tensors)
+         return tensors
+
+ class Exp(Transform):
+     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         torch._foreach_exp_(tensors)
+         return tensors
+
+ class Sqrt(Transform):
+     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         torch._foreach_sqrt_(tensors)
+         return tensors
+
+ class Reciprocal(Transform):
+     def __init__(self, eps = 0, target: "Target" = 'update'):
+         defaults = dict(eps = eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         eps = self.get_settings('eps', params=params)
+         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
+         torch._foreach_reciprocal_(tensors)
+         return tensors
+
+ class Negate(Transform):
+     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         torch._foreach_neg_(tensors)
+         return tensors
+
+
+ class NanToNum(Transform):
+     """Convert `nan`, `inf` and `-inf` to numbers.
+
+     Args:
+         nan (optional): the value to replace NaNs with. Default is zero.
+         posinf (optional): if a Number, the value to replace positive infinity values with.
+             If None, positive infinity values are replaced with the greatest finite value
+             representable by input's dtype. Default is None.
+         neginf (optional): if a Number, the value to replace negative infinity values with.
+             If None, negative infinity values are replaced with the lowest finite value
+             representable by input's dtype. Default is None.
+     """
+     def __init__(self, nan=None, posinf=None, neginf=None, target: "Target" = 'update'):
+         defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         nan, posinf, neginf = self.get_settings('nan', 'posinf', 'neginf', params=params)
+         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
+
+ class Rescale(Transform):
+     """rescale update to (min, max) range"""
+     def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8, target: "Target" = 'update'):
+         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars):
+         min,max = self.get_settings('min','max', params=params)
+         tensorwise = self.settings[params[0]]['tensorwise']
+         dim = None if tensorwise else 'global'
+         return TensorList(tensors).rescale(min=min, max=max, eps=self.settings[params[0]]['eps'], dim=dim)
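
Each unary transform above receives a list of tensors (one per parameter) and mutates it with an in-place torch._foreach_* call or a per-tensor method. A minimal sketch with made-up tensors showing what Abs, NanToNum and Sqrt do to such a list:

import torch

tensors = [torch.tensor([-2.0, 4.0]), torch.tensor([[float("nan"), -9.0]])]

torch._foreach_abs_(tensors)                      # Abs.transform
tensors = [t.nan_to_num_(0.0) for t in tensors]   # NanToNum with nan=0
torch._foreach_sqrt_(tensors)                     # Sqrt.transform
print(tensors)  # [tensor([1.4142, 2.0000]), tensor([[0., 3.]])]
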
torchzero/modules/ops/utility.py
@@ -0,0 +1,112 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Module, Target, Transform
+ from ...utils.tensorlist import Distributions, TensorList
+
+
+ class Clone(Transform):
+     def __init__(self): super().__init__({}, uses_grad=False)
+     @torch.no_grad
+     def transform(self, tensors, params, grads, vars): return [t.clone() for t in tensors]
+
+ class Grad(Module):
+     def __init__(self):
+         super().__init__({})
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [g.clone() for g in vars.get_grad()]
+         return vars
+
+ class Params(Module):
+     def __init__(self):
+         super().__init__({})
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [p.clone() for p in vars.params]
+         return vars
+
+ class Update(Module):
+     def __init__(self):
+         super().__init__({})
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [u.clone() for u in vars.get_update()]
+         return vars
+
+ class Zeros(Module):
+     def __init__(self):
+         super().__init__({})
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [torch.zeros_like(p) for p in vars.params]
+         return vars
+
+ class Ones(Module):
+     def __init__(self):
+         super().__init__({})
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [torch.ones_like(p) for p in vars.params]
+         return vars
+
+ class Fill(Module):
+     def __init__(self, value: float):
+         defaults = dict(value=value)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [torch.full_like(p, self.settings[p]['value']) for p in vars.params]
+         return vars
+
+ class RandomSample(Module):
+     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
+         defaults = dict(eps=eps, distribution=distribution)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = TensorList(vars.params).sample_like(
+             eps=self.get_settings('eps', params=vars.params), distribution=self.settings[vars.params[0]]['distribution']
+         )
+         return vars
+
+ class Randn(Module):
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, vars):
+         vars.update = [torch.randn_like(p) for p in vars.params]
+         return vars
+
+ class Uniform(Module):
+     def __init__(self, low: float, high: float):
+         defaults = dict(low=low, high=high)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, vars):
+         low,high = self.get_settings('low','high', params=vars.params)
+         vars.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(vars.params, low, high)]
+         return vars
+
+ class GradToNone(Module):
+     def __init__(self): super().__init__()
+     def step(self, vars):
+         vars.grad = None
+         return vars
+
+ class UpdateToNone(Module):
+     def __init__(self): super().__init__()
+     def step(self, vars):
+         vars.update = None
+         return vars
+
+ class Identity(Module):
+     def __init__(self, *args, **kwargs): super().__init__()
+     def step(self, vars): return vars
+
+ NoOp = Identity
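
The source modules above (Zeros, Ones, Fill, Randn, Uniform) ignore the incoming update and simply build fresh tensors with the parameters' shapes. A standalone sketch of the same construction with dummy parameters:

import torch

params = [torch.randn(3), torch.randn(2, 2)]

zeros_update = [torch.zeros_like(p) for p in params]                          # Zeros
fill_update = [torch.full_like(p, 0.1) for p in params]                       # Fill(0.1)
uniform_update = [torch.empty_like(p).uniform_(-1.0, 1.0) for p in params]    # Uniform(-1, 1)
print([u.shape for u in fill_update])  # [torch.Size([3]), torch.Size([2, 2])]
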
torchzero/modules/optimizers/__init__.py
@@ -1,10 +1,18 @@
- r"""
- This include various optimizers as composable modules.
- """
- # from .adam import Adam
- from .sgd import SGD
- from .rprop import Rprop
- from .rmsprop import RMSProp
- from .adagrad import Adagrad
- from .adam import Adam
- from .lion import Lion
+ from .adagrad import Adagrad, FullMatrixAdagrad
+ from .adam import Adam
+ from .lion import Lion
+ from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+ from .rmsprop import RMSprop
+ from .rprop import (
+     BacktrackOnSignChange,
+     Rprop,
+     ScaleLRBySignChange,
+     SignConsistencyLRs,
+     SignConsistencyMask,
+ )
+ from .shampoo import Shampoo
+ from .soap import SOAP
+ from .orthograd import OrthoGrad, orthograd_
+ from .sophia_h import SophiaH
+ # from .curveball import CurveBall
+ # from .spectral import SpectralPreconditioner
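
Assuming the 0.3.1 wheel is installed, the re-exports above make the new optimizer modules importable straight from the subpackage; the module-chaining plumbing around them is not part of this hunk, so this is only an import sketch:

# names taken from the re-exports shown above; how they are composed into a full
# optimizer depends on torchzero's chaining API, which this diff does not show
from torchzero.modules.optimizers import SOAP, Shampoo, SophiaH, OrthoGrad, FullMatrixAdagrad
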