torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_identical.py +2 -2
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +47 -36
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +8 -2
- torchzero/core/chain.py +47 -0
- torchzero/core/functional.py +103 -0
- torchzero/core/modular.py +233 -0
- torchzero/core/module.py +132 -643
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +56 -23
- torchzero/core/transform.py +261 -365
- torchzero/linalg/__init__.py +10 -0
- torchzero/linalg/eigh.py +34 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +95 -0
- torchzero/{utils/linalg → linalg}/qr.py +4 -2
- torchzero/{utils/linalg → linalg}/solve.py +76 -88
- torchzero/linalg/svd.py +20 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +0 -1
- torchzero/modules/adaptive/__init__.py +1 -1
- torchzero/modules/adaptive/adagrad.py +163 -213
- torchzero/modules/adaptive/adahessian.py +74 -103
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +49 -30
- torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/lion.py +5 -10
- torchzero/modules/adaptive/lmadagrad.py +87 -32
- torchzero/modules/adaptive/mars.py +5 -5
- torchzero/modules/adaptive/matrix_momentum.py +47 -51
- torchzero/modules/adaptive/msam.py +70 -52
- torchzero/modules/adaptive/muon.py +59 -124
- torchzero/modules/adaptive/natural_gradient.py +33 -28
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +123 -129
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +15 -18
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +26 -37
- torchzero/modules/experimental/__init__.py +3 -6
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
- torchzero/modules/experimental/newton_solver.py +22 -53
- torchzero/modules/experimental/newtonnewton.py +20 -17
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +5 -5
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/functional.py +8 -1
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +20 -17
- torchzero/modules/least_squares/gn.py +90 -42
- torchzero/modules/line_search/__init__.py +1 -1
- torchzero/modules/line_search/_polyinterp.py +3 -1
- torchzero/modules/line_search/adaptive.py +3 -3
- torchzero/modules/line_search/backtracking.py +3 -3
- torchzero/modules/line_search/interpolation.py +160 -0
- torchzero/modules/line_search/line_search.py +42 -51
- torchzero/modules/line_search/strong_wolfe.py +5 -5
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +10 -78
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +120 -122
- torchzero/modules/misc/multistep.py +63 -61
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +30 -28
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +34 -28
- torchzero/modules/momentum/momentum.py +11 -11
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +19 -19
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +43 -43
- torchzero/modules/quasi_newton/__init__.py +2 -0
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +7 -7
- torchzero/modules/quasi_newton/lsr1.py +7 -7
- torchzero/modules/quasi_newton/quasi_newton.py +25 -16
- torchzero/modules/quasi_newton/sg2.py +292 -0
- torchzero/modules/restarts/restars.py +26 -24
- torchzero/modules/second_order/__init__.py +6 -3
- torchzero/modules/second_order/ifn.py +58 -0
- torchzero/modules/second_order/inm.py +101 -0
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +105 -228
- torchzero/modules/second_order/newton_cg.py +102 -154
- torchzero/modules/second_order/nystrom.py +158 -178
- torchzero/modules/second_order/rsn.py +237 -0
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +11 -10
- torchzero/modules/step_size/adaptive.py +23 -23
- torchzero/modules/step_size/lr.py +15 -15
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +2 -2
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +21 -18
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +12 -13
- torchzero/modules/wrappers/optim_wrapper.py +57 -50
- torchzero/modules/zeroth_order/cd.py +9 -6
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +6 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +112 -88
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
- torchzero-0.4.0.dist-info/RECORD +191 -0
- tests/test_vars.py +0 -185
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/modules/higher_order/__init__.py +0 -1
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.14.dist-info/RECORD +0 -167
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
- {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/backtracking.py
CHANGED

@@ -117,7 +117,7 @@ class Backtracking(LineSearchBase):
 
         # # directional derivative
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), var.get_updates()))
 
         # scale init
         init_scale = self.global_state.get('init_scale', 1)
@@ -136,7 +136,7 @@ class Backtracking(LineSearchBase):
         if adaptive:
             finfo = torch.finfo(var.params[0].dtype)
             if init_scale <= finfo.tiny * 2:
-                self.global_state["init_scale"] =
+                self.global_state["init_scale"] = init * 2
             else:
                 self.global_state['init_scale'] = init_scale * beta**maxiter
         return 0
@@ -199,7 +199,7 @@ class AdaptiveBacktracking(LineSearchBase):
 
         # directional derivative (0 if c = 0 because it is not needed)
         if c == 0: d = 0
-        else: d = -sum(t.sum() for t in torch._foreach_mul(var.
+        else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grads(), update))
 
         # scale beta
         beta = beta * self.global_state['beta_scale']
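The quantity `d` in the hunks above is the directional derivative of the loss along the proposed update: the update is later subtracted from the parameters, so `d = -<g, u>`, summed over every parameter tensor. A tiny standalone sketch of that reduction with made-up values (illustration only, not torchzero code):

import torch

# d = -<g, u> accumulated across a list of parameter tensors,
# matching the torch._foreach_mul pattern in the diff above.
grads   = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
updates = [torch.tensor([0.5, 0.5]), torch.tensor([1.0])]

d = -sum(t.sum() for t in torch._foreach_mul(grads, updates))
print(float(d))  # -(1*0.5 + 2*0.5 + 3*1.0) = -4.5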
torchzero/modules/line_search/interpolation.py
ADDED

@@ -0,0 +1,160 @@
+import math
+from bisect import insort
+
+import numpy as np
+from numpy.polynomial import Polynomial
+
+
+# we have a list of points in ascending order of their `y` value
+class Point:
+    __slots__ = ("x", "y", "d")
+    def __init__(self, x, y, d):
+        self.x = x
+        self.y = y
+        self.d = d
+
+    def __lt__(self, other):
+        return self.y < other.y
+
+def _get_dpoint(points: list[Point]):
+    """returns lowest point with derivative and list of other points"""
+    for i,p in enumerate(points):
+        if p.d is not None:
+            cpoints = points.copy()
+            del cpoints[i]
+            return p, cpoints
+    return None, points
+
+# -------------------------------- quadratic2 -------------------------------- #
+def _fitmin_quadratic2(x1, y1, d1, x2, y2):
+
+    a = (y2 - y1 - d1*(x2 - x1)) / (x2 - x1)**2
+    if a <= 0: return None
+
+    b = d1 - 2*a*x1
+    # c = y_1 - d_1*x_1 + a*x_1**2
+
+    return -b / (2*a)
+
+def quadratic2(points:list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) == 0: return None
+
+    pn = points[0]
+    return _fitmin_quadratic2(pd.x, pd.y, pd.d, pn.x, pn.y)
+
+# -------------------------------- quadratic3 -------------------------------- #
+def _fitmin_quadratic3(x1, y1, x2, y2, x3, y3):
+    quad = Polynomial.fit([x1,x2,x3], [y1,y2,y3], deg=2)
+    a,b,c = quad.coef
+    if a <= 0: return None
+    return -b / (2*a)
+
+def quadratic3(points:list[Point]):
+    if len(points) < 3: return None
+
+    p1,p2,p3 = points[:3]
+    return _fitmin_quadratic3(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y)
+
+# ---------------------------------- cubic3 ---------------------------------- #
+def _minimize_polynomial(poly: Polynomial):
+    roots = poly.deriv().roots()
+    vals = poly(roots)
+    argmin = np.argmin(vals)
+    return roots[argmin], vals[argmin]
+
+
+def _fitmin_cubic3(x1,y1,x2,y2,x3,y3,x4,d4):
+    """x4 is allowed to be equal to x1"""
+
+    A = np.array([
+        [x1**3, x1**2, x1, 1],
+        [x2**3, x2**2, x2, 1],
+        [x3**3, x3**2, x3, 1],
+        [3*x4**2, 2*x4, 1, 0]
+    ])
+
+    B = np.array([y1, y2, y3, d4])
+
+    try:
+        coeffs = np.linalg.solve(A, B)
+    except np.linalg.LinAlgError:
+        return None
+
+    cubic = Polynomial(coeffs)
+    x_min, y_min = _minimize_polynomial(cubic)
+    if y_min < min(y1,y2,y3): return x_min
+    return None
+
+def cubic3(points: list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) < 2: return None
+    p1, p2 = points[:2]
+    return _fitmin_cubic3(pd.x, pd.y, p1.x, p1.y, p2.x, p2.y, pd.x, pd.d)
+
+# ---------------------------------- cubic4 ---------------------------------- #
+def _fitmin_cubic4(x1, y1, x2, y2, x3, y3, x4, y4):
+    cubic = Polynomial.fit([x1,x2,x3,x4], [y1,y2,y3,y4], deg=3)
+    x_min, y_min = _minimize_polynomial(cubic)
+    if y_min < min(y1,y2,y3,y4): return x_min
+    return None
+
+def cubic4(points:list[Point]):
+    if len(points) < 4: return None
+
+    p1,p2,p3,p4 = points[:4]
+    return _fitmin_cubic4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
+
+# ---------------------------------- linear3 --------------------------------- #
+def _linear_intersection(x1,y1,s1,x2,y2,s2):
+    if s1 == 0 or s2 == 0 or s1 == s2: return None
+    return (y1 - s1*x1 - y2 + s2*x2) / (s2 - s1)
+
+def _fitmin_linear3(x1, y1, d1, x2, y2, x3, y3):
+    # we have that
+    # s2 = (y2 - y3) / (x2 - x3) # slope origin in x2 y2
+    # f1(x) = y1 + d1 * (x - x1)
+    # f2(x) = y2 + s2 * (x - x2)
+    # y1 + d1 * (x - x1) = y2 + s2 * (x - x2)
+    # y1 + d1 x - d1 x1 - y2 - s2 x + s2 x2 = 0
+    # s2 x - d1 x = y1 - d1 x1 - y2 + s2 x2
+    # x = (y1 - d1 x1 - y2 + s2 x2) / (s2 - d1)
+
+    if x2 < x1 < x3 or x3 < x1 < x2: # point with derivative in between
+        return None
+
+    if d1 > 0:
+        if x2 > x1 or x3 > x1: return None # intersection is above to the right
+        if x2 > x3: x2,y2,x3,y3 = x3,y3,x2,y2
+    if d1 < 0:
+        if x2 < x1 or x3 < x1: return None # intersection is above to the left
+        if x2 < x3: x2,y2,x3,y3 = x3,y3,x2,y2
+
+    s2 = (y2 - y3) / (x2 - x3)
+    return _linear_intersection(x1,y1,d1,x2,y2,s2)
+
+def linear3(points:list[Point]):
+    pd, points = _get_dpoint(points)
+    if pd is None: return None
+    if len(points) < 2: return None
+    p1, p2 = points[:2]
+    return _fitmin_linear3(pd.x, pd.y, pd.d, p1.x, p1.y, p2.x, p2.y)
+
+# ---------------------------------- linear4 --------------------------------- #
+def _fitmin_linear4(x1, y1, x2, y2, x3, y3, x4, y4):
+    # sort by x
+    points = ((x1,y1), (x2,y2), (x3,y3), (x4,y4))
+    points = sorted(points, key=lambda x: x[0])
+
+    (x1,y1), (x2,y2), (x3,y3), (x4,y4) = points
+    s1 = (y1 - y2) / (x1 - x2)
+    s3 = (y3 - y4) / (x3 - x4)
+
+    return _linear_intersection(x1,y1,s1,x3,y3,s3)
+
+def linear4(points:list[Point]):
+    if len(points) < 4: return None
+    p1,p2,p3,p4 = points[:4]
+    return _fitmin_linear4(p1.x, p1.y, p2.x, p2.y, p3.x, p3.y, p4.x, p4.y)
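Each helper above fits a simple model to already-evaluated points, sorted by ascending loss, and returns the abscissa of its minimum (or None when the fit is unusable). A minimal usage sketch; the import path is inferred from the file listing at the top and the numbers are made up:

from torchzero.modules.line_search.interpolation import Point, quadratic2  # path inferred from the file listing

# Point fields are (x, y, d) = (step size, loss, directional derivative or None).
p0 = Point(x=0.0, y=1.0, d=-2.0)   # loss and slope at step size 0
p1 = Point(x=1.0, y=0.5, d=None)   # loss only, at step size 1

pts = sorted([p0, p1])             # Point.__lt__ orders by ascending y
print(quadratic2(pts))             # minimum of the fitted parabola: 2/3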
torchzero/modules/line_search/line_search.py
CHANGED

@@ -8,8 +8,9 @@ from typing import Any, Literal
 import numpy as np
 import torch
 
-from ...core import Module,
+from ...core import Module, Objective
 from ...utils import tofloat, set_storage_
+from ..functional import clip_by_finfo
 
 
 class MaxLineSearchItersReached(Exception): pass
@@ -103,23 +104,18 @@ class LineSearchBase(Module, ABC):
     ):
         if not math.isfinite(step_size): return
 
-        #
-        step_size =
+        # avoid overflow error
+        step_size = clip_by_finfo(tofloat(step_size), torch.finfo(update[0].dtype))
 
         # skip is parameters are already at suggested step size
         if self._current_step_size == step_size: return
 
-        # this was basically causing floating point imprecision to build up
-        #if False:
-        #     if abs(alpha) < abs(step_size) and step_size != 0:
-        #         torch._foreach_add_(params, update, alpha=alpha)
-
-        # else:
         assert self._initial_params is not None
         if step_size == 0:
             new_params = [p.clone() for p in self._initial_params]
         else:
             new_params = torch._foreach_sub(self._initial_params, update, alpha=step_size)
+
         for c, n in zip(params, new_params):
             set_storage_(c, n)
 
@@ -131,10 +127,7 @@ class LineSearchBase(Module, ABC):
         params: list[torch.Tensor],
         update: list[torch.Tensor],
     ):
-
-        # alpha = [self._current_step_size - s for s in step_size]
-        # if any(a!=0 for a in alpha):
-        #     torch._foreach_add_(params, torch._foreach_mul(update, alpha))
+
         assert self._initial_params is not None
         if not np.isfinite(step_size).all(): step_size = [0 for _ in step_size]
 
@@ -146,7 +139,7 @@ class LineSearchBase(Module, ABC):
         for c, n in zip(params, new_params):
             set_storage_(c, n)
 
-    def _loss(self, step_size: float, var:
+    def _loss(self, step_size: float, var: Objective, closure, params: list[torch.Tensor],
              update: list[torch.Tensor], backward:bool=False) -> float:
 
         # if step_size is 0, we might already know the loss
@@ -172,16 +165,16 @@ class LineSearchBase(Module, ABC):
         # if evaluated loss at step size 0, set it to var.loss
         if step_size == 0:
             var.loss = loss
-            if backward: var.
+            if backward: var.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
         return tofloat(loss)
 
-    def _loss_derivative_gradient(self, step_size: float, var:
+    def _loss_derivative_gradient(self, step_size: float, var: Objective, closure,
                                   params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
-        if (var.
+        if (var.grads is not None) and (step_size == 0):
            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
-            derivative = - sum(t.sum() for t in torch._foreach_mul(var.
+           derivative = - sum(t.sum() for t in torch._foreach_mul(var.grads, update))
 
        else:
            # loss with a backward pass sets params.grad
@@ -191,81 +184,79 @@ class LineSearchBase(Module, ABC):
            derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
                                                                    else torch.zeros_like(p) for p in params], update))
 
-        assert var.
-        return loss, tofloat(derivative), var.
+        assert var.grads is not None
+        return loss, tofloat(derivative), var.grads
 
-    def _loss_derivative(self, step_size: float, var:
+    def _loss_derivative(self, step_size: float, var: Objective, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]
 
-    def evaluate_f(self, step_size: float, var:
+    def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
         """evaluate function value at alpha `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.
+        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)
 
-    def evaluate_f_d(self, step_size: float, var:
+    def evaluate_f_d(self, step_size: float, var: Objective):
         """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.
+        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
-    def evaluate_f_d_g(self, step_size: float, var:
+    def evaluate_f_d_g(self, step_size: float, var: Objective):
         """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.
+        return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())
 
-    def make_objective(self, var:
+    def make_objective(self, var: Objective, backward:bool=False):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.
+        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_updates(), backward=backward)
 
-    def make_objective_with_derivative(self, var:
+    def make_objective_with_derivative(self, var: Objective):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.
+        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_updates())
 
-    def make_objective_with_derivative_and_gradient(self, var:
+    def make_objective_with_derivative_and_gradient(self, var: Objective):
         closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.
+        return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_updates())
 
     @abstractmethod
-    def search(self, update: list[torch.Tensor], var:
+    def search(self, update: list[torch.Tensor], var: Objective) -> float:
         """Finds the step size to use"""
 
     @torch.no_grad
-    def
+    def apply(self, objective: Objective) -> Objective:
         self._reset()
 
-        params =
+        params = objective.params
         self._initial_params = [p.clone() for p in params]
-        update =
+        update = objective.get_updates()
 
         try:
-            step_size = self.search(update=update, var=
+            step_size = self.search(update=update, var=objective)
         except MaxLineSearchItersReached:
             step_size = self._best_step_size
 
-
-        if var.loss_approx is None: var.loss_approx = self._lowest_loss
+        step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))
 
-        #
-        if
-        if var.last_module_lrs is None:
-            self.set_step_size_(step_size, params=params, update=update)
+        # set loss_approx
+        if objective.loss_approx is None: objective.loss_approx = self._lowest_loss
 
-
-
+        # if this is last module, directly update parameters to avoid redundant operations
+        if objective.modular is not None and self is objective.modular.modules[-1]:
+            self.set_step_size_(step_size, params=params, update=update)
 
-
-            return
+            objective.stop = True; objective.skip_update = True
+            return objective
 
         # revert parameters and multiply update by step size
         self.set_step_size_(0, params=params, update=update)
-        torch._foreach_mul_(
-        return
+        torch._foreach_mul_(objective.updates, step_size)
+        return objective
 
 
 
@@ -277,7 +268,7 @@ class GridLineSearch(LineSearchBase):
 
     @torch.no_grad
     def search(self, update, var):
-        start,end,num=itemgetter('start','end','num')(self.defaults)
+        start, end, num = itemgetter('start', 'end', 'num')(self.defaults)
 
         for lr in torch.linspace(start,end,num):
            self.evaluate_f(lr.item(), var=var, backward=False)
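Under the 0.4.0 API the hook is apply(self, objective) rather than the old step(self, var): the base class snapshots the parameters, asks search() for a step size, clips it, and then either writes the parameters directly (when the line search is the last module of the chain) or rescales objective.updates in place. A rough sketch of a custom subclass under that contract; only search, evaluate_f, and get_loss follow the code above, while the constructor signature is an assumption:

import torch

class ThreePointSearch(LineSearchBase):
    """Hedged sketch, not part of torchzero: probe a few fixed step sizes
    and return the one with the lowest loss (0 if none improves)."""
    def __init__(self, candidates=(0.25, 0.5, 1.0)):
        super().__init__(dict(candidates=candidates))  # assumed base-class signature

    @torch.no_grad
    def search(self, update, var):
        best_a, best_f = 0.0, var.get_loss(False)
        for a in self.defaults['candidates']:
            f = self.evaluate_f(a, var=var, backward=False)
            if f < best_f:
                best_a, best_f = a, f
        # apply() in the base class reverts the parameters and either rescales
        # objective.updates by best_a or, as the last module, applies it directly.
        return best_a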
torchzero/modules/line_search/strong_wolfe.py
CHANGED

@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from torch.optim.lbfgs import _cubic_interpolate
 
-from ...utils import as_tensorlist, totensor
+from ...utils import as_tensorlist, totensor, tofloat
 from ._polyinterp import polyinterp, polyinterp2
 from .line_search import LineSearchBase, TerminationCondition, termination_condition
 from ..step_size.adaptive import _bb_geom
@@ -92,7 +92,7 @@ class _StrongWolfe:
             return _apply_bounds(a_lo + 0.5 * (a_hi - a_lo), bounds)
 
         if self.interpolation in ('polynomial', 'polynomial2'):
-            finite_history = [(a, f, g) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
+            finite_history = [(tofloat(a), tofloat(f), tofloat(g)) for a, (f,g) in self.history.items() if math.isfinite(a) and math.isfinite(f) and math.isfinite(g)]
             if bounds is None: bounds = (None, None)
             polyinterp_fn = polyinterp if self.interpolation == 'polynomial' else polyinterp2
             try:
@@ -284,8 +284,8 @@ class StrongWolfe(LineSearchBase):
             'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
             'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)
 
-        dir = as_tensorlist(var.
-        grad_list = var.
+        dir = as_tensorlist(var.get_updates())
+        grad_list = var.get_grads()
 
         g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
         f_0 = var.get_loss(False)
@@ -370,6 +370,6 @@ class StrongWolfe(LineSearchBase):
             self.global_state['initial_scale'] = self.global_state.get('initial_scale', 1) * 0.5
             finfo = torch.finfo(dir[0].dtype)
             if self.global_state['initial_scale'] < finfo.tiny * 2:
-                self.global_state['initial_scale'] =
+                self.global_state['initial_scale'] = init_value * 2
 
         return 0
torchzero/modules/misc/debug.py
CHANGED

@@ -11,9 +11,9 @@ class PrintUpdate(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def
-        self.defaults["print_fn"](f'{self.defaults["text"]}{
-        return
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.updates}')
+        return objective
 
 class PrintShape(Module):
     """Prints shapes of the update."""
@@ -21,10 +21,10 @@ class PrintShape(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def
-        shapes = [u.shape for u in
+    def apply(self, objective):
+        shapes = [u.shape for u in objective.updates] if objective.updates is not None else None
         self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
-        return
+        return objective
 
 class PrintParams(Module):
     """Prints current update."""
@@ -32,9 +32,9 @@ class PrintParams(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def
-        self.defaults["print_fn"](f'{self.defaults["text"]}{
-        return
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.params}')
+        return objective
 
 
 class PrintLoss(Module):
@@ -43,6 +43,6 @@ class PrintLoss(Module):
         defaults = dict(text=text, print_fn=print_fn)
         super().__init__(defaults)
 
-    def
-        self.defaults["print_fn"](f'{self.defaults["text"]}{
-        return
+    def apply(self, objective):
+        self.defaults["print_fn"](f'{self.defaults["text"]}{objective.get_loss(False)}')
+        return objective
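All four debug modules now return the Objective they were given, so they can be dropped anywhere into a chain without changing the update. A hedged construction sketch; text and print_fn mirror the defaults dict above, the registry names tz.m.Adam and tz.m.LR come from the package docstrings, and tz.m.PrintLoss is an assumed registry name:

import logging
import torch
import torchzero as tz

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("opt")

model = torch.nn.Linear(2, 1)

# Hypothetical chain: log the loss between the direction module and the LR module.
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.PrintLoss(text="loss: ", print_fn=log.info),
    tz.m.LR(1e-2),
)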
torchzero/modules/misc/escape.py
CHANGED

@@ -3,7 +3,7 @@ import math
 from typing import Literal
 import torch
 
-from ...core import Modular, Module,
+from ...core import Modular, Module, Objective, Chainable
 from ...utils import NumberList, TensorList
 
 
@@ -15,11 +15,11 @@ class EscapeAnnealing(Module):
 
 
     @torch.no_grad
-    def
-        closure =
+    def apply(self, objective):
+        closure = objective.closure
         if closure is None: raise RuntimeError("Escape requries closure")
 
-        params = TensorList(
+        params = TensorList(objective.params)
         settings = self.settings[params[0]]
         max_region = self.get_settings(params, 'max_region', cls=NumberList)
         max_iter = settings['max_iter']
@@ -41,7 +41,7 @@ class EscapeAnnealing(Module):
         self.global_state['n_bad'] = n_bad
 
         # no progress
-        f_0 =
+        f_0 = objective.get_loss(False)
         if n_bad >= n_tol:
             for i in range(1, max_iter+1):
                 alpha = max_region * (i / max_iter)
@@ -51,12 +51,12 @@ class EscapeAnnealing(Module):
                 f_star = closure(False)
 
                 if math.isfinite(f_star) and f_star < f_0-1e-12:
-
-
-
-                    return
+                    objective.updates = None
+                    objective.stop = True
+                    objective.skip_update = True
+                    return objective
 
                 params.sub_(pert)
 
         self.global_state['n_bad'] = 0
-        return
+        return objective
torchzero/modules/misc/gradient_accumulation.py
CHANGED

@@ -3,74 +3,6 @@ import torch
 from ...core import Chainable, Module
 
 
-# class GradientAccumulation(Module):
-#     """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
-
-#     Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
-#     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
-
-#     .. note::
-#         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
-
-#     Args:
-#         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
-#         n (int): number of gradients to accumulate.
-#         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
-#         stop (bool, optional):
-#             this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
-
-#     Examples:
-#         Adam with gradients accumulated for 16 batches.
-
-#         .. code-block:: python
-
-#             opt = tz.Modular(
-#                 model.parameters(),
-#                 tz.m.GradientAccumulation(
-#                     [tz.m.Adam(), tz.m.LR(1e-2)],
-#                     n=16
-#                 )
-#             )
-
-#     """
-#     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-#         defaults = dict(n=n, mean=mean, stop=stop)
-#         super().__init__(defaults)
-#         self.set_child('modules', modules)
-
-
-#     @torch.no_grad
-#     def step(self, var):
-#         accumulator = self.get_state(var.params, 'accumulator')
-#         settings = self.defaults
-#         n = settings['n']; mean = settings['mean']; stop = settings['stop']
-#         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-#         # add update to accumulator
-#         torch._foreach_add_(accumulator, var.get_update())
-
-#         # step with accumulated updates
-#         if step % n == 0:
-#             if mean:
-#                 torch._foreach_div_(accumulator, n)
-
-#             var.update = [a.clone() for a in accumulator]
-#             var = self.children['modules'].step(var)
-
-#             # zero accumulator
-#             torch._foreach_zero_(accumulator)
-
-#         else:
-#             # prevent update
-#             if stop:
-#                 var.update = None
-#                 var.stop=True
-#                 var.skip_update=True
-
-#         return var
-
-
-
 
 class GradientAccumulation(Module):
     """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
@@ -106,21 +38,21 @@ class GradientAccumulation(Module):
 
 
     @torch.no_grad
-    def
-        accumulator = self.get_state(
+    def apply(self, objective):
+        accumulator = self.get_state(objective.params, 'accumulator')
         settings = self.defaults
         n = settings['n']; mean = settings['mean']; stop = settings['stop']
-        step = self.
+        step = self.increment_counter("step", 0)
 
         # add update to accumulator
-        torch._foreach_add_(accumulator,
+        torch._foreach_add_(accumulator, objective.get_updates())
 
         # step with accumulated updates
-        if step % n == 0:
+        if (step + 1) % n == 0:
            if mean:
                torch._foreach_div_(accumulator, n)
 
-
+           objective.updates = accumulator
 
            # zero accumulator
            self.clear_state_keys('accumulator')
@@ -128,9 +60,9 @@ class GradientAccumulation(Module):
         else:
            # prevent update
            if stop:
-
-
-
+               objective.updates = None
+               objective.stop=True
+               objective.skip_update=True
 
-        return
+        return objective
 
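The surviving docstring (unchanged from the commented-out version removed above) shows the intended usage. Restated as a runnable sketch with a toy model and a closure added for completeness; the backward flag on the closure follows the closure(False) calls visible elsewhere in this diff, and the step/zero_grad calls assume tz.Modular exposes the usual torch.optim closure-based interface:

import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)
X, y = torch.randn(64, 4), torch.randn(64, 1)

# Adam step every 16 accumulated (averaged) gradients, as in the docstring example.
opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation([tz.m.Adam(), tz.m.LR(1e-2)], n=16),
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(32):
    opt.step(closure)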