torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/binary.py

@@ -1,5 +1,4 @@
  #pyright: reportIncompatibleMethodOverride=false
- """"""
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
@@ -7,11 +6,11 @@ from typing import Any
 
  import torch
 
- from ...core import Chainable, Module, Target, Vars, maybe_chain
+ from ...core import Chainable, Module, Target, Var, maybe_chain
  from ...utils import TensorList, tensorlist
 
 
- class BinaryOperation(Module, ABC):
+ class BinaryOperationBase(Module, ABC):
  """Base class for operations that use update as the first operand. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -26,211 +25,258 @@ class BinaryOperation(Module, ABC):
  self.operands[k] = v
 
  @abstractmethod
- def transform(self, vars: Vars, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
+ def transform(self, var: Var, update: list[torch.Tensor], **operands: Any | list[torch.Tensor]) -> Iterable[torch.Tensor]:
  """applies the operation to operands"""
  raise NotImplementedError
 
  @torch.no_grad
- def step(self, vars: Vars) -> Vars:
+ def step(self, var: Var) -> Var:
  # pass cloned update to all module operands
  processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
  for k,v in self.operands.items():
  if k in self.children:
  v: Module
- updated_vars = v.step(vars.clone(clone_update=True))
- processed_operands[k] = updated_vars.get_update()
- vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+ updated_var = v.step(var.clone(clone_update=True))
+ processed_operands[k] = updated_var.get_update()
+ var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
- transformed = self.transform(vars, update=vars.get_update(), **processed_operands)
- vars.update = list(transformed)
- return vars
+ transformed = self.transform(var, update=var.get_update(), **processed_operands)
+ var.update = list(transformed)
+ return var
 
 
- class Add(BinaryOperation):
+ class Add(BinaryOperationBase):
+ """Add :code:`other` to tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors + other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[vars.params[0]]['alpha'])
- else: torch._foreach_add_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
+ else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
  return update
 
- class Sub(BinaryOperation):
+ class Sub(BinaryOperationBase):
+ """Subtract :code:`other` from tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors - other(tensors)`
+ """
  def __init__(self, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[vars.params[0]]['alpha'])
- else: torch._foreach_sub_(update, other, alpha=self.settings[vars.params[0]]['alpha'])
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
+ else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
  return update
 
- class RSub(BinaryOperation):
+ class RSub(BinaryOperationBase):
+ """Subtract tensors from :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) - tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other - TensorList(update)
 
- class Mul(BinaryOperation):
+ class Mul(BinaryOperationBase):
+ """Multiply tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors * other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  torch._foreach_mul_(update, other)
  return update
 
- class Div(BinaryOperation):
+ class Div(BinaryOperationBase):
+ """Divide tensors by :code:`other`. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`tensors / other(tensors)`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  torch._foreach_div_(update, other)
  return update
 
- class RDiv(BinaryOperation):
+ class RDiv(BinaryOperationBase):
+ """Divide :code:`other` by tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) / tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  return other / TensorList(update)
 
- class Pow(BinaryOperation):
+ class Pow(BinaryOperationBase):
+ """Take tensors to the power of :code:`exponent`. :code:`exponent` can be a number or a module.
+
+ If :code:`exponent` is a module, this calculates :code:`tensors ^ exponent(tensors)`
+ """
  def __init__(self, exponent: Chainable | float):
  super().__init__({}, exponent=exponent)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], exponent: float | list[torch.Tensor]):
  torch._foreach_pow_(update, exponent)
  return update
 
- class RPow(BinaryOperation):
+ class RPow(BinaryOperationBase):
+ """Take :code:`other` to the power of tensors. :code:`other` can be a number or a module.
+
+ If :code:`other` is a module, this calculates :code:`other(tensors) ^ tensors`
+ """
  def __init__(self, other: Chainable | float):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
  if isinstance(other, (int, float)): return torch._foreach_pow(other, update) # no in-place
  torch._foreach_pow_(other, update)
  return other
 
- class Lerp(BinaryOperation):
+ class Lerp(BinaryOperationBase):
+ """Does a linear interpolation of tensors and :code:`end` module based on a scalar :code:`weight`.
+
+ The output is given by :code:`output = tensors + weight * (end(tensors) - tensors)`
+ """
  def __init__(self, end: Chainable, weight: float):
  defaults = dict(weight=weight)
  super().__init__(defaults, end=end)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], end: list[torch.Tensor]):
- torch._foreach_lerp_(update, end, weight=self.get_settings('weight',params=vars))
+ def transform(self, var, update: list[torch.Tensor], end: list[torch.Tensor]):
+ torch._foreach_lerp_(update, end, weight=self.get_settings(var.params, 'weight'))
  return update
 
- class CopySign(BinaryOperation):
+ class CopySign(BinaryOperationBase):
+ """Returns tensors with sign copied from :code:`other(tensors)`."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  return [u.copysign_(o) for u, o in zip(update, other)]
 
- class RCopySign(BinaryOperation):
+ class RCopySign(BinaryOperationBase):
+ """Returns :code:`other(tensors)` with sign copied from tensors."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  return [o.copysign_(u) for u, o in zip(update, other)]
  CopyMagnitude = RCopySign
 
- class Clip(BinaryOperation):
+ class Clip(BinaryOperationBase):
+ """clip tensors to be in :code:`(min, max)` range. :code:`min` and :code:`max: can be None, numbers or modules.
+
+ If code:`min` and :code:`max`: are modules, this calculates :code:`tensors.clip(min(tensors), max(tensors))`.
+ """
  def __init__(self, min: float | Chainable | None = None, max: float | Chainable | None = None):
  super().__init__({}, min=min, max=max)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
+ def transform(self, var, update: list[torch.Tensor], min: float | list[torch.Tensor] | None, max: float | list[torch.Tensor] | None):
  return TensorList(update).clamp_(min=min, max=max)
 
- class MirroredClip(BinaryOperation):
- """clip by -value, value"""
+ class MirroredClip(BinaryOperationBase):
+ """clip tensors to be in :code:`(-value, value)` range. :code:`value` can be a number or a module.
+
+ If :code:`value` is a module, this calculates :code:`tensors.clip(-value(tensors), value(tensors))`
+ """
  def __init__(self, value: float | Chainable):
  super().__init__({}, value=value)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], value: float | list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], value: float | list[torch.Tensor]):
  min = -value if isinstance(value, (int,float)) else [-v for v in value]
  return TensorList(update).clamp_(min=min, max=value)
 
- class Graft(BinaryOperation):
- """use direction from update and magnitude from `magnitude` module"""
+ class Graft(BinaryOperationBase):
+ """Outputs tensors rescaled to have the same norm as :code:`magnitude(tensors)`."""
  def __init__(self, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
  super().__init__(defaults, magnitude=magnitude)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+ def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
  return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)
 
- class RGraft(BinaryOperation):
- """use direction from `direction` module and magnitude from update"""
+ class RGraft(BinaryOperationBase):
+ """Outputs :code:`magnitude(tensors)` rescaled to have the same norm as tensors"""
 
  def __init__(self, direction: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
  super().__init__(defaults, direction=direction)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], direction: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[vars.params[0]])
+ def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
  return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)
 
  GraftToUpdate = RGraft
 
- class Maximum(BinaryOperation):
+ class Maximum(BinaryOperationBase):
+ """Outputs :code:`maximum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  torch._foreach_maximum_(update, other)
  return update
 
- class Minimum(BinaryOperation):
+ class Minimum(BinaryOperationBase):
+ """Outputs :code:`minimum(tensors, other(tensors))`"""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  torch._foreach_minimum_(update, other)
  return update
 
 
- class GramSchimdt(BinaryOperation):
- """makes update orthonormal to `other`"""
+ class GramSchimdt(BinaryOperationBase):
+ """outputs tensors made orthogonal to `other(tensors)` via Gram-Schmidt."""
  def __init__(self, other: Chainable):
  super().__init__({}, other=other)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], other: list[torch.Tensor]):
+ def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  update = TensorList(update); other = TensorList(other)
  return update - (other*update) / ((other*other) + 1e-8)
 
 
- class Threshold(BinaryOperation):
- """update above/below threshold, value at and below"""
+ class Threshold(BinaryOperationBase):
+ """Outputs tensors thresholded such that values above :code:`threshold` are set to :code:`value`."""
  def __init__(self, threshold: Chainable | float, value: Chainable | float, update_above: bool):
  defaults = dict(update_above=update_above)
  super().__init__(defaults, threshold=threshold, value=value)
 
  @torch.no_grad
- def transform(self, vars, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
- update_above = self.settings[vars.params[0]]['update_above']
+ def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
+ update_above = self.settings[var.params[0]]['update_above']
  update = TensorList(update)
  if update_above:
  if isinstance(value, list): return update.where_(update>threshold, value)
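
Taken together, the binary.py changes rename BinaryOperation to BinaryOperationBase and thread the renamed Var object (formerly Vars) through transform and step. A downstream subclass would now look roughly like the minimal sketch below; the AddSquared module and the absolute import paths are assumptions inferred from the file layout in this diff, not part of the package.

# Minimal sketch, assuming torchzero.core and torchzero.modules.ops.binary are the
# import paths implied by the layout above; AddSquared is hypothetical.
import torch

from torchzero.core import Chainable, Var
from torchzero.modules.ops.binary import BinaryOperationBase


class AddSquared(BinaryOperationBase):
    """Hypothetical op: adds alpha * other(tensors)**2 to the update."""

    def __init__(self, other: Chainable, alpha: float = 1.0):
        defaults = dict(alpha=alpha)
        super().__init__(defaults, other=other)

    @torch.no_grad
    def transform(self, var: Var, update: list[torch.Tensor], other: list[torch.Tensor]):
        alpha = self.settings[var.params[0]]['alpha']
        # update += alpha * other * other, using the same foreach-style calls as the ops above
        torch._foreach_addcmul_(update, other, other, value=alpha)
        return update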
torchzero/modules/ops/multi.py

@@ -3,15 +3,15 @@
  from abc import ABC, abstractmethod
  from collections.abc import Iterable, Sequence
  from operator import itemgetter
- from typing import Any
+ from typing import Any, Literal
 
  import torch
 
- from ...core import Chainable, Module, Target, Vars, maybe_chain
+ from ...core import Chainable, Module, Target, Var, maybe_chain
  from ...utils import TensorList, tensorlist
 
 
- class MultiOperation(Module, ABC):
+ class MultiOperationBase(Module, ABC):
  """Base class for operations that use operands. This is an abstract class, subclass it and override `transform` method to use it."""
  def __init__(self, defaults: dict[str, Any] | None, **operands: Chainable | Any):
  super().__init__(defaults=defaults)
@@ -29,36 +29,37 @@ class MultiOperation(Module, ABC):
  raise ValueError('At least one operand must be a module')
 
  @abstractmethod
- def transform(self, vars: Vars, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
  """applies the operation to operands"""
  raise NotImplementedError
 
  @torch.no_grad
- def step(self, vars: Vars) -> Vars:
+ def step(self, var: Var) -> Var:
  # pass cloned update to all module operands
  processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
  for k,v in self.operands.items():
  if k in self.children:
  v: Module
- updated_vars = v.step(vars.clone(clone_update=True))
- processed_operands[k] = updated_vars.get_update()
- vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+ updated_var = v.step(var.clone(clone_update=True))
+ processed_operands[k] = updated_var.get_update()
+ var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
- transformed = self.transform(vars, **processed_operands)
- vars.update = transformed
- return vars
+ transformed = self.transform(var, **processed_operands)
+ var.update = transformed
+ return var
 
 
 
- class SubModules(MultiOperation):
+ class SubModules(MultiOperationBase):
+ """Calculates :code:`input - other`. :code:`input` and :code:`other` can be numbers or modules."""
  def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
  defaults = dict(alpha=alpha)
  super().__init__(defaults, input=input, other=other)
 
  @torch.no_grad
- def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
- alpha = self.settings[vars.params[0]]['alpha']
+ def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ alpha = self.settings[var.params[0]]['alpha']
 
  if isinstance(input, (int,float)):
  assert isinstance(other, list)
@@ -68,13 +69,15 @@ class SubModules(MultiOperation):
  else: torch._foreach_sub_(input, other, alpha=alpha)
  return input
 
- class DivModules(MultiOperation):
- def __init__(self, input: Chainable | float, other: Chainable | float):
+ class DivModules(MultiOperationBase):
+ """Calculates :code:`input / other`. :code:`input` and :code:`other` can be numbers or modules."""
+ def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
  defaults = {}
- super().__init__(defaults, input=input, other=other)
+ if other_first: super().__init__(defaults, other=other, input=input)
+ else: super().__init__(defaults, input=input, other=other)
 
  @torch.no_grad
- def transform(self, vars: Vars, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
  if isinstance(input, (int,float)):
  assert isinstance(other, list)
  return input / TensorList(other)
@@ -82,13 +85,15 @@ class DivModules(MultiOperation):
  torch._foreach_div_(input, other)
  return input
 
- class PowModules(MultiOperation):
+
+ class PowModules(MultiOperationBase):
+ """Calculates :code:`input ** exponent`. :code:`input` and :code:`other` can be numbers or modules."""
  def __init__(self, input: Chainable | float, exponent: Chainable | float):
  defaults = {}
  super().__init__(defaults, input=input, exponent=exponent)
 
  @torch.no_grad
- def transform(self, vars: Vars, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
  if isinstance(input, (int,float)):
  assert isinstance(exponent, list)
  return input ** TensorList(exponent)
@@ -96,42 +101,98 @@ class PowModules(MultiOperation):
  torch._foreach_div_(input, exponent)
  return input
 
- class LerpModules(MultiOperation):
+ class LerpModules(MultiOperationBase):
+ """Does a linear interpolation of :code:`input(tensors)` and :code:`end(tensors)` based on a scalar :code:`weight`.
+
+ The output is given by :code:`output = input(tensors) + weight * (end(tensors) - input(tensors))`
+ """
  def __init__(self, input: Chainable, end: Chainable, weight: float):
  defaults = dict(weight=weight)
  super().__init__(defaults, input=input, end=end)
 
  @torch.no_grad
- def transform(self, vars: Vars, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
- torch._foreach_lerp_(input, end, weight=self.settings[vars.params[0]]['weight'])
+ def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
+ torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
  return input
 
- class ClipModules(MultiOperation):
+ class ClipModules(MultiOperationBase):
+ """Calculates :code:`input(tensors).clip(min, max)`. :code:`min` and :code:`max` can be numbers or modules."""
  def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
  defaults = {}
  super().__init__(defaults, input=input, min=min, max=max)
 
  @torch.no_grad
- def transform(self, vars: Vars, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
+ def transform(self, var: Var, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
  return TensorList(input).clamp_(min=min, max=max)
 
 
- class GraftModules(MultiOperation):
+ class GraftModules(MultiOperationBase):
+ """Outputs :code:`direction` output rescaled to have the same norm as :code:`magnitude` output.
+
+ Args:
+ direction (Chainable): module to use the direction from
+ magnitude (Chainable): module to use the magnitude from
+ tensorwise (bool, optional): whether to calculate norm per-tensor or globally. Defaults to True.
+ ord (float, optional): norm order. Defaults to 2.
+ eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
+ strength (float, optional): strength of grafting. Defaults to 1.
+
+ Example:
+ Shampoo grafted to Adam
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.GraftModules(
+ direction = tz.m.Shampoo(),
+ magnitude = tz.m.Adam(),
+ ),
+ tz.m.LR(1e-3)
+ )
+
+ Reference:
+ Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
+ """
  def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
  super().__init__(defaults, direction=direction, magnitude=magnitude)
 
  @torch.no_grad
- def transform(self, vars, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
- tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[vars.params[0]])
+ def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
+ tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
  return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
-
- class Where(MultiOperation):
- def __init__(self, condition: Chainable, input: Chainable | float, other: Chainable | float):
- super().__init__({}, condition=condition, input=input, other=other)
+ class MultiplyByModuleNorm(MultiOperationBase):
+ """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ defaults = dict(tensorwise=tensorwise, ord=ord)
+ super().__init__(defaults, input=input, norm=norm)
 
  @torch.no_grad
- def transform(self, vars, condition: list[torch.Tensor], input: list[torch.Tensor] | float, other: list[torch.Tensor] | float):
- return tensorlist.where(TensorList(condition).as_bool(), input, other)
+ def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ if tensorwise:
+ if ord == 'mean_abs': n = [t.mean() for t in torch._foreach_abs(norm)]
+ else: n = torch._foreach_norm(norm, ord)
+ else: n = TensorList(norm).global_vector_norm(ord)
+
+ torch._foreach_mul_(input, n)
+ return input
+
+ class DivideByModuleNorm(MultiOperationBase):
+ """Outputs :code:`input` divided by norm of the :code:`norm` output."""
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ defaults = dict(tensorwise=tensorwise, ord=ord)
+ super().__init__(defaults, input=input, norm=norm)
 
+ @torch.no_grad
+ def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ if tensorwise:
+ if ord == 'mean_abs': n = [t.mean().clip(min=1e-8) for t in torch._foreach_abs(norm)]
+ else: n = torch._foreach_clamp_min(torch._foreach_norm(norm, ord), 1e-8)
+ else: n = TensorList(norm).global_vector_norm(ord).clip(min=1e-8)
+
+ torch._foreach_div_(input, n)
+ return input
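
The multi.py counterpart works the same way, except the operands carry all of the data: there is no update argument, and every module operand arrives in transform already stepped, as a list of tensors. A minimal hypothetical subclass against the new Var API might look like the sketch below; the WeightedSum module and the import paths are assumptions, not part of this diff.

# Minimal sketch, assuming torchzero.modules.ops.multi exports MultiOperationBase;
# WeightedSum is hypothetical.
import torch

from torchzero.core import Chainable, Var
from torchzero.modules.ops.multi import MultiOperationBase


class WeightedSum(MultiOperationBase):
    """Hypothetical op: returns a * input(tensors) + b * other(tensors)."""

    def __init__(self, input: Chainable, other: Chainable, a: float = 0.5, b: float = 0.5):
        defaults = dict(a=a, b=b)
        super().__init__(defaults, input=input, other=other)

    @torch.no_grad
    def transform(self, var: Var, input: list[torch.Tensor], other: list[torch.Tensor]) -> list[torch.Tensor]:
        settings = self.settings[var.params[0]]
        torch._foreach_mul_(input, settings['a'])               # input *= a
        torch._foreach_add_(input, other, alpha=settings['b'])  # input += b * other
        return input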