torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
- tests/test_identical.py +22 -22
- tests/test_opts.py +199 -198
- torchzero/__init__.py +1 -1
- torchzero/core/__init__.py +1 -1
- torchzero/core/functional.py +1 -1
- torchzero/core/modular.py +5 -5
- torchzero/core/module.py +2 -2
- torchzero/core/objective.py +10 -10
- torchzero/core/transform.py +1 -1
- torchzero/linalg/__init__.py +3 -2
- torchzero/linalg/eigh.py +223 -4
- torchzero/linalg/orthogonalize.py +2 -4
- torchzero/linalg/qr.py +12 -0
- torchzero/linalg/solve.py +1 -3
- torchzero/linalg/svd.py +47 -20
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +10 -10
- torchzero/modules/adaptive/adahessian.py +2 -2
- torchzero/modules/adaptive/adam.py +1 -1
- torchzero/modules/adaptive/adan.py +1 -1
- torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
- torchzero/modules/adaptive/esgd.py +2 -2
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +2 -1
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +2 -2
- torchzero/modules/adaptive/matrix_momentum.py +1 -1
- torchzero/modules/adaptive/msam.py +4 -4
- torchzero/modules/adaptive/muon.py +9 -6
- torchzero/modules/adaptive/natural_gradient.py +32 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rprop.py +2 -2
- torchzero/modules/adaptive/sam.py +4 -4
- torchzero/modules/adaptive/shampoo.py +28 -3
- torchzero/modules/adaptive/soap.py +3 -3
- torchzero/modules/adaptive/sophia_h.py +2 -2
- torchzero/modules/clipping/clipping.py +7 -7
- torchzero/modules/conjugate_gradient/cg.py +2 -2
- torchzero/modules/experimental/__init__.py +5 -0
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +2 -2
- torchzero/modules/experimental/newtonnewton.py +34 -40
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/rfdm.py +4 -4
- torchzero/modules/least_squares/gn.py +68 -45
- torchzero/modules/line_search/backtracking.py +2 -2
- torchzero/modules/line_search/line_search.py +1 -1
- torchzero/modules/line_search/strong_wolfe.py +2 -2
- torchzero/modules/misc/escape.py +1 -1
- torchzero/modules/misc/gradient_accumulation.py +1 -1
- torchzero/modules/misc/misc.py +1 -1
- torchzero/modules/misc/multistep.py +4 -7
- torchzero/modules/misc/regularization.py +2 -2
- torchzero/modules/misc/split.py +1 -1
- torchzero/modules/misc/switch.py +2 -2
- torchzero/modules/momentum/cautious.py +3 -3
- torchzero/modules/momentum/momentum.py +1 -1
- torchzero/modules/ops/higher_level.py +1 -1
- torchzero/modules/ops/multi.py +1 -1
- torchzero/modules/projections/projection.py +5 -2
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +1 -1
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +3 -3
- torchzero/modules/quasi_newton/lsr1.py +3 -3
- torchzero/modules/quasi_newton/quasi_newton.py +44 -29
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +17 -17
- torchzero/modules/second_order/inm.py +33 -25
- torchzero/modules/second_order/newton.py +132 -130
- torchzero/modules/second_order/newton_cg.py +3 -3
- torchzero/modules/second_order/nystrom.py +83 -32
- torchzero/modules/second_order/rsn.py +41 -44
- torchzero/modules/smoothing/laplacian.py +1 -1
- torchzero/modules/smoothing/sampling.py +2 -3
- torchzero/modules/step_size/adaptive.py +6 -6
- torchzero/modules/step_size/lr.py +2 -2
- torchzero/modules/trust_region/cubic_regularization.py +1 -1
- torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
- torchzero/modules/trust_region/trust_cg.py +1 -1
- torchzero/modules/variance_reduction/svrg.py +4 -5
- torchzero/modules/weight_decay/reinit.py +2 -2
- torchzero/modules/weight_decay/weight_decay.py +5 -5
- torchzero/modules/wrappers/optim_wrapper.py +4 -4
- torchzero/modules/zeroth_order/cd.py +1 -1
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/wrappers/nevergrad.py +0 -9
- torchzero/optim/wrappers/optuna.py +2 -0
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/derivatives.py +4 -4
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- torchzero/modules/adaptive/lmadagrad.py +0 -241
- torchzero-0.4.0.dist-info/RECORD +0 -191
- /torchzero/modules/{functional.py → opt_utils.py} +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/optim/mbs.py
ADDED
@@ -0,0 +1,291 @@
+from typing import NamedTuple
+import math
+from collections.abc import Iterable
+from decimal import ROUND_HALF_UP, Decimal
+
+import numpy as np
+
+
+def format_number(number, n):
+    """Rounds to n significant digits after the decimal point."""
+    if number == 0: return 0
+    if math.isnan(number) or math.isinf(number) or (not math.isfinite(number)): return number
+    if n <= 0: raise ValueError("n must be positive")
+
+    dec = Decimal(str(number))
+    if dec.is_zero(): return 0
+    if number > 10**n or dec % 1 == 0: return int(dec)
+
+    if abs(dec) >= 1:
+        places = n
+    else:
+        frac_str = format(abs(dec), 'f').split('.')[1]
+        leading_zeros = len(frac_str) - len(frac_str.lstrip('0'))
+        places = leading_zeros + n
+
+    quantizer = Decimal('1e-' + str(places))
+    rounded_dec = dec.quantize(quantizer, rounding=ROUND_HALF_UP)
+
+    if rounded_dec % 1 == 0: return int(rounded_dec)
+    return float(rounded_dec)
+
+def _nonfinite_to_inf(x):
+    if not math.isfinite(x): return math.inf
+    return x
+
+def _tofloatlist(x) -> list[float]:
+    if isinstance(x, (int,float)): return [x]
+    if isinstance(x, np.ndarray) and x.size == 1: return [float(x.item())]
+    return [float(i) for i in x]
+
+class Trial(NamedTuple):
+    x: float
+    f: tuple[float, ...]
+
+class Solution(NamedTuple):
+    x: float
+    f: tuple[float, ...]
+    trials: list[Trial]
+
+class MBS:
+    """Univariate minimization via grid search followed by refining, supports multi-objective functions.
+
+    This tends to outperform bayesian optimization for learning rate tuning, it is also good for plotting.
+
+    First it evaluates all points defined in ``grid``. The grid doesn't have to be dense and the solution doesn't
+    have to be between the endpoints.
+
+    Then it picks ``num_candidates`` best points per each objective. If any of those points are endpoints,
+    it expands the search space by ``step`` in that direction and evaluates the new endpoint.
+
+    Otherwise it keeps picking points between best points and evaluating them, until ``num_binary`` evaluations
+    have been performed.
+
+    Args:
+        grid (Iterable[float], optional): values for initial grid search. If ``log_scale=True``, should be in log10 scale.
+        step (float, optional): expansion step size. Defaults to 1.
+        num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 2.
+        num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 7.
+        num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 7.
+        rounding (int, optional): rounding is to significant digits, avoids evaluating points that are too close.
+        lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+        ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+        log_scale (bool, optional):
+            whether to minimize in log10 scale. If true, it is assumed that
+            ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+    Example:
+
+    ```python
+    def objective(x: float):
+        x = x * 4
+        return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+    mbs = MBS(grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+
+    x, f, trials = mbs.run(objective)
+    # x - solution
+    # f - value at solution x
+    # trials - list of trials, each trial is a named tuple: Trial(x, f)
+    """
+
+    def __init__(
+        self,
+        grid: Iterable[float],
+        step: float,
+        num_candidates: int = 3,
+        num_binary: int = 20,
+        num_expansions: int = 20,
+        rounding: int| None = 2,
+        lb = None,
+        ub = None,
+        log_scale: bool = False,
+    ):
+        self.objectives: dict[int, dict[float,float]] = {}
+        """dictionary of objectives, each maps point (x) to value (v)"""
+
+        self.evaluated: set[float] = set()
+        """set of evaluated points (x)"""
+
+        grid = tuple(grid)
+        if len(grid) == 0: raise ValueError("At least one grid search point must be specified")
+        self.grid = sorted(grid)
+
+        self.step = step
+        self.num_candidates = num_candidates
+        self.num_binary = num_binary
+        self.num_expansions = num_expansions
+        self.rounding = rounding
+        self.log_scale = log_scale
+        self.lb = lb
+        self.ub = ub
+
+    def _get_best_x(self, n: int, objective: int):
+        """n best points"""
+        obj = self.objectives[objective]
+        v_to_x = [(v,x) for x,v in obj.items()]
+        v_to_x.sort(key = lambda vx: vx[0])
+        xs = [x for v,x in v_to_x]
+        return xs[:n]
+
+    def _suggest_points_around(self, x: float, objective: int):
+        """suggests points around x"""
+        points = list(self.objectives[objective].keys())
+        points.sort()
+        if x not in points: raise RuntimeError(f"{x} not in {points}")
+
+        expansions = []
+        if x == points[0]:
+            expansions.append((x-self.step, 'expansion'))
+
+        if x == points[-1]:
+            expansions.append((x+self.step, 'expansion'))
+
+        if len(expansions) != 0: return expansions
+
+        idx = points.index(x)
+        xm = points[idx-1]
+        xp = points[idx+1]
+
+        x1 = (x - (x - xm)/2)
+        x2 = (x + (xp - x)/2)
+
+        return [(x1, 'binary'), (x2, 'binary')]
+
+    def _out_of_bounds(self, x):
+        if self.lb is not None and x < self.lb: return True
+        if self.ub is not None and x > self.ub: return True
+        return False
+
+    def _evaluate(self, fn, x):
+        """Evaluate a point, returns False if point is already in history"""
+        if self.rounding is not None: x = format_number(x, self.rounding)
+        if x in self.evaluated: return False
+        if self._out_of_bounds(x): return False
+
+        self.evaluated.add(x)
+
+        if self.log_scale: vals = _tofloatlist(fn(10 ** x))
+        else: vals = _tofloatlist(fn(x))
+        vals = [_nonfinite_to_inf(v) for v in vals]
+
+        for idx, v in enumerate(vals):
+            if idx not in self.objectives: self.objectives[idx] = {}
+            self.objectives[idx][x] = v
+
+        return True
+
+    def run(self, fn) -> Solution:
+        # step 1 - grid search
+        for x in self.grid:
+            self._evaluate(fn, x)
+
+        # step 2 - binary search
+        while True:
+            if (self.num_candidates <= 0) or (self.num_expansions <= 0 and self.num_binary <= 0): break
+
+            # suggest candidates
+            candidates: list[tuple[float, str]] = []
+
+            # sample around best points
+            for objective in self.objectives:
+                best_points = self._get_best_x(self.num_candidates, objective)
+                for p in best_points:
+                    candidates.extend(self._suggest_points_around(p, objective=objective))
+
+            # filter
+            if self.num_expansions <= 0:
+                candidates = [(x,t) for x,t in candidates if t != 'expansion']
+
+            if self.num_candidates <= 0:
+                candidates = [(x,t) for x,t in candidates if t != 'binary']
+
+            # if expansion was suggested, discard anything else
+            types = [t for x, t in candidates]
+            if any(t == 'expansion' for t in types):
+                candidates = [(x,t) for x,t in candidates if t == 'expansion']
+
+            # evaluate candidates
+            terminate = False
+            at_least_one_evaluated = False
+            for x, t in candidates:
+                evaluated = self._evaluate(fn, x)
+                if not evaluated: continue
+                at_least_one_evaluated = True
+
+                if t == 'expansion': self.num_expansions -= 1
+                elif t == 'binary': self.num_binary -= 1
+
+                if self.num_binary < 0:
+                    terminate = True
+                    break
+
+            if terminate: break
+            if not at_least_one_evaluated:
+                if self.rounding is None: break
+                self.rounding += 1
+                if self.rounding == 100: break
+
+        # create dict[float, tuple[float,...]]
+        ret = {}
+        for i, objective in enumerate(self.objectives.values()):
+            for x, v in objective.items():
+                if self.log_scale: x = 10 ** x
+                if x not in ret: ret[x] = [None for _ in self.objectives]
+                ret[x][i] = v
+
+        for v in ret.values():
+            assert len(v) == len(self.objectives), v
+            assert all(i is not None for i in v), v
+
+        # ret maps x to list of per-objective values, e.g. {1: [0.1, 0.3], ...}
+        # now make a list of trials as they are easier to work with
+        trials: list[Trial] = []
+        for x, values in ret.items():
+            trials.append(Trial(x=x, f=values))
+
+        # sort trials by sum of values
+        trials.sort(key = lambda trial: sum(trial.f))
+        return Solution(x=trials[0].x, f=trials[0].f, trials=trials)
+
+def mbs_minimize(
+    fn,
+    grid: Iterable[float],
+    step: float,
+    num_candidates: int = 3,
+    num_binary: int = 20,
+    num_expansions: int = 20,
+    rounding=2,
+    lb:float | None = None,
+    ub:float | None = None,
+    log_scale=False,
+) -> Solution:
+    """minimize univariate function via MBS.
+
+    Args:
+        fn (function): objective function that accepts a float and returns a float or a sequence of floats to minimize.
+        step (float, optional): expansion step size. Defaults to 1.
+        num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 2.
+        num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 7.
+        num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 7.
+        rounding (int, optional): rounding is to significant digits, avoids evaluating points that are too close.
+        lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+        ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+        log_scale (bool, optional):
+            whether to minimize in log10 scale. If true, it is assumed that
+            ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+    Example:
+
+    ```python
+    def objective(x: float):
+        x = x * 4
+        return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+    x, f, trials = mbs_minimize(objective, grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+    # x - solution
+    # f - value at solution x
+    # trials - list of trials, each trial is a named tuple: Trial(x, f)
+    """
+    mbs = MBS(grid, step=step, num_candidates=num_candidates, num_binary=num_binary, num_expansions=num_expansions, rounding=rounding, lb=lb, ub=ub, log_scale=log_scale)
+    return mbs.run(fn)
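For context, a minimal sketch (not part of the diff) of how the new ``mbs_minimize`` helper could be used to tune a learning rate on a log10 scale; the objective below is a hypothetical stand-in and the grid values are illustrative assumptions:

```python
import numpy as np
from torchzero.optim.mbs import mbs_minimize

# Hypothetical stand-in objective: "train with this learning rate and return the final loss".
def final_loss(lr: float) -> float:
    return float((np.log10(lr) + 2.5) ** 2)

solution = mbs_minimize(
    final_loss,
    grid=[-4, -3, -2, -1],  # initial grid, given as log10(lr) because log_scale=True
    step=1,                 # expand the range by one decade if the best point is an endpoint
    num_binary=10,
    num_expansions=5,
    log_scale=True,
)
print(solution.x, solution.f)   # best learning rate (reported in original scale) and its value(s)
print(len(solution.trials))     # every evaluated Trial(x, f), convenient for plotting
```

With ``log_scale=True`` the grid and bounds are log10 values, ``fn`` receives the actual learning rate ``10**x``, and the returned trials are mapped back to the original scale.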
torchzero/optim/wrappers/nevergrad.py
CHANGED
@@ -55,15 +55,6 @@ class NevergradWrapper(WrapperBase):
         mutable_sigma = False,
         use_init = True,
     ):
-        """_summary_
-
-        Args:
-            params (_type_): _description_
-            opt_cls (type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]): _description_
-            budget (int | None, optional): _description_. Defaults to None.
-            mutable_sigma (bool, optional): _description_. Defaults to False.
-            use_init (bool, optional): _description_. Defaults to True.
-        """
         defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
         super().__init__(params, defaults)
         self.opt_cls = opt_cls
torchzero/optim/wrappers/optuna.py
CHANGED
@@ -45,6 +45,7 @@ class OptunaSampler(WrapperBase):
         self.study = optuna.create_study(sampler=self.sampler)

         # some optuna samplers use torch
+        # and require torch.enable_grad
         with torch.enable_grad():
             trial = self.study.ask()

@@ -58,6 +59,7 @@ class OptunaSampler(WrapperBase):
         params.from_vec_(vec)

         loss = closure()
+
         with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))

         return loss
torchzero/utils/benchmarks/__init__.py
File without changes

torchzero/utils/benchmarks/logistic.py
ADDED
@@ -0,0 +1,122 @@
+from functools import partial
+from typing import Any, cast
+
+import numpy as np
+import torch
+import tqdm
+
+
+def generate_correlated_logistic_data(n_samples=2000, n_features=32, n_correlated_pairs=512, correlation=0.99, seed=0):
+    """Hard logistic regression dataset with correlated features"""
+    generator = np.random.default_rng(seed)
+
+    # ------------------------------------- X ------------------------------------ #
+    X = generator.standard_normal(size=(n_samples, n_features))
+    weights = generator.uniform(-2, 2, n_features)
+
+    used_pairs = []
+    for i in range(n_correlated_pairs):
+        idxs = None
+        while idxs is None or idxs in used_pairs:
+            idxs = tuple(generator.choice(n_features, size=2, replace=False).tolist())
+
+        used_pairs.append(idxs)
+        idx1, idx2 = idxs
+
+        noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
+        X[:, idx2] = correlation * X[:, idx1] + noise
+
+        w = generator.integers(1, 51)
+        weights[idx1] = w
+        weights[idx2] = -w
+
+    # ---------------------------------- logits ---------------------------------- #
+    logits = X @ weights
+    probabilities = 1 / (1 + np.exp(-logits))
+    y = generator.binomial(1, probabilities).astype(np.float32)
+
+    X = X - X.mean(0, keepdims=True)
+    X = X / X.std(0, keepdims=True)
+    return X, y
+
+
+# if __name__ == '__main__':
+#     X, y = generate_correlated_logistic_data()
+
+#     plt.figure(figsize=(10, 8))
+#     sns.heatmap(pl.DataFrame(X).corr(), annot=True, cmap='coolwarm', fmt=".2f")
+#     plt.show()
+
+
+
+
+def _tensorlist_equal(t1, t2):
+    return all(a == b for a, b in zip(t1, t2))
+
+_placeholder = cast(Any, ...)
+
+def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps: int, tol:float=0, l1:float=0, l2:float=0, pbar:bool=False, *, _assert_on_evaluated_same_params: bool = False):
+    # ------------------------------- verify inputs ------------------------------ #
+    n_samples, n_features = X.size()
+
+    if y.ndim != 1: raise ValueError(f"y should be 1d, got {y.shape}")
+    if y.size(0) != n_samples: raise ValueError(f"y should have {n_samples} elements, got {y.shape}")
+    if y.device != X.device: raise ValueError(f"X and y should be on same device, got {X.device = }, {y.device = }")
+    device = X.device
+    dtype = X.dtype
+
+    # ---------------------------- model and criterion --------------------------- #
+    n_targets = int(y.amax()) + 1
+    binary = n_targets == 2
+
+    if binary:
+        criterion = torch.nn.functional.binary_cross_entropy_with_logits
+        model = torch.nn.Linear(n_features, 1).to(device=device, dtype=dtype)
+        y = y.to(dtype=dtype)
+    else:
+        model = torch.nn.Linear(n_features, n_targets).to(device=device, dtype=dtype)
+        criterion = torch.nn.functional.cross_entropy
+        y = y.long()
+
+    optimizer = opt_fn(list(model.parameters()))
+
+    # ---------------------------------- closure --------------------------------- #
+    def _l1_penalty():
+        return sum(p.abs().sum() for p in model.parameters())
+    def _l2_penalty():
+        return sum(p.square().sum() for p in model.parameters())
+
+    def closure(backward=True, evaluated_params: list = _placeholder, epoch: int = _placeholder):
+        y_hat = model(X)
+        loss = criterion(y_hat.squeeze(), y)
+
+        if l1 > 0: loss += _l1_penalty() * l1
+        if l2 > 0: loss += _l2_penalty() * l2
+
+        if backward:
+            optimizer.zero_grad()
+            loss.backward()
+
+        # here I also test to make sure the optimizer doesn't evaluate same parameters twice per step
+        # this is for tests
+        if _assert_on_evaluated_same_params:
+            for p in evaluated_params:
+                assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
+
+            evaluated_params.append([p.clone() for p in model.parameters()])
+
+        return loss
+
+    # --------------------------------- optimize --------------------------------- #
+    losses = []
+    epochs = tqdm.trange(max_steps, disable=not pbar)
+    for epoch in epochs:
+        evaluated_params = []
+        loss = float(optimizer.step(partial(closure, evaluated_params=evaluated_params, epoch=epoch)))
+
+        losses.append(loss)
+        epochs.set_postfix_str(f"{loss:.5f}")
+        if loss <= tol:
+            break
+
+    return losses
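A rough usage sketch (not part of the diff) of how the new benchmark helpers might be wired together; the optimizer choice and hyperparameters are assumptions, and any optimizer whose ``step`` accepts a closure and returns the loss should work:

```python
import torch
from torchzero.utils.benchmarks.logistic import (
    generate_correlated_logistic_data,
    run_logistic_regression,
)

# Build the synthetic ill-conditioned dataset (numpy arrays) and move it to torch.
X_np, y_np = generate_correlated_logistic_data(n_samples=1000, n_features=16, n_correlated_pairs=64)
X = torch.as_tensor(X_np, dtype=torch.float32)
y = torch.as_tensor(y_np)

# opt_fn receives the model's parameter list and must return an optimizer whose
# step() accepts a closure and returns the loss (torch.optim.SGD does).
losses = run_logistic_regression(
    X, y,
    opt_fn=lambda params: torch.optim.SGD(params, lr=1e-2),
    max_steps=200,
    l2=1e-4,
)
print(losses[-1])  # final training loss
```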
torchzero/utils/derivatives.py
CHANGED
@@ -7,7 +7,7 @@ from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
 from .tensorlist import TensorList

 def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_outputs = torch.cat([i.
+    flat_outputs = torch.cat([i.ravel() for i in outputs])
     grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
     jac = []
     for i in range(flat_outputs.numel()):
@@ -24,7 +24,7 @@ def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], crea
 
 
 def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-    flat_outputs = torch.cat([i.
+    flat_outputs = torch.cat([i.ravel() for i in outputs])
     return torch.autograd.grad(
         flat_outputs,
         wrt,
@@ -40,10 +40,10 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:
 
     Args:
         jacs (Sequence[torch.Tensor]):
-            output from jacobian_wrt where ach tensor has the shape
+            output from jacobian_wrt where ach tensor has the shape ``(*output.shape, *wrt[i].shape)``.
 
     Returns:
-        torch.Tensor: has the shape
+        torch.Tensor: has the shape ``(output.ndim, wrt.ndim)``.
     """
     if not jacs:
         return torch.empty(0, 0)