torchzero 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/scheduling/lr_schedulers.py
@@ -0,0 +1,131 @@
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Any, overload, TYPE_CHECKING
+ import random
+
+ import torch
+ from ...core import OptimizerModule
+
+
+ if TYPE_CHECKING:
+     from ...optim import Modular
+
+
+ # LR SCHEDULING MOVED TO LR MODULE
+
+ # def _set_momentum_hook(optimizer, state, momentum):
+ #     for module in optimizer.unrolled_modules:
+ #         if 'momentum' in module.defaults:
+ #             for g in module.param_groups:
+ #                 g['momentum'] = momentum
+ #         if 'beta1' in module.defaults:
+ #             for g in module.param_groups:
+ #                 g['beta1'] = momentum
+
+ # def _add_scheduler_hook(opt: "Modular", scheduler_cls, id):
+ #     """post-init hook that sets `scheduler_step_fn` to the scheduler step."""
+ #     # get LR module
+ #     lr_module = opt.get_lr_module()
+
+ #     # get current LRScheduler module
+ #     scheds = [i for i in opt.unrolled_modules if isinstance(i, LRScheduler)]
+ #     scheds = [i for i in scheds if i.id == id]
+ #     if len(scheds) != 1:
+ #         raise RuntimeError(f"more than 1 module with id {id}: {scheds}")
+
+ #     sch_module = scheds[0]
+
+ #     # make a scheduler and save the step function
+ #     scheduler = scheduler_cls(lr_module)
+ #     sch_module.scheduler_step_fn = scheduler.step
+
+
+ # class LRScheduler(OptimizerModule):
+ #     """Use any pytorch lr scheduler.
+
+ #     Important - the lr is applied multiplicatively and multiplies with learning rate of other modules,
+ #     so usually base learning rate of the lr scheduler, such as `max_lr` for OneCycleLR, should be set to 1.
+
+ #     Args:
+ #         lr_scheduler (Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any]):
+ #             something like:
+ #             .. code:: py
+ #                 lambda opt: OneCycleLR(opt, max_lr = 1, total_steps = 60000)
+ #         update_every (int, optional):
+ #             call `step` every n steps, useful for schedulers that only step once per epoch. Defaults to 1.
+ #         cycle_momentum (bool, optional):
+ #             enables support for cycling momentum with schedulers that support it, such as `OneCycleLR`.
+ #             Unlike lr, momentum is not applied multiplicatively, but set to all other modules with
+ #             `momentum` or `beta` settings. Has no effect if there are no modules that support momentum. Defaults to False.
+ #         init_lr (float, optional):
+ #             initial lr, I believe most lr schedulers ignore this. Defaults to 1.
+ #         init_momentum (float, optional):
+ #             initial init_momentum, I believe most lr schedulers ignore this.
+ #             Has no effect if `cycle_momentum` is False or there are no modules that support momentum. Defaults to 0.
+ #     """
+ #     def __init__(
+ #         self,
+ #         lr_scheduler: Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler | Any],
+ #         step_every: int = 1,
+ #         cycle_momentum: bool = True,
+ #     ):
+ #         super().__init__({})
+ #         scheduler = lr_scheduler(self.dummy_opt)
+ #         self.update_every = step_every
+ #         self.cycle_momentum = cycle_momentum
+
+ #         self.scheduler_step_fn = scheduler.step
+ #         self.cur = 0
+ #         self.cur_lr = init_lr
+ #         self.cur_momentum = init_momentum
+
+ #         self.id = random.random()
+
+ #     def step(self, state):
+ #         if self.cur % self.update_every == 0:
+ #             self.scheduler_step_fn()
+ #             self.cur_lr = self.dummy_opt.first_param_group['lr']
+ #             self.cur_momentum = self.dummy_opt.first_param_group['momentum']
+
+ #         params = self.get_params()
+ #         ascent = state.maybe_use_grad_(params)
+ #         ascent *= self.cur_lr
+
+ #         if self.cycle_momentum:
+ #             state.add_post_step_hook(partial(_set_momentum_hook, momentum = self.cur_momentum))
+
+ class LRWarmup(OptimizerModule):
+     """Linear learning rate warmup.
+
+     Args:
+         n_steps (int): number of warmup steps.
+         start_lr (float, optional): initial lr. Defaults to 1e-8.
+         end_lr (float, optional): final lr. Defaults to 1.
+         delay_steps (int, optional): number of `start_lr` steps before starting the warmup. Defaults to 0.
+     """
+     def __init__(self, n_steps: int, start_lr: float = 1e-8, end_lr: float = 1, delay_steps: int = 0):
+
+         super().__init__({})
+         self.n_steps = n_steps
+         self.start_lr = start_lr
+         self.end_lr = end_lr
+         self.delay_steps = delay_steps
+
+         self.cur = 0
+
+     def _update(self, state, ascent):
+         if self.cur < self.delay_steps:
+             if self.start_lr != 1: ascent *= self.start_lr
+
+         elif self.cur >= self.n_steps + self.delay_steps:
+             if self.end_lr != 1: ascent *= self.end_lr
+
+         else:
+             remaining = (self.n_steps - (self.cur - self.delay_steps)) / self.n_steps
+             lr = (self.start_lr * remaining) + self.end_lr * (1 - remaining)
+             ascent *= lr
+
+         self.cur += 1
+         return ascent
+
+
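Note: the `_update` method of `LRWarmup` above scales the ascent direction by a linearly interpolated factor. A minimal standalone sketch of the same schedule, outside the `OptimizerModule` machinery (the `warmup_factor` helper below is illustrative, not part of the package):

    def warmup_factor(step: int, n_steps: int, start_lr: float = 1e-8,
                      end_lr: float = 1.0, delay_steps: int = 0) -> float:
        # factor the ascent direction is multiplied by at a given step
        if step < delay_steps:
            return start_lr
        if step >= n_steps + delay_steps:
            return end_lr
        remaining = (n_steps - (step - delay_steps)) / n_steps
        return start_lr * remaining + end_lr * (1 - remaining)

    # e.g. with n_steps=100 the factor ramps from ~1e-8 at step 0 to 1.0 at step 100
    factors = [warmup_factor(s, n_steps=100) for s in range(120)]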
torchzero/modules/scheduling/step_size.py
@@ -0,0 +1,80 @@
+ import random
+ from typing import Any
+
+ from ...core import OptimizerModule
+ from ...tensorlist import TensorList
+
+
+ class PolyakStepSize(OptimizerModule):
+     """Polyak step-size. Meant to be placed at the start of the chain, where the ascent direction is the gradient, although other placements may work.
+     This can also be used with SGD, since SPS (Stochastic Polyak Step-Size) appears to use the same formula.
+
+     Args:
+         max (float | None, optional): maximum possible step size. Defaults to None.
+         min_obj_value (float, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+         use_grad (bool, optional):
+             if True, uses the dot product of the update and the gradient to compute the step size.
+             Otherwise the dot product of the update with itself is used, which has no geometric meaning, so it probably won't work well.
+             Defaults to True.
+         parameterwise (bool, optional):
+             if True, calculates the Polyak step-size for each parameter separately,
+             if False calculates one global step size for all parameters. Defaults to False.
+         alpha (float, optional): multiplier to the Polyak step-size. Defaults to 1.
+     """
+     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
+
+         defaults = dict(alpha = alpha)
+         super().__init__(defaults)
+         self.max = max
+         self.min_obj_value = min_obj_value
+         self.use_grad = use_grad
+         self.parameterwise = parameterwise
+
+     def _update(self, state, ascent):
+         if state.closure is None: raise ValueError("PolyakStepSize requires closure")
+         if state.fx0 is None: state.fx0 = state.closure(False) # can only happen when placed after SPSA
+
+         alpha = self.get_group_key('alpha')
+
+         if self.parameterwise:
+             if self.use_grad: denom = (ascent*state.maybe_compute_grad_(self.get_params())).mean()
+             else: denom = ascent.pow(2).mean()
+             polyak_step_size: TensorList | Any = (state.fx0 - self.min_obj_value) / denom.where(denom!=0, 1) # type:ignore
+             polyak_step_size = polyak_step_size.where(denom != 0, 0)
+             if self.max is not None: polyak_step_size = polyak_step_size.clamp_max(self.max)
+
+         else:
+             if self.use_grad: denom = (ascent*state.maybe_compute_grad_(self.get_params())).total_mean()
+             else: denom = ascent.pow(2).total_mean()
+             if denom == 0: polyak_step_size = 0 # we converged
+             else: polyak_step_size = (state.fx0 - self.min_obj_value) / denom
+
+             if self.max is not None:
+                 if polyak_step_size > self.max: polyak_step_size = self.max
+
+         ascent.mul_(alpha * polyak_step_size)
+         return ascent
+
+
+
+ class RandomStepSize(OptimizerModule):
+     """Uses a random global step size from `low` to `high`.
+
+     Args:
+         low (float, optional): minimum step size. Defaults to 0.
+         high (float, optional): maximum step size. Defaults to 1.
+         parameterwise (bool, optional):
+             if True, generates a random step size for each parameter separately,
+             if False generates one global random step size. Defaults to False.
+     """
+     def __init__(self, low: float = 0, high: float = 1, parameterwise=False):
+         super().__init__({})
+         self.low = low; self.high = high
+         self.parameterwise = parameterwise
+
+     def _update(self, state, ascent):
+         if self.parameterwise:
+             lr = [random.uniform(self.low, self.high) for _ in range(len(ascent))]
+         else:
+             lr = random.uniform(self.low, self.high)
+         return ascent.mul_(lr) # type:ignore
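Note: the global (non-parameterwise) branch of `PolyakStepSize` above follows the classic Polyak rule, step size = (f(x) - f*) / (d·g), with d the update and g the gradient (this implementation uses a mean rather than a sum reduction). A minimal sketch on a toy quadratic, not using the module system (all names below are illustrative):

    import torch

    def f(x): return (x ** 2).sum()          # toy objective with known minimum f* = 0

    x = torch.randn(10, requires_grad=True)
    min_obj_value = 0.0                       # estimated lowest possible loss
    for _ in range(50):
        loss = f(x)
        loss.backward()
        with torch.no_grad():
            g = x.grad
            denom = (g * g).sum()             # classic rule uses d·g with d = g
            step_size = 0.0 if denom == 0 else (loss.item() - min_obj_value) / denom
            x -= step_size * g
            x.grad = None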
torchzero/modules/second_order/__init__.py
@@ -0,0 +1,4 @@
+ r"""
+ This includes modules that use the hessian computed via autograd.
+ """
+ from .newton import ExactNewton, LinearSystemSolvers, FallbackLinearSystemSolvers, LINEAR_SYSTEM_SOLVERS
torchzero/modules/second_order/newton.py
@@ -0,0 +1,165 @@
+ from typing import Literal
+ from collections import abc
+
+ import torch
+
+ from ...utils.derivatives import hessian_list_to_mat, jacobian_and_hessian
+ from ...tensorlist import TensorList
+ from ...core import OptimizerModule
+
+
+ def _cholesky_solve(hessian: torch.Tensor, grad: torch.Tensor):
+     cholesky, info = torch.linalg.cholesky_ex(hessian) # pylint:disable=not-callable
+     if info == 0:
+         grad.unsqueeze_(1)
+         return torch.cholesky_solve(grad, cholesky), True
+     return None, False
+
+ def _lu_solve(hessian: torch.Tensor, grad: torch.Tensor):
+     try:
+         newton_step, info = torch.linalg.solve_ex(hessian, grad) # pylint:disable=not-callable
+         if info == 0: return newton_step, True
+         return None, False
+     except torch.linalg.LinAlgError:
+         return None, False
+
+
+ def _cholesky_fallback_lu(hessian: torch.Tensor, grad: torch.Tensor):
+     step, success = _cholesky_solve(hessian, grad)
+     if not success:
+         step, success = _lu_solve(hessian, grad)
+     return step, success
+
+ def _least_squares_solve(hessian: torch.Tensor, grad: torch.Tensor):
+     return torch.linalg.lstsq(hessian, grad)[0], True # pylint:disable=not-callable
+
+
+ def _fallback_gd(hessian: torch.Tensor, grad: torch.Tensor, lr = 1e-2):
+     return grad.mul_(lr), True
+
+ def _fallback_safe_diag(hessian: torch.Tensor, grad: torch.Tensor, lr = 1e-2):
+     diag = hessian.diag().reciprocal_().nan_to_num_(1,1,1)
+     if torch.all(diag == 1): # fallback to gd
+         return _fallback_gd(hessian, grad, lr)
+     return grad.mul_(diag * lr), True
+
+
+ def regularize_hessian_(hessian: torch.Tensor, value: float | Literal['eig']):
+     """regularize hessian matrix in-place by adding `value` to its diagonal"""
+     if value == 'eig':
+         # magnitude of the most negative eigenvalue, 0 if all eigenvalues are non-negative
+         value = float(torch.linalg.eigvalsh(hessian).min().clamp(max=0).neg()) # pylint:disable=not-callable
+     if value != 0:
+         hessian.add_(torch.eye(hessian.shape[0], device=hessian.device, dtype=hessian.dtype), alpha = value)
+
+ LinearSystemSolvers = Literal['cholesky', 'lu', 'cholesky_lu', 'lstsq']
+ FallbackLinearSystemSolvers = Literal['lstsq', 'safe_diag', 'gd']
+
+ LINEAR_SYSTEM_SOLVERS = {
+     "cholesky": _cholesky_solve,
+     "lu": _lu_solve,
+     "cholesky_lu": _cholesky_fallback_lu,
+     "lstsq": _least_squares_solve,
+     "safe_diag": _fallback_safe_diag,
+     "gd": _fallback_gd
+ }
+
+ class ExactNewton(OptimizerModule):
+     """Performs an exact Newton step using batched autograd.
+
+     Note that this doesn't support per-group settings.
+
+     Args:
+         tikhonov (float, optional):
+             Tikhonov regularization (constant value added to the diagonal of the hessian).
+             Also known as Levenberg-Marquardt regularization. Can be set to 'eig', in which case it is set to
+             the magnitude of the smallest eigenvalue of the hessian if that eigenvalue is negative. Defaults to 0.
+         solver (LinearSystemSolvers, optional):
+             solver for Hx = g. Defaults to "cholesky_lu" (cholesky, or LU if it fails).
+         fallback (FallbackLinearSystemSolvers, optional):
+             what to do if the solver fails. Defaults to "safe_diag"
+             (uses the nonzero diagonal elements of the hessian, or falls back to gradient descent if all elements are 0).
+         validate (bool, optional):
+             checks, with an additional forward pass, whether the step increased the loss by more than `loss * tol`.
+             If it did, the step is undone and a gradient descent step is performed instead.
+         tol (float, optional):
+             only has effect if `validate` is enabled.
+             If the loss increased by more than `loss * tol`, a gradient descent step is performed instead.
+             Set this to 0 to guarantee that the loss never increases. Defaults to 1.
+         gd_lr (float, optional):
+             only has effect if `validate` is enabled.
+             Gradient descent step learning rate. Defaults to 1e-2.
+         batched_hessian (bool, optional):
+             whether to use the experimental pytorch vmap-vectorized hessian calculation. As per the pytorch docs,
+             it should be faster, but since this feature is experimental, there may be performance cliffs.
+             Defaults to True.
+         diag (bool, optional):
+             only use the diagonal of the hessian. This still calculates the full hessian!
+             This is mainly useful for benchmarking.
+     """
+     def __init__(
+         self,
+         tikhonov: float | Literal['eig'] = 0.0,
+         solver: LinearSystemSolvers = "cholesky_lu",
+         fallback: FallbackLinearSystemSolvers = "safe_diag",
+         validate=False,
+         tol: float = 1,
+         gd_lr = 1e-2,
+         batched_hessian=True,
+         diag: bool = False,
+     ):
+         super().__init__({})
+         self.tikhonov: float | Literal['eig'] = tikhonov
+         self.batched_hessian = batched_hessian
+
+         self.solver: abc.Callable = LINEAR_SYSTEM_SOLVERS[solver]
+         self.fallback: abc.Callable = LINEAR_SYSTEM_SOLVERS[fallback]
+
+         self.validate = validate
+         self.gd_lr = gd_lr
+         self.tol = tol
+
+         self.diag = diag
+
+     @torch.no_grad
+     def step(self, state):
+         if state.closure is None: raise ValueError("Newton requires a closure to compute the gradient.")
+
+         params = self.get_params()
+
+         # exact hessian via autograd
+         with torch.enable_grad():
+             state.fx0 = state.closure(False)
+             grads, hessian = jacobian_and_hessian([state.fx0], params) # type:ignore
+         state.grad = grads = TensorList(grads).squeeze_(0)
+         gvec = grads.to_vec()
+         hessian = hessian_list_to_mat(hessian)
+
+         # tikhonov regularization
+         regularize_hessian_(hessian, self.tikhonov)
+
+         # calculate the newton step
+         if self.diag:
+             newton_step = gvec / hessian.diag()
+         else:
+             newton_step, success = self.solver(hessian, gvec)
+             if not success:
+                 newton_step, success = self.fallback(hessian, gvec)
+                 if not success:
+                     newton_step, success = _fallback_gd(hessian, gvec)
+
+         # set the ascent direction to the newton step
+         state.ascent = grads.from_vec(newton_step.squeeze_().nan_to_num_(0,0,0))
+
+         # validate that the newton step didn't increase the loss
+         if self.validate:
+
+             params.sub_(state.ascent)
+             fx1 = state.closure(False)
+             params.add_(state.ascent)
+
+             # if the loss increased, set the ascent direction to the gradient normalized to `gd_lr`
+             if (not fx1.isfinite()) or fx1 - state.fx0 > state.fx0 * self.tol: # type:ignore
+                 state.ascent = grads.div_(grads.total_vector_norm(2) / self.gd_lr)
+
+         # perform an update with the ascent direction, or pass it to the child.
+         return self._update_params_or_step_with_next(state, params=params)
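Note: the solver chain in `ExactNewton` tries a Cholesky solve first and falls back to an LU-based solve, as in `_cholesky_fallback_lu` above. A small standalone check of that pattern with `torch.linalg` (the matrices below are made up for illustration):

    import torch

    def solve_with_fallback(H: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # try Cholesky first (requires H to be symmetric positive definite) ...
        L, info = torch.linalg.cholesky_ex(H)
        if info == 0:
            return torch.cholesky_solve(g.unsqueeze(1), L).squeeze(1)
        # ... otherwise fall back to a general LU-based solve
        step, info = torch.linalg.solve_ex(H, g)
        return step

    n = 5
    g = torch.randn(n)
    spd = torch.randn(n, n); spd = spd @ spd.T + n * torch.eye(n)   # positive definite: Cholesky path
    indef = torch.diag(torch.tensor([1., -1., 2., -2., 3.]))        # indefinite: LU path
    for H in (spd, indef):
        d = solve_with_fallback(H, g)
        assert torch.allclose(H @ d, g, atol=1e-4)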
torchzero/modules/smoothing/__init__.py
@@ -0,0 +1,5 @@
+ r"""
+ Gradient smoothing and orthogonalization methods.
+ """
+ from .laplacian_smoothing import LaplacianSmoothing, gradient_laplacian_smoothing_
+ from .gaussian_smoothing import GaussianSmoothing
torchzero/modules/smoothing/gaussian_smoothing.py
@@ -0,0 +1,90 @@
+ from contextlib import nullcontext
+
+ import numpy as np
+ import torch
+
+ from ...tensorlist import TensorList, Distributions, mean as tlmean
+ from ...utils.python_tools import _ScalarLoss
+ from ...core import _ClosureType, OptimizationState, OptimizerModule, _maybe_pass_backward
+
+
+ def _numpy_or_torch_mean(losses: list):
+     """Returns the mean of a list of losses, which can be either numpy arrays or torch tensors."""
+     if isinstance(losses[0], torch.Tensor):
+         return torch.mean(torch.stack(losses))
+     return np.mean(losses).item()
+
+ class GaussianSmoothing(OptimizerModule):
+     """Samples and averages the value and gradient at multiple random points around the current position.
+     This effectively applies smoothing to the objective function.
+
+     Args:
+         n_samples (int, optional): number of gradient samples taken around the current position. Defaults to 4.
+         sigma (float, optional): how far from the current position to sample. Defaults to 0.1.
+         distribution (tl.Distributions, optional): distribution for the random points. Defaults to "normal".
+         sample_x0 (bool, optional): if True, the first sample is taken at the current position x0. Defaults to False.
+         randomize_every (int | None, optional): randomizes the points every n steps. Defaults to 1.
+     """
+     def __init__(
+         self,
+         n_samples: int = 4,
+         sigma: float = 0.1,
+         distribution: Distributions = "normal",
+         sample_x0 = False,
+         randomize_every: int | None = 1,
+     ):
+         defaults = dict(sigma = sigma)
+         super().__init__(defaults)
+         self.n_samples = n_samples
+         self.distribution: Distributions = distribution
+         self.randomize_every = randomize_every
+         self.current_step = 0
+         self.perturbations = None
+         self.sample_x0 = sample_x0
+
+     @torch.no_grad()
+     def step(self, state: OptimizationState):
+         if state.closure is None: raise ValueError('GaussianSmoothing requires closure.')
+         closure = state.closure
+         params = self.get_params()
+         sigmas = self.get_group_key('sigma')
+
+         # generate random perturbations
+         if self.perturbations is None or (self.randomize_every is not None and self.current_step % self.randomize_every == 0):
+             if self.sample_x0:
+                 self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples-1)]
+             else:
+                 self.perturbations = [params.sample_like(sigmas, distribution=self.distribution) for _ in range(self.n_samples)]
+
+         @torch.no_grad
+         def smooth_closure(backward = True):
+             losses = []
+             grads = []
+
+             # sample the gradient and loss at x0
+             if self.sample_x0:
+                 with torch.enable_grad() if backward else nullcontext():
+                     losses.append(_maybe_pass_backward(closure, backward))
+                 if backward: grads.append(params.grad.clone())
+
+             # sample gradients at points around the current params
+             # and average them
+             if self.perturbations is None: raise RuntimeError('perturbations have not been generated')
+             for p in self.perturbations:
+                 params.add_(p)
+                 with torch.enable_grad() if backward else nullcontext():
+                     losses.append(_maybe_pass_backward(closure, backward))
+                 if backward: grads.append(params.grad.clone())
+                 params.sub_(p)
+
+             # set the new averaged grads and return the average loss
+             if backward: params.set_grad_(tlmean(grads))
+             return _numpy_or_torch_mean(losses)
+
+
+         self.current_step += 1
+         state.closure = smooth_closure
+         return self._update_params_or_step_with_next(state)
+
+
+ # todo single loop gaussian homotopy?
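Note: the wrapped closure in `GaussianSmoothing` averages gradients taken at perturbed copies of the parameters, which approximates the gradient of a Gaussian-blurred objective. A tiny illustration on a scalar function, independent of the closure/module machinery (all names below are illustrative):

    import torch

    def f(x): return torch.sin(5 * x) + x ** 2      # wiggly objective

    x = torch.tensor(1.0, requires_grad=True)
    n_samples, sigma = 32, 0.3

    grads = []
    for _ in range(n_samples):
        # evaluate the gradient at a randomly perturbed copy of x
        xp = (x + sigma * torch.randn(())).detach().requires_grad_(True)
        f(xp).backward()
        grads.append(xp.grad)

    smoothed_grad = torch.stack(grads).mean()       # ~ gradient of the Gaussian-smoothed f at x
    f(x).backward()
    print(x.grad.item(), smoothed_grad.item())      # raw vs. smoothed gradient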
torchzero/modules/smoothing/laplacian_smoothing.py
@@ -0,0 +1,128 @@
+ from typing import Literal
+ from collections.abc import Iterable
+
+ import torch
+
+ from ...tensorlist import TensorList
+ from ...core import OptimizerModule
+
+
+ def vector_laplacian_smoothing(input: torch.Tensor, sigma: float = 1) -> torch.Tensor:
+     """Returns a new vector with laplacian smoothing applied to it. This flattens the input!"""
+     vec = input.view(-1)
+     v = torch.zeros_like(vec)
+     v[0] = -2
+     v[1] = 1
+     v[-1] = 1
+     numerator = torch.fft.fft(vec) # pylint: disable = not-callable
+     denominator = 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
+     return torch.fft.ifft(numerator / denominator).real # pylint: disable = not-callable
+
+ def gradient_laplacian_smoothing_(params: Iterable[torch.Tensor], sigma: float = 1, layerwise=True, min_numel = 4):
+     """Applies laplacian smoothing to the gradients of an iterable of parameters.
+
+     This updates gradients in-place.
+
+     Args:
+         params (abc.Iterable[torch.Tensor]): an iterable of Tensors that will have their gradients smoothed.
+         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
+         layerwise (bool, optional):
+             If True, applies smoothing to each parameter's gradient separately,
+             otherwise applies it to all gradients concatenated into a single vector. Defaults to True.
+         min_numel (int, optional):
+             minimum number of elements in a parameter to apply laplacian smoothing to.
+             Only has effect if `layerwise` is True. Defaults to 4.
+
+     Reference:
+         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
+         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+     """
+     grads = TensorList(params).get_existing_grads()
+     if layerwise:
+         for g in grads:
+             if g.numel() >= min_numel:
+                 g.set_(vector_laplacian_smoothing(g, sigma).reshape(g.shape)) # type:ignore
+     else:
+         vec = grads.to_vec()
+         grads.from_vec_(vector_laplacian_smoothing(vec, sigma))
+
+
+ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
+     """The denominator is always the same and depends only on the size of the vector and the sigma."""
+     v = torch.zeros_like(tensor.view(-1))
+     v[0] = -2
+     v[1] = 1
+     v[-1] = 1
+     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
+
+ class LaplacianSmoothing(OptimizerModule):
+     """Applies laplacian smoothing via a fast Fourier transform solver.
+
+     Args:
+         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
+         layerwise (bool, optional):
+             If True, applies smoothing to each parameter's gradient separately,
+             otherwise applies it to all gradients concatenated into a single vector. Defaults to True.
+         min_numel (int, optional):
+             minimum number of elements in a parameter to apply laplacian smoothing to.
+             Only has effect if `layerwise` is True. Defaults to 4.
+         target (str, optional):
+             determines what this module updates.
+
+             "ascent" - it updates the ascent (default).
+
+             "grad" - it updates the gradient (and sets `.grad` attributes to the updated gradient).
+
+             "closure" - it makes a new closure that sets the updated ascent to the `.grad` attributes.
+
+     Reference:
+         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
+         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+
+     """
+     def __init__(self, sigma: float = 1, layerwise=True, min_numel = 4, target: Literal['ascent', 'grad', 'closure',] = 'ascent'):
+         # sigma from defaults is used in the layerwise case,
+         # otherwise self.sigma is used
+         defaults = dict(sigma = sigma)
+         self.sigma = sigma
+         super().__init__(defaults, target=target)
+         self.layerwise = layerwise
+         self.min_numel = min_numel
+
+         # precomputed denominator for when layerwise=False
+         self.full_denominator = None
+
+
+     @torch.no_grad
+     def _update(self, state, ascent):
+         params = self.get_params()
+         sigmas = self.get_group_key('sigma')
+
+         # layerwise laplacian smoothing
+         if self.layerwise:
+
+             # precompute the denominator for each layer and store it in that parameter's state,
+             # then apply the smoothing; parameters with fewer than `min_numel` elements pass through unchanged
+             smoothed_direction = TensorList()
+             for p, g, σ in zip(params, ascent, sigmas):
+                 if p.numel() >= self.min_numel:
+                     pstate = self.state[p]
+                     if 'denominator' not in pstate: pstate['denominator'] = _precompute_denominator(p, σ)
+                     smoothed_direction.append(torch.fft.ifft(torch.fft.fft(g.view(-1)) / pstate['denominator']).real.reshape(g.shape)) # pylint: disable = not-callable
+                 else:
+                     smoothed_direction.append(g)
+             return smoothed_direction
+
+         # else
+         # full laplacian smoothing
+         # precompute the full denominator
+         if self.full_denominator is None:
+             self.full_denominator = _precompute_denominator(ascent.to_vec(), self.sigma)
+
+         # apply the smoothing
+         vec = ascent.to_vec()
+         return ascent.from_vec(torch.fft.ifft(torch.fft.fft(vec) / self.full_denominator).real) # pylint: disable = not-callable
+
+
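Note: `vector_laplacian_smoothing` above solves (I - sigma * L) y = g in the Fourier domain, where L is the circulant second-difference (Laplacian) operator encoded by the kernel v = [-2, 1, 0, ..., 0, 1]. A quick standalone check of that equivalence against a dense solve (purely illustrative):

    import torch

    n, sigma = 16, 1.0
    g = torch.randn(n)

    # FFT-domain solve, as in vector_laplacian_smoothing
    v = torch.zeros(n); v[0], v[1], v[-1] = -2.0, 1.0, 1.0
    y_fft = torch.fft.ifft(torch.fft.fft(g) / (1 - sigma * torch.fft.fft(v))).real

    # equivalent dense solve of (I - sigma * L) y = g with the circulant Laplacian L
    idx = (torch.arange(n)[:, None] - torch.arange(n)[None, :]) % n
    L = v[idx]
    y_dense = torch.linalg.solve(torch.eye(n) - sigma * L, g)

    assert torch.allclose(y_fft, y_dense, atol=1e-4)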
torchzero/modules/weight_averaging/__init__.py
@@ -0,0 +1,2 @@
+ from .ema import SwitchEMA
+ from .swa import PeriodicSWA, CyclicSWA
torchzero/modules/weight_averaging/ema.py
@@ -0,0 +1,72 @@
+ import torch
+ from ...core import OptimizerModule
+
+
+ def _reset_stats_hook(optimizer, state):
+     for module in optimizer.unrolled_modules:
+         module: OptimizerModule
+         module.reset_stats()
+
+ # the reason why this needs to be at the end is ??? I NEED TO REMEMBER
+ class SwitchEMA(OptimizerModule):
+     """Switch-EMA. Every n steps, switches the params to an exponential moving average of past weights.
+
+     In the paper the switch happens after each epoch.
+
+     Please put this module at the end, after all other modules.
+
+     This can also function as plain EMA: set `update_every` to None and instead call `set_ema` and `unset_ema` on this module.
+
+
+     Args:
+         update_every (int | None): number of steps (batches) between setting model parameters to the EMA. If None, parameters are never switched automatically.
+         momentum (float): EMA momentum factor.
+         reset_stats (bool, optional):
+             if True, when setting model parameters to the EMA, resets other modules' stats such as momentum velocities.
+             It might be better to set this to False if `update_every` is very small. Defaults to True.
+
+     Reference:
+         https://arxiv.org/abs/2402.09240
+     """
+     def __init__(self, update_every: int | None, momentum: float = 0.99, reset_stats: bool = True):
+         defaults = dict(momentum=momentum)
+         super().__init__(defaults)
+         self.update_every = update_every
+         self.cur_step = 0
+         self._reset_stats = reset_stats
+         self.orig_params = None
+
+     def set_ema(self):
+         """sets module parameters to the EMA and stores the original parameters, which can be restored by calling `unset_ema`"""
+         params = self.get_params()
+         self.orig_params = params.clone()
+         params.set_(self.get_state_key('ema', init = 'params', params=params))
+
+     def unset_ema(self):
+         """Undoes `set_ema`."""
+         if self.orig_params is None: raise ValueError('call `set_ema` first, and then `unset_ema`.')
+         params = self.get_params()
+         params.set_(self.orig_params)
+
+     @torch.no_grad
+     def step(self, state):
+         # if self.next_module is not None:
+         #     warn(f'EMA should usually be the last module, but {self.next_module.__class__.__name__} is after it.')
+         self.cur_step += 1
+
+         params = self.get_params()
+         # state.maybe_use_grad_(params)
+         # update params with the child. Averaging is always applied at the end.
+         ret = self._update_params_or_step_with_next(state, params)
+
+         ema = self.get_state_key('ema', init = 'params', params=params)
+         momentum = self.get_group_key('momentum')
+
+         ema.lerp_compat_(params, 1 - momentum)
+
+         if (self.update_every is not None) and (self.cur_step % self.update_every == 0):
+             params.set_(ema.clone())
+             if self._reset_stats: state.add_post_step_hook(_reset_stats_hook)
+
+         return ret
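Note: the core of `SwitchEMA.step` above is a lerp of the EMA buffer towards the current weights, followed by an occasional switch of the weights to the EMA. A minimal sketch with plain tensors (no `OptimizerModule` state; names are illustrative):

    import torch

    params = [torch.randn(3, 3), torch.randn(3)]          # stand-ins for model parameters
    ema = [p.clone() for p in params]                      # EMA buffers initialized to the parameters
    momentum, update_every = 0.99, 100

    for step in range(1, 1001):
        # ... an optimizer step would modify `params` here ...
        for e, p in zip(ema, params):
            e.lerp_(p, 1 - momentum)                       # ema = momentum * ema + (1 - momentum) * p
        if step % update_every == 0:
            for e, p in zip(ema, params):
                p.copy_(e)                                 # switch: load the EMA back into the parameters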