torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/multi.py
CHANGED

````diff
@@ -7,8 +7,8 @@ from typing import Any, Literal
 
 import torch
 
-from ...core import Chainable, Module,
-from ...utils import TensorList,
+from ...core import Chainable, Module, Objective
+from ...utils import TensorList, Metrics
 
 
 class MultiOperationBase(Module, ABC):
@@ -29,36 +29,39 @@ class MultiOperationBase(Module, ABC):
         raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self,
+    def transform(self, objective: Objective, **operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: dict[str, Any | list[torch.Tensor]] = self.operands.copy()
 
         for k,v in self.operands.items():
             if k in self.children:
                 v: Module
-
-                processed_operands[k] =
-
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[k] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(
-
-        return
+        transformed = self.transform(objective, **processed_operands)
+        objective.updates = transformed
+        return objective
 
 
 
 class SubModules(MultiOperationBase):
-    """Calculates
+    """Calculates ``input - other``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, alpha: float = 1):
         defaults = dict(alpha=alpha)
         super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         alpha = self.defaults['alpha']
 
         if isinstance(input, (int,float)):
@@ -70,14 +73,14 @@ class SubModules(MultiOperationBase):
         return input
 
 class DivModules(MultiOperationBase):
-    """Calculates
+    """Calculates ``input / other``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, other: Chainable | float, other_first:bool=False):
         defaults = {}
         if other_first: super().__init__(defaults, other=other, input=input)
         else: super().__init__(defaults, input=input, other=other)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(other, list)
             return input / TensorList(other)
@@ -87,13 +90,13 @@ class DivModules(MultiOperationBase):
 
 
 class PowModules(MultiOperationBase):
-    """Calculates
+    """Calculates ``input ** exponent``. ``input`` and ``other`` can be numbers or modules."""
     def __init__(self, input: Chainable | float, exponent: Chainable | float):
         defaults = {}
         super().__init__(defaults, input=input, exponent=exponent)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, input: float | list[torch.Tensor], exponent: float | list[torch.Tensor]) -> list[torch.Tensor]:
         if isinstance(input, (int,float)):
             assert isinstance(exponent, list)
             return input ** TensorList(exponent)
@@ -102,32 +105,32 @@ class PowModules(MultiOperationBase):
         return input
 
 class LerpModules(MultiOperationBase):
-    """Does a linear interpolation of
+    """Does a linear interpolation of ``input(tensors)`` and ``end(tensors)`` based on a scalar ``weight``.
 
-    The output is given by
+    The output is given by ``output = input(tensors) + weight * (end(tensors) - input(tensors))``
     """
     def __init__(self, input: Chainable, end: Chainable, weight: float):
         defaults = dict(weight=weight)
         super().__init__(defaults, input=input, end=end)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
         torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
         return input
 
 class ClipModules(MultiOperationBase):
-    """Calculates
+    """Calculates ``input(tensors).clip(min, max)``. ``min`` and ``max`` can be numbers or modules."""
     def __init__(self, input: Chainable, min: float | Chainable | None = None, max: float | Chainable | None = None):
         defaults = {}
         super().__init__(defaults, input=input, min=min, max=max)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, input: list[torch.Tensor], min: float | list[torch.Tensor], max: float | list[torch.Tensor]) -> list[torch.Tensor]:
         return TensorList(input).clamp_(min=min, max=max)
 
 
-class
-    """Outputs
+class Graft(MultiOperationBase):
+    """Outputs ``direction`` output rescaled to have the same norm as ``magnitude`` output.
 
     Args:
         direction (Chainable): module to use the direction from
@@ -137,40 +140,40 @@ class GraftModules(MultiOperationBase):
         eps (float, optional): clips denominator to be no less than this value. Defaults to 1e-6.
         strength (float, optional): strength of grafting. Defaults to 1.
 
-    Example:
-        Shampoo grafted to Adam
-
-        .. code-block:: python
+    ### Example:
 
-
-
-
-
-
-
-
-
+    Shampoo grafted to Adam
+    ```python
+    opt = tz.Optimizer(
+        model.parameters(),
+        tz.m.GraftModules(
+            direction = tz.m.Shampoo(),
+            magnitude = tz.m.Adam(),
+        ),
+        tz.m.LR(1e-3)
+    )
+    ```
 
     Reference:
-        Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.
+        [Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803.](https://arxiv.org/pdf/2002.11803)
     """
     def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
         super().__init__(defaults, direction=direction, magnitude=magnitude)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
         tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
         return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
 
 class MultiplyByModuleNorm(MultiOperationBase):
