torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +43 -33
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +48 -52
- torchzero/core/module.py +130 -50
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +2 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +15 -12
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +1 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +32 -32
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +50 -48
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +10 -10
- torchzero/modules/quasi_newton/sg2.py +19 -19
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +49 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +57 -90
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +157 -177
- torchzero/modules/second_order/rsn.py +106 -96
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +10 -10
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +93 -69
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
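The `{utils/linalg → linalg}` renames above imply an import-path change for the linear-algebra helpers; a minimal sketch of that change (module names come from the file list, their public contents are not shown in this diff):

```python
import torchzero.linalg as linalg            # new subpackage added in 0.4.0 (torchzero/linalg/__init__.py)
# import torchzero.utils.linalg as linalg    # old location, removed in 0.4.0
```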
torchzero/modules/misc/multistep.py
CHANGED

@@ -2,49 +2,49 @@ from collections.abc import Iterable

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective
 from ...utils import TensorList

-def _sequential_step(self: Module,
-    params =
+def _sequential_step(self: Module, objective: Objective, sequential: bool):
+    params = objective.params
     steps = self.settings[params[0]]['steps']

-    if sequential: modules = self.get_children_sequence() * steps
+    if sequential: modules: list[Module] = self.get_children_sequence() * steps
     else: modules = [self.children['module']] * steps

-    if
+    if objective.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')

     # store original params unless this is last module and can update params directly
     params_before_steps = [p.clone() for p in params]

     # first step - pass var as usual
-
-
+    objective = modules[0].step(objective)
+    new_objective = objective

     # subsequent steps - update parameters and create new var
     if len(modules) > 1:
         for m in modules[1:]:

             # update params
-            if (not
+            if (not new_objective.skip_update):
                 # if new_var.last_module_lrs is not None:
                 #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

-                torch._foreach_sub_(params,
+                torch._foreach_sub_(params, new_objective.get_updates())

             # create new var since we are at a new point, that means grad, update and loss will be None
-
-            model=
+            new_objective = Objective(params=new_objective.params, closure=new_objective.closure,
+                                      model=new_objective.model, current_step=new_objective.current_step + 1)

             # step
-
+            new_objective = m.step(new_objective)

     # final parameter update
-    if (not
+    if (not new_objective.skip_update):
         # if new_var.last_module_lrs is not None:
         #     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)

-        torch._foreach_sub_(params,
+        torch._foreach_sub_(params, new_objective.get_updates())

     # if last module, update is applied so return new var
     # if params_before_steps is None:

@@ -53,13 +53,13 @@ def _sequential_step(self: Module, var: Var, sequential: bool):
     #     return new_var

     # otherwise use parameter difference as update
-
+    objective.updates = list(torch._foreach_sub(params_before_steps, params))
     for p, bef in zip(params, params_before_steps):
         p.set_(bef) # pyright:ignore[reportArgumentType]
-    return
+    return objective

 class Multistep(Module):
-    """Performs
+    """Performs ``steps`` inner steps with ``module`` per each step.

     The update is taken to be the parameter difference between parameters before and after the inner loop."""
     def __init__(self, module: Chainable, steps: int):

@@ -68,11 +68,11 @@ class Multistep(Module):
         self.set_child('module', module)

     @torch.no_grad
-    def
-        return _sequential_step(self,
+    def apply(self, objective):
+        return _sequential_step(self, objective, sequential=False)

 class Sequential(Module):
-    """On each step, this sequentially steps with
+    """On each step, this sequentially steps with ``modules`` ``steps`` times.

     The update is taken to be the parameter difference between parameters before and after the inner loop."""
     def __init__(self, modules: Iterable[Chainable], steps: int=1):

@@ -81,28 +81,28 @@ class Sequential(Module):
         self.set_children_sequence(modules)

     @torch.no_grad
-    def
-        return _sequential_step(self,
+    def apply(self, objective):
+        return _sequential_step(self, objective, sequential=True)


 class NegateOnLossIncrease(Module):
-    """Uses an extra forward pass to evaluate loss at
-    if loss is larger than at
-    the update is set to 0 if
+    """Uses an extra forward pass to evaluate loss at ``parameters+update``,
+    if loss is larger than at ``parameters``,
+    the update is set to 0 if ``backtrack=False`` and to ``-update`` otherwise"""
     def __init__(self, backtrack=False):
         defaults = dict(backtrack=backtrack)
         super().__init__(defaults=defaults)

     @torch.no_grad
-    def
-        closure =
+    def apply(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
         backtrack = self.defaults['backtrack']

-        update =
-        f_0 =
+        update = objective.get_updates()
+        f_0 = objective.get_loss(backward=False)

-        torch._foreach_sub_(
+        torch._foreach_sub_(objective.params, update)
         f_1 = closure(False)

         if f_1 <= f_0:

@@ -111,15 +111,15 @@ class NegateOnLossIncrease(Module):
             # var.skip_update = True
             # return var

-            torch._foreach_add_(
-            return
+            torch._foreach_add_(objective.params, update)
+            return objective

-        torch._foreach_add_(
+        torch._foreach_add_(objective.params, update)
         if backtrack:
-            torch._foreach_neg_(
+            torch._foreach_neg_(objective.updates)
         else:
-            torch._foreach_zero_(
-        return
+            torch._foreach_zero_(objective.updates)
+        return objective


 class Online(Module):

@@ -147,48 +147,50 @@ class Online(Module):
     """
     def __init__(self, *modules: Module,):
         super().__init__()
+        if len(modules) == 0:
+            raise RuntimeError("Online got empty list of modules. To make a module online, wrap it in tz.m.Online, e.g. `tz.m.Online(tz.m.LBFGS())`")

         self.set_child('module', modules)

     @torch.no_grad
-    def update(self,
-        closure =
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise ValueError("Closure must be passed for Online")

         step = self.global_state.get('step', 0) + 1
         self.global_state['step'] = step

-        params = TensorList(
+        params = TensorList(objective.params)
         p_cur = params.clone()
         p_prev = self.get_state(params, 'p_prev', cls=TensorList)

         module = self.children['module']
-        var_c =
+        var_c = objective.clone(clone_updates=False)

         # on 1st step just step and store previous params
         if step == 1:
             p_prev.copy_(params)

             module.update(var_c)
-
+            objective.update_attrs_from_clone_(var_c)
             return

         # restore previous params and update
-
+        prev_objective = Objective(params=params, closure=closure, model=objective.model, current_step=objective.current_step)
         params.set_(p_prev)
         module.reset_for_online()
-        module.update(
+        module.update(prev_objective)

         # restore current params and update
         params.set_(p_cur)
         p_prev.copy_(params)
         module.update(var_c)
-
+        objective.update_attrs_from_clone_(var_c)

     @torch.no_grad
-    def apply(self,
+    def apply(self, objective):
         module = self.children['module']
-        return module.apply(
+        return module.apply(objective.clone(clone_updates=False))

-    def get_H(self,
-        return self.children['module'].get_H(
+    def get_H(self, objective):
+        return self.children['module'].get_H(objective)
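For orientation, the hunks above show the shape of the 0.4.0 module API: `update`, `apply` and `step` receive an `Objective` (carrying `params`, `closure`, `grads`, `updates` and `loss`) in place of the removed `Var`. A minimal sketch of a custom module written against that convention, using only attributes visible in this diff; the class itself is illustrative and not part of torchzero:

```python
import torch
from torchzero.core import Module, Objective  # import path mirrors ``from ...core import ...`` above


class NegateUpdate(Module):
    """Illustrative module: flips the sign of the current update."""

    def __init__(self):
        super().__init__()

    @torch.no_grad
    def apply(self, objective: Objective) -> Objective:
        # same pattern as NegateOnLossIncrease above: mutate the update tensors in place
        torch._foreach_neg_(objective.get_updates())
        return objective
```

Chained as `tz.Modular(model.parameters(), NegateUpdate(), tz.m.LR(1e-3))` it would slot in like the built-in modules shown in the docstring examples elsewhere in this diff.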
torchzero/modules/misc/regularization.py
CHANGED

@@ -1,14 +1,14 @@
 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Transform
 from ...core.reformulation import Reformulation
-from ...utils import Distributions, NumberList, TensorList
+from ...utils import Distributions, Metrics, NumberList, TensorList, evaluate_metric


 class Dropout(Transform):
     """Applies dropout to the update.

-    For each weight the update to that weight has
+    For each weight the update to that weight has ``p`` probability to be set to 0.
     This can be used to implement gradient dropout or update dropout depending on placement.

     Args:

@@ -18,36 +18,37 @@ class Dropout(Transform):
         target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.


-    Examples:
-    Gradient dropout.
+    ### Examples:

-
+    Gradient dropout.

-
-
-
-
-
-
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Dropout(0.5),
+        tz.m.Adam(),
+        tz.m.LR(1e-3)
+    )
+    ```

-
+    Update dropout.

-
-
-
-
-
-
-
-
+    ``python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adam(),
+        tz.m.Dropout(0.5),
+        tz.m.LR(1e-3)
+    )
+    ```

     """
-    def __init__(self, p: float = 0.5, graft: bool=False
+    def __init__(self, p: float = 0.5, graft: bool=False):
         defaults = dict(p=p, graft=graft)
-        super().__init__(defaults
+        super().__init__(defaults)

     @torch.no_grad
-    def
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         p = NumberList(s['p'] for s in settings)
         graft = settings[0]['graft']

@@ -67,32 +68,31 @@ class WeightDropout(Module):
     """
     Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.

-    Dropout can be disabled for a parameter by setting
+    Dropout can be disabled for a parameter by setting ``use_dropout=False`` in corresponding parameter group.

     Args:
         p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
-        graft (bool, optional):
-            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
     """
-    def __init__(self, p: float = 0.5
-        defaults = dict(p=p,
+    def __init__(self, p: float = 0.5):
+        defaults = dict(p=p, use_dropout=True)
         super().__init__(defaults)

     @torch.no_grad
-    def
-        closure =
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(
+        params = TensorList(objective.params)
         p = NumberList(self.settings[p]['p'] for p in params)

         # create masks
         mask = []
-        for p
+        for p in params:
             prob = self.settings[p]['p']
             use_dropout = self.settings[p]['use_dropout']
             if use_dropout: mask.append(_bernoulli_like(p, prob))
             else: mask.append(torch.ones_like(p))

+        # create a closure that evaluates masked parameters
         @torch.no_grad
         def dropout_closure(backward=True):
             orig_params = params.clone()

@@ -104,15 +104,14 @@ class WeightDropout(Module):
             params.copy_(orig_params)
             return loss

-
-        return var
+        objective.closure = dropout_closure


 class PerturbWeights(Module):
     """
     Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.

-    Can be disabled for a parameter by setting
+    Can be disabled for a parameter by setting ``perturb=False`` in corresponding parameter group.

     Args:
         alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.

@@ -120,15 +119,22 @@ class PerturbWeights(Module):
         distribution (bool, optional):
             distribution of the random perturbation. Defaults to False.
     """
-
-
+
+    def __init__(
+        self,
+        alpha: float = 0.1,
+        relative: bool = True,
+        distribution: Distributions = "normal",
+        metric: Metrics = "mad",
+    ):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, metric=metric, perturb=True)
         super().__init__(defaults)

     @torch.no_grad
-    def
-        closure =
+    def update(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(
+        params = TensorList(objective.params)

         # create perturbations
         perts = []

@@ -140,7 +146,7 @@ class PerturbWeights(Module):

             alpha = settings['alpha']
             if settings['relative']:
-                alpha *= p
+                alpha *= evaluate_metric(p, settings["metric"])

             distribution = self.settings[p]['distribution'].lower()
             if distribution in ('normal', 'gaussian'):

@@ -163,5 +169,4 @@ class PerturbWeights(Module):
             params.sub_(perts)
             return loss

-
-        return var
+        objective.closure = perturbed_closure
torchzero/modules/misc/split.py
CHANGED

@@ -1,54 +1,53 @@
-import warnings
 from collections.abc import Callable, Sequence, Iterable
 from typing import cast

 import torch

-from ...core import Chainable, Module,
+from ...core import Chainable, Module, Objective


 def _split(
     module: Module,
     idxs,
     params,
-
+    objective: Objective,
 ):
     split_params = [p for i,p in enumerate(params) if i in idxs]

     split_grad = None
-    if
-        split_grad = [g for i,g in enumerate(
+    if objective.grads is not None:
+        split_grad = [g for i,g in enumerate(objective.grads) if i in idxs]

     split_update = None
-    if
-        split_update = [u for i,u in enumerate(
+    if objective.updates is not None:
+        split_update = [u for i,u in enumerate(objective.updates) if i in idxs]

-
-
-
-
+    split_obj = objective.clone(clone_updates=False, parent=objective)
+    split_obj.params = split_params
+    split_obj.grads = split_grad
+    split_obj.updates = split_update

-
+    split_obj = module.step(split_obj)

     # those should be set due to var being parent
-    if
-        assert
+    if split_obj.grads is not None:
+        assert objective.grads is not None

-    if
-        assert
+    if split_obj.loss is not None:
+        assert objective.loss is not None

-    if
+    if split_obj.updates is not None:

         # make sure update is set, it will be filled with ``true`` and ``false`` tensors
-        if
-            if
-            else:
+        if objective.updates is None:
+            if objective.grads is None: objective.updates = [cast(torch.Tensor, None) for _ in objective.params]
+            else: objective.updates = [g.clone() for g in objective.grads]

         # set all tensors from this split
-        for idx, u in zip(idxs,
-
+        for idx, u in zip(idxs, split_obj.updates):
+            objective.updates[idx] = u

-    return
+    return objective

 _SingleFilter = Callable[[torch.Tensor], bool] | torch.Tensor | Iterable[torch.Tensor] | torch.nn.Module | Iterable[torch.nn.Module]
 Filter = _SingleFilter | Iterable[_SingleFilter]

@@ -101,9 +100,12 @@ class Split(Module):
         if true is not None: self.set_child('true', true)
         if false is not None: self.set_child('false', false)

-    def
+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError

-
+    def step(self, objective):
+
+        params = objective.params
         filter = _make_filter(self.settings[params[0]]['filter'])

         true_idxs = []

@@ -114,10 +116,10 @@ class Split(Module):

         if 'true' in self.children and len(true_idxs) > 0:
             true = self.children['true']
-
+            objective = _split(true, idxs=true_idxs, params=params, objective=objective)

         if 'false' in self.children and len(false_idxs) > 0:
             false = self.children['false']
-
+            objective = _split(false, idxs=false_idxs, params=params, objective=objective)

-        return
+        return objective
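A sketch of how `Split` might be wired up, assuming the constructor exposes the `filter`, `true` and `false` names that appear as settings and child keys in the hunk above (an inference, not an API shown in this diff):

```python
import torch
import torchzero as tz  # ``tz`` alias as in the docstring examples in this diff

model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

opt = tz.Modular(
    model.parameters(),
    tz.m.Split(
        # a filter may be a Callable[[torch.Tensor], bool] per the _SingleFilter alias above
        filter=lambda p: p.ndim >= 2,
        true=tz.m.Adam(),                       # routed to matrix-shaped parameters
        false=[tz.m.SignSGD(), tz.m.Mul(0.5)],  # routed to everything else
    ),
    tz.m.LR(1e-3),
)
```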
torchzero/modules/misc/switch.py
CHANGED

@@ -14,20 +14,21 @@ class Alternate(Module):
     Args:
         steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.

-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ### Examples:
+    Alternate between Adam, SignSGD and RMSprop
+
+    ```python
+
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Alternate(
+            tz.m.Adam(),
+            [tz.m.SignSGD(), tz.m.Mul(0.5)],
+            tz.m.RMSprop(),
+        ),
+        tz.m.LR(1e-3),
+    )
+    ```
     """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):

@@ -43,14 +44,17 @@ class Alternate(Module):
         self.global_state['current_module_idx'] = 0
         self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps

+    def update(self, objective): raise RuntimeError
+    def apply(self, objective): raise RuntimeError
+
     @torch.no_grad
-    def step(self,
+    def step(self, objective):
         # get current module
         current_module_idx = self.global_state.setdefault('current_module_idx', 0)
         module = self.children[f'module_{current_module_idx}']

         # step
-
+        objective = module.step(objective.clone(clone_updates=False))

         # number of steps until next module
         steps = self.defaults['steps']

@@ -72,28 +76,29 @@ class Alternate(Module):

             self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]

-        return
+        return objective

 class Switch(Alternate):
-    """After
+    """After ``steps`` steps switches to the next module.

     Args:
         steps (int | Iterable[int]): Number of steps to perform with each module.

-    Examples:
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ### Examples:
+
+    Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Switch(
+            [tz.m.Adam(), tz.m.LR(1e-3)],
+            [tz.m.LBFGS(), tz.m.Backtracking()],
+            [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+            steps = (1000, 2000)
+        )
+    )
+    ```
     """

     LOOP = False