torchzero 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchzero/__init__.py +4 -0
- torchzero/core/__init__.py +13 -0
- torchzero/core/module.py +471 -0
- torchzero/core/tensorlist_optimizer.py +219 -0
- torchzero/modules/__init__.py +21 -0
- torchzero/modules/adaptive/__init__.py +4 -0
- torchzero/modules/adaptive/adaptive.py +192 -0
- torchzero/modules/experimental/__init__.py +19 -0
- torchzero/modules/experimental/experimental.py +294 -0
- torchzero/modules/experimental/quad_interp.py +104 -0
- torchzero/modules/experimental/subspace.py +259 -0
- torchzero/modules/gradient_approximation/__init__.py +7 -0
- torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
- torchzero/modules/gradient_approximation/base_approximator.py +110 -0
- torchzero/modules/gradient_approximation/fdm.py +125 -0
- torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
- torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
- torchzero/modules/gradient_approximation/rfdm.py +125 -0
- torchzero/modules/line_search/__init__.py +30 -0
- torchzero/modules/line_search/armijo.py +56 -0
- torchzero/modules/line_search/base_ls.py +139 -0
- torchzero/modules/line_search/directional_newton.py +217 -0
- torchzero/modules/line_search/grid_ls.py +158 -0
- torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
- torchzero/modules/meta/__init__.py +12 -0
- torchzero/modules/meta/alternate.py +65 -0
- torchzero/modules/meta/grafting.py +195 -0
- torchzero/modules/meta/optimizer_wrapper.py +173 -0
- torchzero/modules/meta/return_overrides.py +46 -0
- torchzero/modules/misc/__init__.py +10 -0
- torchzero/modules/misc/accumulate.py +43 -0
- torchzero/modules/misc/basic.py +115 -0
- torchzero/modules/misc/lr.py +96 -0
- torchzero/modules/misc/multistep.py +51 -0
- torchzero/modules/misc/on_increase.py +53 -0
- torchzero/modules/momentum/__init__.py +4 -0
- torchzero/modules/momentum/momentum.py +106 -0
- torchzero/modules/operations/__init__.py +29 -0
- torchzero/modules/operations/multi.py +298 -0
- torchzero/modules/operations/reduction.py +134 -0
- torchzero/modules/operations/singular.py +113 -0
- torchzero/modules/optimizers/__init__.py +10 -0
- torchzero/modules/optimizers/adagrad.py +49 -0
- torchzero/modules/optimizers/adam.py +118 -0
- torchzero/modules/optimizers/lion.py +28 -0
- torchzero/modules/optimizers/rmsprop.py +51 -0
- torchzero/modules/optimizers/rprop.py +99 -0
- torchzero/modules/optimizers/sgd.py +54 -0
- torchzero/modules/orthogonalization/__init__.py +2 -0
- torchzero/modules/orthogonalization/newtonschulz.py +159 -0
- torchzero/modules/orthogonalization/svd.py +86 -0
- torchzero/modules/quasi_newton/__init__.py +4 -0
- torchzero/modules/regularization/__init__.py +22 -0
- torchzero/modules/regularization/dropout.py +34 -0
- torchzero/modules/regularization/noise.py +77 -0
- torchzero/modules/regularization/normalization.py +328 -0
- torchzero/modules/regularization/ortho_grad.py +78 -0
- torchzero/modules/regularization/weight_decay.py +92 -0
- torchzero/modules/scheduling/__init__.py +2 -0
- torchzero/modules/scheduling/lr_schedulers.py +131 -0
- torchzero/modules/scheduling/step_size.py +80 -0
- torchzero/modules/second_order/__init__.py +4 -0
- torchzero/modules/second_order/newton.py +165 -0
- torchzero/modules/smoothing/__init__.py +5 -0
- torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
- torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
- torchzero/modules/weight_averaging/__init__.py +2 -0
- torchzero/modules/weight_averaging/ema.py +72 -0
- torchzero/modules/weight_averaging/swa.py +171 -0
- torchzero/optim/__init__.py +10 -0
- torchzero/optim/experimental/__init__.py +20 -0
- torchzero/optim/experimental/experimental.py +343 -0
- torchzero/optim/experimental/ray_search.py +83 -0
- torchzero/optim/first_order/__init__.py +18 -0
- torchzero/optim/first_order/cautious.py +158 -0
- torchzero/optim/first_order/forward_gradient.py +70 -0
- torchzero/optim/first_order/optimizers.py +570 -0
- torchzero/optim/modular.py +132 -0
- torchzero/optim/quasi_newton/__init__.py +1 -0
- torchzero/optim/quasi_newton/directional_newton.py +58 -0
- torchzero/optim/second_order/__init__.py +1 -0
- torchzero/optim/second_order/newton.py +94 -0
- torchzero/optim/wrappers/__init__.py +0 -0
- torchzero/optim/wrappers/nevergrad.py +113 -0
- torchzero/optim/wrappers/nlopt.py +165 -0
- torchzero/optim/wrappers/scipy.py +439 -0
- torchzero/optim/zeroth_order/__init__.py +4 -0
- torchzero/optim/zeroth_order/fdm.py +87 -0
- torchzero/optim/zeroth_order/newton_fdm.py +146 -0
- torchzero/optim/zeroth_order/rfdm.py +217 -0
- torchzero/optim/zeroth_order/rs.py +85 -0
- torchzero/random/__init__.py +1 -0
- torchzero/random/random.py +46 -0
- torchzero/tensorlist.py +819 -0
- torchzero/utils/__init__.py +0 -0
- torchzero/utils/compile.py +39 -0
- torchzero/utils/derivatives.py +99 -0
- torchzero/utils/python_tools.py +25 -0
- torchzero/utils/torch_tools.py +92 -0
- torchzero-0.0.1.dist-info/LICENSE +21 -0
- torchzero-0.0.1.dist-info/METADATA +118 -0
- torchzero-0.0.1.dist-info/RECORD +104 -0
- torchzero-0.0.1.dist-info/WHEEL +5 -0
- torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/scheduling/lr_schedulers.py
@@ -0,0 +1,131 @@
from collections.abc import Callable
from functools import partial
from typing import Any, overload, TYPE_CHECKING
import random

import torch
from ...core import OptimizerModule


if TYPE_CHECKING:
    from ...optim import Modular


# LR SCHEDULING MOVED TO LR MODULE

# def _set_momentum_hook(optimizer, state, momentum):
#     for module in optimizer.unrolled_modules:
#         if 'momentum' in module.defaults:
#             for g in module.param_groups:
#                 g['momentum'] = momentum
#         if 'beta1' in module.defaults:
#             for g in module.param_groups:
#                 g['beta1'] = momentum

# def _add_scheduler_hook(opt: "Modular", scheduler_cls, id):
#     """post-init hook that sets `scheduler_step_fn` to the scheduler step."""
#     # get LR module
#     lr_module = opt.get_lr_module()

#     # get current LRScheduler module
#     scheds = [i for i in opt.unrolled_modules if isinstance(i, LRScheduler)]
#     scheds = [i for i in scheds if i.id == id]
#     if len(scheds) != 1:
#         raise RuntimeError(f"more than 1 module with id {id}: {scheds}")

#     sch_module = scheds[0]

#     # make a scheduler and save the step function
#     scheduler = scheduler_cls(lr_module)
#     sch_module.scheduler_step_fn = scheduler.step


# class LRScheduler(OptimizerModule):
#     """Use any pytorch lr scheduler.

#     Important - the lr is applied multiplicatively and multiplies with learning rate of other modules,
#     so usually base learning rate of the lr scheduler, such as `max_lr` for OneCycleLR, should be set to 1.

#     Args:
#         lr_scheduler (Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any]):
#             something like:
#             .. code:: py
#                 lambda opt: OneCycleLR(opt, max_lr = 1, total_steps = 60000)
#         update_every (int, optional):
#             call `step` every n steps, useful for schedulers that only step once per epoch. Defaults to 1.
#         cycle_momentum (bool, optional):
#             enables support for cycling momentum with schedulers that support it, such as `OneCycleLR`.
#             Unlike lr, momentum is not applied multiplicatively, but set to all other modules with
#             `momentum` or `beta` settings. Has no effect if there are no modules that support momentum. Defaults to False.
#         init_lr (float, optional):
#             initial lr, I believe most lr schedulers ignore this. Defaults to 1.
#         init_momentum (float, optional):
#             initial init_momentum, I believe most lr schedulers ignore this.
#             Has no effect if `cycle_momentum` is False or there are no modules that support momentum. Defaults to 0.
#     """
#     def __init__(
#         self,
#         lr_scheduler: Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any],
#         step_every: int = 1,
#         cycle_momentum: bool = True,
#     ):
#         super().__init__({})
#         scheduler = lr_scheduler(self.dummy_opt)
#         self.update_every = step_every
#         self.cycle_momentum = cycle_momentum

#         self.scheduler_step_fn = scheduler.step
#         self.cur = 0
#         self.cur_lr = init_lr
#         self.cur_momentum = init_momentum

#         self.id = random.random()

#     def step(self, state):
#         if self.cur % self.update_every == 0:
#             self.scheduler_step_fn()
#             self.cur_lr = self.dummy_opt.first_param_group['lr']
#             self.cur_momentum = self.dummy_opt.first_param_group['momentum']

#         params = self.get_params()
#         ascent = state.maybe_use_grad_(params)
#         ascent *= self.cur_lr

#         if self.cycle_momentum:
#             state.add_post_step_hook(partial(_set_momentum_hook, momentum = self.cur_momentum))

class LRWarmup(OptimizerModule):
    """Linear learning rate warmup.

    Args:
        n_steps (int): number of warmup steps.
        start_lr (float, optional): initial lr. Defaults to 1e-8.
        end_lr (float, optional): final lr. Defaults to 1.
        delay_steps (int, optional): number of `start_lr` steps before starting the warmup. Defaults to 0.
    """
    def __init__(self, n_steps: int, start_lr: float = 1e-8, end_lr: float = 1, delay_steps: int = 0):

        super().__init__({})
        self.n_steps = n_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.delay_steps = delay_steps

        self.cur = 0

    def _update(self, state, ascent):
        if self.cur < self.delay_steps:
            if self.start_lr != 1: ascent *= self.start_lr

        elif self.cur >= self.n_steps + self.delay_steps:
            if self.end_lr != 1: ascent *= self.end_lr

        else:
            remaining = (self.n_steps - (self.cur - self.delay_steps)) / self.n_steps
            lr = (self.start_lr * remaining) + self.end_lr * (1 - remaining)
            ascent *= lr

        self.cur += 1
        return ascent
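For reference, the multiplier that `LRWarmup._update` applies to the ascent direction can be written as a standalone function. This is an illustrative sketch only (the name `warmup_lr` is not part of the package), mirroring the linear interpolation above:

def warmup_lr(step: int, n_steps: int, start_lr: float = 1e-8, end_lr: float = 1.0, delay_steps: int = 0) -> float:
    # before the warmup window: keep the small starting multiplier
    if step < delay_steps:
        return start_lr
    # after the warmup window: keep the final multiplier
    if step >= n_steps + delay_steps:
        return end_lr
    # inside the window: linear interpolation from start_lr to end_lr
    remaining = (n_steps - (step - delay_steps)) / n_steps
    return start_lr * remaining + end_lr * (1 - remaining)

# warmup_lr(0, 100) ~= 1e-8, warmup_lr(50, 100) ~= 0.5, warmup_lr(100, 100) == 1.0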
torchzero/modules/scheduling/step_size.py
@@ -0,0 +1,80 @@
import random
from typing import Any

from ...core import OptimizerModule
from ...tensorlist import TensorList


class PolyakStepSize(OptimizerModule):
    """Polyak step-size. Meant to be used at the beginning when ascent is the gradient but other placements may work.
    This can also work with SGD as SPS (Stochastic Polyak Step-Size) seems to use the same formula.

    Args:
        max (float | None, optional): maximum possible step size. Defaults to None.
        min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
        use_grad (bool, optional):
            if True, uses dot product of update and gradient to compute the step size.
            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
            Defaults to True.
        parameterwise (bool, optional):
            if True, calculate Polyak step-size for each parameter separately,
            if False calculate one global step size for all parameters. Defaults to False.
        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
    """
    def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):

        defaults = dict(alpha = alpha)
        super().__init__(defaults)
        self.max = max
        self.min_obj_value = min_obj_value
        self.use_grad = use_grad
        self.parameterwise = parameterwise

    def _update(self, state, ascent):
        if state.closure is None: raise ValueError("PolyakStepSize requires closure")
        if state.fx0 is None: state.fx0 = state.closure(False) # can only happen when placed after SPSA

        alpha = self.get_group_key('alpha')

        if self.parameterwise:
            if self.use_grad: denom = (ascent*state.maybe_compute_grad_(self.get_params())).mean()
            else: denom = ascent.pow(2).mean()
            polyak_step_size: TensorList | Any = (state.fx0 - self.min_obj_value) / denom.where(denom!=0, 1) # type:ignore
            polyak_step_size = polyak_step_size.where(denom != 0, 0)
            if self.max is not None: polyak_step_size = polyak_step_size.clamp_max(self.max)

        else:
            if self.use_grad: denom = (ascent*state.maybe_compute_grad_(self.get_params())).total_mean()
            else: denom = ascent.pow(2).total_mean()
            if denom == 0: polyak_step_size = 0 # we converged
            else: polyak_step_size = (state.fx0 - self.min_obj_value) / denom

            if self.max is not None:
                if polyak_step_size > self.max: polyak_step_size = self.max

        ascent.mul_(alpha * polyak_step_size)
        return ascent



class RandomStepSize(OptimizerModule):
    """Uses random global step size from `low` to `high`.

    Args:
        low (float, optional): minimum learning rate. Defaults to 0.
        high (float, optional): maximum learning rate. Defaults to 1.
        parameterwise (bool, optional):
            if True, generate random step size for each parameter separately,
            if False generate one global random step size. Defaults to False.
    """
    def __init__(self, low: float = 0, high: float = 1, parameterwise=False):
        super().__init__({})
        self.low = low; self.high = high
        self.parameterwise = parameterwise

    def _update(self, state, ascent):
        if self.parameterwise:
            lr = [random.uniform(self.low, self.high) for _ in range(len(ascent))]
        else:
            lr = random.uniform(self.low, self.high)
        return ascent.mul_(lr) # type:ignore
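The global (non-`parameterwise`) branch of `PolyakStepSize` above computes step = (f(x) - f*) / denom, with the denominator taken as the mean (rather than the sum) of the elementwise product of the update and the gradient. A minimal standalone sketch on a plain gradient tensor (illustrative names, not the package API):

import torch

def polyak_step_size(fx0: float, g: torch.Tensor, min_obj_value: float = 0.0, max_step: float | None = None) -> float:
    # with use_grad=True and the ascent equal to the gradient, the denominator is mean(g * g)
    denom = (g * g).mean().item()
    if denom == 0:
        return 0.0  # converged
    step = (fx0 - min_obj_value) / denom
    if max_step is not None:
        step = min(step, max_step)
    return step

# e.g. for f(x) = x^2 at x = 3: fx0 = 9, g = tensor([6.]), step = 9 / 36 = 0.25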
torchzero/modules/second_order/newton.py
@@ -0,0 +1,165 @@
from typing import Literal
from collections import abc

import torch

from ...utils.derivatives import hessian_list_to_mat, jacobian_and_hessian
from ...tensorlist import TensorList
from ...core import OptimizerModule


def _cholesky_solve(hessian: torch.Tensor, grad: torch.Tensor):
    cholesky, info = torch.linalg.cholesky_ex(hessian) # pylint:disable=not-callable
    if info == 0:
        grad.unsqueeze_(1)
        return torch.cholesky_solve(grad, cholesky), True
    return None, False

def _lu_solve(hessian: torch.Tensor, grad: torch.Tensor):
    try:
        newton_step, info = torch.linalg.solve_ex(hessian, grad) # pylint:disable=not-callable
        if info == 0: return newton_step, True
        return None, False
    except torch.linalg.LinAlgError:
        return None, False


def _cholesky_fallback_lu(hessian: torch.Tensor, grad: torch.Tensor):
    step, success = _cholesky_solve(hessian, grad)
    if not success:
        step, success = _lu_solve(hessian, grad)
    return step, success

def _least_squares_solve(hessian: torch.Tensor, grad: torch.Tensor):
    return torch.linalg.lstsq(hessian, grad)[0], True # pylint:disable=not-callable


def _fallback_gd(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2):
    return grad.mul_(1e-2), True

def _fallback_safe_diag(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2):
    diag = hessian.diag().reciprocal_().nan_to_num_(1,1,1)
    if torch.all(diag == 1): # fallback to gd
        return _fallback_gd(hessian, grad, lr)
    return grad.mul_(diag * lr), True


def regularize_hessian_(hessian: torch.Tensor, value: float | Literal['eig']):
    """regularize hessian matrix in-place"""
    if value == 'eig':
        value = torch.linalg.eigvalsh(hessian).min().clamp_(max=0).neg_() # pylint:disable=not-callable
    elif value != 0:
        hessian.add_(torch.eye(hessian.shape[0], device=hessian.device,dtype=hessian.dtype), alpha = value)

LinearSystemSolvers = Literal['cholesky', 'lu', 'cholesky_lu', 'lstsq']
FallbackLinearSystemSolvers = Literal['lstsq', 'safe_diag', 'gd']

LINEAR_SYSTEM_SOLVERS = {
    "cholesky": _cholesky_solve,
    "lu": _lu_solve,
    "cholesky_lu": _cholesky_fallback_lu,
    "lstsq": _least_squares_solve,
    "safe_diag": _fallback_safe_diag,
    "gd": _fallback_gd
}

class ExactNewton(OptimizerModule):
    """Peforms an exact Newton step using batched autograd.

    Note that this doesn't support per-group settings.

    Args:
        tikhonov (float, optional):
            tikhonov regularization (constant value added to the diagonal of the hessian).
            Also known as Levenberg-Marquardt regularization. Can be set to 'eig', so it will be set
            to the smallest eigenvalue of the hessian if that value is negative. Defaults to 0.
        solver (Solvers, optional):
            solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
        fallback (Solvers, optional):
            what to do if solver fails. Defaults to "safe_diag"
            (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
        validate (bool, optional):
            validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
            If not, undo the step and perform a gradient descent step.
        tol (float, optional):
            only has effect if `validate` is enabled.
            If loss increased by `loss * tol`, perform gradient descent step.
            Set this to 0 to guarantee that loss always decreases. Defaults to 1.
        gd_lr (float, optional):
            only has effect if `validate` is enabled.
            Gradient descent step learning rate. Defaults to 1e-2.
        batched_hessian (bool, optional):
            whether to use experimental pytorch vmap-vectorized hessian calculation. As per pytorch docs,
            should be faster, but this feature being experimental, there may be performance cliffs.
            Defaults to True.
        diag (False, optional):
            only use the diagonal of the hessian. This will still calculate the full hessian!
            This is mainly useful for benchmarking.
    """
    def __init__(
        self,
        tikhonov: float | Literal['eig'] = 0.0,
        solver: LinearSystemSolvers = "cholesky_lu",
        fallback: FallbackLinearSystemSolvers = "safe_diag",
        validate=False,
        tol: float = 1,
        gd_lr = 1e-2,
        batched_hessian=True,
        diag: bool = False,
    ):
        super().__init__({})
        self.tikhonov: float | Literal['eig'] = tikhonov
        self.batched_hessian = batched_hessian

        self.solver: abc.Callable = LINEAR_SYSTEM_SOLVERS[solver]
        self.fallback: abc.Callable = LINEAR_SYSTEM_SOLVERS[fallback]

        self.validate = validate
        self.gd_lr = gd_lr
        self.tol = tol

        self.diag = diag

    @torch.no_grad
    def step(self, state):
        if state.closure is None: raise ValueError("Newton requires a closure to compute the gradient.")

        params = self.get_params()

        # exact hessian via autograd
        with torch.enable_grad():
            state.fx0 = state.closure(False)
            grads, hessian = jacobian_and_hessian([state.fx0], params) # type:ignore
            state.grad = grads = TensorList(grads).squeeze_(0)
            gvec = grads.to_vec()
            hessian = hessian_list_to_mat(hessian)

        # tikhonov regularization
        regularize_hessian_(hessian, self.tikhonov)

        # calculate newton step
        if self.diag:
            newton_step = gvec / hessian.diag()
        else:
            newton_step, success = self.solver(hessian, gvec)
            if not success:
                newton_step, success = self.fallback(hessian, gvec)
                if not success:
                    newton_step, success = _fallback_gd(hessian, gvec)

        # apply the `_update` method
        state.ascent = grads.from_vec(newton_step.squeeze_().nan_to_num_(0,0,0))

        # validate if newton step decreased loss
        if self.validate:

            params.sub_(state.ascent)
            fx1 = state.closure(False)
            params.add_(state.ascent)

            # if loss increases, set ascent direction to grad times lr
            if (not fx1.isfinite()) or fx1 - state.fx0 > state.fx0 * self.tol: # type:ignore
                state.ascent = grads.div_(grads.total_vector_norm(2) / self.gd_lr)

        # peform an update with the ascent direction, or pass it to the child.
        return self._update_params_or_step_with_next(state, params=params)
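A self-contained sketch of the regularized Newton step that `ExactNewton` assembles from the pieces above (Tikhonov regularization, Cholesky solve with an LU fallback), shown on a toy quadratic with plain `torch.autograd` calls rather than the package's `jacobian_and_hessian` helper; names and values are illustrative:

import torch
from torch.autograd.functional import hessian

def f(x):
    return (x[0] - 1) ** 2 + 10 * (x[1] + 2) ** 2

x = torch.tensor([0.0, 0.0], requires_grad=True)
loss = f(x)
grad, = torch.autograd.grad(loss, x)

# full Hessian via autograd
H = hessian(f, x.detach())

# Tikhonov / Levenberg-Marquardt regularization: H + tikhonov * I
tikhonov = 1e-4
H = H + tikhonov * torch.eye(H.shape[0])

# "cholesky_lu": try a Cholesky solve, fall back to a general (LU) solve
L, info = torch.linalg.cholesky_ex(H)
if info == 0:
    newton_step = torch.cholesky_solve(grad.unsqueeze(1), L).squeeze(1)
else:
    newton_step = torch.linalg.solve(H, grad)

x = (x - newton_step).detach()  # lands at the minimum (1, -2), up to the tiny tikhonov term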
torchzero/modules/smoothing/gaussian_smoothing.py
@@ -0,0 +1,90 @@
from contextlib import nullcontext

import numpy as np
import torch

from ...tensorlist import TensorList, Distributions, mean as tlmean
from ...utils.python_tools import _ScalarLoss
from ...core import _ClosureType, OptimizationState, OptimizerModule, _maybe_pass_backward


def _numpy_or_torch_mean(losses: list):
    """Returns the mean of a list of losses, which can be either numpy arrays or torch tensors."""
    if isinstance(losses[0], torch.Tensor):
        return torch.mean(torch.stack(losses))
    return np.mean(losses).item()

class GaussianSmoothing(OptimizerModule):
    """Samples and averages value and gradients in multiple random points around current position.
    This effectively applies smoothing to the function.

    Args:
        n_samples (int, optional): number of gradient samples from around current position. Defaults to 4.
        sigma (float, optional): how far from current position to sample from. Defaults to 0.1.
        distribution (tl.Distributions, optional): distribution for random positions. Defaults to "normal".
        sample_x0 (bool, optional): 1st sample will be x0. Defaults to False.
        randomize_every (int | None, optional): randomizes the points every n steps. Defaults to 1.
    """
    def __init__(
        self,
        n_samples: int = 4,
        sigma: float = 0.1,
        distribution: Distributions = "normal",
        sample_x0 = False,
        randomize_every: int | None = 1,
    ):
        defaults = dict(sigma = sigma)
        super().__init__(defaults)
        self.n_samples = n_samples
        self.distribution: Distributions = distribution
        self.randomize_every = randomize_every
        self.current_step = 0
        self.perturbations = None
        self.sample_x0 = sample_x0

    @torch.no_grad()
    def step(self, state: OptimizationState):
        if state.closure is None: raise ValueError('GaussianSmoothing requires closure.')
        closure = state.closure
        params = self.get_params()
        sigmas = self.get_group_key('sigma')

        # generate random perturbations
        if self.perturbations is None or (self.randomize_every is not None and self.current_step % self.randomize_every == 0):
            if self.sample_x0:
                self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples-1)]
            else:
                self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples)]

        @torch.no_grad
        def smooth_closure(backward = True):
            losses = []
            grads = []

            # sample gradient and loss at x0
            if self.sample_x0:
                with torch.enable_grad() if backward else nullcontext():
                    losses.append(closure())
                    if backward: grads.append(params.grad.clone())

            # sample gradients from points around current params
            # and average them
            if self.perturbations is None: raise ValueError('who set perturbations to None???')
            for p in self.perturbations:
                params.add_(p)
                with torch.enable_grad() if backward else nullcontext():
                    losses.append(_maybe_pass_backward(closure, backward))
                    if backward: grads.append(params.grad.clone())
                params.sub_(p)

            # set the new averaged grads and return average loss
            if backward: params.set_grad_(tlmean(grads))
            return _numpy_or_torch_mean(losses)


        self.current_step += 1
        state.closure = smooth_closure
        return self._update_params_or_step_with_next(state)


# todo single loop gaussian homotopy?
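Conceptually, the closure built by `GaussianSmoothing.step` evaluates the objective at several perturbed copies of the parameters and averages the losses and gradients, which approximates optimizing a Gaussian-smoothed surrogate of the objective. A minimal sketch of that averaging on a single plain tensor (illustrative, not the package API):

import torch

def smoothed_loss_and_grad(f, x: torch.Tensor, n_samples: int = 4, sigma: float = 0.1):
    losses, grads = [], []
    for _ in range(n_samples):
        # evaluate at a random point around the current position
        xp = (x + sigma * torch.randn_like(x)).requires_grad_(True)
        loss = f(xp)
        loss.backward()
        losses.append(loss.detach())
        grads.append(xp.grad)
    # average loss and gradient over the sampled points
    return torch.stack(losses).mean(), torch.stack(grads).mean(0)

loss, grad = smoothed_loss_and_grad(lambda t: (t ** 2).sum(), torch.zeros(3))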
torchzero/modules/smoothing/laplacian_smoothing.py
@@ -0,0 +1,128 @@
from typing import Literal
from collections.abc import Iterable

import torch

from ...tensorlist import TensorList
from ...core import OptimizerModule


def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
    """Returns a new vector with laplacian smoothing applied to it. This flattens the input!"""
    vec = input.view(-1)
    v = torch.zeros_like(vec)
    v[0] = -2
    v[1] = 1
    v[-1] = 1
    numerator = torch.fft.fft(vec) # pylint: disable = not-callable
    denominator = 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
    return torch.fft.ifft(numerator / denominator).real # pylint: disable = not-callable

def gradient_laplacian_smoothing_(params: Iterable[torch.Tensor], sigma: float = 1, layerwise=True, min_numel = 4):
    """Applies laplacian smoothing to gradients of an iterable of parameters.

    This updates gradients in-place.

    Args:
        params (abc.Iterable[torch.Tensor]): an iterable of Tensors that will have gradients smoothed.
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately,
            Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.

    Reference:
        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
    """
    grads = TensorList(params).get_existing_grads()
    if layerwise:
        for g in grads:
            if g.numel() >= min_numel:
                g.set_(vector_laplacian_smoothing(g, sigma).reshape(g.shape)) # type:ignore
    else:
        vec = grads.to_vec()
        grads.from_vec_(vector_laplacian_smoothing(vec, sigma))


def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
    """Denominator will always be the same and depends on the size of the vector and the sigma."""
    v = torch.zeros_like(tensor.view(-1))
    v[0] = -2
    v[1] = 1
    v[-1] = 1
    return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

class LaplacianSmoothing(OptimizerModule):
    """Applies laplacian smoothing via a fast Fourier transform solver.

    Args:
        sigma (float, optional): controls the amount of smoothing. Defaults to 1.
        layerwise (bool, optional):
            If True, applies smoothing to each parameter's gradient separately,
            Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
        min_numel (int, optional):
            minimum number of elements in a parameter to apply laplacian smoothing to.
            Only has effect if `layerwise` is True. Defaults to 4.
        target (str, optional):
            determines what this module updates.

            "ascent" - it updates the ascent (default).

            "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).

            "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.

    Reference:
        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*

    """
    def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Literal['ascent', 'grad', 'closure',] = 'ascent'):
        # sigma from defaults is used in layerwise case
        # otherwise self.sigma is used
        defaults = dict(sigma = sigma)
        self.sigma = 1
        super().__init__(defaults, target=target)
        self.layerwise = layerwise
        self.min_numel = min_numel

        # precomputed denominator for when layerwise=False
        self.full_denominator = None


    @torch.no_grad
    def _update(self, state, ascent):
        params = self.get_params()
        sigmas = self.get_group_key('sigma')

        # layerwise laplacian smoothing
        if self.layerwise:

            # precompute the denominator for each layer and store it in each parameters state
            denominators = TensorList()
            for p, σ in zip(params, sigmas):
                if p.numel() > self.min_numel:
                    den = self.state[p]
                    if 'denominator' not in den: den['denominator'] = _precompute_denominator(p, σ)
                    denominators.append(den['denominator'])

            # apply the smoothing
            smoothed_direction = TensorList()
            for g, σ, den in zip(ascent, sigmas, denominators):
                smoothed_direction.append(torch.fft.ifft(torch.fft.fft(g.view(-1)) / den).real.reshape(g.shape)) # pylint: disable = not-callable
            return smoothed_direction

        # else
        # full laplacian smoothing
        # precompute full denominator
        if self.full_denominator is None:
            self.full_denominator = _precompute_denominator(ascent.to_vec(), self.sigma)

        # apply the smoothing
        vec = ascent.to_vec()
        return ascent.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.full_denominator).real) # pylint: disable = not-callable
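The FFT expressions above are a fast way of solving a circulant linear system: with the kernel v = [-2, 1, 0, ..., 0, 1], dividing by 1 - sigma * fft(v) in frequency space solves (I - sigma * L) y = g, where L is the periodic 1-D Laplacian, so the result is a smoothed copy of g. A small numerical check of that equivalence (illustrative only):

import torch

g = torch.randn(8)
sigma = 1.0

# FFT route, same kernel as vector_laplacian_smoothing / _precompute_denominator
v = torch.zeros_like(g)
v[0] = -2
v[1] = 1
v[-1] = 1
y_fft = torch.fft.ifft(torch.fft.fft(g) / (1 - sigma * torch.fft.fft(v))).real

# dense route: circulant matrix whose columns are rolled copies of v
L = torch.stack([v.roll(i) for i in range(8)], dim=1)
y_dense = torch.linalg.solve(torch.eye(8) - sigma * L, g)

assert torch.allclose(y_fft, y_dense, atol=1e-5)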
torchzero/modules/weight_averaging/ema.py
@@ -0,0 +1,72 @@
import torch
from ...core import OptimizerModule


def _reset_stats_hook(optimizer, state):
    for module in optimizer.unrolled_modules:
        module: OptimizerModule
        module.reset_stats()

# the reason why this needs to be at the end is ??? I NEED TO REMEMBER
class SwitchEMA(OptimizerModule):
    """Switch-EMA. Every n steps switches params to an exponential moving average of past weights.

    In the paper the switch happens after each epoch.

    Please put this module at the end, after all other modules.

    This can also function as EMA, set `update_every` to None and instead call `set_ema` and `unset_ema` on this module.


    Args:
        update_every (int): number of steps (batches) between setting model parameters to EMA.
        momentum (int): EMA momentum factor.
        reset_stats (bool, optional):
            if True, when setting model parameters to EMA, resets other modules stats such as momentum velocities.
            It might be better to set this to False if `update_every` is very small. Defaults to True.

    reference
        https://arxiv.org/abs/2402.09240
    """
    def __init__(self, update_every: int | None, momentum: float = 0.99, reset_stats: bool = True):
        defaults = dict(momentum=momentum)
        super().__init__(defaults)
        self.update_every = update_every
        self.cur_step = 0
        self.update_every = update_every
        self._reset_stats = reset_stats
        self.orig_params = None

    def set_ema(self):
        """sets module parameters to EMA, stores original parameters that can be restored by calling `unset_ema`"""
        params = self.get_params()
        self.orig_params = params.clone()
        params.set_(self.get_state_key('ema', init = 'params', params=params))

    def unset_ema(self):
        """Undoes `set_ema`."""
        if self.orig_params is None: raise ValueError('call `set_ema` first, and then `unset_ema`.')
        params = self.get_params()
        params.set_(self.orig_params)

    @torch.no_grad
    def step(self, state):
        # if self.next_module is not None:
        #     warn(f'EMA should usually be the last module, but {self.next_module.__class__.__name__} is after it.')
        self.cur_step += 1

        params = self.get_params()
        # state.maybe_use_grad_(params)
        # update params with the child. Averaging is always applied at the end.
        ret = self._update_params_or_step_with_next(state, params)

        ema = self.get_state_key('ema', init = 'params', params=params)
        momentum = self.get_group_key('momentum')

        ema.lerp_compat_(params, 1 - momentum)

        if (self.update_every is not None) and (self.cur_step % self.update_every == 0):
            params.set_(ema.clone())
            if self._reset_stats: state.add_post_step_hook(_reset_stats_hook)

        return ret