torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
+from collections import deque
+from operator import itemgetter
+import torch
+
+from ...core import Transform, Chainable, Module, Var, apply_transform
+from ...utils import TensorList, as_tensorlist, NumberList
+from ...modules.quasi_newton.lbfgs import _adaptive_damping, lbfgs, _lerp_params_update_
+
+class ExpandedLBFGS(Module):
+    """L-BFGS but uses differences between more pairs than just consecutive ones. Window size controls how far away the pairs are allowed to be.
+    """
+    def __init__(
+        self,
+        history_size=10,
+        window_size: int = 3,
+        tol: float | None = 1e-10,
+        damping: bool = False,
+        init_damping=0.9,
+        eigval_bounds=(0.5, 50),
+        params_beta: float | None = None,
+        grads_beta: float | None = None,
+        update_freq = 1,
+        z_beta: float | None = None,
+        tol_reset: bool = False,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(history_size=history_size, window_size=window_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
+        super().__init__(defaults)
+
+        self.global_state['s_history'] = deque(maxlen=history_size)
+        self.global_state['y_history'] = deque(maxlen=history_size)
+        self.global_state['sy_history'] = deque(maxlen=history_size)
+        self.global_state['p_history'] = deque(maxlen=window_size)
+        self.global_state['g_history'] = deque(maxlen=window_size)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    def reset(self):
+        self.state.clear()
+        self.global_state['step'] = 0
+        self.global_state['s_history'].clear()
+        self.global_state['y_history'].clear()
+        self.global_state['sy_history'].clear()
+        self.global_state['p_history'].clear()
+        self.global_state['g_history'].clear()
+
+    @torch.no_grad
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        # history of s and k
+        s_history: deque[TensorList] = self.global_state['s_history']
+        y_history: deque[TensorList] = self.global_state['y_history']
+        sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+        p_history: deque[TensorList] = self.global_state['p_history']
+        g_history: deque[TensorList] = self.global_state['g_history']
+
+        tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
+            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
+        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
+
+        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+
+        # 1st step - there are no previous params and grads, lbfgs will do normalized GD step
+        if step == 0:
+            s = None; y = None; ys = None
+        else:
+            s = l_params - prev_l_params
+            y = l_update - prev_l_grad
+            ys = s.dot(y)
+
+            if damping:
+                s, y, ys = _adaptive_damping(s, y, ys, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+        prev_l_params.copy_(l_params)
+        prev_l_grad.copy_(l_update)
+
+        # update effective preconditioning state
+        if step % update_freq == 0:
+            if ys is not None and ys > 1e-10:
+                assert s is not None and y is not None
+                s_history.append(s)
+                y_history.append(y)
+                sy_history.append(ys)
+
+            if len(p_history) > 1:
+                for p_i, g_i in zip(list(p_history)[:-1], list(g_history)[:-1]):
+                    s_i = l_params - p_i
+                    y_i = l_update - g_i
+                    ys_i = s_i.dot(y_i)
+
+                    if ys_i > 1e-10:
+                        if damping:
+                            s_i, y_i, ys_i = _adaptive_damping(s_i, y_i, ys_i, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+                        s_history.append(s_i)
+                        y_history.append(y_i)
+                        sy_history.append(ys_i)
+
+            p_history.append(l_params.clone())
+            g_history.append(l_update.clone())
+
+
+        # step with inner module before applying preconditioner
+        if self.children:
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
+
+        # tolerance on gradient difference to avoid exploding after converging
+        if tol is not None:
+            if y is not None and y.abs().global_max() <= tol:
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                if tol_reset: self.reset()
+                return var
+
+        # lerp initial H^-1 @ q guess
+        z_ema = None
+        if z_beta is not None:
+            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
+
+        # precondition
+        dir = lbfgs(
+            tensors_=as_tensorlist(update),
+            s_history=s_history,
+            y_history=y_history,
+            sy_history=sy_history,
+            y=y,
+            sy=ys,
+            z_beta = z_beta,
+            z_ema = z_ema,
+            step=step
+        )
+
+        var.update = dir
+
+        return var
+
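A minimal sketch of the windowed-pair idea the ExpandedLBFGS docstring describes (not part of the diff; names and shapes are illustrative): besides the consecutive (s, y) pair, every older iterate within `window_size` steps of the current one contributes an extra pair, kept only if it passes the usual positive-curvature check.

```python
from collections import deque
import torch

def windowed_pairs(param_hist: deque, grad_hist: deque, window_size: int, eps: float = 1e-10):
    # newest iterate and gradient
    x_now, g_now = param_hist[-1], grad_hist[-1]
    pairs = []
    # older entries that fall inside the window (the newest one is excluded)
    for x_old, g_old in zip(list(param_hist)[-window_size-1:-1], list(grad_hist)[-window_size-1:-1]):
        s = x_now - x_old
        y = g_now - g_old
        sy = s.dot(y)
        if sy > eps:          # keep only pairs with positive curvature
            pairs.append((s, y, sy))
    return pairs

# usage on flattened iterates
xs = deque([torch.randn(4) for _ in range(5)], maxlen=5)
gs = deque([torch.randn(4) for _ in range(5)], maxlen=5)
print(len(windowed_pairs(xs, gs, window_size=3)))
```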
@@ -2,12 +2,12 @@ import torch
 
 from ...core import Chainable
 from ...utils import vec_to_tensors
-from .projection import Projection
+from ..projections import ProjectionBase
 
 
-class FFTProjection(Projection):
+class FFTProjection(ProjectionBase):
     # norm description copied from pytorch docstring
-    """Project update into Fourrier space of real-valued inputs.
+    """Project update into Fourier space of real-valued inputs.
 
     Args:
         modules (Chainable): modules that will optimize the projected update.
@@ -45,8 +45,8 @@ class FFTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)
 
     @torch.no_grad
-    def project(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def project(self, tensors, params, grads, loss, states, settings, current):
+        settings = settings[0]
         one_d = settings['one_d']
         norm = settings['norm']
 
@@ -60,14 +60,14 @@ class FFTProjection(Projection):
         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable
 
     @torch.no_grad
-    def unproject(self, tensors, var, current):
-        settings = self.settings[var.params[0]]
+    def unproject(self, projected_tensors, params, grads, loss, projected_states, projected_settings, current):
+        settings = projected_settings[0]
         one_d = settings['one_d']
         norm = settings['norm']
 
         if one_d:
-            vec = torch.view_as_complex(tensors[0])
+            vec = torch.view_as_complex(projected_tensors[0])
             unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
-            return vec_to_tensors(unprojected_vec, reference=var.params)
+            return vec_to_tensors(unprojected_vec, reference=params)
 
-        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, var.params)] # pylint:disable=not-callable
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(projected_tensors, params)] # pylint:disable=not-callable
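A standalone check (not from the diff) of the transform FFTProjection builds on: `rfftn` of a real tensor, stored via `view_as_real`, round-trips through `irfftn` when the original shape is passed back through `s=`.

```python
import torch

x = torch.randn(8, 8)
# project: real FFT, stored as real (…, 2) pairs
proj = torch.view_as_real(torch.fft.rfftn(x, norm="ortho"))
# unproject: back to complex, inverse FFT to the original shape
back = torch.fft.irfftn(torch.view_as_complex(proj.contiguous()), s=x.shape, norm="ortho")
print(torch.allclose(x, back, atol=1e-6))
```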
@@ -0,0 +1,85 @@
+from collections import deque
+
+import torch
+
+from ...core import TensorwiseTransform
+
+
+def eigh_solve(H: torch.Tensor, g: torch.Tensor):
+    try:
+        L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
+        return Q @ ((Q.mH @ g) / L)
+    except torch.linalg.LinAlgError:
+        return None
+
+
+class HNewton(TensorwiseTransform):
+    """This treats gradient differences as Hvps with vectors being parameter differences, using past gradients that are close to each other. Basically this is another limited memory quasi newton method to test.
+
+    .. warning::
+        Experimental.
+
+    """
+    def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
+        defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
+
+        history_size = setting['history_size']
+
+        if 'param_history' not in state:
+            state['param_history'] = deque(maxlen=history_size)
+            state['grad_history'] = deque(maxlen=history_size)
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        param_history.append(param.ravel())
+        grad_history.append(tensor.ravel())
+
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        window_size = setting['window_size']
+        reg = setting['reg']
+        tol = setting['tol']
+
+        param_history: deque = state['param_history']
+        grad_history: deque = state['grad_history']
+        g = tensor.ravel()
+
+        n = len(param_history)
+        s_list = []
+        y_list = []
+
+        for i in range(n):
+            for j in range(i):
+                if i - j <= window_size:
+                    p_i, g_i = param_history[i], grad_history[i]
+                    p_j, g_j = param_history[j], grad_history[j]
+                    s = p_i - p_j # vec in hvp
+                    y = g_i - g_j # hvp
+                    if s.dot(y) > tol:
+                        s_list.append(s)
+                        y_list.append(y)
+
+        if len(s_list) < 1:
+            scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+            tensor.mul_(scale)
+            return tensor
+
+        S = torch.stack(s_list, 1)
+        Y = torch.stack(y_list, 1)
+
+        B = S.T @ Y
+        if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
+        g_proj = g @ S
+
+        newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
+        if info != 0:
+            newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
+        newton = S @ newton_proj
+        return newton.view_as(tensor)
+
+
+        # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
+        # tensor.mul_(scale)
+        # return tensor
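A small sanity check of the linear algebra HNewton's `apply_tensor` performs (an assumption-laden sketch, not code from the package): on a quadratic, gradient differences are exact Hessian-vector products, so solving the projected system (SᵀY) z = Sᵀg and expanding the step as S z satisfies the Newton condition restricted to span(S).

```python
import torch

torch.manual_seed(0)
H = torch.randn(6, 6); H = H @ H.T + torch.eye(6)   # SPD "Hessian"
g = torch.randn(6)                                   # current gradient
S = torch.randn(6, 3)                                # columns: parameter differences
Y = H @ S                                            # gradient differences (exact Hvps on a quadratic)

B = S.T @ Y                                          # the small system apply_tensor builds
z = torch.linalg.solve(B, S.T @ g)
step = S @ z                                         # Newton step restricted to span(S)
print(torch.allclose(S.T @ (H @ step), S.T @ g, atol=1e-4))  # Galerkin condition holds
```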
@@ -4,8 +4,8 @@ from typing import Any
 
 import torch
 
-from ....core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
-from ....utils import NumberList, TensorList, as_tensorlist
+from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
+from ...utils import NumberList, TensorList, as_tensorlist
 
 
 def _adaptive_damping(
@@ -43,32 +43,31 @@ def lbfgs(
         if scale < 1e-5: scale = 1 / tensors_.abs().mean()
         return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
 
-    else:
-        # 1st loop
-        alpha_list = []
-        q = tensors_.clone()
-        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-            p_i = 1 / ys_i # this is also denoted as ρ (rho)
-            alpha = p_i * s_i.dot(q)
-            alpha_list.append(alpha)
-            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-        # calculate z
-        # s.y/y.y is also this weird y-looking symbol I couldn't find
-        # z is it times q
-        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-        z = q * (ys_k / (y_k.dot(y_k)))
-
-        if z_tfm is not None:
-            z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
-
-        # 2nd loop
-        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-            p_i = 1 / ys_i
-            beta_i = p_i * y_i.dot(z)
-            z.add_(s_i, alpha = alpha_i - beta_i)
-
-        return z
+    # 1st loop
+    alpha_list = []
+    q = tensors_.clone()
+    for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / ys_i # this is also denoted as ρ (rho)
+        alpha = p_i * s_i.dot(q)
+        alpha_list.append(alpha)
+        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+    # calculate z
+    # s.y/y.y is also this weird y-looking symbol I couldn't find
+    # z is it times q
+    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+    z = q * (ys_k / (y_k.dot(y_k)))
+
+    if z_tfm is not None:
+        z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
+
+    # 2nd loop
+    for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / ys_i
+        beta_i = p_i * y_i.dot(z)
+        z.add_(s_i, alpha = alpha_i - beta_i)
+
+    return z
 
 def _apply_tfms_into_history(
     self: Module,
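For reference, a flat-tensor sketch of the two-loop recursion used in the `lbfgs` function above (an assumption: plain 1-D tensors in place of TensorLists), computing z ≈ H⁻¹q from the stored (s, y) history.

```python
import torch

def two_loop(q, s_hist, y_hist):
    # first loop: walk the history newest-to-oldest, peeling off curvature components
    alphas = []
    q = q.clone()
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / y.dot(s)
        a = rho * s.dot(q)
        alphas.append(a)
        q -= a * y
    # initial inverse-Hessian guess H0 = (s·y / y·y) I from the most recent pair
    s, y = s_hist[-1], y_hist[-1]
    z = q * (s.dot(y) / y.dot(y))
    # second loop: walk oldest-to-newest, adding the corrections back in
    for (s, y), a in zip(zip(s_hist, y_hist), reversed(alphas)):
        beta = (1.0 / y.dot(s)) * y.dot(z)
        z += (a - beta) * s
    return z

# exact on a diagonal quadratic when the history spans the coordinate basis
H = torch.diag(torch.tensor([1.0, 4.0, 9.0]))
s_hist = [torch.eye(3)[i] for i in range(3)]
y_hist = [H @ s for s in s_hist]
q = torch.tensor([2.0, 8.0, 18.0])
print(two_loop(q, s_hist, y_hist), torch.linalg.solve(H, q))
```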
@@ -22,8 +22,9 @@ from ..second_order.newton import (
 
 
 class NewtonNewton(Module):
-    """
-    Method that I thought of and then it worked.
+    """Applies Newton-like preconditioning to the Newton step.
+
+    This is a method that I thought of and then it worked. Here is how it works:
 
     1. Calculate newton step by solving Hx=g
 
@@ -34,6 +35,9 @@ class NewtonNewton(Module):
     4. Optionally, repeat (if order is higher than 3.)
 
     Memory is n^order. It tends to converge faster on convex functions, but can be unstable on non-convex. Orders higher than 3 are usually too unstable and have little benefit.
+
+    The 3rd order variant can minimize some convex functions with up to 100 variables in less time than Newton's method,
+    provided pytorch can vectorize the hessian computation efficiently.
     """
     def __init__(
         self,
@@ -83,6 +87,6 @@ class NewtonNewton(Module):
         if x is None: x = least_squares_solve(H, xp)
         xp = x.squeeze()
 
-        var.update = vec_to_tensors(xp, params)
+        var.update = vec_to_tensors(xp.nan_to_num_(0,0,0), params)
         return var
 
@@ -0,0 +1,220 @@
+import math
+from collections.abc import Mapping
+from operator import itemgetter
+
+import torch
+
+from ...core import Module
+from ...utils import TensorList
+
+
+
+def adaptive_tracking(
+    f,
+    f_0,
+    f_1,
+    t_0,
+    maxiter: int
+):
+
+    t = t_0
+    f_t = f(t)
+
+    # backtrack
+    if f_t > f_0:
+        if f_1 > f_0: t = min(0.5, t_0/2)
+        while f_t > f_0:
+            maxiter -= 1
+            if maxiter < 0: return 0, f_0
+            t = t/2
+            f_t = f(t) if t!=1 else f_1
+        return t, f_t
+
+    # forwardtrack
+    f_prev = f_t
+    t *= 2
+    f_t = f(t)
+    if f_prev < f_t: return t/2, f_prev
+    while f_prev >= f_t:
+        maxiter -= 1
+        if maxiter < 0: return t, f_t
+        f_prev = f_t
+        t *= 2
+        f_t = f(t)
+    return t/2, f_prev
+
+
+
+class ParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        normalize = settings['normalize']
+        maxiter = settings['maxiter']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+            grad = x_0.grad
+
+        x_0 += v_0
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a = v_1 - v_0
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + 0.5*a*t**2
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value.detach().cpu()
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        #     options = dict(options) if isinstance(options, Mapping) else {}
+        #     options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + 0.5*a*t**2
+        return var
+
+class CubicParabolaSearch(Module):
+    """"""
+    def __init__(
+        self,
+        step_size: float = 1e-2,
+        adaptive: bool=True,
+        normalize: bool=False,
+        # method: str | None = None,
+        maxiter: int | None = 10,
+        # bracket=None,
+        # bounds=None,
+        # tol: float | None = None,
+        # options=None,
+    ):
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+        defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
+        super().__init__(defaults)
+
+        import scipy.optimize
+        self.scopt = scipy.optimize
+
+
+    @torch.no_grad
+    def step(self, var):
+        x_0 = TensorList(var.params)
+        closure = var.closure
+        assert closure is not None
+        settings = self.settings[x_0[0]]
+        step_size = settings['step_size']
+        adaptive = settings['adaptive']
+        maxiter = settings['maxiter']
+        normalize = settings['normalize']
+        if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
+
+        grad = TensorList(var.get_grad())
+        f_0 = var.get_loss(False)
+
+        scale = 1
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        if adaptive: scale = grad.abs().mean().clip(min=1e-8)
+
+        # make step
+        v_0 = grad * (step_size/scale)
+        x_0 -= v_0
+        with torch.enable_grad():
+            f_1 = closure()
+            grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_1 = grad * (step_size/scale)
+        a_0 = v_1 - v_0
+
+        # make another step
+        x_0 -= v_1
+        with torch.enable_grad():
+            f_2 = closure()
+            grad = x_0.grad
+
+        if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
+        v_2 = grad * (step_size/scale)
+        a_1 = v_2 - v_1
+
+        j = a_1 - a_0
+
+        x_0 += v_0
+        x_0 += v_1
+
+        def parabolic_objective(t: float):
+            nonlocal x_0
+
+            step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+            x_0 -= step
+            value = closure(False)
+            x_0 += step
+            return value
+
+
+        prev_t = self.global_state.get('prev_t', 2)
+        t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
+        self.global_state['prev_t'] = t
+
+        # method, bracket, bounds, tol, options, maxiter = itemgetter(
+        #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
+
+        # if maxiter is not None:
+        #     options = dict(options) if isinstance(options, Mapping) else {}
+        #     options['maxiter'] = maxiter
+
+        # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
+        # t = res.x
+
+        var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
+        return var
+
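A plain-tensor sketch (assumptions: a toy quadratic, no Modular/TensorList machinery) of the curved path ParabolaSearch evaluates: v_0 is the scaled gradient at x, v_1 the scaled gradient after stepping by v_0, a = v_1 - v_0, and adaptive_tracking picks t along x - (v_0 t + a t²/2).

```python
import torch

A = torch.diag(torch.tensor([1.0, 10.0]))        # ill-conditioned quadratic bowl
f = lambda x: 0.5 * x @ A @ x
x = torch.tensor([1.0, 1.0])

step_size = 1e-2
v0 = step_size * (A @ x)                          # scaled gradient at x
v1 = step_size * (A @ (x - v0))                   # scaled gradient after one step
a = v1 - v0                                       # "acceleration" term of the parabola

path = lambda t: x - (v0 * t + 0.5 * a * t**2)    # curved trial path
for t in [1.0, 2.0, 4.0, 8.0]:                    # the step sizes adaptive_tracking would try
    print(t, f(path(t)).item())
```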
@@ -4,19 +4,19 @@ from ...core import Target, Transform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
 class ReduceOutwardLR(Transform):
-    """
-    When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
+    """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-    A note on this is that it sounded good but its really bad in practice.
+    .. warning::
+        This sounded good but after testing turns out it sucks.
     """
     def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
         super().__init__(defaults, uses_grad=use_grad, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
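A minimal sketch (illustrative, not the module's actual implementation) of the rule the ReduceOutwardLR docstring states: entries whose update sign matches the corresponding weight's sign are scaled by `mul`.

```python
import torch

def reduce_outward(update: torch.Tensor, weight: torch.Tensor, mul: float = 0.5):
    # mask of entries where the update sign matches the weight sign (zeros excluded)
    matching = (update.sign() == weight.sign()) & (weight != 0)
    return torch.where(matching, update * mul, update)

w = torch.tensor([ 1.0, -2.0, 3.0])
u = torch.tensor([ 0.5,  0.5, -1.0])
print(reduce_outward(u, w, mul=0.5))   # only the first entry is scaled
```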