torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/backtracking.py

@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):
 
  # # directional derivative
  if c == 0: d = 0
- else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), var.get_update()))
+ else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))
 
  # scale init
  init_scale = self.global_state.get('init_scale', 1)
@@ -136,7 +136,7 @@ class Backtracking(LineSearchBase):
  if adaptive:
  finfo = torch.finfo(var.params[0].dtype)
  if init_scale <= finfo.tiny * 2:
- self.global_state["init_scale"] = finfo.max / 2
+ self.global_state["init_scale"] = init * 2
  else:
  self.global_state['init_scale'] = init_scale * beta**maxiter
  return 0
@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):
 
  # directional derivative (0 if c = 0 because it is not needed)
  if c == 0: d = 0
- else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))
+ else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))
 
  # scale beta
  beta = beta * self.global_state['beta_scale']
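
Note on the hunks above: the functional changes are the accessor renames get_grad()/get_update() -> get_grads()/get_updates() and the adaptive init_scale fallback (reset to init * 2 instead of finfo.max / 2); the directional derivative d is computed the same way before and after. As a standalone illustration in plain PyTorch (no torchzero types involved), d is the negative dot product between the per-parameter gradient list and update list; since the line search later steps parameters as x - a*u (see the torch._foreach_sub call in line_search.py below), d is the slope of the loss along the search direction at a = 0:

    import torch

    # per-parameter gradient and update lists, as a line search sees them
    grads   = [torch.tensor([1.0, 2.0]), torch.tensor([[0.5]])]
    updates = [torch.tensor([0.1, 0.2]), torch.tensor([[1.0]])]

    # same expression as in the hunks above
    d = -sum(t.sum() for t in torch._foreach_mul(grads, updates))
    print(float(d))  # -(1.0*0.1 + 2.0*0.2 + 0.5*1.0) = -1.0

A negative d indicates a descent direction, which is presumably what the sufficient-decrease test with coefficient c relies on.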
torchzero/modules/line_search/interpolation.py (new file)

@@ -0,0 +1,160 @@
+ import math
+ from bisect import insort
+
+ import numpy as np
+ from numpy.polynomial import Polynomial
+
+
+ # we have a list of points in ascending order of their `y` value
+ class Point:
+     __slots__ = ("x", "y", "d")
+     def __init__(self, x, y, d):
+         self.x = x
+         self.y = y
+         self.d = d
+
+     def __lt__(self, other):
+         return self.y < other.y
+
+ def _get_dpoint(points: list[Point]):
+     """returns lowest point with derivative and list of other points"""
+     for i,p in enumerate(points):
+         if p.d is not None:
+             cpoints = points.copy()
+             del cpoints[i]
+             return p, cpoints
+     return None, points
+
+ # -------------------------------- quadratic2 -------------------------------- #
+ def _fitmin_quadratic2(x1, y1, d1, x2, y2):
+
+     a = (y2 - y1 - d1*(x2 - x1)) / (x2 - x1)**2
+     if a <= 0: return None
+
+     b = d1 - 2*a*x1
+     # c = y_1 - d_1*x_1 + a*x_1**2
+
+     return -b / (2*a)
+
+ def quadratic2(points:list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) == 0: return None
+
+     pn = points[0]
+     return _fitmin_quadratic2(pd.x, pd.y, pd.d, pn.x, pn.y)
+
+ # -------------------------------- quadratic3 -------------------------------- #
+ def _fitmin_quadratic3(x1, y1, x2, y2, x3, y3):
+     quad = Polynomial.fit([x1,x2,x3], [y1,y2,y3], deg=2)
+     a,b,c = quad.coef
+     if a <= 0: return None
+     return -b / (2*a)
+
+ def quadratic3(points:list[Point]):
+     if len(points) < 3: return None
+
+     p1,p2,p3 = points[:3]
+     return _fitmin_quadratic3(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y)
+
+ # ---------------------------------- cubic3 ---------------------------------- #
+ def _minimize_polynomial(poly: Polynomial):
+     roots = poly.deriv().roots()
+     vals = poly(roots)
+     argmin = np.argmin(vals)
+     return roots[argmin], vals[argmin]
+
+
+ def _fitmin_cubic3(x1,y1,x2,y2,x3,y3,x4,d4):
+     """x4 is allowed to be equal to x1"""
+
+     A = np.array([
+         [x1**3, x1**2, x1, 1],
+         [x2**3, x2**2, x2, 1],
+         [x3**3, x3**2, x3, 1],
+         [3*x4**2, 2*x4, 1, 0]
+     ])
+
+     B = np.array([y1, y2, y3, d4])
+
+     try:
+         coeffs = np.linalg.solve(A, B)
+     except np.linalg.LinAlgError:
+         return None
+
+     cubic = Polynomial(coeffs)
+     x_min, y_min = _minimize_polynomial(cubic)
+     if y_min < min(y1,y2,y3): return x_min
+     return None
+
+ def cubic3(points: list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) < 2: return None
+     p1, p2 = points[:2]
+     return _fitmin_cubic3(pd.x, pd.y, p1.x, p1.y, p2.x, p2.y, pd.x, pd.d)
+
+ # ---------------------------------- cubic4 ---------------------------------- #
+ def _fitmin_cubic4(x1, y1, x2, y2, x3, y3, x4, y4):
+     cubic = Polynomial.fit([x1,x2,x3,x4], [y1,y2,y3,y4], deg=3)
+     x_min, y_min = _minimize_polynomial(cubic)
+     if y_min < min(y1,y2,y3,y4): return x_min
+     return None
+
+ def cubic4(points:list[Point]):
+     if len(points) < 4: return None
+
+     p1,p2,p3,p4 = points[:4]
+     return _fitmin_cubic4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
+
+ # ---------------------------------- linear3 --------------------------------- #
+ def _linear_intersection(x1,y1,s1,x2,y2,s2):
+     if s1 == 0 or s2 == 0 or s1 == s2: return None
+     return (y1 - s1*x1 - y2 + s2*x2) / (s2 - s1)
+
+ def _fitmin_linear3(x1, y1, d1, x2, y2, x3, y3):
+     # we have that
+     # s2 = (y2 - y3) / (x2 - x3)  # slope origin in x2 y2
+     # f1(x) = y1 + d1 * (x - x1)
+     # f2(x) = y2 + s2 * (x - x2)
+     # y1 + d1 * (x - x1) = y2 + s2 * (x - x2)
+     # y1 + d1 x - d1 x1 - y2 - s2 x + s2 x2 = 0
+     # s2 x - d1 x = y1 - d1 x1 - y2 + s2 x2
+     # x = (y1 - d1 x1 - y2 + s2 x2) / (s2 - d1)
+
+     if x2 < x1 < x3 or x3 < x1 < x2: # point with derivative in between
+         return None
+
+     if d1 > 0:
+         if x2 > x1 or x3 > x1: return None # intersection is above to the right
+         if x2 > x3: x2,y2,x3,y3 = x3,y3,x2,y2
+     if d1 < 0:
+         if x2 < x1 or x3 < x1: return None # intersection is above to the left
+         if x2 < x3: x2,y2,x3,y3 = x3,y3,x2,y2
+
+     s2 = (y2 - y3) / (x2 - x3)
+     return _linear_intersection(x1,y1,d1,x2,y2,s2)
+
+ def linear3(points:list[Point]):
+     pd, points = _get_dpoint(points)
+     if pd is None: return None
+     if len(points) < 2: return None
+     p1, p2 = points[:2]
+     return _fitmin_linear3(pd.x, pd.y, pd.d, p1.x, p1.y, p2.x, p2.y)
+
+ # ---------------------------------- linear4 --------------------------------- #
+ def _fitmin_linear4(x1, y1, x2, y2, x3, y3, x4, y4):
+     # sort by x
+     points = ((x1,y1), (x2,y2), (x3,y3), (x4,y4))
+     points = sorted(points, key=lambda x: x[0])
+
+     (x1,y1), (x2,y2), (x3,y3), (x4,y4) = points
+     s1 = (y1 - y2) / (x1 - x2)
+     s3 = (y3 - y4) / (x3 - x4)
+
+     return _linear_intersection(x1,y1,s1,x3,y3,s3)
+
+ def linear4(points:list[Point]):
+     if len(points) < 4: return None
+     p1,p2,p3,p4 = points[:4]
+     return _fitmin_linear4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
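
The new interpolation.py above provides fit-and-minimize helpers over a list of sampled points (step size x, loss y, optional directional derivative d), each returning the minimizer of a low-order model or None when no usable fit exists. A minimal usage sketch, assuming the helpers are imported straight from the new file's path (they may or may not be re-exported elsewhere) and using made-up sample values for f(x) = (x - 2)**2:

    # hypothetical direct import from the new module shown above
    from torchzero.modules.line_search.interpolation import Point, quadratic2

    # two trial points of f(x) = (x - 2)**2; only the first carries a derivative
    points = [Point(x=0.0, y=4.0, d=-4.0), Point(x=1.0, y=1.0, d=None)]
    points.sort()              # the module keeps points in ascending order of y (Point.__lt__)

    # quadratic through (0, 4) with slope -4 and (1, 1): minimizer at x = 2.0
    print(quadratic2(points))

quadratic2 combines the one point that carries a derivative with the lowest of the remaining points; the cubic and linear variants follow the same Point-list protocol with more samples.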
torchzero/modules/line_search/line_search.py

@@ -8,8 +8,9 @@ from typing import Any, Literal
  import numpy as np
  import torch
 
- from ...core import Module, Target, Var
+ from ...core import Module, Objective
  from ...utils import tofloat, set_storage_
+ from ..functional import clip_by_finfo
 
 
  class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
  ):
  if not math.isfinite(step_size): return
 
- # fixes overflow when backtracking keeps increasing alpha after converging
- step_size = max(min(tofloat(step_size), 1e36), -1e36)
+ # avoid overflow error
+ step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
 
  # skip is parameters are already at suggested step size
  if self._current_step_size == step_size: return
 
- # this was basically causing floating point imprecision to build up
- #if False:
- # if abs(alpha) < abs(step_size) and step_size != 0:
- # torch._foreach_add_(params, update, alpha=alpha)
-
- # else:
  assert self._initial_params is not None
  if step_size == 0:
  new_params = [p.clone() for p in self._initial_params]
  else:
  new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
  for c, n in zip(params, new_params):
  set_storage_(c, n)
 
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
  params: list[torch.Tensor],
  update: list[torch.Tensor],
  ):
- # if not np.isfinite(step_size): step_size = [0 for _ in step_size]
- # alpha = [self._current_step_size - s for s in step_size]
- # if any(a!=0 for a in alpha):
- # torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
  assert self._initial_params is not None
  if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
 
@@ -146,7 +139,7 @@ class LineSearchBase(Module, ABC):
  for c, n in zip(params, new_params):
  set_storage_(c, n)
 
- def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
+ def _loss(self, step_size: float, var: Objective, closure, params: list[torch.Tensor],
  update: list[torch.Tensor], backward:bool=False) -> float:
 
  # if step_size is 0, we might already know the loss
@@ -172,16 +165,16 @@ class LineSearchBase(Module, ABC):
  # if evaluated loss at step size 0, set it to var.loss
  if step_size == 0:
  var.loss = loss
- if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+ if backward: var.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
  return tofloat(loss)
 
- def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
+ def _loss_derivative_gradient(self, step_size: float, var: Objective, closure,
  params: list[torch.Tensor], update: list[torch.Tensor]):
  # if step_size is 0, we might already know the derivative
- if (var.grad is not None) and (step_size == 0):
+ if (var.grads is not None) and (step_size == 0):
  loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
- derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
+ derivative = - sum(t.sum() for t in torch._foreach_mul(var.grads, update))
 
  else:
  # loss with a backward pass sets params.grad
@@ -191,81 +184,79 @@ class LineSearchBase(Module, ABC):
  derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
  else torch.zeros_like(p) for p in params], update))
 
- assert var.grad is not None
- return loss, tofloat(derivative), var.grad
+ assert var.grads is not None
+ return loss, tofloat(derivative), var.grads
 
- def _loss_derivative(self, step_size: float, var: Var, closure,
+ def _loss_derivative(self, step_size: float, var: Objective, closure,
  params: list[torch.Tensor], update: list[torch.Tensor]):
  return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
 
- def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+ def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
  """evaluate function value at alpha `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
+ return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)
 
- def evaluate_f_d(self, step_size: float, var: Var):
+ def evaluate_f_d(self, step_size: float, var: Objective):
  """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+ return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
- def evaluate_f_d_g(self, step_size: float, var: Var):
+ def evaluate_f_d_g(self, step_size: float, var: Objective):
  """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+ return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
- def make_objective(self, var: Var, backward:bool=False):
+ def make_objective(self, var: Objective, backward:bool=False):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
+ return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_updates(), backward=backward)
 
- def make_objective_with_derivative(self, var: Var):
+ def make_objective_with_derivative(self, var: Objective):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
+ return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_updates())
 
- def make_objective_with_derivative_and_gradient(self, var: Var):
+ def make_objective_with_derivative_and_gradient(self, var: Objective):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+ return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_updates())
 
  @abstractmethod
- def search(self, update: list[torch.Tensor], var: Var) -> float:
+ def search(self, update: list[torch.Tensor], var: Objective) -> float:
  """Finds the step size to use"""
 
  @torch.no_grad
- def step(self, var: Var) -> Var:
+ def apply(self, objective: Objective) -> Objective:
  self._reset()
 
- params = var.params
+ params = objective.params
  self._initial_params = [p.clone() for p in params]
- update = var.get_update()
+ update = objective.get_updates()
 
  try:
- step_size = self.search(update=update, var=var)
+ step_size = self.search(update=update, var=objective)
  except MaxLineSearchItersReached:
  step_size = self._best_step_size
 
- # set loss_approx
- if var.loss_approx is None: var.loss_approx = self._lowest_loss
+ step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
 
- # this is last module - set step size to found step_size times lr
- if var.is_last:
- if var.last_module_lrs is None:
- self.set_step_size_(step_size, params=params, update=update)
+ # set loss_approx
+ if objective.loss_approx is None: objective.loss_approx = self._lowest_loss
 
- else:
- self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
+ # if this is last module, directly update parameters to avoid redundant operations
+ if objective.modular is not None and self is objective.modular.modules[-1]:
+ self.set_step_size_(step_size, params=params, update=update)
 
- var.stop = True; var.skip_update = True
- return var
+ objective.stop = True; objective.skip_update = True
+ return objective
 
  # revert parameters and multiply update by step size
  self.set_step_size_(0, params=params, update=update)
- torch._foreach_mul_(var.update, step_size)
- return var
+ torch._foreach_mul_(objective.updates, step_size)
+ return objective
 
 
 
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):
 
  @torch.no_grad
  def search(self, update, var):
- start,end,num=itemgetter('start','end','num')(self.defaults)
+ start, end, num = itemgetter('start', 'end', 'num')(self.defaults)
 
  for lr in torch.linspace(start,end,num):
  self.evaluate_f(lr.item(), var=var, backward=False)
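
Two behavioral notes on the line_search.py hunks above: the hard-coded clamp to plus/minus 1e36 is replaced by clip_by_finfo(..., torch.finfo(update[0].dtype)), which is imported from ..functional but whose body is not part of this diff, and the old var.is_last / last_module_lrs branch is replaced by checking whether the module is literally the last entry of objective.modular.modules. Assuming clip_by_finfo simply clamps a float into the finite range of the given dtype, a stand-in would look like this (an assumption, not torchzero's actual implementation):

    import torch

    def clip_by_finfo_sketch(x: float, finfo) -> float:
        # assumed behavior only: clamp a python float to the finite range of a dtype;
        # this is a sketch, not torchzero's clip_by_finfo
        return max(min(x, finfo.max), finfo.min)

    print(clip_by_finfo_sketch(1e40, torch.finfo(torch.float32)))   # ~3.4028e+38
    print(clip_by_finfo_sketch(-1e40, torch.finfo(torch.float64)))  # ~-1.7977e+308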
torchzero/modules/line_search/strong_wolfe.py

@@ -7,7 +7,7 @@ import numpy as np
  import torch
  from torch.optim.lbfgs import _cubic_interpolate
 
- from ...utils import as_tensorlist, totensor
+ from ...utils import as_tensorlist, totensor, tofloat
  from ._polyinterp import polyinterp, polyinterp2
  from .line_search import LineSearchBase, TerminationCondition, termination_condition
  from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
  return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
 
  if self.interpolation in ('polynomial', 'polynomial2'):
- finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+ finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
  if bounds is None: bounds = (None, None)
  polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
  try:
@@ -284,8 +284,8 @@ class StrongWolfe(LineSearchBase):
  'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
  'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
 
- dir = as_tensorlist(var.get_update())
- grad_list = var.get_grad()
+ dir = as_tensorlist(var.get_updates())
+ grad_list = var.get_grads()
 
  g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
  f_0 = var.get_loss(False)
@@ -370,6 +370,6 @@ class StrongWolfe(LineSearchBase):
  self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
  finfo = torch.finfo(dir[0].dtype)
  if self.global_state['initial_scale'] < finfo.tiny * 2:
- self.global_state['initial_scale'] = finfo.max / 2
+ self.global_state['initial_scale'] = init_value * 2
 
  return 0
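
In the last hunk above, the adaptive StrongWolfe fallback changes: initial_scale is repeatedly halved (the * 0.5 line), and once it drops below finfo.tiny * 2 it is now reset to init_value * 2, twice the configured initial step, instead of finfo.max / 2, an enormous value. For reference, the magnitudes involved:

    import torch

    f32 = torch.finfo(torch.float32)
    print(f32.tiny)     # ~1.1755e-38, the underflow threshold in the check above
    print(f32.max / 2)  # ~1.7014e+38, the old reset value
    # new reset value: init_value * 2, i.e. twice the configured initial step size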
torchzero/modules/misc/debug.py

@@ -11,9 +11,9 @@ class PrintUpdate(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)
 
- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.updates}')
+ return objective
 
  class PrintShape(Module):
  """Prints shapes of the update."""
@@ -21,10 +21,10 @@ class PrintShape(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)
 
- def step(self, var):
- shapes = [u.shape for u in var.update] if var.update is not None else None
+ def apply(self, objective):
+ shapes = [u.shape for u in objective.updates] if objective.updates is not None else None
  self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
- return var
+ return objective
 
  class PrintParams(Module):
  """Prints current update."""
@@ -32,9 +32,9 @@ class PrintParams(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)
 
- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.params}')
+ return objective
 
 
  class PrintLoss(Module):
@@ -43,6 +43,6 @@ class PrintLoss(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)
 
- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.get_loss(False)}')
+ return objective
torchzero/modules/misc/escape.py

@@ -3,7 +3,7 @@ import math
  from typing import Literal
  import torch
 
- from ...core import Modular, Module, Var, Chainable
+ from ...core import Modular, Module, Objective, Chainable
  from ...utils import NumberList, TensorList
 
 
@@ -15,11 +15,11 @@ class EscapeAnnealing(Module):
 
 
  @torch.no_grad
- def step(self, var):
- closure = var.closure
+ def apply(self, objective):
+ closure = objective.closure
  if closure is None: raise RuntimeError("Escape requries closure")
 
- params = TensorList(var.params)
+ params = TensorList(objective.params)
  settings = self.settings[params[0]]
  max_region = self.get_settings(params, 'max_region', cls=NumberList)
  max_iter = settings['max_iter']
@@ -41,7 +41,7 @@ class EscapeAnnealing(Module):
  self.global_state['n_bad'] = n_bad
 
  # no progress
- f_0 = var.get_loss(False)
+ f_0 = objective.get_loss(False)
  if n_bad >= n_tol:
  for i in range(1, max_iter+1):
  alpha = max_region * (i / max_iter)
@@ -51,12 +51,12 @@ class EscapeAnnealing(Module):
  f_star = closure(False)
 
  if math.isfinite(f_star) and f_star < f_0-1e-12:
- var.update = None
- var.stop = True
- var.skip_update = True
- return var
+ objective.updates = None
+ objective.stop = True
+ objective.skip_update = True
+ return objective
 
  params.sub_(pert)
 
  self.global_state['n_bad'] = 0
- return var
+ return objective
torchzero/modules/misc/gradient_accumulation.py

@@ -3,74 +3,6 @@ import torch
  from ...core import Chainable, Module
 
 
- # class GradientAccumulation(Module):
- # """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
-
- # Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
- # is more computationally efficient, but sometimes it is not feasible due to memory constraints.
-
- # .. note::
- # Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
-
- # Args:
- # modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
- # n (int): number of gradients to accumulate.
- # mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
- # stop (bool, optional):
- # this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
-
- # Examples:
- # Adam with gradients accumulated for 16 batches.
-
- # .. code-block:: python
-
- # opt = tz.Modular(
- # model.parameters(),
- # tz.m.GradientAccumulation(
- # [tz.m.Adam(), tz.m.LR(1e-2)],
- # n=16
- # )
- # )
-
- # """
- # def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
- # defaults = dict(n=n, mean=mean, stop=stop)
- # super().__init__(defaults)
- # self.set_child('modules', modules)
-
-
-
- # @torch.no_grad
- # def step(self, var):
- # accumulator = self.get_state(var.params, 'accumulator')
- # settings = self.defaults
- # n = settings['n']; mean = settings['mean']; stop = settings['stop']
- # step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
- # # add update to accumulator
- # torch._foreach_add_(accumulator, var.get_update())
-
- # # step with accumulated updates
- # if step % n == 0:
- # if mean:
- # torch._foreach_div_(accumulator, n)
-
- # var.update = [a.clone() for a in accumulator]
- # var = self.children['modules'].step(var)
-
- # # zero accumulator
- # torch._foreach_zero_(accumulator)
-
- # else:
- # # prevent update
- # if stop:
- # var.update = None
- # var.stop=True
- # var.skip_update=True
-
- # return var
-
-
-
 
  class GradientAccumulation(Module):
  """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
@@ -106,21 +38,21 @@ class GradientAccumulation(Module):
 
 
  @torch.no_grad
- def step(self, var):
- accumulator = self.get_state(var.params, 'accumulator')
+ def apply(self, objective):
+ accumulator = self.get_state(objective.params, 'accumulator')
  settings = self.defaults
  n = settings['n']; mean = settings['mean']; stop = settings['stop']
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+ step = self.increment_counter("step", 0)
 
  # add update to accumulator
- torch._foreach_add_(accumulator, var.get_update())
+ torch._foreach_add_(accumulator, objective.get_updates())
 
  # step with accumulated updates
- if step % n == 0:
+ if (step + 1) % n == 0:
  if mean:
  torch._foreach_div_(accumulator, n)
 
- var.update = accumulator
+ objective.updates = accumulator
 
  # zero accumulator
  self.clear_state_keys('accumulator')
@@ -128,9 +60,9 @@ class GradientAccumulation(Module):
  else:
  # prevent update
  if stop:
- var.update = None
- var.stop=True
- var.skip_update=True
+ objective.updates = None
+ objective.stop=True
+ objective.skip_update=True
 
- return var
+ return objective
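
The first hunk above deletes an old commented-out draft of GradientAccumulation whose docstring carried the module's usage example; the example still applies to the kept implementation. The second hunk also changes the counter semantics: the old code pre-incremented a manual counter (the first step was 1, tested with step % n == 0), while increment_counter("step", 0) apparently returns the pre-increment value starting at 0, hence the (step + 1) % n test. Usage, reconstructed from the removed docstring (Adam with gradients accumulated over 16 batches; the torchzero-as-tz import alias is implied by the docstring's tz.Modular / tz.m usage and the model setup here is only illustrative):

    import torch.nn as nn
    import torchzero as tz

    model = nn.Linear(10, 1)

    # Adam stepping once every 16 accumulated gradients
    opt = tz.Modular(
        model.parameters(),
        tz.m.GradientAccumulation(
            [tz.m.Adam(), tz.m.LR(1e-2)],
            n=16,
        ),
    )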