torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/debug.py
@@ -0,0 +1,48 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Module
+ from ...utils.tensorlist import Distributions
+
+ class PrintUpdate(Module):
+     """Prints the current update."""
+     def __init__(self, text = 'update = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
+         return var
+
+ class PrintShape(Module):
+     """Prints the shapes of the update."""
+     def __init__(self, text = 'shapes = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         shapes = [u.shape for u in var.update] if var.update is not None else None
+         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
+         return var
+
+ class PrintParams(Module):
+     """Prints the current parameters."""
+     def __init__(self, text = 'params = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.params}')
+         return var
+
+
+ class PrintLoss(Module):
+     """Prints var.get_loss()."""
+     def __init__(self, text = 'loss = ', print_fn = print):
+         defaults = dict(text=text, print_fn=print_fn)
+         super().__init__(defaults)
+
+     def step(self, var):
+         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.get_loss(False)}')
+         return var
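Note: these debug modules read the running var and pass it through unchanged, so they can be dropped anywhere in a modular optimizer chain. A minimal usage sketch follows; the tz.Modular / tz.m.Adam / tz.m.LR pattern mirrors the example in the GradientAccumulation docstring later in this diff, but the exact export paths of the debug modules under tz.m, and tz.Modular accepting several modules as positional arguments, are assumptions.

import torchzero as tz  # assumption: package imported under its usual name

# assumption: the debug modules are exported under tz.m like the other modules
opt = tz.Modular(
    model.parameters(),                        # model is an existing torch.nn.Module
    tz.m.PrintLoss(),                          # print the loss before any transform
    tz.m.Adam(),
    tz.m.PrintUpdate(text='adam update = '),   # inspect the update Adam produced
    tz.m.LR(1e-2),
)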
torchzero/modules/misc/escape.py
@@ -0,0 +1,60 @@
+ import torch
+
+ from ...core import Module
+ from ...utils import TensorList, NumberList
+
+
+ class EscapeAnnealing(Module):
+     """If parameters stop changing, this runs a backward-annealing random search."""
+     def __init__(self, max_region:float = 1, max_iter:int = 1000, tol=1e-6, n_tol: int = 10):
+         defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+         super().__init__(defaults)
+
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError("Escape requires closure")
+
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         max_region = self.get_settings(params, 'max_region', cls=NumberList)
+         max_iter = settings['max_iter']
+         tol = settings['tol']
+         n_tol = settings['n_tol']
+
+         n_bad = self.global_state.get('n_bad', 0)
+
+         prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+         diff = params-prev_params
+         prev_params.copy_(params)
+
+         if diff.abs().global_max() <= tol:
+             n_bad += 1
+
+         else:
+             n_bad = 0
+
+         self.global_state['n_bad'] = n_bad
+
+         # no progress
+         f_0 = var.get_loss(False)
+         if n_bad >= n_tol:
+             for i in range(1, max_iter+1):
+                 alpha = max_region * (i / max_iter)
+                 pert = params.sample_like(distribution='sphere').mul_(alpha)
+
+                 params.add_(pert)
+                 f_star = closure(False)
+
+                 if f_star < f_0-1e-10:
+                     var.update = None
+                     var.stop = True
+                     var.skip_update = True
+                     return var
+
+                 else:
+                     params.sub_(pert)
+
+             self.global_state['n_bad'] = 0
+         return var
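In words: EscapeAnnealing counts consecutive steps whose largest parameter change is below tol; once n_tol such steps accumulate, it samples spherical perturbations of linearly growing radius (up to max_region) and keeps the first point that lowers the loss. The following is a standalone sketch of that escape loop in plain PyTorch, detached from the library's Module machinery; here closure is assumed to be a zero-argument callable returning the loss (the library closure additionally takes a backward flag).

import torch

@torch.no_grad()
def escape_anneal(params, closure, max_region=1.0, max_iter=1000):
    # sketch of the loop in EscapeAnnealing.step: grow the perturbation
    # radius linearly and accept the first point that improves the loss
    params = list(params)
    f_0 = closure()
    for i in range(1, max_iter + 1):
        alpha = max_region * (i / max_iter)          # radius grows linearly
        noise = [torch.randn_like(p) for p in params]
        total = torch.sqrt(sum(n.pow(2).sum() for n in noise))
        pert = [n / total * alpha for n in noise]    # point on a sphere of radius alpha
        for p, e in zip(params, pert): p.add_(e)
        if closure() < f_0 - 1e-10:
            return True                              # escaped to a lower-loss point, keep it
        for p, e in zip(params, pert): p.sub_(e)
    return False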
torchzero/modules/misc/gradient_accumulation.py
@@ -0,0 +1,70 @@
+ import torch
+
+ from ...core import Chainable, Module
+
+
+ class GradientAccumulation(Module):
+     """Uses :code:`n` steps to accumulate gradients; after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and the parameters are updated.
+
+     Accumulating gradients for :code:`n` steps is equivalent to increasing the batch size by :code:`n`. Increasing the batch size
+     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+     .. note::
+         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+     Args:
+         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+         n (int): number of gradients to accumulate.
+         mean (bool, optional): if True, uses the mean of the accumulated gradients, otherwise uses the sum. Defaults to True.
+         stop (bool, optional):
+             this module prevents the next modules from stepping unless :code:`n` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.
+
+     Examples:
+         Adam with gradients accumulated for 16 batches.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.GradientAccumulation(
+                     modules=[tz.m.Adam(), tz.m.LR(1e-2)],
+                     n=16
+                 )
+             )
+
+     """
+     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+         defaults = dict(n=n, mean=mean, stop=stop)
+         super().__init__(defaults)
+         self.set_child('modules', modules)
+
+
+     @torch.no_grad
+     def step(self, var):
+         accumulator = self.get_state(var.params, 'accumulator')
+         settings = self.settings[var.params[0]]
+         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         # add update to accumulator
+         torch._foreach_add_(accumulator, var.get_update())
+
+         # step with accumulated updates
+         if step % n == 0:
+             if mean:
+                 torch._foreach_div_(accumulator, n)
+
+             var.update = [a.clone() for a in accumulator]
+             var = self.children['modules'].step(var)
+
+             # zero accumulator
+             torch._foreach_zero_(accumulator)
+
+         else:
+             # prevent update
+             if stop:
+                 var.stop=True
+                 var.skip_update=True
+
+         return var
+
torchzero/modules/misc/misc.py
@@ -0,0 +1,316 @@
+ from collections import deque
+ from collections.abc import Iterable
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+ from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+ class Previous(TensorwiseTransform):
+     """Maintains an update from n steps back; for example, if n=1, returns the previous update."""
+     def __init__(self, n=1, target: Target = 'update'):
+         defaults = dict(n=n)
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
+         n = setting['n']
+
+         if 'history' not in state:
+             state['history'] = deque(maxlen=n+1)
+
+         state['history'].append(tensor)
+
+         return state['history'][0]
+
+
+ class LastDifference(Transform):
+     """Outputs difference between past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+         difference = torch._foreach_sub(tensors, prev_tensors)
+         for p, c in zip(prev_tensors, tensors): p.set_(c)
+         return difference
+
+ class LastGradDifference(Module):
+     """Outputs difference between past two gradients."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         grad = var.get_grad()
+         prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+         difference = torch._foreach_sub(grad, prev_grad)
+         for p, c in zip(prev_grad, grad): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+ class LastParamDifference(Module):
+     """Outputs difference between past two parameters, which is the effective previous update."""
+     def __init__(self):
+         super().__init__({})
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+         difference = torch._foreach_sub(params, prev_params)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         var.update = list(difference)
+         return var
+
+
+
+ class LastProduct(Transform):
+     """Outputs product of past two updates."""
+     def __init__(self,target: Target = 'update'):
+         super().__init__({}, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+         prod = torch._foreach_mul(tensors, prev)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return prod
+
+ class LastRatio(Transform):
+     """Outputs ratio between past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+         defaults = dict(numerator=numerator)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class LastAbsoluteRatio(Transform):
+     """Outputs ratio between absolute values of past two updates; the numerator is determined by the :code:`numerator` argument."""
+     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+         defaults = dict(numerator=numerator, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+         numerator = settings[0]['numerator']
+         eps = NumberList(s['eps'] for s in settings)
+
+         torch._foreach_abs_(tensors)
+         torch._foreach_clamp_min_(prev, eps)
+
+         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+         else: ratio = torch._foreach_div(prev, tensors)
+         for p, c in zip(prev, tensors): p.set_(c)
+         return ratio
+
+ class GradSign(Transform):
+     """Copies gradient sign to update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+ class UpdateSign(Transform):
+     """Outputs gradient with sign copied from the update."""
+     def __init__(self, target: Target = 'update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+ class GraftToGrad(Transform):
+     """Grafts update to the gradient, that is, the update is rescaled to have the same norm as the gradient."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class GraftGradToUpdate(Transform):
+     """Outputs the gradient grafted to the update, that is, the gradient rescaled to have the same norm as the update."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=True, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert grads is not None
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+ class GraftToParams(Transform):
+     """Grafts update to the parameters, that is, the update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+ class Relative(Transform):
+     """Multiplies update by absolute parameter values to make it relative to their magnitude; :code:`min_value` is the minimum allowed value to avoid getting stuck at 0."""
+     def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+         defaults = dict(min_value=min_value)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+         torch._foreach_mul_(tensors, mul)
+         return tensors
+
+ class FillLoss(Module):
+     """Outputs tensors filled with the loss value times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, backward: bool = True):
+         defaults = dict(alpha=alpha, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha = self.get_settings(var.params, 'alpha')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+         return var
+
+ class MulByLoss(Module):
+     """Multiplies update by the loss times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_mul_(var.update, mul)
+         return var
+
+ class DivByLoss(Module):
+     """Divides update by the loss times :code:`alpha`."""
+     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+         torch._foreach_div_(var.update, mul)
+         return var
+
+
+ class NoiseSign(Transform):
+     """Outputs random tensors with sign copied from the update."""
+     def __init__(self, distribution:Distributions = 'normal', alpha = 1):
+         defaults = dict(distribution=distribution, alpha=alpha)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         alpha = [s['alpha'] for s in settings]
+         distribution = self.settings[params[0]]['distribution']
+         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+ class HpuEstimate(Transform):
+     """Returns ``y/||s||``, where ``y`` is the difference between the current and previous update (gradient), and ``s`` is the difference between the current and previous parameters. The returned tensors are a finite difference approximation to the Hessian times the previous update."""
+     def __init__(self):
+         defaults = dict()
+         super().__init__(defaults, uses_grad=False)
+
+     def reset_for_online(self):
+         super().reset_for_online()
+         self.clear_state_keys('prev_params', 'prev_update')
+
+     @torch.no_grad
+     def update_tensors(self, tensors, params, grads, loss, states, settings):
+         prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+         s = torch._foreach_sub(params, prev_params)
+         y = torch._foreach_sub(tensors, prev_update)
+         for p, c in zip(prev_params, params): p.copy_(c)
+         for p, c in zip(prev_update, tensors): p.copy_(c)
+         torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+         self.store(params, ['s', 'y'], [s, y])
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         return [self.state[p]['y'] for p in params]
+
+ class RandomHvp(Module):
+     """Returns a Hessian-vector product with a random vector."""
+
+     def __init__(
+         self,
+         n_samples: int = 1,
+         distribution: Distributions = "normal",
+         update_freq: int = 1,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h=1e-3,
+     ):
+         defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         distribution = settings['distribution']
+         hvp_method = settings['hvp_method']
+         h = settings['h']
+         update_freq = settings['update_freq']
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad = None
+             for i in range(n_samples):
+                 u = params.sample_like(distribution=distribution)
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+             if update_freq != 1:
+                 assert D is not None
+                 D_buf = self.get_state(params, "D", cls=TensorList)
+                 D_buf.set_(D)
+
+         if D is None:
+             D = self.get_state(params, "D", cls=TensorList)
+
+         var.update = list(D)
+         return var
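Note on the grafting transforms above: GraftToGrad rescales the update to the gradient's norm, GraftToParams to the parameters' norm, and GraftGradToUpdate does the reverse. TensorList.graft_ is the library's own implementation; the helper below is only an illustrative sketch of the global (non-tensorwise, ord=2) case, under the assumption that eps is a divide-by-zero guard on the source norm.

import torch

def graft(update, reference, eps=1e-6):
    # rescale `update` so its global L2 norm matches that of `reference`
    u_norm = torch.sqrt(sum(u.pow(2).sum() for u in update))
    r_norm = torch.sqrt(sum(r.pow(2).sum() for r in reference))
    scale = r_norm / u_norm.clamp(min=eps)
    return [u * scale for u in update]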
torchzero/modules/misc/multistep.py
@@ -0,0 +1,158 @@
+ from collections.abc import Iterable
+
+ import torch
+
+ from ...core import Chainable, Module, Var
+ from ...utils import TensorList
+
+ def _sequential_step(self: Module, var: Var, sequential: bool):
+     params = var.params
+     steps = self.settings[params[0]]['steps']
+
+     if sequential: modules = self.get_children_sequence() * steps
+     else: modules = [self.children['module']] * steps
+
+     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+     # store original params unless this is last module and can update params directly
+     params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+     # first step - pass var as usual
+     var = modules[0].step(var)
+     new_var = var
+
+     # subsequent steps - update parameters and create new var
+     if len(modules) > 1:
+         for m in modules[1:]:
+
+             # update params
+             if (not new_var.skip_update):
+                 if new_var.last_module_lrs is not None:
+                     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                 torch._foreach_sub_(params, new_var.get_update())
+
+             # create new var since we are at a new point, that means grad, update and loss will be None
+             new_var = Var(params=new_var.params, closure=new_var.closure,
+                           model=new_var.model, current_step=new_var.current_step + 1)
+
+             # step
+             new_var = m.step(new_var)
+
+         # final parameter update
+         if (not new_var.skip_update):
+             if new_var.last_module_lrs is not None:
+                 torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+             torch._foreach_sub_(params, new_var.get_update())
+
+     # if last module, update is applied so return new var
+     if params_before_steps is None:
+         new_var.stop = True
+         new_var.skip_update = True
+         return new_var
+
+     # otherwise use parameter difference as update
+     var.update = list(torch._foreach_sub(params_before_steps, params))
+     for p, bef in zip(params, params_before_steps):
+         p.set_(bef) # pyright:ignore[reportArgumentType]
+     return var
+
+ class Multistep(Module):
+     """On each step, performs :code:`steps` inner steps with :code:`module`.
+
+     The update is taken to be the difference between the parameters before and after the inner loop."""
+     def __init__(self, module: Chainable, steps: int):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_child('module', module)
+
+     @torch.no_grad
+     def step(self, var):
+         return _sequential_step(self, var, sequential=False)
+
+ class Sequential(Module):
+     """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+     The update is taken to be the difference between the parameters before and after the inner loop."""
+     def __init__(self, modules: Iterable[Chainable], steps: int=1):
+         defaults = dict(steps=steps)
+         super().__init__(defaults)
+         self.set_children_sequence(modules)
+
+     @torch.no_grad
+     def step(self, var):
+         return _sequential_step(self, var, sequential=True)
+
+
+ class NegateOnLossIncrease(Module):
+     """Uses an extra forward pass to evaluate the loss at :code:`parameters+update`;
+     if the loss is larger than at :code:`parameters`,
+     the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise."""
+     def __init__(self, backtrack=False):
+         defaults = dict(backtrack=backtrack)
+         super().__init__(defaults=defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+         backtrack = self.settings[var.params[0]]['backtrack']
+
+         update = var.get_update()
+         f_0 = var.get_loss(backward=False)
+
+         torch._foreach_sub_(var.params, update)
+         f_1 = closure(False)
+
+         if f_1 <= f_0:
+             if var.is_last and var.last_module_lrs is None:
+                 var.stop = True
+                 var.skip_update = True
+                 return var
+
+             torch._foreach_add_(var.params, update)
+             return var
+
+         torch._foreach_add_(var.params, update)
+         if backtrack:
+             torch._foreach_neg_(var.update)
+         else:
+             torch._foreach_zero_(var.update)
+         return var
+
+
+ class Online(Module):
+     """Allows certain modules to be used for mini-batch optimization."""
+     def __init__(self, module: Chainable):
+         super().__init__()
+
+         self.set_child('module', module)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise ValueError("Closure must be passed for Online")
+         step = self.global_state.get('step', 0) + 1
+         self.global_state['step'] = step
+         params = TensorList(var.params)
+         p_cur = params.clone()
+         p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+         module = self.children['module']
+
+         if step == 1:
+             var = module.step(var.clone(clone_update=False))
+
+             p_prev.copy_(params)
+             return var
+
+         # restore previous params
+         var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+         params.set_(p_prev)
+         module.reset_for_online()
+         module.update(var_prev)
+
+         # restore current params
+         params.set_(p_cur)
+         p_prev.copy_(params)
+         return module.step(var.clone(clone_update=False))
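Multistep wraps an inner chain and repeats it `steps` times per outer step, reporting the net parameter change as a single update. A hedged usage sketch in the same tz.Modular style as the GradientAccumulation example above; the tz.m.Multistep export path, and `module` accepting a list as a Chainable, are assumptions.

import torchzero as tz  # assumption: package imported under its usual name

# assumption: Multistep is exported under tz.m like the other modules
opt = tz.Modular(
    model.parameters(),
    tz.m.Multistep(
        module=[tz.m.Adam(), tz.m.LR(1e-2)],  # inner chain, stepped 4 times per outer step
        steps=4,
    ),
)

Per the check in `_sequential_step`, the optimizer step must be given a closure whenever more than one inner step is taken.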