-    """Outputs
+    """Outputs ``input`` multiplied by norm of the ``norm`` output."""
     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, input: list[torch.Tensor], norm:list[torch.Tensor]):
         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
             n = TensorList(norm).metric(ord)
@@ -181,13 +184,13 @@ class MultiplyByModuleNorm(MultiOperationBase):
         return input
 
 class DivideByModuleNorm(MultiOperationBase):
-    """Outputs
+    """Outputs ``input`` divided by norm of the ``norm`` output."""
     def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
         defaults = dict(tensorwise=tensorwise, ord=ord)
         super().__init__(defaults, input=input, norm=norm)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective, input: list[torch.Tensor], norm:list[torch.Tensor]):
         tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
         if tensorwise:
             n = TensorList(norm).metric(ord)
````
torchzero/modules/ops/reduce.py
CHANGED

```diff
@@ -5,7 +5,7 @@ from typing import Any, cast
 
 import torch
 
-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective, maybe_chain
 
 
 class ReduceOperationBase(Module, ABC):
@@ -26,34 +26,37 @@ class ReduceOperationBase(Module, ABC):
         raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self,
+    def transform(self, objective: Objective, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
         """applies the operation to operands"""
         raise NotImplementedError
 
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective: Objective) -> Objective:
         # pass cloned update to all module operands
         processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
         for i, v in enumerate(self.operands):
             if f'operand_{i}' in self.children:
                 v: Module
-
-                processed_operands[i] =
-
+                updated_obj = v.step(objective.clone(clone_updates=True))
+                processed_operands[i] = updated_obj.get_updates()
+                objective.update_attrs_from_clone_(updated_obj) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(
-
-        return
+        transformed = self.transform(objective, *processed_operands)
+        objective.updates = transformed
+        return objective
 
 class Sum(ReduceOperationBase):
-    """Outputs sum of
+    """Outputs sum of ``inputs`` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         sum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -64,14 +67,14 @@ class Sum(ReduceOperationBase):
         return sum
 
 class Mean(Sum):
-    """Outputs a mean of
+    """Outputs a mean of ``inputs`` that can be modules or numbers."""
     USE_MEAN = True
 
 
 class WeightedSum(ReduceOperationBase):
+    """Outputs a weighted sum of ``inputs`` that can be modules or numbers."""
     USE_MEAN = False
     def __init__(self, *inputs: Chainable | float, weights: Iterable[float]):
-        """Outputs a weighted sum of :code:`inputs` that can be modules or numbers."""
         weights = list(weights)
         if len(inputs) != len(weights):
             raise ValueError(f'Number of inputs {len(inputs)} must match number of weights {len(weights)}')
@@ -79,7 +82,7 @@ class WeightedSum(ReduceOperationBase):
         super().__init__(defaults=defaults, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         weights = self.defaults['weights']
         sum = cast(list, sorted_inputs[0])
@@ -94,16 +97,16 @@ class WeightedSum(ReduceOperationBase):
 
 
 class WeightedMean(WeightedSum):
-    """Outputs weighted mean of
+    """Outputs weighted mean of ``inputs`` that can be modules or numbers."""
     USE_MEAN = True
 
 class Median(ReduceOperationBase):
-    """Outputs median of
+    """Outputs median of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         res = []
         lists = [i for i in inputs if isinstance(i, list)]
         floats = [i for i in inputs if isinstance(i, (int,float))]
@@ -112,12 +115,12 @@ class Median(ReduceOperationBase):
         return res
 
 class Prod(ReduceOperationBase):
-    """Outputs product of
+    """Outputs product of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         prod = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -127,12 +130,12 @@ class Prod(ReduceOperationBase):
         return prod
 
 class MaximumModules(ReduceOperationBase):
-    """Outputs elementwise maximum of
+    """Outputs elementwise maximum of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         maximum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -142,12 +145,12 @@ class MaximumModules(ReduceOperationBase):
         return maximum
 
 class MinimumModules(ReduceOperationBase):
-    """Outputs elementwise minimum of
+    """Outputs elementwise minimum of ``inputs`` that can be modules or numbers."""
     def __init__(self, *inputs: Chainable | float):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self,
+    def transform(self, objective: Objective, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         minimum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
```
torchzero/modules/ops/unary.py
CHANGED

