torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/core/transform.py
CHANGED

@@ -25,7 +25,8 @@ class Transform(Module, ABC):
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
         scale_first: bool = False,
@@ -35,49 +36,48 @@ class Transform(Module, ABC):
         super().__init__(defaults)
         self._target: Target = target
         self._uses_grad = uses_grad
+        self._uses_loss = uses_loss
         self._concat_params = concat_params
         self._update_freq = update_freq
         self._scale_first = scale_first
         self._inner = inner

-    def
+    def update_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> None:
-        """
+        """update function, this shouldn't be called directly. Updates this module."""

     @abstractmethod
-    def
+    def apply_tensors(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]],
     ) -> Sequence[torch.Tensor]:
-        """Applies the update rule to `tensors
+        """apply function, this shouldn't be called directly. Applies the update rule to `tensors` and returns them.
+        If possible, this shouldn't modify the internal state of this transform."""

     @final
     @torch.no_grad
-    def
+    def transform_update(
         self,
         tensors: list[torch.Tensor],
         params: list[torch.Tensor],
         grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         states: list[dict[str, Any]],
         settings: Sequence[Mapping[str, Any]] | None,
-    ) ->
-        """
-        un_tensors = tensors
-        un_params = params
-        un_grads = grads
+    ) -> None:
+        """Updates this transform from an arbitrary sequence of tensors."""
         if self._concat_params:
             tensors = [torch.cat([t.ravel() for t in tensors])]
             params = [torch.cat([p.ravel() for p in params])]
@@ -86,24 +86,61 @@ class Transform(Module, ABC):
         if settings is None:
             settings = [self.defaults for _ in tensors]

-        step = self.global_state.get('__step', 0)
+        step = self.global_state.get('__step', 0) # that way it gets reset correctly
+        self.global_state['__step'] = step + 1
+
         num = len(tensors)
         states = states[:num]
         settings = settings[:num]

-        update_freq = self._update_freq
-        scale_first = self._scale_first
         scale_factor = 1

         # scaling factor for 1st step
-        if
+        if self._scale_first and step == 0:
             # initial step size guess from pytorch LBFGS
             scale_factor = 1 / TensorList(tensors).abs().global_sum().clip(min=1)
             scale_factor = scale_factor.clip(min=torch.finfo(tensors[0].dtype).eps)

         # update transform
-        if step %
-            self.
+        if step % self._update_freq == 0:
+            self.update_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+        # store for transform_apply
+        self.global_state["__tensors"] = tensors
+        self.global_state["__params"] = params
+        self.global_state["__grads"] = grads
+        self.global_state["__scale_factor"] = scale_factor
+
+
+    @final
+    @torch.no_grad
+    def transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+        states: list[dict[str, Any]],
+        settings: Sequence[Mapping[str, Any]] | None,
+    ) -> list[torch.Tensor]:
+        """Applies this transform to an arbitrary sequence of tensors.
+        This can be used after ``transform_update`` has been used at least once."""
+
+        if settings is None:
+            settings = [self.defaults for _ in tensors]
+
+        num = len(tensors)
+        states = states[:num]
+        settings = settings[:num]
+
+        un_tensors = tensors
+        un_params = params
+        un_grads = grads
+
+        tensors = self.global_state.pop("__tensors")
+        params = self.global_state.pop("__params")
+        grads = self.global_state.pop("__grads")
+        scale_factor = self.global_state.pop("__scale_factor")

         # step with inner
         if self._inner is not None:
@@ -112,27 +149,17 @@ class Transform(Module, ABC):
             tensors = [torch.cat([t.ravel() for t in tensors])]

         # apply transform
-        tensors = list(self.
+        tensors = list(self.apply_tensors(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings))

         # scale initial step, when preconditioner might not have been applied
-        if
+        if self._scale_first and self.global_state['__step'] == 1:
             torch._foreach_mul_(tensors, scale_factor)

-        self.global_state['__step'] = step + 1
         if self._concat_params:
             tensors = vec_to_tensors(vec=tensors[0], reference=un_tensors)
         return tensors

-
-    @torch.no_grad
-    def keyed_transform(
-        self,
-        tensors: list[torch.Tensor],
-        params: list[torch.Tensor],
-        grads: list[torch.Tensor] | None,
-        loss: torch.Tensor | None,
-    ):
-        """Applies this transform to `tensors`, `params` will be used as keys and need to always point to same tensor objects."""
+    def _get_keyed_states_settings(self, params: list[torch.Tensor]):
         if self._concat_params:
             p = params[0]
             states = [self.state[p]]
@@ -145,41 +172,116 @@ class Transform(Module, ABC):
             states.append(self.state[p])
             settings.append(self.settings[p])

-        return
+        return states, settings
+
+    @final
+    @torch.no_grad
+    def keyed_transform_update(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        self.transform_update(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    @final
+    @torch.no_grad
+    def keyed_transform_apply(
+        self,
+        tensors: list[torch.Tensor],
+        params: list[torch.Tensor],
+        grads: list[torch.Tensor] | None,
+        loss: torch.Tensor | float | None,
+    ):
+        """`params` will be used as keys and need to always point to same tensor objects.`"""
+        states, settings = self._get_keyed_states_settings(params)
+        return self.transform_apply(tensors=tensors, params=params, grads=grads, loss=loss, states=states, settings=settings)
+
+
+    def pre_step(self, var: Var) -> None:
+        """Logic to run pre-transform, this way transform has access to Var."""
+    def post_step(self, var: Var) -> None:
+        """Logic to run post-transform, this way transform has access to Var."""
+
+    def update(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
+        if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+        self.pre_step(var)
+
+        # update
+        self.keyed_transform_update(update, params, var.grad, var.loss)
+
+    def apply(self, var: Var):
+        if self._target != 'update':
+            raise ValueError("Target must be 'update' to use `update` and `apply` methods. "
+                             f"With {self._target = } only `step` method can be used.")

-    def step(self, var: Var) -> Var:
         # var may change, therefore current params and grads have to be extracted and passed explicitly
+        update = var.get_update() # this sets loss
         if self._uses_grad: var.get_grad()
+        if self._uses_loss: var.get_loss(False)
+        params=var.params
+
+        # apply
+        var.update = self.keyed_transform_apply(update, params, var.grad, var.loss)
+        self.post_step(var)
+        return var
+
+    def step(self, var: Var) -> Var:
+
+        # var may change, therefore current params and grads have to be extracted and passed explicitly
+        if self._target in ('update', 'update_difference'): var.get_update() # this sets loss
+        if self._uses_grad or self._target == 'grad': var.get_grad()
+        if self._uses_loss: var.get_loss(False)
         params=var.params
+        self.pre_step(var)

         # ---------------------------------- update ---------------------------------- #
         if self._target == 'update':
             update = var.get_update()
-
+            self.keyed_transform_update(update, params, var.grad, var.loss)
+            var.update = list(self.keyed_transform_apply(update, params, var.grad, var.loss))
             return var

         # ----------------------------------- grad ----------------------------------- #
         if self._target == 'grad':
             grad = var.get_grad()
-
+            self.keyed_transform_update(grad, params, grad, var.loss)
+            var.grad = list(self.keyed_transform_apply(grad, params, grad, var.loss))
             return var

         # ------------------------------- params_direct ------------------------------ #
         if self._target == 'params_direct':
-
+            self.keyed_transform_update(var.params, params, var.grad, var.loss)
+            new_params = self.keyed_transform_apply(var.params, params, var.grad, var.loss)
             for p, new_p in zip(var.params, new_params): set_storage_(p, new_p)
             return var

         # ----------------------------- params_differnce ----------------------------- #
         if self._target == 'params_difference':
-
+            p_clone = [p.clone() for p in var.params]
+            self.keyed_transform_update(p_clone, params, var.grad, var.loss)
+            new_params = tuple(self.keyed_transform_apply(p_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(var.params, new_params))
             return var

         # ----------------------------- update_difference ---------------------------- #
         if self._target == 'update_difference':
             update = var.get_update()
-
+            u_clone = [u.clone() for u in update]
+            self.keyed_transform_update(u_clone, params, var.grad, var.loss)
+            new_update = tuple(self.keyed_transform_apply(u_clone, params, var.grad, var.loss))
             var.update = list(torch._foreach_sub(update, new_update))
             return var

@@ -193,7 +295,8 @@ class Transform(Module, ABC):
             if backward:
                 loss = original_closure()
                 current_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
+                self.keyed_transform_update(current_grad, params, var.grad, var.loss)
+                transformed_grad = list(self.keyed_transform_apply(current_grad, params, var.grad, var.loss))
                 for p, g in zip(params, transformed_grad):
                     p.grad = g

@@ -203,6 +306,7 @@ class Transform(Module, ABC):
                 return loss

             var.closure = transformed_closure
+            self.post_step(var)
             return var

         # ---------------------------------- invalid --------------------------------- #
@@ -225,7 +329,8 @@ class TensorwiseTransform(Transform, ABC):
     def __init__(
         self,
         defaults: dict[str,Any] | None,
-        uses_grad: bool,
+        uses_grad: bool = False,
+        uses_loss: bool = False,
         concat_params: bool = False,
         update_freq: int = 1,
         scale_first: bool = False,
@@ -238,6 +343,7 @@ class TensorwiseTransform(Transform, ABC):
             concat_params=concat_params,
             update_freq=update_freq,
             scale_first=scale_first,
+            uses_loss=uses_loss,
             inner=inner,
             target=target,
         )
@@ -247,9 +353,9 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-
+        setting: Mapping[str, Any],
     ) -> None:
         """Updates this transform. By default does nothing - if logic is in `apply` method."""

@@ -259,20 +365,20 @@ class TensorwiseTransform(Transform, ABC):
         tensor: torch.Tensor,
         param: torch.Tensor,
         grad: torch.Tensor | None,
-        loss: torch.Tensor | None,
+        loss: torch.Tensor | float | None,
         state: dict[str, Any],
-
+        setting: Mapping[str, Any],
     ) -> torch.Tensor:
         """Applies the update rule to `tensor`."""

     @final
-    def
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
             self.update_tensor(t, p, g, loss, state, setting)

     @final
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         applied = []
         if grads is None: grads = [None]*len(tensors)
         for t,p,g,state,setting in zip(tensors, params, grads, states, settings):
@@ -284,7 +390,7 @@ def apply_transform(
     tensors: list[torch.Tensor],
     params: list[torch.Tensor],
     grads: list[torch.Tensor] | None,
-    loss: torch.Tensor | None = None,
+    loss: torch.Tensor | float | None = None,
     var: Var | None = None,
     current_step: int = 0,
 ):
@@ -292,9 +398,10 @@ def apply_transform(
         var = Var(params=params, closure=None, model=None, current_step=current_step)
         var.loss = loss

-    if isinstance(tfm, Transform):
+    if isinstance(tfm, Transform) and tfm._target == 'update':
         if tfm._uses_grad and grads is None: grads = var.get_grad()
-
+        tfm.keyed_transform_update(tensors, params, grads, loss)
+        return list(tfm.keyed_transform_apply(tensors, params, grads, loss))

     if isinstance(tfm, Chain): tfm = tfm.get_children_sequence() # pyright: ignore[reportAssignmentType]
     if isinstance(tfm, Sequence):
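The core change in this file is that a transform's single step is now split into an update phase and an apply phase (`update_tensors` / `apply_tensors`, driven by `transform_update` / `transform_apply` and their keyed wrappers), and `loss` may now be a plain float. Below is a minimal sketch of what a subclass could look like against the signatures visible in this diff; the `ScaleByFactor` class and its rule are made up for illustration and are not part of torchzero.

```python
from torchzero.core import Transform  # Transform is imported from torchzero.core elsewhere in this diff


class ScaleByFactor(Transform):
    """Illustrative only: multiplies the incoming update by a constant factor."""

    def __init__(self, factor: float = 0.5):
        # uses_grad / uses_loss now default to False, so only `defaults` is required
        super().__init__(defaults=dict(factor=factor))

    def update_tensors(self, tensors, params, grads, loss, states, settings):
        # update phase: a stateful transform would refresh its per-parameter `states` here;
        # this rule keeps no state, so there is nothing to do
        pass

    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        # apply phase: produce the transformed tensors without touching internal state,
        # reading per-parameter settings the same way ClipValue reads s['value']
        return [t * s['factor'] for t, s in zip(tensors, settings)]
```

Keeping all state mutation in `update_tensors` is what lets `update_freq` skip the update on some steps (see the `step % self._update_freq == 0` gate above) while `apply_tensors` still runs on every call.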
torchzero/modules/__init__.py
CHANGED

@@ -1,7 +1,7 @@
 from .clipping import *
 from .grad_approximation import *
 from .line_search import *
-from .
+from .step_size import *
 from .momentum import *
 from .ops import *
 from .optimizers import *
@@ -11,4 +11,5 @@ from .smoothing import *
 from .weight_decay import *
 from .wrappers import *
 from .second_order import *
-from .higher_order import *
+from .higher_order import *
+from .misc import *
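For downstream code, the visible effect of this reshuffle (together with the removal of `torchzero/modules/lr/` and the new `torchzero/modules/step_size/` package in the file listing above) is that the learning-rate modules are now pulled in via `from .step_size import *`. A hedged migration sketch follows; only `tz.Modular`, `tz.m.Adam` and `tz.m.LR` are confirmed by the docstrings elsewhere in this diff, and the toy model is made up for the example.

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # toy model for illustration

# 0.3.10: learning-rate modules lived under torchzero.modules.lr (removed in 0.3.11)
# 0.3.11: they are re-exported through torchzero.modules.step_size, so the tz.m alias keeps working:
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.LR(1e-2),
)
```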
torchzero/modules/clipping/clipping.py
CHANGED

@@ -5,7 +5,7 @@ import math
 import torch

 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList


 def clip_grad_value_(params: Iterable[torch.Tensor], value: float):
@@ -24,7 +24,7 @@ def _clip_norm_(
     min: float | NumberList | None,
     max: float | NumberList | None,
     norm_value: float | NumberList | None,
-    ord: float,
+    ord: float | Literal['mean_abs'],
     dim: int | Sequence[int] | Literal["global"] | None,
     inverse_dims: bool,
     min_size: int,
@@ -54,9 +54,13 @@ def _clip_norm_(
         size = math.prod(tensor.size(d) for d in real_dim)
         if size < min_size: continue

-
+        if ord == 'mean_abs':
+            norm = tensor.abs().mean(dim=real_dim, keepdim=True)
+        else:
+            norm: torch.Tensor = torch.linalg.vector_norm(tensor, ord=ord, dim=real_dim, keepdim=True) # pylint:disable=not-callable
+
         if norm.numel() == 1 and norm == 0: continue
-        norm = torch.where(norm
+        norm = torch.where(norm <= 1e-12, 1, norm)

         # normalize = True, perform normalization
         norm_v = norm_value[i] if isinstance(norm_value, (list,tuple)) else norm_value
@@ -90,7 +94,7 @@ def _clip_norm_(
 def clip_grad_norm_(
     params: Iterable[torch.Tensor],
     max_norm: float | None,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 2,
@@ -118,7 +122,7 @@ def clip_grad_norm_(
 def normalize_grads_(
     params: Iterable[torch.Tensor],
     norm_value: float,
-    ord: float = 2,
+    ord: float | Literal['mean_abs'] = 2,
     dim: int | Sequence[int] | Literal["global"] | None = None,
     inverse_dims: bool = False,
     min_size: int = 1,
@@ -145,13 +149,43 @@ def normalize_grads_(


 class ClipValue(Transform):
-    """Clips update magnitude to be within `(-value, value)` range.
+    """Clips update magnitude to be within `(-value, value)` range.
+
+    Args:
+        value (float): value to clip to.
+        target (str): refer to :ref:`target argument` in documentation.
+
+    Examples:
+
+        Gradient clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipValue(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipValue(1),
+                tz.m.LR(1e-2),
+            )
+
+    """
     def __init__(self, value: float, target: Target = 'update'):
         defaults = dict(value=value)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         value = [s['value'] for s in settings]
         return TensorList(tensors).clip_([-v for v in value], value)

@@ -172,21 +206,45 @@ class ClipNorm(Transform):
             minimal numer of elements in a parameter or slice to clip norm. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.ClipNorm(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update norm clipping:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.ClipNorm(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         max_norm: float,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(max_norm=max_norm,ord=ord,dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         max_norm = NumberList(s['max_norm'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])
         _clip_norm_(
@@ -218,21 +276,45 @@ class Normalize(Transform):
             minimal size of a dimension to normalize along it. Defaults to 1.
         target (str, optional):
             what this affects.
+
+    Examples:
+
+        Gradient normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Normalize(1),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+            )
+
+        Update normalization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Normalize(1),
+                tz.m.LR(1e-2),
+            )
     """
     def __init__(
         self,
         norm_value: float = 1,
-        ord: float = 2,
+        ord: float | Literal['mean_abs'] = 2,
         dim: int | Sequence[int] | Literal["global"] | None = None,
         inverse_dims: bool = False,
         min_size: int = 1,
         target: Target = "update",
     ):
         defaults = dict(norm_value=norm_value,ord=ord,dim=dim,min_size=min_size, inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         norm_value = NumberList(s['norm_value'] for s in settings)
         ord, dim, min_size, inverse_dims = itemgetter('ord', 'dim', 'min_size', 'inverse_dims')(settings[0])

@@ -299,6 +381,21 @@ class Centralize(Transform):
             if True, the `dims` argument is inverted, and all other dimensions are centralized.
         min_size (int, optional):
             minimal size of a dimension to normalize along it. Defaults to 1.
+
+    Examples:
+
+        Standard gradient centralization:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Centralize(dim=0),
+                tz.m.LR(1e-2),
+            )
+
+    References:
+        - Yong, H., Huang, J., Hua, X., & Zhang, L. (2020). Gradient centralization: A new optimization technique for deep neural networks. In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part I 16 (pp. 635-652). Springer International Publishing. https://arxiv.org/abs/2004.01461
     """
     def __init__(
         self,
@@ -308,10 +405,10 @@ class Centralize(Transform):
         target: Target = "update",
     ):
         defaults = dict(dim=dim,min_size=min_size,inverse_dims=inverse_dims)
-        super().__init__(defaults,
+        super().__init__(defaults, target=target)

     @torch.no_grad
-    def
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         dim, min_size, inverse_dims = itemgetter('dim', 'min_size', 'inverse_dims')(settings[0])

         _centralize_(tensors_ = TensorList(tensors), dim=dim, inverse_dims=inverse_dims, min_size=min_size)