torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
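The hunks below (files 51-59 in the list above) all follow the same 0.4.0 API migration: `Transform` becomes `TensorTransform`, `update_tensors`/`apply_tensors` become `multi_tensor_update`/`multi_tensor_apply`, `Module.step(var)` becomes `update`/`apply(objective)`, and the `Var` object with its `var.update`/`var.grad` fields is replaced by `Objective` with `objective.updates`/`objective.grads`. A rough before/after skeleton inferred only from these hunks (class names are placeholders and the snippet is not runnable as-is):

    # 0.3.15-style module, as it appears on the "-" side of the hunks
    class MyModule(Module):
        def step(self, var):
            var.update = [t.neg() for t in var.get_update()]
            return var

    # 0.4.0-style module, as it appears on the "+" side
    class MyModule(Module):
        def apply(self, objective):
            objective.updates = [t.neg() for t in objective.get_updates()]
            return objective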
torchzero/modules/conjugate_gradient/cg.py
@@ -3,21 +3,14 @@ from typing import Literal
 
 import torch
 
-from ...core import (
-    Chainable,
-    Modular,
-    Module,
-    Transform,
-    Var,
-    apply_transform,
-)
-from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
-from ..line_search import LineSearchBase
+from ...core import Chainable, TensorTransform
+
+from ...utils import TensorList, safe_dict_update_, unpack_dicts, unpack_states
 from ..quasi_newton.quasi_newton import HessianUpdateStrategy
 from ..functional import safe_clip
 
 
-class ConguateGradientBase(Transform, ABC):
+class ConguateGradientBase(TensorTransform, ABC):
     """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
 
     This is an abstract class, to use it, subclass it and override `get_beta`.
@@ -52,13 +45,8 @@ class ConguateGradientBase(Transform, ABC):
     """
     def __init__(self, defaults, clip_beta: bool, restart_interval: int | None | Literal['auto'], inner: Chainable | None = None):
         if defaults is None: defaults = {}
-        defaults['restart_interval'] = restart_interval
-        defaults['clip_beta'] = clip_beta
-        super().__init__(defaults, uses_grad=False)
-
-        if inner is not None:
-            self.set_child('inner', inner)
-
+        safe_dict_update_(defaults, dict(restart_interval=restart_interval, clip_beta=clip_beta))
+        super().__init__(defaults, inner=inner)
 
     def reset_for_online(self):
         super().reset_for_online()
@@ -74,40 +62,38 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def update_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
-        params = as_tensorlist(params)
-
-        step = self.global_state.get('step', 0) + 1
-        self.global_state['step'] = step
+    def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        params = TensorList(params)
+        self.increment_counter("step", start=0)
 
         # initialize on first step
-        if self.global_state.get('stage', 0) == 0:
+        if self.global_state.get('stage', "first step") == "first update":
             g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
             d_prev.copy_(tensors)
             g_prev.copy_(tensors)
             self.initialize(params, tensors)
-            self.global_state['stage'] = 1
+            self.global_state['stage'] = "first apply"
 
         else:
             # if `update_tensors` was called multiple times before `apply_tensors`,
             # stage becomes 2
-            self.global_state['stage'] = 2
+            self.global_state['stage'] = "initialized"
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        tensors = as_tensorlist(tensors)
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
         step = self.global_state['step']
 
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+        assert self.global_state['stage'] != "first update"
 
-        assert self.global_state['stage'] != 0
-        if self.global_state['stage'] == 1:
-            self.global_state['stage'] = 2
+        # on 1st apply we don't have previous gradients
+        # so just return tensors
+        if self.global_state['stage'] == "first apply":
+            self.global_state['stage'] = "initialized"
             return tensors
 
-        params = as_tensorlist(params)
+        params = TensorList(params)
         g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
 
         # get beta
@@ -119,10 +105,13 @@ class ConguateGradientBase(Transform, ABC):
         dir = tensors.add_(d_prev.mul_(beta))
         d_prev.copy_(dir)
 
-        # resetting
+        # resetting every `reset_interval` steps, use step+1 to not reset on 1st step
+        # so if reset_interval=2, then 1st step collects g_prev and d_prev, then
+        # two steps will happen until reset.
        restart_interval = settings[0]['restart_interval']
        if restart_interval == 'auto': restart_interval = tensors.global_numel() + 1
-        if restart_interval is not None and step % restart_interval == 0:
+
+        if restart_interval is not None and (step + 1) % restart_interval == 0:
            self.state.clear()
            self.global_state.clear()
 
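A standalone sketch of the new restart schedule from the hunk above, assuming the internal `step` counter starts at 0 on the first update (as `increment_counter("step", start=0)` suggests); with the `(step + 1) % restart_interval` test, no restart fires on the very first step:

    def restart_steps(num_steps, restart_interval):
        # 0-indexed steps on which the state would be cleared
        return [step for step in range(num_steps) if (step + 1) % restart_interval == 0]

    print(restart_steps(6, 2))  # [1, 3, 5]
    print(restart_steps(6, 3))  # [2, 5]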
torchzero/modules/experimental/__init__.py
@@ -1,4 +1,5 @@
 """Those are various ideas of mine plus some other modules that I decided not to move to other sub-packages for whatever reason. This is generally less tested and shouldn't be used."""
+from .coordinate_momentum import CoordinateMomentum
 from .curveball import CurveBall
 
 # from dct import DCTProjection
@@ -6,14 +7,9 @@ from .fft import FFTProjection
 from .gradmin import GradMin
 from .higher_order_newton import HigherOrderNewton
 from .l_infinity import InfinityNormTrustRegion
-from .momentum import (
-    CoordinateMomentum,
-    NesterovEMASquared,
-    PrecenteredEMASquared,
-    SqrtNesterovEMASquared,
-)
 from .newton_solver import NewtonSolver
 from .newtonnewton import NewtonNewton
 from .reduce_outward_lr import ReduceOutwardLR
 from .scipy_newton_cg import ScipyNewtonCG
+from .spsa1 import SPSA1
 from .structural_projections import BlockPartition, TensorizeProjection
torchzero/modules/experimental/coordinate_momentum.py
@@ -0,0 +1,36 @@
+import torch
+
+from ...core import TensorTransform
+from ...utils import NumberList, TensorList, unpack_states
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    sets `velocity_` to p% random values from `tensors`.
+
+    Returns `velocity_`
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(TensorTransform):
+    """Maintains a momentum buffer, on each step each value in the buffer has ``p`` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+    """
+    def __init__(self, p: float = 0.1):
+        defaults = dict(p=p)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
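A minimal single-tensor sketch of what `coordinate_momentum_` does, written with plain torch instead of the torchzero `TensorList` helpers (`bernoulli_like`, `masked_set_`): each buffer entry is overwritten by the incoming value with probability `p`, and a clone of the buffer is returned as the update.

    import torch

    def coordinate_momentum_step(update: torch.Tensor, velocity: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        # overwrite each velocity entry with the new value with probability p
        mask = torch.bernoulli(torch.full_like(update, p)).bool()
        velocity[mask] = update[mask]
        return velocity.clone()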
torchzero/modules/experimental/curveball.py
@@ -1,25 +1,25 @@
 from typing import Literal
-from collections.abc import Callable
+
 import torch
 
-from ...core import Module, Target, Transform, Chainable, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+from ...core import Chainable, Transform, step, HVPMethod
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
 
 def curveball(
     tensors: TensorList,
     z_: TensorList,
-    Hz: TensorList,
+    Hzz: TensorList,
     momentum: float | NumberList,
     precond_lr: float | NumberList,
 ):
     """returns z_, clone it!!! (no just negate it)"""
-    delta = Hz + tensors
+    delta = Hzz + tensors
     z_.mul_(momentum).sub_(delta.mul_(precond_lr)) # z ← ρz − βΔ
     return z_
 
 
-class CurveBall(Module):
+class CurveBall(Transform):
     """CurveBall method from https://arxiv.org/pdf/1805.08095#page=4.09.
 
     For now this implementation does not include automatic ρ, α and β hyper-parameters in closed form, therefore it is expected to underperform compared to official implementation (https://github.com/jotaf98/pytorch-curveball/tree/master) so I moved this to experimental.
@@ -36,7 +36,7 @@ class CurveBall(Module):
         self,
         precond_lr: float=1e-3,
         momentum: float=0.9,
-        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        hvp_method: HVPMethod = "autograd",
         h: float = 1e-3,
         reg: float = 1,
         inner: Chainable | None = None,
@@ -44,46 +44,30 @@ class CurveBall(Module):
         defaults = dict(precond_lr=precond_lr, momentum=momentum, hvp_method=hvp_method, h=h, reg=reg)
         super().__init__(defaults)
 
-        if inner is not None: self.set_child('inner', inner)
+        self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, var):
-
-        params = var.params
-        settings = self.settings[params[0]]
-        hvp_method = settings['hvp_method']
-        h = settings['h']
+    def apply_states(self, objective, states, settings):
+        params = objective.params
+        fs = settings[0]
+        hvp_method = fs['hvp_method']
+        h = fs['h']
 
-        precond_lr, momentum, reg = self.get_settings(params, 'precond_lr', 'momentum', 'reg', cls=NumberList)
+        precond_lr, momentum, reg = unpack_dicts(settings, 'precond_lr', 'momentum', 'reg', cls=NumberList)
 
-
-        closure = var.closure
+        closure = objective.closure
         assert closure is not None
 
-        z, Hz = self.get_state(params, 'z', 'Hz', cls=TensorList)
-
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-            Hvp = hvp(params, grad, z)
-
-        elif hvp_method == 'forward':
-            loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)
-
-        elif hvp_method == 'central':
-            loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
-
-        else:
-            raise ValueError(hvp_method)
-
-
-        Hz.set_(Hvp + z*reg)
+        z, Hz = unpack_states(states, params, 'z', 'Hz', cls=TensorList)
+        Hz, _ = objective.hessian_vector_product(z, rgrad=None, at_x0=True, hvp_method=hvp_method, h=h)
 
+        Hz = TensorList(Hz)
+        Hzz = Hz.add_(z * reg)
 
-        update = var.get_update()
-        if 'inner' in self.children:
-            update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)
+        objective = self.inner_step("inner", objective, must_exist=False)
+        updates = objective.get_updates()
 
-        z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
-        var.update = z.neg()
+        z = curveball(TensorList(updates), z, Hzz, momentum=momentum, precond_lr=precond_lr)
+        objective.updates = z.neg()
 
-        return var
+        return objective
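In terms of a single parameter tensor, the recurrence computed above is z ← ρ·z − β·((H + reg·I)z + g), where ρ is `momentum`, β is `precond_lr`, g is the (possibly inner-transformed) incoming update, and Hz is the Hessian-vector product; the module then negates z to form the update it hands to the rest of the chain. A plain-torch sketch of that recurrence, illustrative only:

    import torch

    def curveball_z(g, z, Hz, momentum, precond_lr, reg):
        Hzz = Hz + z * reg                           # (H + reg*I) z, the Hzz of the hunk above
        delta = Hzz + g                              # Δ = Hzz + g
        return momentum * z - precond_lr * delta     # z ← ρz − βΔ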
torchzero/modules/experimental/gradmin.py
@@ -5,7 +5,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Var, Chainable
+from ...core import Module, Objective, Chainable
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -43,7 +43,7 @@ class GradMin(Reformulation):
         super().__init__(defaults, modules=modules)
 
     @torch.no_grad
-    def closure(self, backward, closure, params, var):
+    def closure(self, backward, closure, params, objective):
         settings = self.settings[params[0]]
         loss_term = settings['loss_term']
         relative = settings['relative']
torchzero/modules/experimental/higher_order_newton.py
@@ -1,21 +1,12 @@
-import itertools
 import math
-import warnings
-from collections.abc import Callable
-from contextlib import nullcontext
-from functools import partial
 from typing import Any, Literal
 
 import numpy as np
 import scipy.optimize
 import torch
 
-from ...core import Chainable, Module, apply_transform
+from ...core import DerivativesMethod, Module
 from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
-from ...utils.derivatives import (
-    flatten_jacobian,
-    jacobian_wrt,
-)
 
 _LETTERS = 'abcdefghijklmnopqrstuvwxyz'
 def _poly_eval(s: np.ndarray, c, derivatives):
@@ -195,22 +186,22 @@ class HigherOrderNewton(Module):
         max_attempts = 10,
         boundary_tol: float = 1e-2,
         de_iters: int | None = None,
-        vectorize: bool = True,
+        derivatives_method: DerivativesMethod = "batched_autograd",
     ):
         if init is None:
             if trust_method == 'bounds': init = 1
             else: init = 0.1
 
-        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
+        defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad, derivatives_method=derivatives_method)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('HigherOrderNewton requires closure')
 
-        settings = self.settings[params[0]]
+        settings = self.defaults
         order = settings['order']
         nplus = settings['nplus']
         nminus = settings['nminus']
@@ -219,31 +210,12 @@ class HigherOrderNewton(Module):
         trust_method = settings['trust_method']
         de_iters = settings['de_iters']
         max_attempts = settings['max_attempts']
-        vectorize = settings['vectorize']
         boundary_tol = settings['boundary_tol']
         rho_good = settings['rho_good']
         rho_bad = settings['rho_bad']
 
         # ------------------------ calculate grad and hessian ------------------------ #
-        with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
-
-            g_list = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g_list)
-
-            g = torch.cat([t.ravel() for t in g_list])
-            n = g.numel()
-            derivatives = [g]
-            T = g # current derivatives tensor
-
-            # get all derivative up to order
-            for o in range(2, order + 1):
-                is_last = o == order
-                T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
-                with torch.no_grad() if is_last else nullcontext():
-                    # the shape is (ndim, ) * order
-                    T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
-                    derivatives.append(T)
+        loss, *derivatives = objective.derivatives(order=order, at_x0=True, method=self.defaults["derivatives_method"])
 
         x0 = torch.cat([p.ravel() for p in params])
 
@@ -301,7 +273,8 @@ class HigherOrderNewton(Module):
         vec_to_tensors_(x0, params)
         reduction = loss - loss_star
 
-        rho = reduction / (max(pred_reduction, 1e-8))
+        rho = reduction / (max(pred_reduction, finfo.tiny * 2)) # pyright:ignore[reportArgumentType]
+
         # failed step
         if rho < rho_bad:
             self.global_state['trust_region'] = trust_value * nminus
@@ -320,8 +293,9 @@ class HigherOrderNewton(Module):
         assert x_star is not None
         if success:
             difference = vec_to_tensors(x0 - x_star, params)
-            var.update = list(difference)
+            objective.updates = list(difference)
         else:
-            var.update = params.zeros_like()
-        return var
+            objective.updates = params.zeros_like()
+
+        return objective
 
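The `rho` computed above is the standard trust-region quality ratio, the actual reduction of the loss divided by the reduction predicted by the local polynomial model; the change only replaces the fixed `1e-8` floor on the denominator with `finfo.tiny * 2`. Only the failure branch is visible in this hunk, so the following is a hedged sketch of that branch alone:

    def shrink_on_failure(trust_value, rho, rho_bad, nminus):
        # the visible branch: a poor actual/predicted ratio shrinks the trust region
        if rho < rho_bad:
            return trust_value * nminus
        return trust_value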
torchzero/modules/experimental/newton_solver.py
@@ -1,11 +1,10 @@
-from collections.abc import Callable, Iterable
-from typing import Any, Literal, overload
+from collections.abc import Callable
+from typing import Any
 
 import torch
 
-from ...core import Chainable, Modular, Module, apply_transform
-from ...utils import TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
+from ...core import Chainable, Modular, Module, step, HVPMethod
+from ...utils import TensorList
 
 from ..quasi_newton import LBFGS
 
@@ -19,24 +18,26 @@ class NewtonSolver(Module):
         tol:float | None=1e-3,
         reg: float = 0,
         warm_start=True,
-        hvp_method: Literal["forward", "central", "autograd"] = "autograd",
+        hvp_method: HVPMethod = "autograd",
         reset_solver: bool = False,
         h: float= 1e-3,
+
         inner: Chainable | None = None,
     ):
-        defaults = dict(tol=tol, h=h,reset_solver=reset_solver, maxiter=maxiter, maxiter1=maxiter1, reg=reg, warm_start=warm_start, solver=solver, hvp_method=hvp_method)
-        super().__init__(defaults,)
+        defaults = locals().copy()
+        del defaults['self'], defaults['inner']
+        super().__init__(defaults)
 
-        if inner is not None:
-            self.set_child('inner', inner)
+        self.set_child("inner", inner)
 
         self._num_hvps = 0
         self._num_hvps_last_step = 0
 
     @torch.no_grad
-    def step(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def apply(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonCG requires closure')
 
         settings = self.settings[params[0]]
@@ -44,51 +45,19 @@ class NewtonSolver(Module):
         maxiter = settings['maxiter']
         maxiter1 = settings['maxiter1']
         tol = settings['tol']
-        reg = settings['reg']
         hvp_method = settings['hvp_method']
         warm_start = settings['warm_start']
         h = settings['h']
         reset_solver = settings['reset_solver']
 
         self._num_hvps_last_step = 0
-        # ---------------------- Hessian vector product function --------------------- #
-        if hvp_method == 'autograd':
-            grad = var.get_grad(create_graph=True)
-
-            def H_mm(x):
-                self._num_hvps_last_step += 1
-                with torch.enable_grad():
-                    Hvp = TensorList(hvp(params, grad, x, retain_graph=True))
-                if reg != 0: Hvp = Hvp + (x*reg)
-                return Hvp
-
-        else:
-
-            with torch.enable_grad():
-                grad = var.get_grad()
-
-            if hvp_method == 'forward':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            elif hvp_method == 'central':
-                def H_mm(x):
-                    self._num_hvps_last_step += 1
-                    Hvp = TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])
-                    if reg != 0: Hvp = Hvp + (x*reg)
-                    return Hvp
-
-            else:
-                raise ValueError(hvp_method)
 
+        # ---------------------- Hessian vector product function --------------------- #
+        _, H_mv = objective.list_Hvp_function(hvp_method=hvp_method, h=h, at_x0=True)
 
         # -------------------------------- inner step -------------------------------- #
-        b = as_tensorlist(grad)
-        if 'inner' in self.children:
-            b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))
+        objective = self.inner_step("inner", objective, must_exist=False)
+        b = TensorList(objective.get_updates())
 
         # ---------------------------------- run cg ---------------------------------- #
         x0 = None
@@ -112,7 +81,7 @@ class NewtonSolver(Module):
             solver = self.global_state['solver']
 
        def lstsq_closure(backward=True):
-            Hx = H_mm(x).detach()
+            Hx = H_mv(x).detach()
            # loss = (Hx-b).pow(2).global_mean()
            # if backward:
            #     solver.zero_grad()
@@ -122,7 +91,7 @@ class NewtonSolver(Module):
            loss = residual.pow(2).global_mean()
            if backward:
                with torch.no_grad():
-                    H_residual = H_mm(residual)
+                    H_residual = H_mv(residual)
                    n = residual.global_numel()
                    x.set_grad_((2.0 / n) * H_residual)
 
@@ -143,8 +112,8 @@ class NewtonSolver(Module):
            assert x0 is not None
            x0.copy_(x)
 
-        var.update = x.detach()
+        objective.updates = x.detach()
        self._num_hvps += self._num_hvps_last_step
-        return var
+        return objective
 
 
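Why `lstsq_closure` sets the gradient to `(2.0 / n) * H_residual`: the inner problem minimizes mean((Hx − b)²), whose gradient is (2/n)·Hᵀ(Hx − b), and since H stands in for a symmetric Hessian operator, Hᵀ(Hx − b) = H(Hx − b), i.e. exactly one extra Hessian-vector product per backward. A small numerical check with a dense symmetric matrix:

    import torch

    n = 5
    A = torch.randn(n, n)
    H = A + A.T                      # symmetric stand-in for the Hessian operator
    b = torch.randn(n)
    x = torch.randn(n, requires_grad=True)

    loss = (H @ x - b).pow(2).mean()
    loss.backward()

    manual = (2.0 / n) * (H @ (H @ x.detach() - b))
    print(torch.allclose(x.grad, manual, atol=1e-5))  # True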
torchzero/modules/experimental/newtonnewton.py
@@ -7,7 +7,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Chainable, Module, apply_transform
+from ...core import Chainable, Module, step
+from ...linalg.linear_operator import Dense
 from ...utils import TensorList, vec_to_tensors
 from ...utils.derivatives import (
     flatten_jacobian,
@@ -19,7 +20,7 @@ from ..second_order.newton import (
     _least_squares_solve,
     _lu_solve,
 )
-from ...utils.linalg.linear_operator import Dense
+
 
 class NewtonNewton(Module):
     """Applies Newton-like preconditioning to Newton step.
@@ -51,9 +52,10 @@ class NewtonNewton(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def update(self, var):
-        params = TensorList(var.params)
-        closure = var.closure
+    def update(self, objective):
+
+        params = TensorList(objective.params)
+        closure = objective.closure
         if closure is None: raise RuntimeError('NewtonNewton requires closure')
 
         settings = self.settings[params[0]]
@@ -66,9 +68,9 @@ class NewtonNewton(Module):
         # ------------------------ calculate grad and hessian ------------------------ #
         Hs = []
         with torch.enable_grad():
-            loss = var.loss = var.loss_approx = closure(False)
+            loss = objective.loss = objective.loss_approx = closure(False)
             g_list = torch.autograd.grad(loss, params, create_graph=True)
-            var.grad = list(g_list)
+            objective.grads = list(g_list)
 
             xp = torch.cat([t.ravel() for t in g_list])
             I = torch.eye(xp.numel(), dtype=xp.dtype, device=xp.device)
@@ -93,13 +95,14 @@ class NewtonNewton(Module):
         self.global_state['xp'] = xp.nan_to_num_(0,0,0)
 
     @torch.no_grad
-    def apply(self, var):
-        params = var.params
+    def apply(self, objective):
+        params = objective.params
         xp = self.global_state['xp']
-        var.update = vec_to_tensors(xp, params)
-        return var
+        objective.updates = vec_to_tensors(xp, params)
+        return objective
 
-    def get_H(self, var):
+    @torch.no_grad
+    def get_H(self, objective=...):
         Hs = self.global_state["Hs"]
         if len(Hs) == 1: return Dense(Hs[0])
         return Dense(torch.linalg.multi_dot(self.global_state["Hs"])) # pylint:disable=not-callable
torchzero/modules/experimental/reduce_outward_lr.py
@@ -1,28 +1,28 @@
 import torch
 
-from ...core import Target, Transform
+from ...core import TensorTransform
 from ...utils import TensorList, unpack_states, unpack_dicts
 
-class ReduceOutwardLR(Transform):
+class ReduceOutwardLR(TensorTransform):
     """When update sign matches weight sign, the learning rate for that weight is multiplied by `mul`.
 
     This means updates that move weights towards zero have higher learning rates.
 
-    .. warning::
+    Warning:
         This sounded good but after testing turns out it sucks.
     """
-    def __init__(self, mul = 0.5, use_grad=False, invert=False, target: Target = 'update'):
+    def __init__(self, mul = 0.5, use_grad=False, invert=False):
         defaults = dict(mul=mul, use_grad=use_grad, invert=invert)
-        super().__init__(defaults, uses_grad=use_grad, target=target)
+        super().__init__(defaults, uses_grad=use_grad)
 
     @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+    def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
         params = TensorList(params)
         tensors = TensorList(tensors)
 
         mul = [s['mul'] for s in settings]
         s = settings[0]
-        use_grad = s['use_grad']
+        use_grad = self._uses_grad
         invert = s['invert']
 
         if use_grad: cur = grads