```diff
@@ -2,102 +2,102 @@ from collections import deque
 
 import torch
 
-from ...core import
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_dicts,unpack_states
 
-class UnaryLambda(
-    """Applies
+class UnaryLambda(TensorTransform):
+    """Applies ``fn`` to input tensors.
 
-
+    ``fn`` must accept and return a list of tensors.
     """
-    def __init__(self, fn
+    def __init__(self, fn):
         defaults = dict(fn=fn)
-        super().__init__(defaults=defaults
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return settings[0]['fn'](tensors)
 
-class UnaryParameterwiseLambda(
-    """Applies
+class UnaryParameterwiseLambda(TensorTransform):
+    """Applies ``fn`` to each input tensor.
 
-
+    ``fn`` must accept and return a tensor.
     """
-    def __init__(self, fn
+    def __init__(self, fn):
         defaults = dict(fn=fn)
-        super().__init__(
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def
+    def single_tensor_apply(self, tensor, param, grad, loss, state, setting):
         return setting['fn'](tensor)
 
-class CustomUnaryOperation(
-    """Applies
+class CustomUnaryOperation(TensorTransform):
+    """Applies ``getattr(tensor, name)`` to each tensor
     """
-    def __init__(self, name: str
+    def __init__(self, name: str):
         defaults = dict(name=name)
-        super().__init__(defaults=defaults
+        super().__init__(defaults=defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         return getattr(tensors, settings[0]['name'])()
 
 
-class Abs(
-    """Returns
-    def __init__(self
+class Abs(TensorTransform):
+    """Returns ``abs(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors
 
-class Sign(
-    """Returns
-    def __init__(self
+class Sign(TensorTransform):
+    """Returns ``sign(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors
 
-class Exp(
-    """Returns
-    def __init__(self
+class Exp(TensorTransform):
+    """Returns ``exp(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors
 
-class Sqrt(
-    """Returns
-    def __init__(self
+class Sqrt(TensorTransform):
+    """Returns ``sqrt(input)``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors
 
-class Reciprocal(
-    """Returns
-    def __init__(self, eps = 0
+class Reciprocal(TensorTransform):
+    """Returns ``1 / input``"""
+    def __init__(self, eps = 0):
         defaults = dict(eps = eps)
-        super().__init__(defaults
+        super().__init__(defaults)
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors
 
-class Negate(
-    """Returns
-    def __init__(self
+class Negate(TensorTransform):
+    """Returns ``- input``"""
+    def __init__(self): super().__init__()
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors
 
 
-class NanToNum(
-    """Convert
+class NanToNum(TensorTransform):
+    """Convert ``nan``, ``inf`` and `-`inf`` to numbers.
 
     Args:
         nan (optional): the value to replace NaNs with. Default is zero.
@@ -108,23 +108,23 @@ class NanToNum(Transform):
         If None, negative infinity values are replaced with the lowest finite value
         representable by input's dtype. Default is None.
     """
-    def __init__(self, nan=None, posinf=None, neginf=None
+    def __init__(self, nan=None, posinf=None, neginf=None):
         defaults = dict(nan=nan, posinf=posinf, neginf=neginf)
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
-class Rescale(
-    """Rescales input to
-    def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8
+class Rescale(TensorTransform):
+    """Rescales input to ``(min, max)`` range"""
+    def __init__(self, min: float, max: float, tensorwise: bool = False, eps:float=1e-8):
         defaults = dict(min=min, max=max, eps=eps, tensorwise=tensorwise)
-        super().__init__(defaults
+        super().__init__(defaults)
 
     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         min, max = unpack_dicts(settings, 'min','max')
         tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'
```