torchzero 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. tests/test_identical.py +1 -1
  2. torchzero/__init__.py +3 -1
  3. torchzero/_minimize/__init__.py +0 -0
  4. torchzero/_minimize/methods.py +95 -0
  5. torchzero/_minimize/minimize.py +518 -0
  6. torchzero/core/__init__.py +5 -5
  7. torchzero/core/chain.py +2 -1
  8. torchzero/core/functional.py +2 -1
  9. torchzero/core/module.py +75 -4
  10. torchzero/core/transform.py +6 -5
  11. torchzero/linalg/eigh.py +116 -68
  12. torchzero/linalg/linear_operator.py +1 -0
  13. torchzero/linalg/orthogonalize.py +60 -5
  14. torchzero/linalg/sketch.py +39 -0
  15. torchzero/modules/__init__.py +1 -0
  16. torchzero/modules/adaptive/adagrad.py +2 -0
  17. torchzero/modules/adaptive/adam.py +5 -1
  18. torchzero/modules/adaptive/adan.py +3 -0
  19. torchzero/modules/adaptive/ggt.py +20 -18
  20. torchzero/modules/adaptive/lion.py +3 -1
  21. torchzero/modules/adaptive/mars.py +6 -5
  22. torchzero/modules/adaptive/msam.py +3 -0
  23. torchzero/modules/adaptive/rmsprop.py +2 -0
  24. torchzero/modules/adaptive/rprop.py +9 -7
  25. torchzero/modules/adaptive/shampoo.py +9 -1
  26. torchzero/modules/adaptive/soap.py +32 -29
  27. torchzero/modules/basis/__init__.py +2 -0
  28. torchzero/modules/basis/ggt_basis.py +199 -0
  29. torchzero/modules/basis/soap_basis.py +254 -0
  30. torchzero/modules/clipping/ema_clipping.py +32 -27
  31. torchzero/modules/clipping/growth_clipping.py +1 -0
  32. torchzero/modules/experimental/__init__.py +1 -6
  33. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  34. torchzero/modules/experimental/cubic_adam.py +4 -0
  35. torchzero/modules/grad_approximation/__init__.py +3 -2
  36. torchzero/modules/least_squares/gn.py +6 -0
  37. torchzero/modules/misc/gradient_accumulation.py +1 -0
  38. torchzero/modules/misc/misc.py +6 -0
  39. torchzero/modules/momentum/averaging.py +6 -0
  40. torchzero/modules/momentum/momentum.py +13 -9
  41. torchzero/modules/ops/__init__.py +0 -1
  42. torchzero/modules/ops/accumulate.py +4 -0
  43. torchzero/modules/ops/higher_level.py +6 -1
  44. torchzero/modules/second_order/inm.py +4 -0
  45. torchzero/modules/second_order/newton.py +11 -3
  46. torchzero/modules/second_order/newton_cg.py +7 -3
  47. torchzero/modules/second_order/nystrom.py +14 -19
  48. torchzero/modules/second_order/rsn.py +37 -6
  49. torchzero/modules/trust_region/trust_region.py +2 -1
  50. torchzero/utils/benchmarks/logistic.py +33 -18
  51. torchzero/utils/optuna_tools.py +1 -1
  52. torchzero/utils/params.py +13 -1
  53. torchzero/utils/tensorlist.py +2 -2
  54. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA +1 -1
  55. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/RECORD +58 -55
  56. torchzero/modules/experimental/adanystrom.py +0 -258
  57. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  58. torchzero/modules/experimental/eigen_sr1.py +0 -182
  59. torchzero/modules/experimental/eigengrad.py +0 -207
  60. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  61. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/WHEEL +0 -0
  62. {torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/higher_level.py CHANGED
@@ -30,6 +30,7 @@ class EMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
         super().__init__(defaults)
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -57,7 +58,7 @@ class SqrtEMASquared(TensorTransform):
     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
         super().__init__(defaults)
-
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -141,6 +142,8 @@ class CenteredEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
@@ -175,6 +178,8 @@ class CenteredSqrtEMASquared(TensorTransform):
     def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
         super().__init__(defaults, uses_grad=False)
+        self.add_projected_keys("grad", "exp_avg")
+        self.add_projected_keys("grad_sq", "exp_avg_sq", "max_exp_avg_sq")
 
     @torch.no_grad
     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
torchzero/modules/second_order/inm.py CHANGED
@@ -35,6 +35,8 @@ class ImprovedNewton(Transform):
         self,
         damping: float = 0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool | None = None,
         use_lstsq: bool = False,
@@ -89,6 +91,8 @@ class ImprovedNewton(Transform):
            state = self.global_state,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
+           eigv_tol = fs["eigv_tol"],
+           truncate = fs["truncate"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )
torchzero/modules/second_order/newton.py CHANGED
@@ -7,6 +7,7 @@ from ...core import Chainable, Transform, Objective, HessianMethod
 from ...utils import vec_to_tensors_
 from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
 from ...linalg import torch_linalg
+from ...linalg.eigh import regularize_eigh
 
 def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
     try:
@@ -30,6 +31,8 @@ def _newton_update_state_(
     H: torch.Tensor,
     damping: float,
     eigval_fn: Callable | None,
+    eigv_tol: float | None,
+    truncate: int | None,
     precompute_inverse: bool,
     use_lstsq: bool,
 ):
@@ -39,10 +42,11 @@ def _newton_update_state_(
         reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
         H += reg
 
-    # if eigval_fn is given, we don't need H or H_inv, we store factors
-    if eigval_fn is not None:
+    # if any args require eigendecomp, we don't need H or H_inv, we store factors
+    if any(i is not None for i in [eigval_fn, eigv_tol, truncate]):
         L, Q = torch_linalg.eigh(H, retry_float64=True)
-        L = eigval_fn(L)
+        if eigval_fn is not None: L = eigval_fn(L)
+        L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eigv_tol)
         state["L"] = L
         state["Q"] = Q
         return
@@ -216,6 +220,8 @@ class Newton(Transform):
         self,
         damping: float = 0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool | None = None,
         use_lstsq: bool = False,
@@ -242,6 +248,8 @@ class Newton(Transform):
            H=H,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
+           eigv_tol = fs["eigv_tol"],
+           truncate = fs["truncate"],
            precompute_inverse = precompute_inverse,
            use_lstsq = fs["use_lstsq"]
        )
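
Note on the newton.py change above: when eigval_fn, eigv_tol, or truncate is set, the module stores the eigendecomposition H = Q diag(L) Qᵀ and applies the step through the (regularized) factors instead of forming H or its inverse. The snippet below is a minimal standalone sketch of that idea, assuming regularize_eigh drops eigenvalues below a relative tolerance and/or keeps only the largest `truncate` of them; `truncated_newton_step` is illustrative, not torchzero code.

```python
# Hedged sketch: recovering a Newton step from stored eigen-factors, with
# eigenvalue truncation analogous to the new eigv_tol / truncate arguments.
import torch

def truncated_newton_step(H: torch.Tensor, g: torch.Tensor,
                          tol: float | None = None,
                          truncate: int | None = None) -> torch.Tensor:
    L, Q = torch.linalg.eigh(H)                  # H = Q @ diag(L) @ Q.T
    keep = torch.ones_like(L, dtype=torch.bool)
    if tol is not None:
        keep &= L.abs() > tol * L.abs().max()    # drop near-zero eigenvalues
    if truncate is not None:
        mask = torch.zeros_like(keep)
        mask[L.abs().topk(min(truncate, L.numel())).indices] = True
        keep &= mask                             # keep only the largest ones
    L, Q = L[keep], Q[:, keep]
    return Q @ ((Q.T @ g) / L)                   # step restricted to kept eigenspace

H = torch.tensor([[2.0, 0.0], [0.0, 1e-12]])     # nearly singular Hessian
g = torch.tensor([1.0, 1.0])
print(truncated_newton_step(H, g, tol=1e-8))     # ignores the tiny eigenvalue
```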
torchzero/modules/second_order/newton_cg.py CHANGED
@@ -226,7 +226,8 @@ class NewtonCGSteihaug(Transform):
         tol: float = 1e-8,
         reg: float = 1e-8,
         solver: Literal['cg', "minres"] = 'cg',
-        adapt_tol: bool = True,
+        adapt_tol: bool = False,
+        terminate_on_tr: bool = True,
         npc_terminate: bool = False,
 
         # hvp settings
@@ -272,7 +273,6 @@ class NewtonCGSteihaug(Transform):
        npc_terminate=fs["npc_terminate"]
        miniter=fs["miniter"]
        max_history=fs["max_history"]
-       adapt_tol=fs["adapt_tol"]
 
 
        # ------------------------------- trust region ------------------------------- #
@@ -294,9 +294,13 @@ class NewtonCGSteihaug(Transform):
        finfo = torch.finfo(orig_params[0].dtype)
        if trust_radius < finfo.tiny * 2:
            trust_radius = self.global_state['trust_radius'] = init
-           if adapt_tol:
+
+           if fs["adapt_tol"]:
                self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
 
+           if fs["terminate_on_tr"]:
+               objective.should_terminate = True
+
        elif trust_radius > finfo.max / 2:
            trust_radius = self.global_state['trust_radius'] = init
 
torchzero/modules/second_order/nystrom.py CHANGED
@@ -5,7 +5,7 @@ import torch
 
 from ...core import Chainable, Transform, HVPMethod
 from ...utils import TensorList, vec_to_tensors
-from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
+from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod, orthogonalize
 from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity
 
 class NystromSketchAndSolve(Transform):
@@ -75,7 +75,7 @@ class NystromSketchAndSolve(Transform):
     """
     def __init__(
         self,
-        rank: int,
+        rank: int = 100,
         reg: float | None = 1e-2,
         eigv_tol: float = 0,
         truncate: int | None = None,
@@ -109,17 +109,15 @@ class NystromSketchAndSolve(Transform):
 
        generator = self.get_generator(params[0].device, seed=fs['seed'])
        try:
+           Omega = torch.randn([ndim, min(fs["rank"], ndim)], device=device, dtype=dtype, generator=generator)
+           Omega = orthogonalize(Omega, fs["orthogonalize_method"])
+           HOmega = H_mm(Omega)
+
            # compute the approximation
            L, Q = nystrom_approximation(
-               A_mv=H_mv,
-               A_mm=H_mm,
-               ndim=ndim,
-               rank=min(fs["rank"], ndim),
+               Omega=Omega,
+               AOmega=HOmega,
                eigv_tol=fs["eigv_tol"],
-               orthogonalize_method=fs["orthogonalize_method"],
-               dtype=dtype,
-               device=device,
-               generator=generator,
            )
 
            # regularize
@@ -225,7 +223,7 @@ class NystromPCG(Transform):
     """
     def __init__(
         self,
-        rank: int,
+        rank: int = 100,
         maxiter=None,
         tol=1e-8,
         reg: float = 1e-6,
@@ -260,16 +258,13 @@ class NystromPCG(Transform):
        generator = self.get_generator(device, seed=fs['seed'])
 
        try:
+           Omega = torch.randn(ndim, min(fs["rank"], ndim), device=device, dtype=dtype, generator=generator)
+           HOmega = H_mm(orthogonalize(Omega, fs["orthogonalize_method"]))
+           # compute the approximation
            L, Q = nystrom_approximation(
-               A_mv=None,
-               A_mm=H_mm,
-               ndim=ndim,
-               rank=min(fs["rank"], ndim),
+               Omega=Omega,
+               AOmega=HOmega,
                eigv_tol=fs["eigv_tol"],
-               orthogonalize_method=fs["orthogonalize_method"],
-               dtype=dtype,
-               device=device,
-               generator=generator,
            )
 
            self.global_state["L"] = L
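
Note on the nystrom.py refactor above: nystrom_approximation no longer draws and orthogonalizes the test matrix internally; the caller now builds Omega, orthogonalizes it, computes AOmega = H @ Omega, and passes both in. Below is a hedged sketch of the standard single-pass randomized Nyström construction that such an (Omega, AOmega) interface implies, for a symmetric PSD operator; it is not necessarily torchzero's exact nystrom_approximation.

```python
# Hedged sketch of a generic randomized Nystrom approximation built from a
# precomputed sketch, matching the new Omega / AOmega call shape above.
import torch

def nystrom_from_sketch(Omega: torch.Tensor, AOmega: torch.Tensor, shift: float = 1e-10):
    Y = AOmega + shift * Omega                          # sketch of A + shift*I
    C = torch.linalg.cholesky(Omega.T @ Y)              # core matrix Omega^T A Omega
    B = torch.linalg.solve_triangular(C, Y.T, upper=False).T
    U, S, _ = torch.linalg.svd(B, full_matrices=False)
    L = (S ** 2 - shift).clamp_min(0)                   # eigenvalues of the approximation
    return L, U                                         # A ≈ U diag(L) U^T

n, rank = 50, 10
A = torch.randn(n, n); A = A @ A.T                      # random PSD test matrix
Omega, _ = torch.linalg.qr(torch.randn(n, rank))        # orthonormal test matrix
L, U = nystrom_from_sketch(Omega, A @ Omega)
print(L.shape, U.shape)                                 # rank-10 eigen-approximation
```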
torchzero/modules/second_order/rsn.py CHANGED
@@ -25,9 +25,23 @@ def _orthonormal_sketch(m, n, dtype, device, generator):
     return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
 
 def _rademacher_sketch(m, n, dtype, device, generator):
-    rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
+    rademacher = torch.bernoulli(torch.full((m,n), 0.5, device=device, dtype=dtype), generator = generator).mul_(2).sub_(1)
     return rademacher.mul_(1 / math.sqrt(m))
 
+def _row_sketch(m, n, dtype, device, generator):
+    weights = torch.ones(m, dtype=dtype, device=device)
+    indices = torch.multinomial(weights, n, replacement=False, generator=generator)
+
+    P = torch.zeros(m, n, dtype=dtype, device=device)
+    P[indices, range(n)] = 1
+    return P
+
+def _topk_rows(grad, m, n, dtype, device, generator):
+    _, indices = torch.topk(grad.abs(), n)
+    P = torch.zeros(m, n, dtype=dtype, device=device)
+    P[indices, range(n)] = 1
+    return P
+
 class SubspaceNewton(Transform):
     """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
 
@@ -37,7 +51,9 @@ class SubspaceNewton(Transform):
        sketch_type (str, optional):
            - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
            - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
-           - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
+           - "rows" - samples random rows.
+           - "topk" - samples top-rank rows with largest gradient magnitude.
+           - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
            - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
        damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
        hvp_method (str, optional):
@@ -93,13 +109,15 @@ class SubspaceNewton(Transform):
 
     def __init__(
         self,
-        sketch_size: int,
-        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
+        sketch_size: int = 100,
+        sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher", "rows", "topk"] = "common_directions",
         damping:float=0,
         eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
+        eigv_tol: float | None = None,
+        truncate: int | None = None,
         update_freq: int = 1,
         precompute_inverse: bool = False,
-        use_lstsq: bool = True,
+        use_lstsq: bool = False,
         hvp_method: HVPMethod = "batched_autograd",
         h: float = 1e-2,
         seed: int | None = None,
@@ -131,6 +149,14 @@ class SubspaceNewton(Transform):
        elif sketch_type == 'orthonormal':
            S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
 
+       elif sketch_type == "rows":
+           S = _row_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
+       elif sketch_type == "topk":
+           g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
+           g = torch.cat([t.ravel() for t in g_list])
+           S = _topk_rows(g, ndim, sketch_size, device=device, dtype=dtype, generator=generator)
+
        elif sketch_type == 'common_directions':
            # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
            g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
@@ -189,6 +215,10 @@ class SubspaceNewton(Transform):
        else:
            raise ValueError(f'Unknown sketch_type {sketch_type}')
 
+       # print(f'{S.shape = }')
+       # I = torch.eye(S.size(1), device=S.device, dtype=S.dtype)
+       # print(f'{torch.nn.functional.mse_loss(S.T @ S, I) = }')
+
        # form sketched hessian
        HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
                                                 hvp_method=fs["hvp_method"], h=fs["h"])
@@ -200,9 +230,10 @@ class SubspaceNewton(Transform):
            H = H_sketched,
            damping = fs["damping"],
            eigval_fn = fs["eigval_fn"],
+           eigv_tol = fs["eigv_tol"],
+           truncate = fs["truncate"],
            precompute_inverse = fs["precompute_inverse"],
            use_lstsq = fs["use_lstsq"]
-
        )
 
        self.global_state["S"] = S
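
Note on the new "rows"/"topk" sketch types above: the sketch matrix S has exactly one 1 per column, so Sᵀ H S is a principal submatrix of H and the resulting subspace Newton step only moves the selected coordinates. A small standalone demo of that property (illustrative, not torchzero code):

```python
# Hedged illustration: a 0/1 column-selector sketch extracts a principal
# submatrix of H, and the step S @ step_small is zero outside the selected rows.
import torch

ndim, k = 6, 3
H = torch.randn(ndim, ndim); H = H @ H.T + torch.eye(ndim)   # SPD Hessian
g = torch.randn(ndim)

idx = torch.topk(g.abs(), k).indices        # "topk": rows with largest |g|
S = torch.zeros(ndim, k)
S[idx, torch.arange(k)] = 1                 # one 1 per column, like _topk_rows

H_sketched = S.T @ H @ S                    # equals H[idx][:, idx]
step_small = torch.linalg.solve(H_sketched, S.T @ g)
step = S @ step_small                       # zero outside the selected rows

print(torch.allclose(H_sketched, H[idx][:, idx]))   # True
print(step)                                          # nonzero only at idx
```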
torchzero/modules/trust_region/trust_region.py CHANGED
@@ -1,7 +1,7 @@
 import math
 import warnings
 from abc import ABC, abstractmethod
-from collections.abc import Callable, Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence, MutableMapping
 from functools import partial
 from typing import Any, Literal, Protocol, cast, final, overload
 
@@ -203,6 +203,7 @@ def fixed_radius(
 ) -> tuple[float, bool]:
     return init, True
 
+
 _RADIUS_KEYS = Literal['default', 'fixed']
 _RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
     "default": default_radius,
torchzero/utils/benchmarks/logistic.py CHANGED
@@ -5,39 +5,54 @@ import numpy as np
 import torch
 import tqdm
 
-
-def generate_correlated_logistic_data(n_samples=2000, n_features=32, n_correlated_pairs=512, correlation=0.99, seed=0):
-    """Hard logistic regression dataset with correlated features"""
+def generate_correlated_logistic_data(
+    n_samples=100_000,
+    n_features=32,
+    n_classes=10,
+    n_correlated=768,
+    correlation=0.99,
+    seed=0
+) -> tuple[np.ndarray, np.ndarray]:
+    assert n_classes >= 2
     generator = np.random.default_rng(seed)
 
-    # ------------------------------------- X ------------------------------------ #
     X = generator.standard_normal(size=(n_samples, n_features))
-    weights = generator.uniform(-2, 2, n_features)
+    weights = generator.uniform(-2, 2, size=(n_features, n_classes))
+
+    used_pairs = set()
+    n_correlated = min(n_correlated, n_features * (n_features - 1) // 2)
 
-    used_pairs = []
-    for i in range(n_correlated_pairs):
+    for _ in range(n_correlated):
         idxs = None
         while idxs is None or idxs in used_pairs:
-            idxs = tuple(generator.choice(n_features, size=2, replace=False).tolist())
+            pair = generator.choice(n_features, size=2, replace=False)
+            pair.sort()
+            idxs = tuple(pair)
 
-        used_pairs.append(idxs)
+        used_pairs.add(idxs)
         idx1, idx2 = idxs
 
         noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
         X[:, idx2] = correlation * X[:, idx1] + noise
 
         w = generator.integers(1, 51)
-        weights[idx1] = w
-        weights[idx2] = -w
+        cls = generator.integers(0, n_classes)
+        weights[idx1, cls] = w
+        weights[idx2, cls] = -w
 
-    # ---------------------------------- logits ---------------------------------- #
     logits = X @ weights
-    probabilities = 1 / (1 + np.exp(-logits))
-    y = generator.binomial(1, probabilities).astype(np.float32)
 
-    X = X - X.mean(0, keepdims=True)
-    X = X / X.std(0, keepdims=True)
-    return X, y
+    logits -= logits.max(axis=1, keepdims=True)
+    exp_logits = np.exp(logits)
+    probabilities = exp_logits / exp_logits.sum(axis=1, keepdims=True)
+
+    y_one_hot = generator.multinomial(1, pvals=probabilities)
+    y = np.argmax(y_one_hot, axis=1)
+
+    X -= X.mean(0, keepdims=True)
+    X /= X.std(0, keepdims=True)
+
+    return X, y.astype(np.int64)
 
 
 # if __name__ == '__main__':
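
Note on the rewritten benchmark generator above: it now produces an n_classes-way problem with softmax-sampled integer labels, so downstream code trains with cross-entropy instead of a binary loss. A hedged usage sketch, assuming the function is importable from this module path:

```python
# Hedged usage sketch of the new multiclass benchmark data (import path assumed
# from the file location above; y is now int64 class labels, not 0/1 floats).
import torch
from torchzero.utils.benchmarks.logistic import generate_correlated_logistic_data

X_np, y_np = generate_correlated_logistic_data(n_samples=1000, n_features=32, n_classes=10)
X, y = torch.as_tensor(X_np, dtype=torch.float32), torch.as_tensor(y_np)

model = torch.nn.Linear(32, 10)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(X), y)  # labels, not 0/1 floats
    loss.backward()
    opt.step()
print(float(loss))
```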
@@ -101,7 +116,7 @@ def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps:
        # this is for tests
        if _assert_on_evaluated_same_params:
            for p in evaluated_params:
-               assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
+               assert not _tensorlist_equal(p, model.parameters()), f"{optimizer} evaluated same parameters on epoch {epoch}"
 
            evaluated_params.append([p.clone() for p in model.parameters()])
 
torchzero/utils/optuna_tools.py CHANGED
@@ -27,7 +27,7 @@ def get_momentum(trial: optuna.Trial, prefix: str, conditional: bool=True) -> li
            m = NAG(beta, dampening, lerp)
            if debiased: m = Chain(m, Debias(beta1=beta))
        else:
-           m = EMA(beta, dampening, debiased=debiased, lerp=lerp)
+           m = EMA(beta, dampening, debias=debiased, lerp=lerp)
        return [m]
    return []
 
torchzero/utils/params.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Sequence, Iterable, Mapping
 import warnings
 import torch, numpy as np
 
-
+from .torch_tools import set_storage_
 
 Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Mapping[str, Any]]
 
@@ -147,3 +147,15 @@ def _set_update_and_grad_(
 
     return param_groups
 
+
+def _set_fake_params_(fake_params: Iterable[torch.Tensor], storage: Iterable[torch.Tensor]):
+    """sets ``fake_params`` storage to ``storage`` while they remain the same python object"""
+    for fake_p, s in zip(fake_params, storage):
+        fake_p.set_(s.view_as(s).requires_grad_()) # pyright: ignore[reportArgumentType]
+
+def _empty_fake_param_storage_(fake_params: Iterable[torch.Tensor]):
+    """sets ``fake_params`` storage to empty while they remain the same python object"""
+    for p in fake_params:
+        set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
+
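
Note on the new helpers above: they rely on the fact that Tensor.set_() swaps a tensor's underlying storage while leaving the Python object (and any references held to it) intact. A standalone demonstration of that PyTorch behavior (not of torchzero's set_storage_ helper):

```python
# Hedged sketch of the PyTorch behavior these helpers build on: set_() changes
# what data a tensor points at, while the tensor object itself stays the same.
import torch

fake = torch.empty(0)                 # placeholder "fake parameter"
real = torch.arange(6.0).reshape(2, 3)

ref_before = id(fake)
fake.set_(real)                       # fake now views real's storage
assert id(fake) == ref_before         # same Python object
assert fake.data_ptr() == real.data_ptr()

fake.set_(torch.empty(0))             # release the storage again
print(fake.numel())                   # 0
```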
torchzero/utils/tensorlist.py CHANGED
@@ -330,10 +330,10 @@ class TensorList(list[torch.Tensor | Any]):
 
     def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
         # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
-        if ord == 1: return self.global_sum()
-        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
         if ord == torch.inf: return self.abs().global_max()
         if ord == -torch.inf: return self.abs().global_min()
+        if ord == 1: return self.abs().global_sum()
+        if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
         if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
 
         return self.abs().pow_(ord).global_sum().pow(1/ord)
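
Note on the reordered branches above: the ord == 1 case previously returned global_sum() of the raw values, which is wrong whenever entries are negative; the L1 norm must sum absolute values. A two-line check of the difference:

```python
# Hedged illustration of the ord=1 fix: the old code summed signed values.
import torch

tensors = [torch.tensor([1.0, -2.0]), torch.tensor([-3.0])]
old = sum(t.sum() for t in tensors)          # -4.0, what the old code returned
new = sum(t.abs().sum() for t in tensors)    # 6.0, the correct L1 norm
print(float(old), float(new), float(torch.linalg.vector_norm(torch.cat(tensors), ord=1)))
```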
{torchzero-0.4.1.dist-info → torchzero-0.4.3.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.4.1
+Version: 0.4.3
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 Project-URL: Homepage, https://github.com/inikishev/torchzero
  Project-URL: Homepage, https://github.com/inikishev/torchzero