torchzero 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. torchzero/__init__.py +3 -1
  2. torchzero/_minimize/__init__.py +0 -0
  3. torchzero/_minimize/methods.py +95 -0
  4. torchzero/_minimize/minimize.py +518 -0
  5. torchzero/core/__init__.py +5 -5
  6. torchzero/core/chain.py +2 -1
  7. torchzero/core/functional.py +2 -1
  8. torchzero/core/module.py +75 -4
  9. torchzero/core/transform.py +6 -5
  10. torchzero/linalg/eigh.py +116 -68
  11. torchzero/linalg/linear_operator.py +1 -0
  12. torchzero/linalg/orthogonalize.py +60 -5
  13. torchzero/linalg/sketch.py +39 -0
  14. torchzero/modules/__init__.py +1 -0
  15. torchzero/modules/adaptive/adagrad.py +2 -0
  16. torchzero/modules/adaptive/adam.py +5 -1
  17. torchzero/modules/adaptive/adan.py +3 -0
  18. torchzero/modules/adaptive/ggt.py +20 -18
  19. torchzero/modules/adaptive/lion.py +3 -1
  20. torchzero/modules/adaptive/mars.py +6 -5
  21. torchzero/modules/adaptive/msam.py +3 -0
  22. torchzero/modules/adaptive/rmsprop.py +2 -0
  23. torchzero/modules/adaptive/rprop.py +9 -7
  24. torchzero/modules/adaptive/shampoo.py +9 -1
  25. torchzero/modules/adaptive/soap.py +32 -29
  26. torchzero/modules/basis/__init__.py +2 -0
  27. torchzero/modules/basis/ggt_basis.py +199 -0
  28. torchzero/modules/basis/soap_basis.py +254 -0
  29. torchzero/modules/clipping/ema_clipping.py +32 -27
  30. torchzero/modules/clipping/growth_clipping.py +1 -0
  31. torchzero/modules/experimental/__init__.py +1 -6
  32. torchzero/modules/experimental/coordinate_momentum.py +2 -0
  33. torchzero/modules/experimental/cubic_adam.py +4 -0
  34. torchzero/modules/grad_approximation/__init__.py +3 -2
  35. torchzero/modules/least_squares/gn.py +6 -0
  36. torchzero/modules/misc/gradient_accumulation.py +1 -0
  37. torchzero/modules/misc/misc.py +6 -0
  38. torchzero/modules/momentum/averaging.py +6 -0
  39. torchzero/modules/momentum/momentum.py +4 -0
  40. torchzero/modules/ops/__init__.py +0 -1
  41. torchzero/modules/ops/accumulate.py +4 -0
  42. torchzero/modules/ops/higher_level.py +6 -1
  43. torchzero/modules/second_order/inm.py +4 -0
  44. torchzero/modules/second_order/newton.py +11 -3
  45. torchzero/modules/second_order/newton_cg.py +7 -3
  46. torchzero/modules/second_order/nystrom.py +14 -19
  47. torchzero/modules/second_order/rsn.py +37 -6
  48. torchzero/modules/trust_region/trust_region.py +2 -1
  49. torchzero/utils/benchmarks/logistic.py +33 -18
  50. torchzero/utils/params.py +13 -1
  51. torchzero/utils/tensorlist.py +2 -2
  52. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/METADATA +1 -1
  53. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/RECORD +56 -53
  54. torchzero/modules/experimental/adanystrom.py +0 -258
  55. torchzero/modules/experimental/common_directions_whiten.py +0 -142
  56. torchzero/modules/experimental/eigen_sr1.py +0 -182
  57. torchzero/modules/experimental/eigengrad.py +0 -207
  58. /torchzero/modules/{experimental → grad_approximation}/spsa1.py +0 -0
  59. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/WHEEL +0 -0
  60. {torchzero-0.4.1.dist-info → torchzero-0.4.2.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ from ...core import Chainable, Transform, Objective, HessianMethod
7
7
  from ...utils import vec_to_tensors_
8
8
  from ...linalg.linear_operator import Dense, DenseWithInverse, Eigendecomposition
9
9
  from ...linalg import torch_linalg
10
+ from ...linalg.eigh import regularize_eigh
10
11
 
11
12
  def _try_lu_solve(H: torch.Tensor, g: torch.Tensor):
12
13
  try:
@@ -30,6 +31,8 @@ def _newton_update_state_(
30
31
  H: torch.Tensor,
31
32
  damping: float,
32
33
  eigval_fn: Callable | None,
34
+ eigv_tol: float | None,
35
+ truncate: int | None,
33
36
  precompute_inverse: bool,
34
37
  use_lstsq: bool,
35
38
  ):
@@ -39,10 +42,11 @@ def _newton_update_state_(
39
42
  reg = torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(damping)
40
43
  H += reg
41
44
 
42
- # if eigval_fn is given, we don't need H or H_inv, we store factors
43
- if eigval_fn is not None:
45
+ # if any args require eigendecomp, we don't need H or H_inv, we store factors
46
+ if any(i is not None for i in [eigval_fn, eigv_tol, truncate]):
44
47
  L, Q = torch_linalg.eigh(H, retry_float64=True)
45
- L = eigval_fn(L)
48
+ if eigval_fn is not None: L = eigval_fn(L)
49
+ L, Q = regularize_eigh(L, Q, truncate=truncate, tol=eigv_tol)
46
50
  state["L"] = L
47
51
  state["Q"] = Q
48
52
  return
@@ -216,6 +220,8 @@ class Newton(Transform):
216
220
  self,
217
221
  damping: float = 0,
218
222
  eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
223
+ eigv_tol: float | None = None,
224
+ truncate: int | None = None,
219
225
  update_freq: int = 1,
220
226
  precompute_inverse: bool | None = None,
221
227
  use_lstsq: bool = False,
@@ -242,6 +248,8 @@ class Newton(Transform):
242
248
  H=H,
243
249
  damping = fs["damping"],
244
250
  eigval_fn = fs["eigval_fn"],
251
+ eigv_tol = fs["eigv_tol"],
252
+ truncate = fs["truncate"],
245
253
  precompute_inverse = precompute_inverse,
246
254
  use_lstsq = fs["use_lstsq"]
247
255
  )
@@ -226,7 +226,8 @@ class NewtonCGSteihaug(Transform):
226
226
  tol: float = 1e-8,
227
227
  reg: float = 1e-8,
228
228
  solver: Literal['cg', "minres"] = 'cg',
229
- adapt_tol: bool = True,
229
+ adapt_tol: bool = False,
230
+ terminate_on_tr: bool = True,
230
231
  npc_terminate: bool = False,
231
232
 
232
233
  # hvp settings
@@ -272,7 +273,6 @@ class NewtonCGSteihaug(Transform):
272
273
  npc_terminate=fs["npc_terminate"]
273
274
  miniter=fs["miniter"]
274
275
  max_history=fs["max_history"]
275
- adapt_tol=fs["adapt_tol"]
276
276
 
277
277
 
278
278
  # ------------------------------- trust region ------------------------------- #
@@ -294,9 +294,13 @@ class NewtonCGSteihaug(Transform):
294
294
  finfo = torch.finfo(orig_params[0].dtype)
295
295
  if trust_radius < finfo.tiny * 2:
296
296
  trust_radius = self.global_state['trust_radius'] = init
297
- if adapt_tol:
297
+
298
+ if fs["adapt_tol"]:
298
299
  self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
299
300
 
301
+ if fs["terminate_on_tr"]:
302
+ objective.should_terminate = True
303
+
300
304
  elif trust_radius > finfo.max / 2:
301
305
  trust_radius = self.global_state['trust_radius'] = init
302
306
 
@@ -5,7 +5,7 @@ import torch
5
5
 
6
6
  from ...core import Chainable, Transform, HVPMethod
7
7
  from ...utils import TensorList, vec_to_tensors
8
- from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod
8
+ from ...linalg import nystrom_pcg, nystrom_sketch_and_solve, nystrom_approximation, cg, regularize_eigh, OrthogonalizeMethod, orthogonalize
9
9
  from ...linalg.linear_operator import Eigendecomposition, ScaledIdentity
10
10
 
11
11
  class NystromSketchAndSolve(Transform):
@@ -75,7 +75,7 @@ class NystromSketchAndSolve(Transform):
75
75
  """
76
76
  def __init__(
77
77
  self,
78
- rank: int,
78
+ rank: int = 100,
79
79
  reg: float | None = 1e-2,
80
80
  eigv_tol: float = 0,
81
81
  truncate: int | None = None,
@@ -109,17 +109,15 @@ class NystromSketchAndSolve(Transform):
109
109
 
110
110
  generator = self.get_generator(params[0].device, seed=fs['seed'])
111
111
  try:
112
+ Omega = torch.randn([ndim, min(fs["rank"], ndim)], device=device, dtype=dtype, generator=generator)
113
+ Omega = orthogonalize(Omega, fs["orthogonalize_method"])
114
+ HOmega = H_mm(Omega)
115
+
112
116
  # compute the approximation
113
117
  L, Q = nystrom_approximation(
114
- A_mv=H_mv,
115
- A_mm=H_mm,
116
- ndim=ndim,
117
- rank=min(fs["rank"], ndim),
118
+ Omega=Omega,
119
+ AOmega=HOmega,
118
120
  eigv_tol=fs["eigv_tol"],
119
- orthogonalize_method=fs["orthogonalize_method"],
120
- dtype=dtype,
121
- device=device,
122
- generator=generator,
123
121
  )
124
122
 
125
123
  # regularize
@@ -225,7 +223,7 @@ class NystromPCG(Transform):
225
223
  """
226
224
  def __init__(
227
225
  self,
228
- rank: int,
226
+ rank: int = 100,
229
227
  maxiter=None,
230
228
  tol=1e-8,
231
229
  reg: float = 1e-6,
@@ -260,16 +258,13 @@ class NystromPCG(Transform):
260
258
  generator = self.get_generator(device, seed=fs['seed'])
261
259
 
262
260
  try:
261
+ Omega = torch.randn(ndim, min(fs["rank"], ndim), device=device, dtype=dtype, generator=generator)
262
+ HOmega = H_mm(orthogonalize(Omega, fs["orthogonalize_method"]))
263
+ # compute the approximation
263
264
  L, Q = nystrom_approximation(
264
- A_mv=None,
265
- A_mm=H_mm,
266
- ndim=ndim,
267
- rank=min(fs["rank"], ndim),
265
+ Omega=Omega,
266
+ AOmega=HOmega,
268
267
  eigv_tol=fs["eigv_tol"],
269
- orthogonalize_method=fs["orthogonalize_method"],
270
- dtype=dtype,
271
- device=device,
272
- generator=generator,
273
268
  )
274
269
 
275
270
  self.global_state["L"] = L
@@ -25,9 +25,23 @@ def _orthonormal_sketch(m, n, dtype, device, generator):
25
25
  return _qr_orthonormalize(torch.randn(m, n, dtype=dtype, device=device, generator=generator))
26
26
 
27
27
  def _rademacher_sketch(m, n, dtype, device, generator):
28
- rademacher = torch.bernoulli(torch.full((m,n), 0.5), generator = generator).mul_(2).sub_(1)
28
+ rademacher = torch.bernoulli(torch.full((m,n), 0.5, device=device, dtype=dtype), generator = generator).mul_(2).sub_(1)
29
29
  return rademacher.mul_(1 / math.sqrt(m))
30
30
 
31
+ def _row_sketch(m, n, dtype, device, generator):
32
+ weights = torch.ones(m, dtype=dtype, device=device)
33
+ indices = torch.multinomial(weights, n, replacement=False, generator=generator)
34
+
35
+ P = torch.zeros(m, n, dtype=dtype, device=device)
36
+ P[indices, range(n)] = 1
37
+ return P
38
+
39
+ def _topk_rows(grad, m, n, dtype, device, generator):
40
+ _, indices = torch.topk(grad.abs(), n)
41
+ P = torch.zeros(m, n, dtype=dtype, device=device)
42
+ P[indices, range(n)] = 1
43
+ return P
44
+
31
45
  class SubspaceNewton(Transform):
32
46
  """Subspace Newton. Performs a Newton step in a subspace (random or spanned by past gradients).
33
47
 
@@ -37,7 +51,9 @@ class SubspaceNewton(Transform):
37
51
  sketch_type (str, optional):
38
52
  - "common_directions" - uses history steepest descent directions as the basis[2]. It is orthonormalized on-line using Gram-Schmidt (default).
39
53
  - "orthonormal" - random orthonormal basis. Orthonormality is necessary to use linear operator based modules such as trust region, but it can be slower to compute.
40
- - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis. It is recommended to use at least "orthonormal" - it requires QR but it is still very cheap.
54
+ - "rows" - samples random rows.
55
+ - "topk" - samples top-rank rows with largest gradient magnitude.
56
+ - "rademacher" - approximately orthonormal (if dimension is large) scaled random rademacher basis.
41
57
  - "mixed" - random orthonormal basis but with four directions set to gradient, slow and fast gradient EMAs, and previous update direction.
42
58
  damping (float, optional): hessian damping (scale of identity matrix added to hessian). Defaults to 0.
43
59
  hvp_method (str, optional):
@@ -93,13 +109,15 @@ class SubspaceNewton(Transform):
93
109
 
94
110
  def __init__(
95
111
  self,
96
- sketch_size: int,
97
- sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher"] = "common_directions",
112
+ sketch_size: int = 100,
113
+ sketch_type: Literal["orthonormal", "common_directions", "mixed", "rademacher", "rows", "topk"] = "common_directions",
98
114
  damping:float=0,
99
115
  eigval_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
116
+ eigv_tol: float | None = None,
117
+ truncate: int | None = None,
100
118
  update_freq: int = 1,
101
119
  precompute_inverse: bool = False,
102
- use_lstsq: bool = True,
120
+ use_lstsq: bool = False,
103
121
  hvp_method: HVPMethod = "batched_autograd",
104
122
  h: float = 1e-2,
105
123
  seed: int | None = None,
@@ -131,6 +149,14 @@ class SubspaceNewton(Transform):
131
149
  elif sketch_type == 'orthonormal':
132
150
  S = _orthonormal_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
133
151
 
152
+ elif sketch_type == "rows":
153
+ S = _row_sketch(ndim, sketch_size, device=device, dtype=dtype, generator=generator)
154
+
155
+ elif sketch_type == "topk":
156
+ g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
157
+ g = torch.cat([t.ravel() for t in g_list])
158
+ S = _topk_rows(g, ndim, sketch_size, device=device, dtype=dtype, generator=generator)
159
+
134
160
  elif sketch_type == 'common_directions':
135
161
  # Wang, Po-Wei, Ching-pei Lee, and Chih-Jen Lin. "The common-directions method for regularized empirical risk minimization." Journal of Machine Learning Research 20.58 (2019): 1-49.
136
162
  g_list = objective.get_grads(create_graph=hvp_method in ("batched_autograd", "autograd"))
@@ -189,6 +215,10 @@ class SubspaceNewton(Transform):
189
215
  else:
190
216
  raise ValueError(f'Unknown sketch_type {sketch_type}')
191
217
 
218
+ # print(f'{S.shape = }')
219
+ # I = torch.eye(S.size(1), device=S.device, dtype=S.dtype)
220
+ # print(f'{torch.nn.functional.mse_loss(S.T @ S, I) = }')
221
+
192
222
  # form sketched hessian
193
223
  HS, _ = objective.hessian_matrix_product(S, rgrad=None, at_x0=True,
194
224
  hvp_method=fs["hvp_method"], h=fs["h"])
@@ -200,9 +230,10 @@ class SubspaceNewton(Transform):
200
230
  H = H_sketched,
201
231
  damping = fs["damping"],
202
232
  eigval_fn = fs["eigval_fn"],
233
+ eigv_tol = fs["eigv_tol"],
234
+ truncate = fs["truncate"],
203
235
  precompute_inverse = fs["precompute_inverse"],
204
236
  use_lstsq = fs["use_lstsq"]
205
-
206
237
  )
207
238
 
208
239
  self.global_state["S"] = S
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  import warnings
3
3
  from abc import ABC, abstractmethod
4
- from collections.abc import Callable, Mapping, Sequence
4
+ from collections.abc import Callable, Mapping, Sequence, MutableMapping
5
5
  from functools import partial
6
6
  from typing import Any, Literal, Protocol, cast, final, overload
7
7
 
@@ -203,6 +203,7 @@ def fixed_radius(
203
203
  ) -> tuple[float, bool]:
204
204
  return init, True
205
205
 
206
+
206
207
  _RADIUS_KEYS = Literal['default', 'fixed']
207
208
  _RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
208
209
  "default": default_radius,
@@ -5,39 +5,54 @@ import numpy as np
5
5
  import torch
6
6
  import tqdm
7
7
 
8
-
9
- def generate_correlated_logistic_data(n_samples=2000, n_features=32, n_correlated_pairs=512, correlation=0.99, seed=0):
10
- """Hard logistic regression dataset with correlated features"""
8
+ def generate_correlated_logistic_data(
9
+ n_samples=100_000,
10
+ n_features=32,
11
+ n_classes=10,
12
+ n_correlated=768,
13
+ correlation=0.99,
14
+ seed=0
15
+ ) -> tuple[np.ndarray, np.ndarray]:
16
+ assert n_classes >= 2
11
17
  generator = np.random.default_rng(seed)
12
18
 
13
- # ------------------------------------- X ------------------------------------ #
14
19
  X = generator.standard_normal(size=(n_samples, n_features))
15
- weights = generator.uniform(-2, 2, n_features)
20
+ weights = generator.uniform(-2, 2, size=(n_features, n_classes))
21
+
22
+ used_pairs = set()
23
+ n_correlated = min(n_correlated, n_features * (n_features - 1) // 2)
16
24
 
17
- used_pairs = []
18
- for i in range(n_correlated_pairs):
25
+ for _ in range(n_correlated):
19
26
  idxs = None
20
27
  while idxs is None or idxs in used_pairs:
21
- idxs = tuple(generator.choice(n_features, size=2, replace=False).tolist())
28
+ pair = generator.choice(n_features, size=2, replace=False)
29
+ pair.sort()
30
+ idxs = tuple(pair)
22
31
 
23
- used_pairs.append(idxs)
32
+ used_pairs.add(idxs)
24
33
  idx1, idx2 = idxs
25
34
 
26
35
  noise = generator.standard_normal(n_samples) * np.sqrt(1 - correlation**2)
27
36
  X[:, idx2] = correlation * X[:, idx1] + noise
28
37
 
29
38
  w = generator.integers(1, 51)
30
- weights[idx1] = w
31
- weights[idx2] = -w
39
+ cls = generator.integers(0, n_classes)
40
+ weights[idx1, cls] = w
41
+ weights[idx2, cls] = -w
32
42
 
33
- # ---------------------------------- logits ---------------------------------- #
34
43
  logits = X @ weights
35
- probabilities = 1 / (1 + np.exp(-logits))
36
- y = generator.binomial(1, probabilities).astype(np.float32)
37
44
 
38
- X = X - X.mean(0, keepdims=True)
39
- X = X / X.std(0, keepdims=True)
40
- return X, y
45
+ logits -= logits.max(axis=1, keepdims=True)
46
+ exp_logits = np.exp(logits)
47
+ probabilities = exp_logits / exp_logits.sum(axis=1, keepdims=True)
48
+
49
+ y_one_hot = generator.multinomial(1, pvals=probabilities)
50
+ y = np.argmax(y_one_hot, axis=1)
51
+
52
+ X -= X.mean(0, keepdims=True)
53
+ X /= X.std(0, keepdims=True)
54
+
55
+ return X, y.astype(np.int64)
41
56
 
42
57
 
43
58
  # if __name__ == '__main__':
@@ -101,7 +116,7 @@ def run_logistic_regression(X: torch.Tensor, y: torch.Tensor, opt_fn, max_steps:
101
116
  # this is for tests
102
117
  if _assert_on_evaluated_same_params:
103
118
  for p in evaluated_params:
104
- assert not _tensorlist_equal(p, model.parameters()), f"evaluated same parameters on epoch {epoch}"
119
+ assert not _tensorlist_equal(p, model.parameters()), f"{optimizer} evaluated same parameters on epoch {epoch}"
105
120
 
106
121
  evaluated_params.append([p.clone() for p in model.parameters()])
107
122
 
torchzero/utils/params.py CHANGED
@@ -3,7 +3,7 @@ from collections.abc import Sequence, Iterable, Mapping
3
3
  import warnings
4
4
  import torch, numpy as np
5
5
 
6
-
6
+ from .torch_tools import set_storage_
7
7
 
8
8
  Params = Iterable[torch.Tensor | tuple[str, torch.Tensor] | Mapping[str, Any]]
9
9
 
@@ -147,3 +147,15 @@ def _set_update_and_grad_(
147
147
 
148
148
  return param_groups
149
149
 
150
+
151
+ def _set_fake_params_(fake_params: Iterable[torch.Tensor], storage: Iterable[torch.Tensor]):
152
+ """sets ``fake_params`` storage to ``storage`` while they remain the same python object"""
153
+ for fake_p, s in zip(fake_params, storage):
154
+ fake_p.set_(s.view_as(s).requires_grad_()) # pyright: ignore[reportArgumentType]
155
+
156
+ def _empty_fake_param_storage_(fake_params: Iterable[torch.Tensor]):
157
+ """sets ``fake_params`` storage to empty while they remain the same python object"""
158
+ for p in fake_params:
159
+ set_storage_(p, torch.empty(0, device=p.device, dtype=p.dtype))
160
+
161
+
@@ -330,10 +330,10 @@ class TensorList(list[torch.Tensor | Any]):
330
330
 
331
331
  def global_vector_norm(self, ord:float = 2) -> torch.Tensor:
332
332
  # return torch.linalg.vector_norm(self.to_vec(), ord = ord) # pylint:disable = not-callable
333
- if ord == 1: return self.global_sum()
334
- if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
335
333
  if ord == torch.inf: return self.abs().global_max()
336
334
  if ord == -torch.inf: return self.abs().global_min()
335
+ if ord == 1: return self.abs().global_sum()
336
+ if ord % 2 == 0: return self.pow(ord).global_sum().pow(1/ord)
337
337
  if ord == 0: return (self != 0).global_sum().to(self[0].dtype)
338
338
 
339
339
  return self.abs().pow_(ord).global_sum().pow(1/ord)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchzero
3
- Version: 0.4.1
3
+ Version: 0.4.2
4
4
  Summary: Modular optimization library for PyTorch.
5
5
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
6
6
  Project-URL: Homepage, https://github.com/inikishev/torchzero
@@ -5,50 +5,54 @@ tests/test_objective.py,sha256=HY0rK0z6PpiXvEsCu4mLgTlSVKusnT69S2GbuVcwMRo,7119
5
5
  tests/test_opts.py,sha256=hw7CCw7FD_RJSdiSacyXUSM7DI-_RfP8wJlsz079SNw,44263
6
6
  tests/test_tensorlist.py,sha256=B0Tq4_r-1DOYpS360X7IsLQiWn5fukhIMDKZM6zVO2Y,72164
7
7
  tests/test_utils_optimizer.py,sha256=_JoMqvXXZ6TxugS_CmfmP55Vvp0XrSPCjSz2nJJmaoI,8399
8
- torchzero/__init__.py,sha256=nit4KxrRoW6hJDGOy0jkphuawY5gAvPqrYY11Yct6fA,133
9
- torchzero/core/__init__.py,sha256=h9Ck7XX2XuJUTojU2IMa_2TprXZHbgo748txa3z7-2o,341
10
- torchzero/core/chain.py,sha256=dtFpxnw8vcbi3EeAANXyPtUmyPyv_VuZrTiPlLRmh7c,1899
11
- torchzero/core/functional.py,sha256=TSygtyQHDhqf998--hF48yIFr-y3Ycz8arjjR8x1ILU,3156
8
+ torchzero/__init__.py,sha256=SZLJgf_sjHyqtTzz0f70AtHP_V_WloX1KQF8mm34zdg,175
9
+ torchzero/_minimize/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ torchzero/_minimize/methods.py,sha256=1oktoSdWtiA0JEF34yTkY3_nPB5Q5ODHl18C0mcglNw,2445
11
+ torchzero/_minimize/minimize.py,sha256=JJBmREQvhDxyqGM62xharsuebyefxRADkd6Bg_TE-DQ,17236
12
+ torchzero/core/__init__.py,sha256=lufcll5r98gTjVfQSvz6-wfI0qMAgZtLLSByHuHTats,358
13
+ torchzero/core/chain.py,sha256=-6vW-L5pzg2Rwpq3LKIAoqJGPvCkHKjt_B1boGikQmM,1900
14
+ torchzero/core/functional.py,sha256=D125Hso8fHMSKlyhkir3GGJzXxuIitXmVhKn2Y9x-Ck,3272
12
15
  torchzero/core/modular.py,sha256=Xpp6jfiKArC3Q42G63I9qj3eWcYt-l7d-EIm-59ADcI,9584
13
- torchzero/core/module.py,sha256=HfbPfxXxgyBf9wQl7Fpw6B6Ux6UYfvPEmITC64ozb_Q,18012
16
+ torchzero/core/module.py,sha256=DKGLwLWm9LkOBYZHW9QBoXo9eBgnYz7nmoCXJ0gl0e0,21210
14
17
  torchzero/core/objective.py,sha256=kEIlry7Bxf_zDUoqAIKUTRvvJmCEpn0Ad2crNt18GCc,40005
15
18
  torchzero/core/reformulation.py,sha256=UyAS_xq5sy_mMpmkvtwpHrZHd6Y2RgyPwN0zZlyxFTI,3857
16
- torchzero/core/transform.py,sha256=aJRBtvYjKqD-Ic_AkzeSINYDsTaBAErA-kocEl3PHZw,12244
19
+ torchzero/core/transform.py,sha256=WlHoc5cCY1vXQrwMsIG0g3Kle93kBSbrBfxGz5X9_0Q,12251
17
20
  torchzero/linalg/__init__.py,sha256=wlry3dbncdsySKk6sSdiRefTcc8dIh4DcA0wFyU1MC8,407
18
21
  torchzero/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
19
- torchzero/linalg/eigh.py,sha256=YC8x5NEWWsnc3suCebnTfeb4lVMhy-H8LGOZbGnwd8A,7902
22
+ torchzero/linalg/eigh.py,sha256=l1fX_7hL-DFk8gu20-NuSKDJcRpz58KxUKQHeBhCcHE,9035
20
23
  torchzero/linalg/linalg_utils.py,sha256=1RBpQevb25_hbRieONw6CgnoWOJFXXv8oWTMugEivEk,385
21
- torchzero/linalg/linear_operator.py,sha256=mVEOvu6yY7TYhUdmZm1IAc6_pWnTaykKDgZu_-J-atk,16653
24
+ torchzero/linalg/linear_operator.py,sha256=MWTY7DS8B8IkR28kVA9nmoM-OU-1eBsP22iYXkDrj9A,16654
22
25
  torchzero/linalg/matrix_power.py,sha256=gEWGvh3atc7745dwNcxNg0RtUrVgeKD6KxyRckKkkdQ,1255
23
- torchzero/linalg/orthogonalize.py,sha256=Fv6zv1JvS9AVwjiMVed55J8-pEbVZv7vqoEo5g0Zrv0,3270
26
+ torchzero/linalg/orthogonalize.py,sha256=GSvDZA9evTpu3obqCkEocgpDp_91sRexoAwH2q0zTEY,5345
24
27
  torchzero/linalg/qr.py,sha256=KykXhSlye0vhyP5JjX6pkPnheHKLLbAKmDff8Hogxyo,2857
28
+ torchzero/linalg/sketch.py,sha256=dKD9t7I7stv089cCvZyPAOZ0D9wzVG1TmV3297w0tk4,1261
25
29
  torchzero/linalg/solve.py,sha256=kING1WCioof8_EKgHeyr53dlft_9KtlJnwOWega3DnA,14355
26
30
  torchzero/linalg/svd.py,sha256=jmunSxM-twR5VCUGI_gmV3j7QxMJIe1aBoBlJf5i2fo,1432
27
31
  torchzero/linalg/torch_linalg.py,sha256=brhMXUZyYuxuEV-FyQerep1iL7auutW5kmgJpOzUROw,6001
28
- torchzero/modules/__init__.py,sha256=dsOalCw-OVkD8rhpQdcODc3Hsd_sQ2_2xVC-J8mlSuk,632
32
+ torchzero/modules/__init__.py,sha256=ZN20E2ES6zDf5DuFbZpuCKFinFc5eGR1h00iYZ_XBGU,652
29
33
  torchzero/modules/opt_utils.py,sha256=aj7xqHmeze4izxG9k3L6ziG-K_yj8n8fkFpIv-X8V78,8141
30
34
  torchzero/modules/adaptive/__init__.py,sha256=X8w2Dal3k0WpLQN-WolnWBBgUyIiZF5RnqBlN0dcAYw,1081
31
- torchzero/modules/adaptive/adagrad.py,sha256=hMT-Al-vtD6tzPUpQ79LCNko97D7rJN5ji9JOfBqR3k,12015
35
+ torchzero/modules/adaptive/adagrad.py,sha256=NDwmUZaEk0lWnbgYxN23yTWK5A5dQ9BtoKzRTFSKozY,12131
32
36
  torchzero/modules/adaptive/adahessian.py,sha256=ucf8loS_lU9VjCb_M42WwXESjPJ_KFChLGkIMFWXO5o,8734
33
- torchzero/modules/adaptive/adam.py,sha256=Okm7Sc9fMArQAZ7Ph4Etq68uL-IXKY4YNqHWpTzPoTY,3767
34
- torchzero/modules/adaptive/adan.py,sha256=965tBUwKy6uDiY2la6fVcGcsvGMs90Zg-ZHPtozJGe4,4110
37
+ torchzero/modules/adaptive/adam.py,sha256=RDHYyIAJdi1Pxny8HOHiCFgvPztNwlJlCtzE_ZE-138,3896
38
+ torchzero/modules/adaptive/adan.py,sha256=tmQHiJ5MNwOGP3fp479goHh0xXlhnzULhHxKcVZOkvM,4219
35
39
  torchzero/modules/adaptive/adaptive_heavyball.py,sha256=iDiZqke6z6FOR9mhoHMLMm7jvxjzHIQANTe0FBwNj1Q,2230
36
40
  torchzero/modules/adaptive/aegd.py,sha256=WLN6vvbSRhQ1P753M3bx_becSF-3cTbu37nhz3NvdGM,1903
37
41
  torchzero/modules/adaptive/esgd.py,sha256=gnah-7zk_fMsn7yIWivqDgnaaSdDFXpxg33ywF6TMZg,6173
38
- torchzero/modules/adaptive/ggt.py,sha256=eYCeV3GArdLv9WuWeim0V3CHJYl3FVKtrtsGshkqwWg,6608
39
- torchzero/modules/adaptive/lion.py,sha256=H3aI2qnrMtmkvXcoddzjjxdkoD5cq_QwIkLmd_bVPso,1085
42
+ torchzero/modules/adaptive/ggt.py,sha256=7G0Hh8lWy4o73VYVHcZ1JJyDqeKcXi2Y6Qp3qIHosOY,6512
43
+ torchzero/modules/adaptive/lion.py,sha256=yeKUt3WIITtWx97IQzudgbdai77MCfnL_cu90vRkTmA,1141
40
44
  torchzero/modules/adaptive/lre_optimizers.py,sha256=AwWUIwnBrozR2HFYLfJnMCBHAWWMKzkS63xFKstRgc0,9760
41
- torchzero/modules/adaptive/mars.py,sha256=w-cK-1tFuR74SY01xS5jsg1b9qs3l8eOptGrUyQ2m80,2261
45
+ torchzero/modules/adaptive/mars.py,sha256=WquKzTnCZcxzslcvSBMFJVz_kjuCuAzlesw1bHnKqOg,2325
42
46
  torchzero/modules/adaptive/matrix_momentum.py,sha256=YefF2k746ke7qiiabdhCPCUFB1_fRddAfGCyIOwV3Ok,6789
43
- torchzero/modules/adaptive/msam.py,sha256=nqwjuhBMX2UO-omUIeOcD5ti6PIKfKs-RVCn7ourkKA,6946
47
+ torchzero/modules/adaptive/msam.py,sha256=cHfdNkk3Joy2aENwUZXGf3N0P7zcxYGKuySf699OTfM,7051
44
48
  torchzero/modules/adaptive/muon.py,sha256=jQ6jlfM4vVRidGJ7FrLtgPnZeuIfW_zU72o7LvOKqh8,8023
45
49
  torchzero/modules/adaptive/natural_gradient.py,sha256=8UzacvvIMbYVVE2q0HQ9DLLHYlm1eu6cAiRsOv5XRzQ,7078
46
50
  torchzero/modules/adaptive/orthograd.py,sha256=0u2sfGZJjlJItLX2WRP5fLAD8Wd9SgJzQYAUpARJ64A,1813
47
- torchzero/modules/adaptive/rmsprop.py,sha256=qWVkRmUQ3dui9yBVYtAEll7OlXZDKNT_m70FakTOrTY,4529
48
- torchzero/modules/adaptive/rprop.py,sha256=a4_UkWse5u2JFAEIlxQqDBUwvUfxh1kNs2ZIhtccnWE,11540
51
+ torchzero/modules/adaptive/rmsprop.py,sha256=sb709Smpkm8H3vYOsh7BzWni5hAf3nBQevhagyOt7mo,4655
52
+ torchzero/modules/adaptive/rprop.py,sha256=vw-Rufa-gpHgq1gDarmNQexrFr13lPLq_mj3c3pNB_Q,11593
49
53
  torchzero/modules/adaptive/sam.py,sha256=CTMCqaH9s5EmKQyj1GpqSeTO1weyfsNWPYFN1xaSm_o,5709
50
- torchzero/modules/adaptive/shampoo.py,sha256=C_Mo7UFQtDxW4McWJjT731FNAp3g9MqF0Hka54Yi3xQ,9847
51
- torchzero/modules/adaptive/soap.py,sha256=hz2N6-jUSWU93RNViIS1c-Ue2uKmQx6BxyYg6mEa2fo,12408
54
+ torchzero/modules/adaptive/shampoo.py,sha256=1WpjroFS37HmDLV51iK4d8vtnJWFrGCsDkoQav0p47E,10048
55
+ torchzero/modules/adaptive/soap.py,sha256=jyS6F2o4bMKzMU8H2dDggFQEqMqw4W1rX78u8p3uaV4,12619
52
56
  torchzero/modules/adaptive/sophia_h.py,sha256=O_izgGlUgUlpH3Oi5PdCKTyxus4yO1PaJUFhGXuGG9k,7063
53
57
  torchzero/modules/adaptive/psgd/__init__.py,sha256=g73mAkWEutwU6jzjiwdbYk5Yxgs4i6QVWefFKkm8cDw,223
54
58
  torchzero/modules/adaptive/psgd/_psgd_utils.py,sha256=YtwbUKyVWITZPmpwCBJBC42XQP9HcxNx_znEaIv3hsI,1096
@@ -58,21 +62,20 @@ torchzero/modules/adaptive/psgd/psgd_kron_newton.py,sha256=oH-oI1pvbR-z6H6ma1O2G
58
62
  torchzero/modules/adaptive/psgd/psgd_kron_whiten.py,sha256=vmhkY6cKaRE5qzy_4tUkIJp6qC3L6ESZMuiU_ih5tR4,7299
59
63
  torchzero/modules/adaptive/psgd/psgd_lra_newton.py,sha256=JL8JmqHgcFqfkX7VeD3sRvNj0xeCuDTHxjNyQ_HigBw,4709
60
64
  torchzero/modules/adaptive/psgd/psgd_lra_whiten.py,sha256=SaNYtE4_2tV29CbVaTHi8A6RxmhoMaucF5NoMRg6QaA,4197
65
+ torchzero/modules/basis/__init__.py,sha256=MeXoykwqqmWt-Gx8YWMycVL7m5N4j7Ob_L0GbcwLOfM,65
66
+ torchzero/modules/basis/ggt_basis.py,sha256=wVNFN-9a0xGszudMDi_04mqPSschACF7kiftLkMyqYA,7749
67
+ torchzero/modules/basis/soap_basis.py,sha256=pwlxIa9lW9V1NcLPmhm--LVbyq7ALSfkV_4b6ki1hO8,10479
61
68
  torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
62
69
  torchzero/modules/clipping/clipping.py,sha256=C2dMt0rpuiLMsKq2EWi8qhISSxfCU0nKKGgjWEk2Yxc,14198
63
- torchzero/modules/clipping/ema_clipping.py,sha256=D4NgXzXYMjK_SKQU3rVoOKzaCd9igGQg_7sXiGMgMqI,6750
64
- torchzero/modules/clipping/growth_clipping.py,sha256=I1nk5xXBjk0BzWYzMC58LZHouY44myZNIUjM-duv7zc,6508
70
+ torchzero/modules/clipping/ema_clipping.py,sha256=7lFkQWVkchxlZynYXS4JDjhxB8T5tbE0qsP3GXK6mrA,6916
71
+ torchzero/modules/clipping/growth_clipping.py,sha256=VAmUUeIsSGWrGmZiFAngWUBBsxj4d0QAMf36oAMZL8A,6556
65
72
  torchzero/modules/conjugate_gradient/__init__.py,sha256=G5WcVoiQYupRBeqjI4lCraGeXNSvWT-_-ynpcE6NQS8,184
66
73
  torchzero/modules/conjugate_gradient/cg.py,sha256=fcmP77_v_RPpb0sDV2B_90FvFY67FdJt54KHdccY5YU,14540
67
- torchzero/modules/experimental/__init__.py,sha256=YbBrWu2vkXHiBcDXmim-Yte4ZxfmQCs_0fCeIArvtnM,942
68
- torchzero/modules/experimental/adanystrom.py,sha256=fUWPxxi1aJhWme_d31dBG0XxEZY1hJr6AEiFHdFDxCQ,8970
69
- torchzero/modules/experimental/common_directions_whiten.py,sha256=R_1fQKlvMD99oFrflJLgxl6ObV8jyPc7-NxAUFQeoYA,4941
70
- torchzero/modules/experimental/coordinate_momentum.py,sha256=HzKy8X5qEvud-xKHJYHpzH6ObxzvYcMcdgodsCw4Bbk,1099
71
- torchzero/modules/experimental/cubic_adam.py,sha256=RhcHajUfUAcXZDks0X0doR18YtMItQYPmxuEihud4bo,5137
74
+ torchzero/modules/experimental/__init__.py,sha256=najUDh01Av6gEeMYRV9X9lWAr4ZrC6ZgJcPtNpon7ZQ,734
75
+ torchzero/modules/experimental/coordinate_momentum.py,sha256=4BMmgooPysYlX7QOaTUjBn6MNfBAMujM5TCm72vSexw,1152
76
+ torchzero/modules/experimental/cubic_adam.py,sha256=97sgbtkqG1ziXOMxlCor-L-UzzqgSumz8shVOgYL4oQ,5303
72
77
  torchzero/modules/experimental/curveball.py,sha256=beHGD1Wh9GxYqMBh1k9Ru6TG3U9eZR6_l8ZUQcZzYxw,2765
73
78
  torchzero/modules/experimental/dct.py,sha256=CW-Y2gcjlHlxtIx7SekUOfw2EzujA6v0LcjDYGAfh6M,2433
74
- torchzero/modules/experimental/eigen_sr1.py,sha256=rCcWVplTWQh91xpgDap35CGEex41C19irUfDlq9lviU,6865
75
- torchzero/modules/experimental/eigengrad.py,sha256=UPuyo-OmCmu3XLAPclIfsnMN4qcNwX83m7S_55syukA,8455
76
79
  torchzero/modules/experimental/fft.py,sha256=s95EzvK4-ZJdwZbVhtqwirY9eVy7v6mFDRMgoLY9wjo,3020
77
80
  torchzero/modules/experimental/gradmin.py,sha256=LajM0GU1fB6PsGDg8k0KjKI73RvyZYqPvzcdoVYDq-c,3752
78
81
  torchzero/modules/experimental/higher_order_newton.py,sha256=qLSCbkmd7dw0lAhOJGpvvOesZfCMNt2Vz_mc7HknCMQ,12131
@@ -82,15 +85,15 @@ torchzero/modules/experimental/newton_solver.py,sha256=aHZh8EA-QQop3iGz7Ge37KTNg
82
85
  torchzero/modules/experimental/newtonnewton.py,sha256=TYUuQwHu8bom08czU9lP7MQq5qFBq_JYZTH_Wmm4g-o,3269
83
86
  torchzero/modules/experimental/reduce_outward_lr.py,sha256=ehctg5zLEOHPfiQQUq5ShMj3pDhtxqdNUEneMR9l7Bs,1275
84
87
  torchzero/modules/experimental/scipy_newton_cg.py,sha256=psllNtDwUbkVAXBDKwWEueatOmDNPFy-pMwBkqF3_r0,3902
85
- torchzero/modules/experimental/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
86
88
  torchzero/modules/experimental/structural_projections.py,sha256=IwpgibNDO0slzMyi6djQXRhQO6IagUgUUCr_-7US1IE,4104
87
- torchzero/modules/grad_approximation/__init__.py,sha256=_mQ2sWvnMfqc3RQcVmZuBlphtLZCO7z819abGY6kYuM,196
89
+ torchzero/modules/grad_approximation/__init__.py,sha256=BAFXc73_ORySVDyXiyZxpusXWn7K66KFT9LZEMwVKes,221
88
90
  torchzero/modules/grad_approximation/fdm.py,sha256=hq7U8UkzCfc7z0J1ZmZo9xOLzHHY0uRjebcwZQrBCzA,4376
89
91
  torchzero/modules/grad_approximation/forward_gradient.py,sha256=7fKZoKetYzgD85L3W0x1oG56SdWHj5MDWwmWpV7bpr4,3949
90
92
  torchzero/modules/grad_approximation/grad_approximator.py,sha256=hX4nqa0yw1OkA2UKmzZ3HhvMfL0Wwv1yQePxrgAueS8,4782
91
93
  torchzero/modules/grad_approximation/rfdm.py,sha256=-5zqMB98YNNa1aQXXtf6UNGSJxySO7mn1NksWyPzp3o,19607
94
+ torchzero/modules/grad_approximation/spsa1.py,sha256=DiQ_nHAC8gnqoNNK7oe6djOiwpwvI5aPtpKA43F7jrQ,3607
92
95
  torchzero/modules/least_squares/__init__.py,sha256=mJwE2IXVB3mn_7BzsmDNKhfyViCV8GOrqHJJjz04HR4,41
93
- torchzero/modules/least_squares/gn.py,sha256=3RQ_7e35Ql9uVUUPi34nef9eQNeZ09fldi964V61Tgg,7889
96
+ torchzero/modules/least_squares/gn.py,sha256=hufsWNq_UdEPFDFKNGgCiM4R9739Xu8JqYWSwKkdSZ8,8087
94
97
  torchzero/modules/line_search/__init__.py,sha256=_QjxUJmNC8OqtUuyTJp9wDfHNFKZBZqj6lttWKhG-cI,217
95
98
  torchzero/modules/line_search/_polyinterp.py,sha256=i3sNl6SFAUJi4oxhhjBlcxJY9KRunIZjJ8sGdaJOVjc,10990
96
99
  torchzero/modules/line_search/adaptive.py,sha256=YNabP6-01dhAUDAOuHRPZCwiV5xTRdHmkN667HQ6V3w,3798
@@ -102,21 +105,21 @@ torchzero/modules/line_search/strong_wolfe.py,sha256=9jGjxebuXHbl8wEFpvV0s4mMX4J
102
105
  torchzero/modules/misc/__init__.py,sha256=UYY9CeNepnC8H1LnFa829ux5MEjtGZ9zql624IbCFX8,825
103
106
  torchzero/modules/misc/debug.py,sha256=wFt9wB6IdRSsOGLhQjdjmGt4KdB0V5IT0iBFMj97R3Y,1617
104
107
  torchzero/modules/misc/escape.py,sha256=c_OMf2jQ7MbxkrXWNmgIpZrBe28N9f89tnzuCQ3fu3A,1930
105
- torchzero/modules/misc/gradient_accumulation.py,sha256=Xzjt_ulm6Z3mpmtagoUqoefhoeSDVnmX__tVbcI_RQE,2271
108
+ torchzero/modules/misc/gradient_accumulation.py,sha256=1BVqGXwv1YPg7DRJWP0XY6s-vzxrvyXLdruM1Y5KJ5s,2326
106
109
  torchzero/modules/misc/homotopy.py,sha256=oa0YFYfv8kkg9v7nukdjTwinuyQa4Nt7kTpddUVCSKg,2257
107
- torchzero/modules/misc/misc.py,sha256=f-3qxBq1KYI3iGYJXzv1cHEJHc0ScEp-vCLCgiaEgJQ,15002
110
+ torchzero/modules/misc/misc.py,sha256=eWVyYSYiQxcS7G7aVM4nFYiF0csE9qcztTaP4id5CbE,15306
108
111
  torchzero/modules/misc/multistep.py,sha256=twdE-lU9Wa0b_uquH9kZ-1OwP0gqWfFMJkdjVWJRwe4,6599
109
112
  torchzero/modules/misc/regularization.py,sha256=MCd_tnBYfFnx0b3sM1vHNQ_WbTVfo7l8pxmxGVgWcc0,5935
110
113
  torchzero/modules/misc/split.py,sha256=rmi9PgMgiqddrr8fY8Dbdcl2dgwTn9YBAve_bg5Zd08,4288
111
114
  torchzero/modules/misc/switch.py,sha256=_ycuD23gR0ZvIUmX3feYBr0_WTX22Pfhu3whpiSCMv4,3678
112
115
  torchzero/modules/momentum/__init__.py,sha256=AKWC4HIkN9ZJwN38dJvVJkFEhiP9r93G-kMDokBfsj8,281
113
- torchzero/modules/momentum/averaging.py,sha256=Q6WLwCJwgNY96YIfQXWpsX-2kDR7n0IOMDfZMvNVc9U,3035
116
+ torchzero/modules/momentum/averaging.py,sha256=OTO_LRNiAhbcKTXrWI-uENqIOH_3DX5_1uYJ3eMVcJY,3202
114
117
  torchzero/modules/momentum/cautious.py,sha256=1hD2H08OQaNZG52sheRADBsuf9uJsaoLV4n-UVGUH3Y,8379
115
- torchzero/modules/momentum/momentum.py,sha256=MPHd4TU1bSlEKLGfueNdmaZ13V5J1suW6agBc3SvrTs,4389
116
- torchzero/modules/ops/__init__.py,sha256=xUYzWWLlSwaT8sw3dWywkALqI6YGCZgptWQJVy83HhM,1249
117
- torchzero/modules/ops/accumulate.py,sha256=f-Uutg7gNFRobTc5YI9JlfFiSacXmg0gDhIwQNwZSZg,3439
118
+ torchzero/modules/momentum/momentum.py,sha256=D6Rfy_Ha5jd9uEk3cwCXfGH1dMiP4k4w08SHiE-hChc,4494
119
+ torchzero/modules/ops/__init__.py,sha256=p5hwECuODOv6E4H0lETQHweSsUtMlsGE0d8bfTv2Rwc,1225
120
+ torchzero/modules/ops/accumulate.py,sha256=mbJFwykU2fa6IIfsHVXdhmRp7QX1czpCWjw6AYkNn1k,3636
118
121
  torchzero/modules/ops/binary.py,sha256=eB6zwz5ZSSyeWvwVfuOFMjem93oMB7hCo4kNF705jn8,12219
119
- torchzero/modules/ops/higher_level.py,sha256=cUh-908S0GWVGekmUN5c_Vx0HP3P2tQoKN3COQM5TaQ,8965
122
+ torchzero/modules/ops/higher_level.py,sha256=f9DFNI9rnxc-rShAJOfsiwvyGsWu8FsJwJf5yg_V4eg,9366
120
123
  torchzero/modules/ops/multi.py,sha256=WzNK07_wL7z0Gb2pmv5a15Oss6tW9IG79x1c4ZPmOqQ,8643
121
124
  torchzero/modules/ops/reduce.py,sha256=SzpkNV5NTsVFp-61a1m8lDKJ1ivJmfQofolFWxbbAe4,6526
122
125
  torchzero/modules/ops/unary.py,sha256=vXvWfDFo2CBFwb1ej_WV-fGg61lQRbwN4HklAik8tJY,4844
@@ -136,12 +139,12 @@ torchzero/modules/restarts/__init__.py,sha256=7282ePwN_I0vSeLPYS4TTclE9ZU7pL6Upy
136
139
  torchzero/modules/restarts/restars.py,sha256=gcRZ8VHGg60cFVzsk0TWa6-EXoqEFbEeP1p7fs2Av0Q,9348
137
140
  torchzero/modules/second_order/__init__.py,sha256=42HeVA3Azl_tXV0_injU-q4QOu7lXzt6AVUcwnPy4Ag,313
138
141
  torchzero/modules/second_order/ifn.py,sha256=oAjfFVjLzG6L4n_ELXAWGZSicWizilQy_hQf4hmOoL0,2019
139
- torchzero/modules/second_order/inm.py,sha256=OddoZHQfSuFnlx_7Zj2qiVcC2A_9yMVn_0Gy1A7hNAg,3420
142
+ torchzero/modules/second_order/inm.py,sha256=_FnaUHKLl46AtI_XYwF52wtOUbAaO5EMUNRJspX5FEM,3574
140
143
  torchzero/modules/second_order/multipoint.py,sha256=mHG1SFLsILELIspxZ8U_hxJBlkGwzvUWg96bOIrQsIY,7500
141
- torchzero/modules/second_order/newton.py,sha256=QcLXsglvf4zJEwR4cldsGVZCABQtxb6U5qVmU3spN_A,11061
142
- torchzero/modules/second_order/newton_cg.py,sha256=k8G8CSmeIQZObkWVURFnbF_4g2UvJiwh3xToxn7sFJE,14816
143
- torchzero/modules/second_order/nystrom.py,sha256=WQFfJj0DOfWXyyx36C54m0WqZPIvTTK7n8U7khLhGLg,13359
144
- torchzero/modules/second_order/rsn.py,sha256=9s-JyJNNeDlIFv8YVGn7y8DGPnP93WJEjpUQXehX3uY,9980
144
+ torchzero/modules/second_order/newton.py,sha256=W37_ePdAB1wnlRrNRd2ovNgkbodK1JV8J4SJytVuF_M,11456
145
+ torchzero/modules/second_order/newton_cg.py,sha256=gHmpLRQ2FRr0750gYkFQ7XweJVZmYI6yG9H2vrKvAdA,14925
146
+ torchzero/modules/second_order/nystrom.py,sha256=lGLjtzq2WAWcaT3E6Say82ySZ1yp9I2ASuOqyNTUmiQ,13361
147
+ torchzero/modules/second_order/rsn.py,sha256=13t42cUvY8JQMC4zf4UsqKvpnTXuXZUZJDECCxRYWjg,11286
145
148
  torchzero/modules/smoothing/__init__.py,sha256=RYxCLLfG2onBbMUToaoedsr20rXaayyBt7Ov8OxULrU,80
146
149
  torchzero/modules/smoothing/laplacian.py,sha256=1cewdvnneKn51bbIBqKij0bkveKE7wOYCZ-aGlqzK5M,5201
147
150
  torchzero/modules/smoothing/sampling.py,sha256=bCH7wlTYZ_vtKUKSkI6znORxQ5Z6DGcpo10F-GYvFlE,12880
@@ -155,7 +158,7 @@ torchzero/modules/trust_region/cubic_regularization.py,sha256=QJjLRkfERvOzV5dTdy
155
158
  torchzero/modules/trust_region/dogleg.py,sha256=zwFR49gghxztVGEETF2D4AkeGgHkQRbHGGelav3GuFg,3619
156
159
  torchzero/modules/trust_region/levenberg_marquardt.py,sha256=-qbeEW3qRKou48bBdZ-u4Nv43TMt475XV6P_aWfxtqE,5039
157
160
  torchzero/modules/trust_region/trust_cg.py,sha256=X9rCJQWvptjZVH2H16iekvAYmleKQAYZKRKC3V0JjFY,4455
158
- torchzero/modules/trust_region/trust_region.py,sha256=oXMNIvboz0R_1J0Gfd4IvbnwZFl32csNVv-lTYGB0zk,12913
161
+ torchzero/modules/trust_region/trust_region.py,sha256=ax1pJDr3NPLfojUXRMb-hsxD4MpQL1bPAOwozAVTCJI,12930
159
162
  torchzero/modules/variance_reduction/__init__.py,sha256=3pwPWZpjgz1btfLJ3rEaK7Wl8B1pDh0HIf0kvD_NJH8,22
160
163
  torchzero/modules/variance_reduction/svrg.py,sha256=hXEJ0PUYSksHV0ws3t3cE_4MUTTEn1Htu37iZdDdJCs,8746
161
164
  torchzero/modules/weight_decay/__init__.py,sha256=zQrjSujD0c-rKfKjUpuutfAODljsz1hS3zUNJW7zbh4,132
@@ -196,14 +199,14 @@ torchzero/utils/metrics.py,sha256=XPpOvY257tb4mN3Sje1AVNlQkOXiW24_lXXdtd0JYok,31
196
199
  torchzero/utils/numberlist.py,sha256=iMoqz4IzXy-aE9bqVYJ21GV6pl0z-NeTsXR-LaI8C24,6229
197
200
  torchzero/utils/optimizer.py,sha256=G741IvE57RaVYowr9FEqfRm_opPAeu4UWKU5iPKDMFA,8415
198
201
  torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
199
- torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
202
+ torchzero/utils/params.py,sha256=-amJs518rpI0zzYavTlWrl60JNrgsk1xxdGvIrSw1ZI,6406
200
203
  torchzero/utils/python_tools.py,sha256=HATghTNijlQxmw8rzJfZPPGj1CjcnRxEwogmrgqnARU,4577
201
- torchzero/utils/tensorlist.py,sha256=4rN8gm967pPmtO5kotXqIX7Mal0ps-IHkGBybfeWY4M,56357
204
+ torchzero/utils/tensorlist.py,sha256=wpzBJvIAmw9VXsg1UF8gZtq-eh7GlvdM6WL_7NyPYlY,56363
202
205
  torchzero/utils/thoad_tools.py,sha256=G8k-z0vireEUtI3A_YAR6dtwYjSnN49e_GadcHwwQKc,2319
203
206
  torchzero/utils/torch_tools.py,sha256=DsHaSRGZ3-IuySZJTrkojTbaMMlttJFe0hFvB2xnl2U,5069
204
207
  torchzero/utils/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
205
- torchzero/utils/benchmarks/logistic.py,sha256=RHsjHEWkPqaag0kt3wfmdddh4DhftcyW9r70tj9OGp4,4382
206
- torchzero-0.4.1.dist-info/METADATA,sha256=hB0rFqXnaRbwVkFRwTwjXpKnIFLi8MBvLXbgXTuUGWk,564
207
- torchzero-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
208
- torchzero-0.4.1.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
209
- torchzero-0.4.1.dist-info/RECORD,,
208
+ torchzero/utils/benchmarks/logistic.py,sha256=1c9kB6tDaKsSNlQn44_Lso2_g-85fQK45RvwLZOcJOo,4587
209
+ torchzero-0.4.2.dist-info/METADATA,sha256=nApA6WdQrTYR0c5TXCORxOktKgVwxlyMqgnfkKNPHLk,564
210
+ torchzero-0.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
211
+ torchzero-0.4.2.dist-info/top_level.txt,sha256=ETW_iE2ubg0oMyef_h-ayB5i1OOZZd4SNdR3ltIbHe0,16
212
+ torchzero-0.4.2.dist-info/RECORD,,