torchzero 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. tests/test_identical.py +22 -22
  2. tests/test_opts.py +199 -198
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +1 -1
  5. torchzero/core/functional.py +1 -1
  6. torchzero/core/modular.py +5 -5
  7. torchzero/core/module.py +2 -2
  8. torchzero/core/objective.py +10 -10
  9. torchzero/core/transform.py +1 -1
  10. torchzero/linalg/__init__.py +3 -2
  11. torchzero/linalg/eigh.py +223 -4
  12. torchzero/linalg/orthogonalize.py +2 -4
  13. torchzero/linalg/qr.py +12 -0
  14. torchzero/linalg/solve.py +1 -3
  15. torchzero/linalg/svd.py +47 -20
  16. torchzero/modules/__init__.py +4 -3
  17. torchzero/modules/adaptive/__init__.py +11 -3
  18. torchzero/modules/adaptive/adagrad.py +10 -10
  19. torchzero/modules/adaptive/adahessian.py +2 -2
  20. torchzero/modules/adaptive/adam.py +1 -1
  21. torchzero/modules/adaptive/adan.py +1 -1
  22. torchzero/modules/adaptive/adaptive_heavyball.py +1 -1
  23. torchzero/modules/adaptive/esgd.py +2 -2
  24. torchzero/modules/adaptive/ggt.py +186 -0
  25. torchzero/modules/adaptive/lion.py +2 -1
  26. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  27. torchzero/modules/adaptive/mars.py +2 -2
  28. torchzero/modules/adaptive/matrix_momentum.py +1 -1
  29. torchzero/modules/adaptive/msam.py +4 -4
  30. torchzero/modules/adaptive/muon.py +9 -6
  31. torchzero/modules/adaptive/natural_gradient.py +32 -15
  32. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  33. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  34. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  35. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  36. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  37. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  38. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  39. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  40. torchzero/modules/adaptive/rprop.py +2 -2
  41. torchzero/modules/adaptive/sam.py +4 -4
  42. torchzero/modules/adaptive/shampoo.py +28 -3
  43. torchzero/modules/adaptive/soap.py +3 -3
  44. torchzero/modules/adaptive/sophia_h.py +2 -2
  45. torchzero/modules/clipping/clipping.py +7 -7
  46. torchzero/modules/conjugate_gradient/cg.py +2 -2
  47. torchzero/modules/experimental/__init__.py +5 -0
  48. torchzero/modules/experimental/adanystrom.py +258 -0
  49. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  50. torchzero/modules/experimental/cubic_adam.py +160 -0
  51. torchzero/modules/experimental/eigen_sr1.py +182 -0
  52. torchzero/modules/experimental/eigengrad.py +207 -0
  53. torchzero/modules/experimental/l_infinity.py +1 -1
  54. torchzero/modules/experimental/matrix_nag.py +122 -0
  55. torchzero/modules/experimental/newton_solver.py +2 -2
  56. torchzero/modules/experimental/newtonnewton.py +34 -40
  57. torchzero/modules/grad_approximation/fdm.py +2 -2
  58. torchzero/modules/grad_approximation/rfdm.py +4 -4
  59. torchzero/modules/least_squares/gn.py +68 -45
  60. torchzero/modules/line_search/backtracking.py +2 -2
  61. torchzero/modules/line_search/line_search.py +1 -1
  62. torchzero/modules/line_search/strong_wolfe.py +2 -2
  63. torchzero/modules/misc/escape.py +1 -1
  64. torchzero/modules/misc/gradient_accumulation.py +1 -1
  65. torchzero/modules/misc/misc.py +1 -1
  66. torchzero/modules/misc/multistep.py +4 -7
  67. torchzero/modules/misc/regularization.py +2 -2
  68. torchzero/modules/misc/split.py +1 -1
  69. torchzero/modules/misc/switch.py +2 -2
  70. torchzero/modules/momentum/cautious.py +3 -3
  71. torchzero/modules/momentum/momentum.py +1 -1
  72. torchzero/modules/ops/higher_level.py +1 -1
  73. torchzero/modules/ops/multi.py +1 -1
  74. torchzero/modules/projections/projection.py +5 -2
  75. torchzero/modules/quasi_newton/__init__.py +1 -1
  76. torchzero/modules/quasi_newton/damping.py +1 -1
  77. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  78. torchzero/modules/quasi_newton/lbfgs.py +3 -3
  79. torchzero/modules/quasi_newton/lsr1.py +3 -3
  80. torchzero/modules/quasi_newton/quasi_newton.py +44 -29
  81. torchzero/modules/quasi_newton/sg2.py +69 -205
  82. torchzero/modules/restarts/restars.py +17 -17
  83. torchzero/modules/second_order/inm.py +33 -25
  84. torchzero/modules/second_order/newton.py +132 -130
  85. torchzero/modules/second_order/newton_cg.py +3 -3
  86. torchzero/modules/second_order/nystrom.py +83 -32
  87. torchzero/modules/second_order/rsn.py +41 -44
  88. torchzero/modules/smoothing/laplacian.py +1 -1
  89. torchzero/modules/smoothing/sampling.py +2 -3
  90. torchzero/modules/step_size/adaptive.py +6 -6
  91. torchzero/modules/step_size/lr.py +2 -2
  92. torchzero/modules/trust_region/cubic_regularization.py +1 -1
  93. torchzero/modules/trust_region/levenberg_marquardt.py +2 -2
  94. torchzero/modules/trust_region/trust_cg.py +1 -1
  95. torchzero/modules/variance_reduction/svrg.py +4 -5
  96. torchzero/modules/weight_decay/reinit.py +2 -2
  97. torchzero/modules/weight_decay/weight_decay.py +5 -5
  98. torchzero/modules/wrappers/optim_wrapper.py +4 -4
  99. torchzero/modules/zeroth_order/cd.py +1 -1
  100. torchzero/optim/mbs.py +291 -0
  101. torchzero/optim/wrappers/nevergrad.py +0 -9
  102. torchzero/optim/wrappers/optuna.py +2 -0
  103. torchzero/utils/benchmarks/__init__.py +0 -0
  104. torchzero/utils/benchmarks/logistic.py +122 -0
  105. torchzero/utils/derivatives.py +4 -4
  106. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  107. torchzero-0.4.1.dist-info/RECORD +209 -0
  108. torchzero/modules/adaptive/lmadagrad.py +0 -241
  109. torchzero-0.4.0.dist-info/RECORD +0 -191
  110. /torchzero/modules/{functional.py → opt_utils.py} +0 -0
  111. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  112. {torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/optim/mbs.py ADDED
@@ -0,0 +1,291 @@
+ from typing import NamedTuple
+ import math
+ from collections.abc import Iterable
+ from decimal import ROUND_HALF_UP, Decimal
+
+ import numpy as np
+
+
+ def format_number(number, n):
+     """Rounds to n significant digits after the decimal point."""
+     if number == 0: return 0
+     if not math.isfinite(number): return number  # pass nan/inf through unchanged
+     if n <= 0: raise ValueError("n must be positive")
+
+     dec = Decimal(str(number))
+     if dec.is_zero(): return 0
+     if number > 10**n or dec % 1 == 0: return int(dec)
+
+     if abs(dec) >= 1:
+         places = n
+     else:
+         frac_str = format(abs(dec), 'f').split('.')[1]
+         leading_zeros = len(frac_str) - len(frac_str.lstrip('0'))
+         places = leading_zeros + n
+
+     quantizer = Decimal('1e-' + str(places))
+     rounded_dec = dec.quantize(quantizer, rounding=ROUND_HALF_UP)
+
+     if rounded_dec % 1 == 0: return int(rounded_dec)
+     return float(rounded_dec)
+
+ def _nonfinite_to_inf(x):
+     if not math.isfinite(x): return math.inf
+     return x
+
+ def _tofloatlist(x) -> list[float]:
+     if isinstance(x, (int, float)): return [x]
+     if isinstance(x, np.ndarray) and x.size == 1: return [float(x.item())]
+     return [float(i) for i in x]
+
+ class Trial(NamedTuple):
+     x: float
+     f: tuple[float, ...]
+
+ class Solution(NamedTuple):
+     x: float
+     f: tuple[float, ...]
+     trials: list[Trial]
+
+ class MBS:
+     """Univariate minimization via grid search followed by refinement; supports multi-objective functions.
+
+     This tends to outperform Bayesian optimization for learning rate tuning, and it is also good for plotting.
+
+     First it evaluates all points defined in ``grid``. The grid doesn't have to be dense and the solution doesn't
+     have to be between the endpoints.
+
+     Then it picks the ``num_candidates`` best points per objective. If any of those points are endpoints,
+     it expands the search space by ``step`` in that direction and evaluates the new endpoint.
+
+     Otherwise it keeps picking points between the best points and evaluating them, until ``num_binary`` evaluations
+     have been performed.
+
+     Args:
+         grid (Iterable[float]): values for the initial grid search. If ``log_scale=True``, should be in log10 scale.
+         step (float): expansion step size.
+         num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 3.
+         num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 20.
+         num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 20.
+         rounding (int | None, optional): rounds candidate points to this many significant digits, which avoids evaluating points that are too close. Defaults to 2.
+         lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+         ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+         log_scale (bool, optional):
+             whether to minimize in log10 scale. If true, it is assumed that
+             ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+     Example:
+
+     ```python
+     def objective(x: float):
+         x = x * 4
+         return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+     mbs = MBS(grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+
+     x, f, trials = mbs.run(objective)
+     # x - solution, f - value at solution x
+     # trials - list of trials, each trial is a named tuple: Trial(x, f)
+     ```
+     """
+
+     def __init__(
+         self,
+         grid: Iterable[float],
+         step: float,
+         num_candidates: int = 3,
+         num_binary: int = 20,
+         num_expansions: int = 20,
+         rounding: int | None = 2,
+         lb=None,
+         ub=None,
+         log_scale: bool = False,
+     ):
+         self.objectives: dict[int, dict[float, float]] = {}
+         """dictionary of objectives, each maps point (x) to value (v)"""
+
+         self.evaluated: set[float] = set()
+         """set of evaluated points (x)"""
+
+         grid = tuple(grid)
+         if len(grid) == 0: raise ValueError("At least one grid search point must be specified")
+         self.grid = sorted(grid)
+
+         self.step = step
+         self.num_candidates = num_candidates
+         self.num_binary = num_binary
+         self.num_expansions = num_expansions
+         self.rounding = rounding
+         self.log_scale = log_scale
+         self.lb = lb
+         self.ub = ub
+
+     def _get_best_x(self, n: int, objective: int):
+         """n best points"""
+         obj = self.objectives[objective]
+         v_to_x = [(v, x) for x, v in obj.items()]
+         v_to_x.sort(key=lambda vx: vx[0])
+         xs = [x for v, x in v_to_x]
+         return xs[:n]
+
+     def _suggest_points_around(self, x: float, objective: int):
+         """suggests points around x"""
+         points = list(self.objectives[objective].keys())
+         points.sort()
+         if x not in points: raise RuntimeError(f"{x} not in {points}")
+
+         expansions = []
+         if x == points[0]:
+             expansions.append((x - self.step, 'expansion'))
+
+         if x == points[-1]:
+             expansions.append((x + self.step, 'expansion'))
+
+         if len(expansions) != 0: return expansions
+
+         idx = points.index(x)
+         xm = points[idx-1]
+         xp = points[idx+1]
+
+         x1 = (x - (x - xm)/2)
+         x2 = (x + (xp - x)/2)
+
+         return [(x1, 'binary'), (x2, 'binary')]
+
+     def _out_of_bounds(self, x):
+         if self.lb is not None and x < self.lb: return True
+         if self.ub is not None and x > self.ub: return True
+         return False
+
+     def _evaluate(self, fn, x):
+         """Evaluate a point, returns False if the point is already in history"""
+         if self.rounding is not None: x = format_number(x, self.rounding)
+         if x in self.evaluated: return False
+         if self._out_of_bounds(x): return False
+
+         self.evaluated.add(x)
+
+         if self.log_scale: vals = _tofloatlist(fn(10 ** x))
+         else: vals = _tofloatlist(fn(x))
+         vals = [_nonfinite_to_inf(v) for v in vals]
+
+         for idx, v in enumerate(vals):
+             if idx not in self.objectives: self.objectives[idx] = {}
+             self.objectives[idx][x] = v
+
+         return True
+
+     def run(self, fn) -> Solution:
+         # step 1 - grid search
+         for x in self.grid:
+             self._evaluate(fn, x)
+
+         # step 2 - expansion and binary search
+         while True:
+             if (self.num_candidates <= 0) or (self.num_expansions <= 0 and self.num_binary <= 0): break
+
+             # suggest candidates
+             candidates: list[tuple[float, str]] = []
+
+             # sample around best points
+             for objective in self.objectives:
+                 best_points = self._get_best_x(self.num_candidates, objective)
+                 for p in best_points:
+                     candidates.extend(self._suggest_points_around(p, objective=objective))
+
+             # filter out candidates whose budget is exhausted
+             if self.num_expansions <= 0:
+                 candidates = [(x, t) for x, t in candidates if t != 'expansion']
+
+             if self.num_binary <= 0:
+                 candidates = [(x, t) for x, t in candidates if t != 'binary']
+
+             # if expansion was suggested, discard anything else
+             types = [t for x, t in candidates]
+             if any(t == 'expansion' for t in types):
+                 candidates = [(x, t) for x, t in candidates if t == 'expansion']
+
+             # evaluate candidates
+             terminate = False
+             at_least_one_evaluated = False
+             for x, t in candidates:
+                 evaluated = self._evaluate(fn, x)
+                 if not evaluated: continue
+                 at_least_one_evaluated = True
+
+                 if t == 'expansion': self.num_expansions -= 1
+                 elif t == 'binary': self.num_binary -= 1
+
+                 if self.num_binary < 0:
+                     terminate = True
+                     break
+
+             if terminate: break
+             if not at_least_one_evaluated:
+                 if self.rounding is None: break
+                 self.rounding += 1
+                 if self.rounding == 100: break
+
+         # create dict[float, tuple[float,...]]
+         ret = {}
+         for i, objective in enumerate(self.objectives.values()):
+             for x, v in objective.items():
+                 if self.log_scale: x = 10 ** x
+                 if x not in ret: ret[x] = [None for _ in self.objectives]
+                 ret[x][i] = v
+
+         for v in ret.values():
+             assert len(v) == len(self.objectives), v
+             assert all(i is not None for i in v), v
+
+         # ret maps x to a list of per-objective values, e.g. {1: [0.1, 0.3], ...}
+         # now make a list of trials as they are easier to work with
+         trials: list[Trial] = []
+         for x, values in ret.items():
+             trials.append(Trial(x=x, f=tuple(values)))
+
+         # sort trials by sum of values
+         trials.sort(key=lambda trial: sum(trial.f))
+         return Solution(x=trials[0].x, f=trials[0].f, trials=trials)
+
+ def mbs_minimize(
+     fn,
+     grid: Iterable[float],
+     step: float,
+     num_candidates: int = 3,
+     num_binary: int = 20,
+     num_expansions: int = 20,
+     rounding=2,
+     lb: float | None = None,
+     ub: float | None = None,
+     log_scale=False,
+ ) -> Solution:
+     """Minimize a univariate function via MBS.
+
+     Args:
+         fn (function): objective function that accepts a float and returns a float or a sequence of floats to minimize.
+         step (float): expansion step size.
+         num_candidates (int, optional): number of best points to sample new points around on each iteration. Defaults to 3.
+         num_binary (int, optional): maximum number of new points sampled via binary search. Defaults to 20.
+         num_expansions (int, optional): maximum number of expansions (not counted towards binary search points). Defaults to 20.
+         rounding (int, optional): rounds candidate points to this many significant digits, which avoids evaluating points that are too close. Defaults to 2.
+         lb (float | None, optional): lower bound. If ``log_scale=True``, should be in log10 scale.
+         ub (float | None, optional): upper bound. If ``log_scale=True``, should be in log10 scale.
+         log_scale (bool, optional):
+             whether to minimize in log10 scale. If true, it is assumed that
+             ``grid``, ``lb`` and ``ub`` are given in log10 scale.
+
+     Example:
+
+     ```python
+     def objective(x: float):
+         x = x * 4
+         return -(np.sin(x) * (x / 3) + np.cos(x*2.5) * 2 - 0.05 * (x-5)**2)
+
+     x, f, trials = mbs_minimize(objective, grid=[-1, 0, 1, 2, 3, 4], step=1, num_binary=10, num_expansions=10)
+     # x - solution, f - value at solution x
+     # trials - list of trials, each trial is a named tuple: Trial(x, f)
+     ```
+     """
+     mbs = MBS(grid, step=step, num_candidates=num_candidates, num_binary=num_binary, num_expansions=num_expansions, rounding=rounding, lb=lb, ub=ub, log_scale=log_scale)
+     return mbs.run(fn)
torchzero/optim/wrappers/nevergrad.py CHANGED
@@ -55,15 +55,6 @@ class NevergradWrapper(WrapperBase):
          mutable_sigma = False,
          use_init = True,
          ):
-         """_summary_
-
-         Args:
-             params (_type_): _description_
-             opt_cls (type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]): _description_
-             budget (int | None, optional): _description_. Defaults to None.
-             mutable_sigma (bool, optional): _description_. Defaults to False.
-             use_init (bool, optional): _description_. Defaults to True.
-         """
          defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
          super().__init__(params, defaults)
          self.opt_cls = opt_cls
torchzero/optim/wrappers/optuna.py CHANGED
@@ -45,6 +45,7 @@ class OptunaSampler(WrapperBase):
          self.study = optuna.create_study(sampler=self.sampler)

          # some optuna samplers use torch
+         # and require torch.enable_grad
          with torch.enable_grad():
              trial = self.study.ask()

@@ -58,6 +59,7 @@ class OptunaSampler(WrapperBase):
          params.from_vec_(vec)

          loss = closure()
+
          with torch.enable_grad(): self.study.tell(trial, tofloat(torch.nan_to_num(totensor(loss), 1e32)))

          return loss
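
The unchanged lines around these hunks follow optuna's ask/tell interface; for context, here is a minimal standalone sketch of that pattern in plain optuna, independent of the torchzero wrapper:

```python
import optuna

study = optuna.create_study(sampler=optuna.samplers.TPESampler())
for _ in range(20):
    trial = study.ask()                        # sample a candidate configuration
    x = trial.suggest_float("x", -10.0, 10.0)
    study.tell(trial, (x - 2.0) ** 2)          # report the objective value
print(study.best_params)
```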
torchzero/utils/benchmarks/__init__.py ADDED
File without changes
torchzero/utils/benchmarks/logistic.py ADDED
@@ -0,0 +1,122 @@
+ from functools import partial
+ from typing import Any, cast
+
+ import numpy as np
+ import torch
+ import tqdm
+
+
+ def generate_correlated_logistic_data(n_samples=2000, n_features=32, n_correlated_pairs=512, correlation=0.99, seed=0):
+     """Hard logistic regression dataset with correlated features"""
+     generator = np.random.default_rng(seed)
+
+     # ------------------------------------- X ------------------------------------ #
+     X = generator.standard_normal(size=(n_samples, n_features))
+     weights = generator.uniform(-2, 2, n_features)
+
+     used_pairs = []
+     for i in range(n_correlated_pairs):
+         idxs = None
+         while idxs is None or idxs in used_pairs:
+             idxs = tuple(generator.choice(n_features, size=2, replace=False).tolist())
+
+         used_pairs.append(idxs)
+         idx1, idx2 = idxs
+
+         noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
+         X[:, idx2] = correlation * X[:, idx1] + noise
+
+         w = generator.integers(1, 51)
+         weights[idx1] = w
+         weights[idx2] = -w
+
+     # ---------------------------------- logits ---------------------------------- #
+     logits = X @ weights
+     probabilities = 1 / (1 + np.exp(-logits))
+     y = generator.binomial(1, probabilities).astype(np.float32)
+
+     X = X - X.mean(0, keepdims=True)
+     X = X / X.std(0, keepdims=True)
+     return X, y
+
+
+ # if __name__ == '__main__':
+ #     X, y = generate_correlated_logistic_data()
+
+ #     plt.figure(figsize=(10, 8))
+ #     sns.heatmap(pl.DataFrame(X).corr(), annot=True, cmap='coolwarm', fmt=".2f")
+ #     plt.show()
+
+
+
+
+ def _tensorlist_equal(t1, t2):
+     return all(torch.equal(a, b) for a, b in zip(t1, t2))
+
+ _placeholder = cast(Any, ...)
+
+ def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps: int, tol: float = 0, l1: float = 0, l2: float = 0, pbar: bool = False, *, _assert_on_evaluated_same_params: bool = False):
+     # ------------------------------- verify inputs ------------------------------ #
+     n_samples, n_features = X.size()
+
+     if y.ndim != 1: raise ValueError(f"y should be 1d, got {y.shape}")
+     if y.size(0) != n_samples: raise ValueError(f"y should have {n_samples} elements, got {y.shape}")
+     if y.device != X.device: raise ValueError(f"X and y should be on the same device, got {X.device = }, {y.device = }")
+     device = X.device
+     dtype = X.dtype
+
+     # ---------------------------- model and criterion --------------------------- #
+     n_targets = int(y.amax()) + 1
+     binary = n_targets == 2
+
+     if binary:
+         criterion = torch.nn.functional.binary_cross_entropy_with_logits
+         model = torch.nn.Linear(n_features, 1).to(device=device, dtype=dtype)
+         y = y.to(dtype=dtype)
+     else:
+         model = torch.nn.Linear(n_features, n_targets).to(device=device, dtype=dtype)
+         criterion = torch.nn.functional.cross_entropy
+         y = y.long()
+
+     optimizer = opt_fn(list(model.parameters()))
+
+     # ---------------------------------- closure --------------------------------- #
+     def _l1_penalty():
+         return sum(p.abs().sum() for p in model.parameters())
+     def _l2_penalty():
+         return sum(p.square().sum() for p in model.parameters())
+
+     def closure(backward=True, evaluated_params: list = _placeholder, epoch: int = _placeholder):
+         y_hat = model(X)
+         loss = criterion(y_hat.squeeze(), y)
+
+         if l1 > 0: loss += _l1_penalty() * l1
+         if l2 > 0: loss += _l2_penalty() * l2
+
+         if backward:
+             optimizer.zero_grad()
+             loss.backward()
+
+         # here I also test to make sure the optimizer doesn't evaluate the same parameters twice per step
+         # this is for tests
+         if _assert_on_evaluated_same_params:
+             for p in evaluated_params:
+                 assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
+
+             evaluated_params.append([p.clone() for p in model.parameters()])
+
+         return loss
+
+     # --------------------------------- optimize --------------------------------- #
+     losses = []
+     epochs = tqdm.trange(max_steps, disable=not pbar)
+     for epoch in epochs:
+         evaluated_params = []
+         loss = float(optimizer.step(partial(closure, evaluated_params=evaluated_params, epoch=epoch)))
+
+         losses.append(loss)
+         epochs.set_postfix_str(f"{loss:.5f}")
+         if loss <= tol:
+             break
+
+     return losses
torchzero/utils/derivatives.py CHANGED
@@ -7,7 +7,7 @@ from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
  from .tensorlist import TensorList

  def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+     flat_outputs = torch.cat([i.ravel() for i in outputs])
      grad_ouputs = torch.eye(len(flat_outputs), device=outputs[0].device, dtype=outputs[0].dtype)
      jac = []
      for i in range(flat_outputs.numel()):
@@ -24,7 +24,7 @@ def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], crea


  def _jacobian_batched(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
-     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
+     flat_outputs = torch.cat([i.ravel() for i in outputs])
      return torch.autograd.grad(
          flat_outputs,
          wrt,
@@ -40,10 +40,10 @@ def flatten_jacobian(jacs: Sequence[torch.Tensor]) -> torch.Tensor:

      Args:
          jacs (Sequence[torch.Tensor]):
-             output from jacobian_wrt where ach tensor has the shape `(*output.shape, *wrt[i].shape)`.
+             output from jacobian_wrt where each tensor has the shape ``(*output.shape, *wrt[i].shape)``.

      Returns:
-         torch.Tensor: has the shape `(output.ndim, wrt.ndim)`.
+         torch.Tensor: has the shape ``(output.ndim, wrt.ndim)``.
      """
      if not jacs:
          return torch.empty(0, 0)
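
If `flatten_jacobian` stacks the per-tensor jacobians into one matrix as its docstring describes, usage would look roughly like the sketch below; the shapes here are illustrative assumptions, not taken from the library's tests:

```python
import torch
from torchzero.utils.derivatives import flatten_jacobian

# jacobians of a 4-element output w.r.t. two parameter tensors with
# 3 and 2 elements: each has shape (*output.shape, *wrt[i].shape)
jacs = [torch.randn(4, 3), torch.randn(4, 2)]
J = flatten_jacobian(jacs)
print(J.shape)  # expected (4, 5): one row per output element, one column per parameter
```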
{torchzero-0.4.0.dist-info → torchzero-0.4.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.4.0
+ Version: 0.4.1
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  Project-URL: Homepage, https://github.com/inikishev/torchzero