torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -494
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.1.dist-info/METADATA +379 -0
- torchzero-0.3.1.dist-info/RECORD +128 -0
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
- {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.1.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -132
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.7.dist-info/METADATA +0 -120
- torchzero-0.1.7.dist-info/RECORD +0 -104
- torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/line_search/line_search.py
@@ -0,0 +1,181 @@
+import math
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from functools import partial
+from operator import itemgetter
+from typing import Any
+
+import numpy as np
+import torch
+
+from ...core import Module, Target, Vars
+from ...utils import tofloat
+
+
+class MaxLineSearchItersReached(Exception): pass
+
+
+class LineSearch(Module, ABC):
+    """Base class for line searches.
+    This is an abstract class; to use it, subclass it and override `search`.
+
+    Args:
+        defaults (dict[str, Any] | None): dictionary with defaults.
+        maxiter (int | None, optional):
+            if this is specified, the search method will terminate upon evaluating
+            the objective this many times, and the step size with the lowest loss value will be used.
+            This is useful when passing `make_objective` to an external library which
+            doesn't have a maxiter option. Defaults to None.
+    """
+    def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
+        super().__init__(defaults)
+        self._maxiter = maxiter
+        self._reset()
+
+    def _reset(self):
+        self._current_step_size: float = 0
+        self._lowest_loss = float('inf')
+        self._best_step_size: float = 0
+        self._current_iter = 0
+
+    def set_step_size_(
+        self,
+        step_size: float,
+        params: list[torch.Tensor],
+        update: list[torch.Tensor],
+    ):
+        if not math.isfinite(step_size): return
+        step_size = max(min(tofloat(step_size), 1e36), -1e36) # fixes overflow when backtracking keeps increasing alpha after converging
+        alpha = self._current_step_size - step_size
+        if alpha != 0:
+            torch._foreach_add_(params, update, alpha=alpha)
+            self._current_step_size = step_size
+
+    def _set_per_parameter_step_size_(
+        self,
+        step_size: Sequence[float],
+        params: list[torch.Tensor],
+        update: list[torch.Tensor],
+    ):
+        if not np.all(np.isfinite(step_size)): step_size = [0 for _ in step_size]
+        alpha = [self._current_step_size - s for s in step_size]
+        if any(a != 0 for a in alpha):
+            torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
+    def _loss(self, step_size: float, vars: Vars, closure, params: list[torch.Tensor],
+              update: list[torch.Tensor], backward: bool = False) -> float:
+
+        # if step_size is 0, we might already know the loss
+        if (vars.loss is not None) and (step_size == 0):
+            return tofloat(vars.loss)
+
+        # check max iter
+        if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
+        self._current_iter += 1
+
+        # set new lr and evaluate loss with it
+        self.set_step_size_(step_size, params=params, update=update)
+        if backward:
+            with torch.enable_grad(): loss = closure()
+        else:
+            loss = closure(False)
+
+        # if it is the best so far, record it
+        if loss < self._lowest_loss:
+            self._lowest_loss = tofloat(loss)
+            self._best_step_size = step_size
+
+        # if evaluated loss at step size 0, set it to vars.loss
+        if step_size == 0:
+            vars.loss = loss
+            if backward: vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+        return tofloat(loss)
+
+    def _loss_derivative(self, step_size: float, vars: Vars, closure,
+                         params: list[torch.Tensor], update: list[torch.Tensor]):
+        # if step_size is 0, we might already know the derivative
+        if (vars.grad is not None) and (step_size == 0):
+            loss = self._loss(step_size=step_size, vars=vars, closure=closure, params=params, update=update, backward=False)
+            derivative = - sum(t.sum() for t in torch._foreach_mul(vars.grad, update))

+        else:
+            # loss with a backward pass sets params.grad
+            loss = self._loss(step_size=step_size, vars=vars, closure=closure, params=params, update=update, backward=True)
+
+            # directional derivative
+            derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
+                                                                    else torch.zeros_like(p) for p in params], update))
+
+        return loss, tofloat(derivative)
+
+    def evaluate_step_size(self, step_size: float, vars: Vars, backward: bool = False):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss(step_size=step_size, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+
+    def evaluate_step_size_loss_and_derivative(self, step_size: float, vars: Vars):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return self._loss_derivative(step_size=step_size, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+
+    def make_objective(self, vars: Vars, backward: bool = False):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+
+    def make_objective_with_derivative(self, vars: Vars):
+        closure = vars.closure
+        if closure is None: raise RuntimeError('line search requires closure')
+        return partial(self._loss_derivative, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+
+    @abstractmethod
+    def search(self, update: list[torch.Tensor], vars: Vars) -> float:
+        """Finds the step size to use"""
+
+    @torch.no_grad
+    def step(self, vars: Vars) -> Vars:
+        self._reset()
+        params = vars.params
+        update = vars.get_update()
+
+        try:
+            step_size = self.search(update=update, vars=vars)
+        except MaxLineSearchItersReached:
+            step_size = self._best_step_size
+
+        # set loss_approx
+        if vars.loss_approx is None: vars.loss_approx = self._lowest_loss
+
+        # this is the last module - set step size to found step_size times lr
+        if vars.is_last:
+
+            if vars.last_module_lrs is None:
+                self.set_step_size_(step_size, params=params, update=update)
+
+            else:
+                self._set_per_parameter_step_size_([step_size*lr for lr in vars.last_module_lrs], params=params, update=update)
+
+            vars.stop = True; vars.skip_update = True
+            return vars
+
+        # revert parameters and multiply update by step size
+        self.set_step_size_(0, params=params, update=update)
+        torch._foreach_mul_(vars.update, step_size)
+        return vars
+
+
+class GridLineSearch(LineSearch):
+    """Mostly for testing, this is not practical"""
+    def __init__(self, start, end, num):
+        defaults = dict(start=start, end=end, num=num)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def search(self, update, vars):
+        start, end, num = itemgetter('start','end','num')(self.settings[vars.params[0]])
+
+        for lr in torch.linspace(start, end, num):
+            self.evaluate_step_size(lr.item(), vars=vars, backward=False)
+
+        return self._best_step_size
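The pattern `GridLineSearch` uses at the end of this file generalizes: a subclass only overrides `search`, calls `evaluate_step_size` for each candidate (the base class records the lowest loss and the step size that produced it), and returns `self._best_step_size`. A minimal hypothetical sketch along those lines, assuming the `Module`/`Vars` machinery behaves as the base class above implies (the `HalvingLineSearch` name and the import path are illustrative, not part of the package):

import torch
from torchzero.modules.line_search.line_search import LineSearch  # path as added in this diff

class HalvingLineSearch(LineSearch):
    """Hypothetical example: evaluate init, init/2, init/4, ... and keep the best step size."""
    def __init__(self, init: float = 1.0, num: int = 10):
        defaults = dict(init=init, num=num)
        super().__init__(defaults)

    @torch.no_grad
    def search(self, update, vars):
        settings = self.settings[vars.params[0]]
        step = settings['init']
        for _ in range(settings['num']):
            # evaluates the closure at params - step * update; the base class tracks the lowest loss seen
            self.evaluate_step_size(step, vars=vars, backward=False)
            step *= 0.5
        return self._best_step_size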
torchzero/modules/line_search/scipy.py
@@ -0,0 +1,37 @@
+from collections.abc import Mapping
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearch
+
+
+class ScipyMinimizeScalar(LineSearch):
+    def __init__(
+        self,
+        method: str | None = None,
+        maxiter: int | None = None,
+        bracket=None,
+        bounds=None,
+        tol: float | None = None,
+        options=None,
+    ):
+        defaults = dict(method=method, bracket=bracket, bounds=bounds, tol=tol, options=options, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def search(self, update, vars):
+        objective = self.make_objective(vars=vars)
+        method, bracket, bounds, tol, options, maxiter = itemgetter(
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[vars.params[0]])
+
+        if maxiter is not None:
+            options = dict(options) if isinstance(options, Mapping) else {}
+            options['maxiter'] = maxiter
+
+        res = self.scopt.minimize_scalar(objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        return res.x
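For context, `scipy.optimize.minimize_scalar` only ever sees a plain scalar function of the step size (here built by `make_objective`, which re-evaluates the closure at `params - alpha * update`) and reports the minimizer in `res.x`. A standalone sketch with a toy quadratic standing in for that objective:

import scipy.optimize

def objective(alpha: float) -> float:
    # stands in for LineSearch._loss, i.e. the loss at params - alpha * update
    return (alpha - 0.3) ** 2

res = scipy.optimize.minimize_scalar(objective, method='bounded', bounds=(0.0, 1.0),
                                     options={'maxiter': 20})
print(res.x)  # approximately 0.3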
torchzero/modules/line_search/strong_wolfe.py
@@ -0,0 +1,260 @@
+import math
+import warnings
+from operator import itemgetter
+
+import torch
+from torch.optim.lbfgs import _cubic_interpolate
+
+from .line_search import LineSearch
+from .backtracking import backtracking_line_search
+from ...utils import totensor
+
+
+def _zoom(f,
+          a_l, a_h,
+          f_l, g_l,
+          f_h, g_h,
+          f_0, g_0,
+          c1, c2,
+          maxzoom):
+
+    for i in range(maxzoom):
+        a_j = _cubic_interpolate(
+            *(totensor(i) for i in (a_l, f_l, g_l, a_h, f_h, g_h))
+        )
+
+        # if interpolation fails or produces an endpoint, bisect
+        delta = abs(a_h - a_l)
+        if a_j is None or a_j == a_l or a_j == a_h:
+            a_j = a_l + 0.5 * delta
+
+        f_j, g_j = f(a_j)
+
+        # check armijo
+        armijo = f_j <= f_0 + c1 * a_j * g_0
+
+        # check strong wolfe
+        wolfe = abs(g_j) <= c2 * abs(g_0)
+
+        # minimum between alpha_low and alpha_j
+        if not armijo or f_j >= f_l:
+            a_h = a_j
+            f_h = f_j
+            g_h = g_j
+        else:
+            # alpha_j satisfies armijo
+            if wolfe:
+                return a_j, f_j
+
+            # minimum between alpha_j and alpha_high
+            if g_j * (a_h - a_l) >= 0:
+                # between alpha_low and alpha_j
+                # a_h = a_l
+                # f_h = f_l
+                # g_h = g_l
+                a_h = a_j
+                f_h = f_j
+                g_h = g_j
+
+            # is this messing it up?
+            else:
+                a_l = a_j
+                f_l = f_j
+                g_l = g_j
+
+        # check if the interval is too small
+        delta = abs(a_h - a_l)
+        if delta <= 1e-9 or delta <= 1e-6 * max(abs(a_l), abs(a_h)):
+            l_satisfies_wolfe = (f_l <= f_0 + c1 * a_l * g_0) and (abs(g_l) <= c2 * abs(g_0))
+            h_satisfies_wolfe = (f_h <= f_0 + c1 * a_h * g_0) and (abs(g_h) <= c2 * abs(g_0))
+
+            if l_satisfies_wolfe and h_satisfies_wolfe: return (a_l, f_l) if f_l <= f_h else (a_h, f_h)
+            if l_satisfies_wolfe: return a_l, f_l
+            if h_satisfies_wolfe: return a_h, f_h
+            if f_l <= f_0 + c1 * a_l * g_0: return a_l, f_l
+            return None, None
+
+        if a_j is None or a_j == a_l or a_j == a_h:
+            a_j = a_l + 0.5 * delta
+
+    return None, None
+
+
+def strong_wolfe(
+    f,
+    f_0,
+    g_0,
+    init: float = 1.0,
+    c1: float = 1e-4,
+    c2: float = 0.9,
+    maxiter: int = 25,
+    maxzoom: int = 15,
+    # a_max: float = 1e30,
+    expand: float = 2.0, # factor to increase alpha by in bracketing
+    plus_minus: bool = False,
+) -> tuple[float,float] | tuple[None,None]:
+    a_prev = 0.0
+
+    if g_0 == 0: return None, None
+    if g_0 > 0:
+        # if direction is not a descent direction, perform line search in the opposite direction
+        if plus_minus:
+            def inverted_objective(alpha):
+                l, g = f(-alpha)
+                return l, -g
+            a, v = strong_wolfe(
+                inverted_objective,
+                init=init,
+                f_0=f_0,
+                g_0=-g_0,
+                c1=c1,
+                c2=c2,
+                maxiter=maxiter,
+                # a_max=a_max,
+                expand=expand,
+                plus_minus=False,
+            )
+            if a is not None and v is not None: return -a, v
+        return None, None
+
+    f_prev = f_0
+    g_prev = g_0
+    a_cur = init
+
+    # bracket
+    for i in range(maxiter):
+
+        f_cur, g_cur = f(a_cur)
+
+        # check armijo
+        armijo_violated = f_cur > f_0 + c1 * a_cur * g_0
+        func_increased = f_cur >= f_prev and i > 0
+
+        if armijo_violated or func_increased:
+            return _zoom(f,
+                         a_prev, a_cur,
+                         f_prev, g_prev,
+                         f_cur, g_cur,
+                         f_0, g_0,
+                         c1, c2,
+                         maxzoom=maxzoom,
+                         )
+
+        # check strong wolfe
+        if abs(g_cur) <= c2 * abs(g_0):
+            return a_cur, f_cur
+
+        # minimum is bracketed
+        if g_cur >= 0:
+            return _zoom(f,
+                         #alpha_curr, alpha_prev,
+                         a_prev, a_cur,
+                         #phi_curr, phi_prime_curr,
+                         f_prev, g_prev,
+                         f_cur, g_cur,
+                         f_0, g_0,
+                         c1, c2,
+                         maxzoom=maxzoom,)
+
+        # otherwise continue bracketing
+        a_next = a_cur * expand
+
+        # update the previous point and continue the loop with an increased step size
+        a_prev = a_cur
+        f_prev = f_cur
+        g_prev = g_cur
+        a_cur = a_next
+
+    # max iters reached
+    return None, None
+
+
+def _notfinite(x):
+    if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
+    return not math.isfinite(x)
+
+
+class StrongWolfe(LineSearch):
+    def __init__(
+        self,
+        init: float = 1.0,
+        c1: float = 1e-4,
+        c2: float = 0.9,
+        maxiter: int = 25,
+        maxzoom: int = 10,
+        # a_max: float = 1e10,
+        expand: float = 2.0,
+        adaptive = True,
+        fallback = False,
+        plus_minus = False,
+    ):
+        defaults = dict(init=init, c1=c1, c2=c2, maxiter=maxiter, maxzoom=maxzoom,
+                        expand=expand, adaptive=adaptive, fallback=fallback, plus_minus=plus_minus)
+        super().__init__(defaults=defaults)
+
+        self.global_state['initial_scale'] = 1.0
+        self.global_state['beta_scale'] = 1.0
+
+    @torch.no_grad
+    def search(self, update, vars):
+        objective = self.make_objective_with_derivative(vars=vars)
+
+        init, c1, c2, maxiter, maxzoom, expand, adaptive, fallback, plus_minus = itemgetter(
+            'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+            'expand', 'adaptive', 'fallback', 'plus_minus')(self.settings[vars.params[0]])
+
+        f_0, g_0 = objective(0)
+
+        step_size, f_a = strong_wolfe(
+            objective,
+            f_0=f_0, g_0=g_0,
+            init=init * self.global_state.setdefault("initial_scale", 1),
+            c1=c1,
+            c2=c2,
+            maxiter=maxiter,
+            maxzoom=maxzoom,
+            expand=expand,
+            plus_minus=plus_minus,
+        )
+
+        if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+        if step_size is not None and step_size != 0 and not _notfinite(step_size):
+            self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+            return step_size
+
+        # fall back to backtracking on fail
+        if adaptive: self.global_state['initial_scale'] *= 0.5
+        if not fallback: return 0
+
+        objective = self.make_objective(vars=vars)
+
+        # directional derivative
+        g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
+
+        step_size = backtracking_line_search(
+            objective,
+            g_0,
+            init=init * self.global_state["initial_scale"],
+            beta=0.5 * self.global_state["beta_scale"],
+            c=c1,
+            maxiter=maxiter * 2,
+            a_min=None,
+            try_negative=plus_minus,
+        )
+
+        # found an alpha that reduces the loss
+        if step_size is not None:
+            self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
+            return step_size
+
+        # on fail reduce the beta scale value
+        self.global_state['beta_scale'] /= 1.5
+        return 0
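For reference, the two tests applied throughout `_zoom` and `strong_wolfe` are the standard strong Wolfe conditions on the scalar restriction of the objective along the search direction, with `f_0` and `g_0` playing the roles of $\varphi(0)$ and $\varphi'(0)$:

\varphi(\alpha) \;\le\; \varphi(0) + c_1\,\alpha\,\varphi'(0)            % sufficient decrease, checked as f_j <= f_0 + c1 * a_j * g_0
|\varphi'(\alpha)| \;\le\; c_2\,|\varphi'(0)|, \qquad 0 < c_1 < c_2 < 1  % strong curvature, checked as abs(g_j) <= c2 * abs(g_0)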
torchzero/modules/line_search/trust_region.py
@@ -0,0 +1,61 @@
+from operator import itemgetter
+
+import torch
+
+from .line_search import LineSearch
+
+
+class TrustRegion(LineSearch):
+    """Basic first-order trust region: re-evaluates the closure with updated parameters and scales the step size based on the change in function value."""
+    def __init__(self, nplus: float = 1.5, nminus: float = 0.75, c: float = 1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
+        defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def search(self, update, vars):
+
+        nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack','adaptive')(self.settings[vars.params[0]])
+        step_size = self.global_state.setdefault('step_size', init)
+        previous_success = self.global_state.setdefault('previous_success', False)
+        nplus_mul = self.global_state.setdefault('nplus_mul', 1)
+        nminus_mul = self.global_state.setdefault('nminus_mul', 1)
+
+
+        f_0 = self.evaluate_step_size(0, vars, backward=False)
+
+        # directional derivative (0 if c = 0 because it is not needed)
+        if c == 0: d = 0
+        else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+
+        # test step size
+        sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]
+
+        f_1 = self.evaluate_step_size(step_size, vars, backward=False)
+
+        proposed = step_size
+
+        # very good step
+        if f_1 < sufficient_f:
+            self.global_state['step_size'] *= nplus * nplus_mul
+
+            # two very good steps in a row - increase nplus_mul
+            if adaptive:
+                if previous_success: self.global_state['nplus_mul'] *= nplus
+                else: self.global_state['nplus_mul'] = 1
+
+        # acceptable step
+        #elif f_1 <= f_0: pass
+
+        # bad step
+        if f_1 >= f_0:
+            self.global_state['step_size'] *= nminus * nminus_mul
+
+            # two bad steps in a row - decrease nminus_mul
+            if adaptive:
+                if previous_success: self.global_state['nminus_mul'] *= nminus
+                else: self.global_state['nminus_mul'] = 1
+
+            if backtrack: proposed = 0
+            else: proposed *= nminus * nminus_mul
+
+        return proposed
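Stripped of the adaptive `nplus_mul`/`nminus_mul` compounding and the backtracking option, the radius update that `TrustRegion.search` applies reduces to the following sketch (function and argument names are illustrative, not part of the package):

def update_radius(step_size, f_0, f_1, d, c=1e-4, nplus=1.5, nminus=0.75):
    """Grow the step size after a sufficiently good step, shrink it after a bad one."""
    sufficient = f_0 + c * step_size * min(d, 0)  # Armijo-style target; d is the directional derivative
    new_size = step_size
    if f_1 < sufficient:
        new_size *= nplus    # very good step: expand the region
    if f_1 >= f_0:
        new_size *= nminus   # bad step (loss did not decrease): shrink the region
    return new_size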
torchzero/modules/lr/lr.py
@@ -0,0 +1,59 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, generic_eq
+
+
+def lazy_lr(tensors: TensorList, lr: float | list, inplace: bool):
+    """multiplies by lr if lr is not 1"""
+    if generic_eq(lr, 1): return tensors
+    if inplace: return tensors.mul_(lr)
+    return tensors * lr
+
+class LR(Transform):
+    def __init__(self, lr: float):
+        defaults = dict(lr=lr)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        return lazy_lr(TensorList(tensors), lr=self.get_settings('lr', params=params), inplace=True)
+
+class StepSize(Transform):
+    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
+    def __init__(self, step_size: float, key = 'step_size'):
+        defaults = {"key": key, key: step_size}
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        lrs = []
+        for p in params:
+            settings = self.settings[p]
+            lrs.append(settings[settings['key']])
+        return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)
+
+
+def warmup(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+    """returns warm up lr scalar"""
+    if step > steps: return end_lr
+    return start_lr + (end_lr - start_lr) * (step / steps)
+
+class Warmup(Transform):
+    def __init__(self, start_lr = 1e-5, end_lr: float = 1, steps = 100):
+        defaults = dict(start_lr=start_lr, end_lr=end_lr, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        start_lr, end_lr = self.get_settings('start_lr', 'end_lr', params=params, cls=NumberList)
+        num_steps = self.settings[params[0]]['steps']
+        step = self.global_state.get('step', 0)
+
+        target = lazy_lr(
+            TensorList(tensors),
+            lr=warmup(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+            inplace=True
+        )
+        self.global_state['step'] = step + 1
+        return target
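The schedule computed by `warmup` and applied by `Warmup` is plain linear interpolation between the two learning rates, held at `end_lr` once `steps` is exceeded:

\mathrm{lr}(t) \;=\; \mathrm{start\_lr} + \left(\mathrm{end\_lr} - \mathrm{start\_lr}\right)\,\min\!\left(\frac{t}{\mathrm{steps}},\; 1\right)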
torchzero/modules/lr/step_size.py
@@ -0,0 +1,97 @@
+import random
+from typing import Any
+
+import torch
+
+from ...core import Transform
+from ...utils import TensorList, NumberList
+
+
+class PolyakStepSize(Transform):
+    """Polyak step-size.
+
+    Args:
+        max (float | None, optional): maximum possible step size. Defaults to None.
+        min_obj_value (float, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        use_grad (bool, optional):
+            if True, uses the dot product of update and gradient to compute the step size.
+            Otherwise, the dot product of the update with itself is used, which has no geometric meaning, so it probably won't work well.
+            Defaults to True.
+        parameterwise (bool, optional):
+            if True, calculate the Polyak step-size for each parameter separately,
+            if False calculate one global step size for all parameters. Defaults to False.
+        alpha (float, optional): multiplier to the Polyak step-size. Defaults to 1.
+    """
+    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
+
+        defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
+        super().__init__(defaults, uses_grad=use_grad)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        loss = vars.get_loss(False)
+        assert grads is not None
+        tensors = TensorList(tensors)
+        grads = TensorList(grads)
+        alpha = self.get_settings('alpha', params=params, cls=NumberList)
+        settings = self.settings[params[0]]
+        parameterwise = settings['parameterwise']
+        use_grad = settings['use_grad']
+        max = settings['max']
+        min_obj_value = settings['min_obj_value']
+
+        if parameterwise:
+            if use_grad: denom = (tensors * grads).sum()
+            else: denom = tensors.pow(2).sum()
+            polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom != 0, 1)
+            polyak_step_size = polyak_step_size.where(denom != 0, 0)
+            if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
+
+        else:
+            if use_grad: denom = tensors.dot(grads)
+            else: denom = tensors.dot(tensors)
+            if denom == 0: polyak_step_size = 0 # we converged
+            else: polyak_step_size = (loss - min_obj_value) / denom
+
+            if max is not None:
+                if polyak_step_size > max: polyak_step_size = max
+
+        tensors.mul_(alpha * polyak_step_size)
+        return tensors
+
+
+
+class RandomStepSize(Transform):
+    """Uses a random global step size from `low` to `high`.
+
+    Args:
+        low (float, optional): minimum learning rate. Defaults to 0.
+        high (float, optional): maximum learning rate. Defaults to 1.
+        parameterwise (bool, optional):
+            if True, generate a random step size for each parameter separately,
+            if False generate one global random step size. Defaults to False.
+    """
+    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed: int | None = None):
+        defaults = dict(low=low, high=high, parameterwise=parameterwise, seed=seed)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        settings = self.settings[params[0]]
+        parameterwise = settings['parameterwise']
+
+        seed = settings['seed']
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = random.Random(seed)
+        generator: random.Random = self.global_state['generator']
+
+        if parameterwise:
+            low, high = self.get_settings('low', 'high', params=params)
+            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+        else:
+            low = settings['low']
+            high = settings['high']
+            lr = generator.uniform(low, high)
+
+        torch._foreach_mul_(tensors, lr)
+        return tensors
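In the global case with `use_grad=True`, the step size computed by `PolyakStepSize` is the classical Polyak rule, scaled by `alpha`, clamped to `max`, and set to 0 when the denominator vanishes; with `use_grad=False` the denominator becomes $\langle u, u \rangle$:

\eta \;=\; \alpha\,\frac{f(x) - f^{*}}{\langle u,\, g\rangle},
\qquad u = \text{update},\quad g = \nabla f(x),\quad f^{*} = \texttt{min\_obj\_value}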