torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_opts.py +54 -21
- tests/test_tensorlist.py +2 -2
- tests/test_vars.py +61 -61
- torchzero/core/__init__.py +2 -3
- torchzero/core/module.py +49 -49
- torchzero/core/transform.py +219 -158
- torchzero/modules/__init__.py +1 -0
- torchzero/modules/clipping/clipping.py +10 -10
- torchzero/modules/clipping/ema_clipping.py +14 -13
- torchzero/modules/clipping/growth_clipping.py +16 -18
- torchzero/modules/experimental/__init__.py +12 -3
- torchzero/modules/experimental/absoap.py +50 -156
- torchzero/modules/experimental/adadam.py +15 -14
- torchzero/modules/experimental/adamY.py +17 -27
- torchzero/modules/experimental/adasoap.py +19 -129
- torchzero/modules/experimental/curveball.py +12 -12
- torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
- torchzero/modules/experimental/eigendescent.py +117 -0
- torchzero/modules/experimental/etf.py +172 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/newton_solver.py +11 -11
- torchzero/modules/experimental/newtonnewton.py +88 -0
- torchzero/modules/experimental/reduce_outward_lr.py +8 -5
- torchzero/modules/experimental/soapy.py +19 -146
- torchzero/modules/experimental/spectral.py +79 -204
- torchzero/modules/experimental/structured_newton.py +12 -12
- torchzero/modules/experimental/subspace_preconditioners.py +13 -10
- torchzero/modules/experimental/tada.py +38 -0
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +5 -5
- torchzero/modules/grad_approximation/grad_approximator.py +21 -21
- torchzero/modules/grad_approximation/rfdm.py +28 -15
- torchzero/modules/higher_order/__init__.py +1 -0
- torchzero/modules/higher_order/higher_order_newton.py +256 -0
- torchzero/modules/line_search/backtracking.py +42 -23
- torchzero/modules/line_search/line_search.py +40 -40
- torchzero/modules/line_search/scipy.py +18 -3
- torchzero/modules/line_search/strong_wolfe.py +21 -32
- torchzero/modules/line_search/trust_region.py +18 -6
- torchzero/modules/lr/__init__.py +1 -1
- torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
- torchzero/modules/lr/lr.py +20 -16
- torchzero/modules/momentum/averaging.py +25 -10
- torchzero/modules/momentum/cautious.py +73 -35
- torchzero/modules/momentum/ema.py +92 -41
- torchzero/modules/momentum/experimental.py +21 -13
- torchzero/modules/momentum/matrix_momentum.py +96 -54
- torchzero/modules/momentum/momentum.py +24 -4
- torchzero/modules/ops/accumulate.py +51 -21
- torchzero/modules/ops/binary.py +36 -36
- torchzero/modules/ops/debug.py +7 -7
- torchzero/modules/ops/misc.py +128 -129
- torchzero/modules/ops/multi.py +19 -19
- torchzero/modules/ops/reduce.py +16 -16
- torchzero/modules/ops/split.py +26 -26
- torchzero/modules/ops/switch.py +4 -4
- torchzero/modules/ops/unary.py +20 -20
- torchzero/modules/ops/utility.py +37 -37
- torchzero/modules/optimizers/adagrad.py +33 -24
- torchzero/modules/optimizers/adam.py +31 -34
- torchzero/modules/optimizers/lion.py +4 -4
- torchzero/modules/optimizers/muon.py +6 -6
- torchzero/modules/optimizers/orthograd.py +4 -5
- torchzero/modules/optimizers/rmsprop.py +13 -16
- torchzero/modules/optimizers/rprop.py +52 -49
- torchzero/modules/optimizers/shampoo.py +17 -23
- torchzero/modules/optimizers/soap.py +12 -19
- torchzero/modules/optimizers/sophia_h.py +13 -13
- torchzero/modules/projections/dct.py +4 -4
- torchzero/modules/projections/fft.py +6 -6
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +57 -57
- torchzero/modules/projections/structural.py +17 -17
- torchzero/modules/quasi_newton/__init__.py +33 -4
- torchzero/modules/quasi_newton/cg.py +67 -17
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
- torchzero/modules/quasi_newton/lbfgs.py +12 -12
- torchzero/modules/quasi_newton/lsr1.py +11 -11
- torchzero/modules/quasi_newton/olbfgs.py +19 -19
- torchzero/modules/quasi_newton/quasi_newton.py +254 -47
- torchzero/modules/second_order/newton.py +32 -20
- torchzero/modules/second_order/newton_cg.py +13 -12
- torchzero/modules/second_order/nystrom.py +21 -21
- torchzero/modules/smoothing/gaussian.py +21 -21
- torchzero/modules/smoothing/laplacian.py +7 -9
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +43 -9
- torchzero/modules/wrappers/optim_wrapper.py +11 -11
- torchzero/optim/wrappers/directsearch.py +244 -0
- torchzero/optim/wrappers/fcmaes.py +97 -0
- torchzero/optim/wrappers/mads.py +90 -0
- torchzero/optim/wrappers/nevergrad.py +4 -4
- torchzero/optim/wrappers/nlopt.py +28 -14
- torchzero/optim/wrappers/optuna.py +70 -0
- torchzero/optim/wrappers/scipy.py +162 -13
- torchzero/utils/__init__.py +2 -6
- torchzero/utils/derivatives.py +2 -1
- torchzero/utils/optimizer.py +55 -74
- torchzero/utils/python_tools.py +17 -4
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
- torchzero-0.3.10.dist-info/RECORD +139 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
- torchzero/core/preconditioner.py +0 -138
- torchzero/modules/experimental/algebraic_newton.py +0 -145
- torchzero/modules/experimental/tropical_newton.py +0 -136
- torchzero-0.3.9.dist-info/RECORD +0 -131
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
|
@@ -1,145 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from functools import partial
|
|
3
|
-
from typing import Literal
|
|
4
|
-
from collections.abc import Callable
|
|
5
|
-
import torch
|
|
6
|
-
import torchalgebras as ta
|
|
7
|
-
|
|
8
|
-
from ...core import Chainable, apply, Module
|
|
9
|
-
from ...utils import vec_to_tensors, TensorList
|
|
10
|
-
from ...utils.derivatives import (
|
|
11
|
-
hessian_list_to_mat,
|
|
12
|
-
hessian_mat,
|
|
13
|
-
jacobian_and_hessian_wrt,
|
|
14
|
-
)
|
|
15
|
-
|
|
16
|
-
class MaxItersReached(Exception): pass
|
|
17
|
-
def tropical_lstsq(
|
|
18
|
-
H: torch.Tensor,
|
|
19
|
-
g: torch.Tensor,
|
|
20
|
-
solver,
|
|
21
|
-
maxiter,
|
|
22
|
-
tol,
|
|
23
|
-
algebra,
|
|
24
|
-
verbose,
|
|
25
|
-
):
|
|
26
|
-
"""it can run on any algebra with add despite it saying tropical"""
|
|
27
|
-
algebra = ta.get_algebra(algebra)
|
|
28
|
-
|
|
29
|
-
x = torch.zeros_like(g, requires_grad=True)
|
|
30
|
-
best_x = x.detach().clone()
|
|
31
|
-
best_loss = float('inf')
|
|
32
|
-
opt = solver([x])
|
|
33
|
-
|
|
34
|
-
niter = 0
|
|
35
|
-
def closure(backward=True):
|
|
36
|
-
nonlocal niter, best_x, best_loss
|
|
37
|
-
if niter == maxiter: raise MaxItersReached
|
|
38
|
-
niter += 1
|
|
39
|
-
|
|
40
|
-
g_hat = algebra.mm(H, x)
|
|
41
|
-
loss = torch.nn.functional.mse_loss(g_hat, g)
|
|
42
|
-
if loss < best_loss:
|
|
43
|
-
best_x = x.detach().clone()
|
|
44
|
-
best_loss = loss.detach()
|
|
45
|
-
|
|
46
|
-
if backward:
|
|
47
|
-
opt.zero_grad()
|
|
48
|
-
loss.backward()
|
|
49
|
-
return loss
|
|
50
|
-
|
|
51
|
-
loss = None
|
|
52
|
-
prev_loss = float('inf')
|
|
53
|
-
for i in range(maxiter):
|
|
54
|
-
try:
|
|
55
|
-
loss = opt.step(closure)
|
|
56
|
-
if loss == 0: break
|
|
57
|
-
if tol is not None and prev_loss - loss < tol: break
|
|
58
|
-
prev_loss = loss
|
|
59
|
-
except MaxItersReached:
|
|
60
|
-
break
|
|
61
|
-
|
|
62
|
-
if verbose: print(f'{best_loss = } after {niter} iters')
|
|
63
|
-
return best_x.detach()
|
|
64
|
-
|
|
65
|
-
def tikhonov(H: torch.Tensor, reg: float, algebra: ta.Algebra = ta.TropicalSemiring()):
|
|
66
|
-
if reg!=0:
|
|
67
|
-
I = ta.AlgebraicTensor(torch.eye(H.size(-1), dtype=H.dtype, device=H.device), algebra)
|
|
68
|
-
I = I * reg
|
|
69
|
-
H = algebra.add(H, I.data)
|
|
70
|
-
return H
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class AlgebraicNewton(Module):
|
|
74
|
-
"""newton in other algebras, not that it works."""
|
|
75
|
-
def __init__(
|
|
76
|
-
self,
|
|
77
|
-
reg: float | None = None,
|
|
78
|
-
hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
|
|
79
|
-
vectorize: bool = True,
|
|
80
|
-
solver=lambda p: torch.optim.LBFGS(p, line_search_fn='strong_wolfe'),
|
|
81
|
-
maxiter=1000,
|
|
82
|
-
tol: float | None = 1e-10,
|
|
83
|
-
algebra: ta.Algebra | str = 'tropical max',
|
|
84
|
-
verbose: bool = False,
|
|
85
|
-
inner: Chainable | None = None,
|
|
86
|
-
):
|
|
87
|
-
defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize)
|
|
88
|
-
super().__init__(defaults)
|
|
89
|
-
|
|
90
|
-
self.algebra = ta.get_algebra(algebra)
|
|
91
|
-
self.lstsq_args:dict = dict(solver=solver, maxiter=maxiter, tol=tol, algebra=algebra, verbose=verbose)
|
|
92
|
-
|
|
93
|
-
if inner is not None:
|
|
94
|
-
self.set_child('inner', inner)
|
|
95
|
-
|
|
96
|
-
@torch.no_grad
|
|
97
|
-
def step(self, vars):
|
|
98
|
-
params = TensorList(vars.params)
|
|
99
|
-
closure = vars.closure
|
|
100
|
-
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
101
|
-
|
|
102
|
-
settings = self.settings[params[0]]
|
|
103
|
-
reg = settings['reg']
|
|
104
|
-
hessian_method = settings['hessian_method']
|
|
105
|
-
vectorize = settings['vectorize']
|
|
106
|
-
|
|
107
|
-
# ------------------------ calculate grad and hessian ------------------------ #
|
|
108
|
-
if hessian_method == 'autograd':
|
|
109
|
-
with torch.enable_grad():
|
|
110
|
-
loss = vars.loss = vars.loss_approx = closure(False)
|
|
111
|
-
g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
|
|
112
|
-
g_list = [t[0] for t in g_list] # remove leading dim from loss
|
|
113
|
-
vars.grad = g_list
|
|
114
|
-
H = hessian_list_to_mat(H_list)
|
|
115
|
-
|
|
116
|
-
elif hessian_method in ('func', 'autograd.functional'):
|
|
117
|
-
strat = 'forward-mode' if vectorize else 'reverse-mode'
|
|
118
|
-
with torch.enable_grad():
|
|
119
|
-
g_list = vars.get_grad(retain_graph=True)
|
|
120
|
-
H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
|
|
121
|
-
method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
|
|
122
|
-
|
|
123
|
-
else:
|
|
124
|
-
raise ValueError(hessian_method)
|
|
125
|
-
|
|
126
|
-
# -------------------------------- inner step -------------------------------- #
|
|
127
|
-
if 'inner' in self.children:
|
|
128
|
-
g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
|
|
129
|
-
g = torch.cat([t.view(-1) for t in g_list])
|
|
130
|
-
|
|
131
|
-
# ------------------------------- regulazition ------------------------------- #
|
|
132
|
-
if reg is not None: H = tikhonov(H, reg)
|
|
133
|
-
|
|
134
|
-
# ----------------------------------- solve ---------------------------------- #
|
|
135
|
-
tropical_update = tropical_lstsq(H, g, **self.lstsq_args)
|
|
136
|
-
# what now? w - u is not defined, it is defined for max version if u < w
|
|
137
|
-
# w = params.to_vec()
|
|
138
|
-
# w_hat = self.algebra.sub(w, tropical_update)
|
|
139
|
-
# update = w_hat - w
|
|
140
|
-
# no
|
|
141
|
-
# it makes sense to solve tropical system and sub normally
|
|
142
|
-
# the only thing is that tropical system can have no solutions
|
|
143
|
-
|
|
144
|
-
vars.update = vec_to_tensors(tropical_update, params)
|
|
145
|
-
return vars
|
|
@@ -1,136 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from functools import partial
|
|
3
|
-
from typing import Literal
|
|
4
|
-
from collections.abc import Callable
|
|
5
|
-
import torch
|
|
6
|
-
|
|
7
|
-
from ...core import Chainable, apply, Module
|
|
8
|
-
from ...utils import vec_to_tensors, TensorList
|
|
9
|
-
from ...utils.derivatives import (
|
|
10
|
-
hessian_list_to_mat,
|
|
11
|
-
hessian_mat,
|
|
12
|
-
jacobian_and_hessian_wrt,
|
|
13
|
-
)
|
|
14
|
-
from ..second_order.newton import lu_solve, cholesky_solve, least_squares_solve
|
|
15
|
-
|
|
16
|
-
def tropical_sum(x, dim): return torch.amax(x, dim=dim)
|
|
17
|
-
def tropical_mul(x, y): return x+y
|
|
18
|
-
|
|
19
|
-
def tropical_matmul(x: torch.Tensor, y: torch.Tensor):
|
|
20
|
-
# this imlements matmul by calling mul and sum
|
|
21
|
-
|
|
22
|
-
x_squeeze = False
|
|
23
|
-
y_squeeze = False
|
|
24
|
-
|
|
25
|
-
if x.ndim == 1:
|
|
26
|
-
x_squeeze = True
|
|
27
|
-
x = x.unsqueeze(0)
|
|
28
|
-
|
|
29
|
-
if y.ndim == 1:
|
|
30
|
-
y_squeeze = True
|
|
31
|
-
y = y.unsqueeze(1)
|
|
32
|
-
|
|
33
|
-
res = tropical_sum(tropical_mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)
|
|
34
|
-
|
|
35
|
-
if x_squeeze: res = res.squeeze(-2)
|
|
36
|
-
if y_squeeze: res = res.squeeze(-1)
|
|
37
|
-
|
|
38
|
-
return res
|
|
39
|
-
|
|
40
|
-
def tropical_dot(x:torch.Tensor, y:torch.Tensor):
|
|
41
|
-
assert x.ndim == 1 and y.ndim == 1
|
|
42
|
-
return tropical_matmul(x.unsqueeze(0), y.unsqueeze(1))
|
|
43
|
-
|
|
44
|
-
def tropical_outer(x:torch.Tensor, y:torch.Tensor):
|
|
45
|
-
assert x.ndim == 1 and y.ndim == 1
|
|
46
|
-
return tropical_matmul(x.unsqueeze(1), y.unsqueeze(0))
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
|
50
|
-
r = b.unsqueeze(1) - A
|
|
51
|
-
return r.amin(dim=-2)
|
|
52
|
-
|
|
53
|
-
def tropical_solve_and_reconstruct(A: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
54
|
-
r = b.unsqueeze(1) - A
|
|
55
|
-
x = r.amin(dim=-2)
|
|
56
|
-
b_hat = tropical_matmul(A, x)
|
|
57
|
-
return x, b_hat
|
|
58
|
-
|
|
59
|
-
def tikhonov(H: torch.Tensor, reg: float):
|
|
60
|
-
if reg!=0: H += torch.eye(H.size(-1), dtype=H.dtype, device=H.device) * reg
|
|
61
|
-
return H
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
class TropicalNewton(Module):
|
|
65
|
-
"""suston"""
|
|
66
|
-
def __init__(
|
|
67
|
-
self,
|
|
68
|
-
reg: float | None = None,
|
|
69
|
-
hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
|
|
70
|
-
vectorize: bool = True,
|
|
71
|
-
interpolate:bool=False,
|
|
72
|
-
inner: Chainable | None = None,
|
|
73
|
-
):
|
|
74
|
-
defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, interpolate=interpolate)
|
|
75
|
-
super().__init__(defaults)
|
|
76
|
-
|
|
77
|
-
if inner is not None:
|
|
78
|
-
self.set_child('inner', inner)
|
|
79
|
-
|
|
80
|
-
@torch.no_grad
|
|
81
|
-
def step(self, vars):
|
|
82
|
-
params = TensorList(vars.params)
|
|
83
|
-
closure = vars.closure
|
|
84
|
-
if closure is None: raise RuntimeError('NewtonCG requires closure')
|
|
85
|
-
|
|
86
|
-
settings = self.settings[params[0]]
|
|
87
|
-
reg = settings['reg']
|
|
88
|
-
hessian_method = settings['hessian_method']
|
|
89
|
-
vectorize = settings['vectorize']
|
|
90
|
-
interpolate = settings['interpolate']
|
|
91
|
-
|
|
92
|
-
# ------------------------ calculate grad and hessian ------------------------ #
|
|
93
|
-
if hessian_method == 'autograd':
|
|
94
|
-
with torch.enable_grad():
|
|
95
|
-
loss = vars.loss = vars.loss_approx = closure(False)
|
|
96
|
-
g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
|
|
97
|
-
g_list = [t[0] for t in g_list] # remove leading dim from loss
|
|
98
|
-
vars.grad = g_list
|
|
99
|
-
H = hessian_list_to_mat(H_list)
|
|
100
|
-
|
|
101
|
-
elif hessian_method in ('func', 'autograd.functional'):
|
|
102
|
-
strat = 'forward-mode' if vectorize else 'reverse-mode'
|
|
103
|
-
with torch.enable_grad():
|
|
104
|
-
g_list = vars.get_grad(retain_graph=True)
|
|
105
|
-
H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
|
|
106
|
-
method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
|
|
107
|
-
|
|
108
|
-
else:
|
|
109
|
-
raise ValueError(hessian_method)
|
|
110
|
-
|
|
111
|
-
# -------------------------------- inner step -------------------------------- #
|
|
112
|
-
if 'inner' in self.children:
|
|
113
|
-
g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
|
|
114
|
-
g = torch.cat([t.view(-1) for t in g_list])
|
|
115
|
-
|
|
116
|
-
# ------------------------------- regulazition ------------------------------- #
|
|
117
|
-
if reg is not None: H = tikhonov(H, reg)
|
|
118
|
-
|
|
119
|
-
# ----------------------------------- solve ---------------------------------- #
|
|
120
|
-
tropical_update, g_hat = tropical_solve_and_reconstruct(H, g)
|
|
121
|
-
|
|
122
|
-
g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
|
|
123
|
-
abs_error = torch.linalg.vector_norm(g-g_hat) # pylint:disable=not-callable
|
|
124
|
-
rel_error = abs_error/g_norm.clip(min=1e-8)
|
|
125
|
-
|
|
126
|
-
if interpolate:
|
|
127
|
-
if rel_error > 1e-8:
|
|
128
|
-
|
|
129
|
-
update = cholesky_solve(H, g)
|
|
130
|
-
if update is None: update = lu_solve(H, g)
|
|
131
|
-
if update is None: update = least_squares_solve(H, g)
|
|
132
|
-
|
|
133
|
-
tropical_update.lerp_(update.ravel(), rel_error.clip(max=1))
|
|
134
|
-
|
|
135
|
-
vars.update = vec_to_tensors(tropical_update, params)
|
|
136
|
-
return vars
|
torchzero-0.3.9.dist-info/RECORD
DELETED
|
@@ -1,131 +0,0 @@
|
|
|
1
|
-
docs/source/conf.py,sha256=jd80ZT2IdCx7nlQrpOTJL8UhGBNm6KYyXlpp0jmRiAw,1849
|
|
2
|
-
tests/test_identical.py,sha256=NZ7A8Rm1U9Q16d-cG2G_wccpPtNALyoKYJt9qMownMc,11568
|
|
3
|
-
tests/test_module.py,sha256=qX3rjdSJsbA8JO17bPTUIDspe7bg2dogqxMw__KV7SU,2039
|
|
4
|
-
tests/test_opts.py,sha256=TZVaCv2ZLdHSkL6snTEkqhTMHqlcO55L-c56k6Hh4xc,40850
|
|
5
|
-
tests/test_tensorlist.py,sha256=Djpr5C0T5d_gz-j-P-bpo_X51DC4twbtT9c-xDSFbP0,72438
|
|
6
|
-
tests/test_utils_optimizer.py,sha256=bvC0Ehvs2L8fohpyIF5Vfr9OKTycpnODWLPflXilU1c,8414
|
|
7
|
-
tests/test_vars.py,sha256=3p9dsHk7SJpMd-WRD0ziBNq5FEHRBJGSxbMLD8ES4J0,6815
|
|
8
|
-
torchzero/__init__.py,sha256=L7IJ1qZ3o8E9oRwlJZBK2_2yII_eeGEk57Of6EfVbrk,112
|
|
9
|
-
torchzero/core/__init__.py,sha256=2JRyeGZprTexAeEPQOIl9fLFGBwzvya-AwKyt7XAmGQ,210
|
|
10
|
-
torchzero/core/module.py,sha256=Razw3c71Kfegznm0vQxsii1KuTUCPBC9UGyq2v-KX4M,27568
|
|
11
|
-
torchzero/core/preconditioner.py,sha256=n9oh7kZdt1kU3Wh472lnvLrsXwhR5Wqe6lIp7JuAJ_I,6336
|
|
12
|
-
torchzero/core/transform.py,sha256=ajNJcX45ds-_lc5CqxgLfEFGil6_BYLerB0WvoTi8rM,10303
|
|
13
|
-
torchzero/modules/__init__.py,sha256=BDeyuSd2s1WFUUXIo3tGTNp4aYp4A2B94cydpPW24nY,332
|
|
14
|
-
torchzero/modules/functional.py,sha256=HXNzmPe7LsPadryEm7zrcEKqGej16QDwSgBkbEvggFM,6492
|
|
15
|
-
torchzero/modules/clipping/__init__.py,sha256=ZaffMF7mIRK6hZSfuZadgjNTX6hF5ANiLBny2w3S7I8,250
|
|
16
|
-
torchzero/modules/clipping/clipping.py,sha256=I-5utyrqdKtF5yaH-9m2F3UqdfpPmA2bSSFUAZ_d60Q,12544
|
|
17
|
-
torchzero/modules/clipping/ema_clipping.py,sha256=pLeNuEBLpJ74io2sHn_ZVYaQ6ydEfhpVfVEX2bFttd0,5947
|
|
18
|
-
torchzero/modules/clipping/growth_clipping.py,sha256=OD-kdia2Rn-DvYlYV6EZlGPDVTh9tj-W9mpiZPc3hOQ,6772
|
|
19
|
-
torchzero/modules/experimental/__init__.py,sha256=fEPDYDl7qhaFoferDRmG3ehwuqSvx4Vt2uOz0Y7h4to,483
|
|
20
|
-
torchzero/modules/experimental/absoap.py,sha256=Z4MS4pDPSQ9IaTk8g57OfrsWcYVOT72x533KKtn2Zxk,13512
|
|
21
|
-
torchzero/modules/experimental/adadam.py,sha256=OAPF1-NUbg79V3QOTYzsQlRC97C7XHj5boOLDqLz3PE,4029
|
|
22
|
-
torchzero/modules/experimental/adamY.py,sha256=g1pAHwgdyDdKvObZ67lCSc36L99tl5jlQgOr4lMJCDo,4595
|
|
23
|
-
torchzero/modules/experimental/adasoap.py,sha256=JdV6rB9xfqL3vbHpZCLmkJZKRObZ1nVoEmabtIeVT3E,11195
|
|
24
|
-
torchzero/modules/experimental/algebraic_newton.py,sha256=sq5ZD_j_EtlxIjNnS0rKKwTSG_JuwsZOg9ZMMQTuQm0,5154
|
|
25
|
-
torchzero/modules/experimental/curveball.py,sha256=Uk30uLEztTHD5IUJLJm9Nn3x31DF9kQHmeLFhc065us,3262
|
|
26
|
-
torchzero/modules/experimental/gradmin.py,sha256=iJmEvDEdVdck0C-94pY3iGxnIoNv6Fu6vj3f7lS6aQM,3686
|
|
27
|
-
torchzero/modules/experimental/newton_solver.py,sha256=iGI2LHLaZd2ovpbq1Vogs76os0zWG7VwM7nUz8RzxVg,3071
|
|
28
|
-
torchzero/modules/experimental/reduce_outward_lr.py,sha256=kjtRwepBGBca77ToM-lw3b8ywptMtmSdC_jQfjJAwlY,1184
|
|
29
|
-
torchzero/modules/experimental/soapy.py,sha256=Ishd2Jj6BbhjrLyC48zf-cjMmA1kJb_uKXESQBIML_s,10990
|
|
30
|
-
torchzero/modules/experimental/spectral.py,sha256=8_n208V2yPY3z5pCym-FvwO7DGFhozNgWlpIBtQSdrI,12139
|
|
31
|
-
torchzero/modules/experimental/structured_newton.py,sha256=uWczR-uAXHaFwf0mlOThv2sLG0irH6Gz1hKlGHtPAj4,3386
|
|
32
|
-
torchzero/modules/experimental/subspace_preconditioners.py,sha256=WnHpga7Kx4-N2xU5vP3uUHRER70ymyNJCWbSx2zXWOk,4976
|
|
33
|
-
torchzero/modules/experimental/tropical_newton.py,sha256=uq66ouhgrgc8iYGozDQ3_rtbubj8rKRwb1jfcdnlpHg,4903
|
|
34
|
-
torchzero/modules/grad_approximation/__init__.py,sha256=DVFjf0cXuF70NA0nJ2WklpP01PQgrRZxUjUQjjQeSos,195
|
|
35
|
-
torchzero/modules/grad_approximation/fdm.py,sha256=2PNNBIMup1xlOwLFAwAS3xAVd-7GGVyerMeKH1ug9LQ,3591
|
|
36
|
-
torchzero/modules/grad_approximation/forward_gradient.py,sha256=Kb8RNGAIb2tKzgofnEn4pQjS7TPq824B_P14idyy8e0,3564
|
|
37
|
-
torchzero/modules/grad_approximation/grad_approximator.py,sha256=Pa1Lv52T7WawUJUUA3IHm7mVypBQXLbjc5_15FkVwnQ,2938
|
|
38
|
-
torchzero/modules/grad_approximation/rfdm.py,sha256=s7OSMFnIEr43WKCT0TXdgzz_6odOkRN0BcKWkFbbPAE,10189
|
|
39
|
-
torchzero/modules/line_search/__init__.py,sha256=nkOUPLe88wE91ICEhprl2pJsvaKtbI3KzYOdT83AGsg,253
|
|
40
|
-
torchzero/modules/line_search/backtracking.py,sha256=FG_-KAN9whvBNZyhDa5-ta46IQFm8hagVvaPTXCCV88,6307
|
|
41
|
-
torchzero/modules/line_search/line_search.py,sha256=4z0fHJAGAZT2IVAOUxZetAszPtNuXfXdFzs1_WUWT2c,7296
|
|
42
|
-
torchzero/modules/line_search/scipy.py,sha256=7tfxXT8RAIHpRv-e5w9C8RNvkvgwgxHZaWI25RjTYy0,1156
|
|
43
|
-
torchzero/modules/line_search/strong_wolfe.py,sha256=Y6UXd2Br30YWta1phZx1wiSsFQC6wbgmvOpVITcmJpw,7504
|
|
44
|
-
torchzero/modules/line_search/trust_region.py,sha256=_zOje00BLvIMi0d5H9qZavqf3MWeB48Q-WosgXu3Ef4,2349
|
|
45
|
-
torchzero/modules/lr/__init__.py,sha256=pNxbBUGzDp24O6g7pu1bRg1tzh4eh-mSxVbhOItKHpc,90
|
|
46
|
-
torchzero/modules/lr/lr.py,sha256=wlubixzgxnm4ucyiEtGWzQOskaLXLInvSaR0sGKxto8,2161
|
|
47
|
-
torchzero/modules/lr/step_size.py,sha256=0HWYAYhVqWCCYe_-guBnMaOpqLbsMm4-F6bRFjltBsc,4036
|
|
48
|
-
torchzero/modules/momentum/__init__.py,sha256=pSD7vxu8PySrYOSHQMi3C9heYdcQr8y6WC_rwMybZm0,544
|
|
49
|
-
torchzero/modules/momentum/averaging.py,sha256=hyH5jzvYTbB1Vcjx0j_v4dtPp54GUUDOZYVDADGjcfE,2672
|
|
50
|
-
torchzero/modules/momentum/cautious.py,sha256=QCoBXpYcIUOrgY6XXHA30m0-MVy7iGCGxZGFLyDwqkc,5841
|
|
51
|
-
torchzero/modules/momentum/ema.py,sha256=4ubPpq9TL0oQZ5_eXBwU5oRbxV3faHMEM1a_kv8vRqI,7733
|
|
52
|
-
torchzero/modules/momentum/experimental.py,sha256=ze9oxqxdmqRFQyVdG7iBA-hICft5mxeAM6GCTQ4ewes,6352
|
|
53
|
-
torchzero/modules/momentum/matrix_momentum.py,sha256=IQjCp2Kb53bCaReM7fHBil_pwH9oiH029YkWFq0OIDw,4894
|
|
54
|
-
torchzero/modules/momentum/momentum.py,sha256=hcmmYysGItb3b7MBBVhoODh7p4Fyit68cZzD0NUBmvA,1540
|
|
55
|
-
torchzero/modules/ops/__init__.py,sha256=hxMZFSXX7xvitXkuBiYykVGX3p03Xprm_QA2CMg4eW8,1601
|
|
56
|
-
torchzero/modules/ops/accumulate.py,sha256=YGI11YxgTWvIBq5maDRWiSA-v-FS-XoaSYPU2SSrBY8,2759
|
|
57
|
-
torchzero/modules/ops/binary.py,sha256=-b0yvKvfDx9-HcaaxLWzg5C6rUl24oP3OltSF-iXi6w,9731
|
|
58
|
-
torchzero/modules/ops/debug.py,sha256=9sJOHRMwTMaOgOi2QFwCH7g2WPF1o3oyouPJO-MQQg4,862
|
|
59
|
-
torchzero/modules/ops/misc.py,sha256=xdxnGbRArWBqzyufUdrCQH-mAI9utRF0zxcvWCkEfZc,16383
|
|
60
|
-
torchzero/modules/ops/multi.py,sha256=P7mSG0LnDMkuZNSgtpHRNgqglqksrdxITCzkhmEjqxU,5742
|
|
61
|
-
torchzero/modules/ops/reduce.py,sha256=xvFHZG5Wf7KxfFLkynFGBOK6xywyTXsbCasW6h2OYAU,5695
|
|
62
|
-
torchzero/modules/ops/split.py,sha256=fFcDnJZ-e46cx_fx_TkGlVsFYOL1Y8UAp_pUPJOOdm4,2303
|
|
63
|
-
torchzero/modules/ops/switch.py,sha256=5idKd9xBP-KbqZjWBcr6ZDjso8BRpTNQYJg4xKWwmng,2511
|
|
64
|
-
torchzero/modules/ops/unary.py,sha256=h3MXS6jydZjfFetjaBCWCUWTXdQcNKnxEC6uGS6yh3c,4794
|
|
65
|
-
torchzero/modules/ops/utility.py,sha256=p-mc2j1mQEMLxp4brnAnzgmK6VKbSnYd2U8vkAwTKd8,3117
|
|
66
|
-
torchzero/modules/optimizers/__init__.py,sha256=BbT2nhIt4p74t1cO8ziQgzqZHaLvyuleXQbccugd06M,554
|
|
67
|
-
torchzero/modules/optimizers/adagrad.py,sha256=1DIBJ_7gJ35qidXMK4IkHYF_37Bl9Ptl9mAgfOq6YAk,4834
|
|
68
|
-
torchzero/modules/optimizers/adam.py,sha256=xctnENJ9rcpv2sis4zAGPGoy-ccJC1iVl8SvBynaG50,4093
|
|
69
|
-
torchzero/modules/optimizers/lion.py,sha256=eceNfITCozqYob0thWbIV7AdY1yAIJMqb4GJfB8a1SA,1087
|
|
70
|
-
torchzero/modules/optimizers/muon.py,sha256=m3LpwD6AF7E-1v3VVPHAN8S_tPTTFKZ5RpkzKea4K4g,9598
|
|
71
|
-
torchzero/modules/optimizers/orthograd.py,sha256=5BLnNJTYuGUClHmlxaXZ1jNvBR4zSFDGG8nM20lZdhk,2046
|
|
72
|
-
torchzero/modules/optimizers/rmsprop.py,sha256=d10Y9Ck-391tVysO3xMHg3g2Pe0UEZplgebEyDYi3Z4,4333
|
|
73
|
-
torchzero/modules/optimizers/rprop.py,sha256=n4k5-9F3ppH0Xl-4l4vNXfqVf2r67vMPCkstUaQKPLw,10974
|
|
74
|
-
torchzero/modules/optimizers/shampoo.py,sha256=AHHV6d71DqKDPCg52ShWIPIRSGtWkMc1v1XwXgDG3qY,8606
|
|
75
|
-
torchzero/modules/optimizers/soap.py,sha256=Kf2BAtIf2QY1V2ZJcUjRLcp2WfIVLd3mNclnaT3Nmds,11520
|
|
76
|
-
torchzero/modules/optimizers/sophia_h.py,sha256=8pSlYVm66xWplzdP8MX3MCTzzIYHsxGzDEXJKA03Zgg,4279
|
|
77
|
-
torchzero/modules/projections/__init__.py,sha256=OCxlh_-Tx-xpl31X03CeFJH9XveH563oEsWc8rUvX0A,196
|
|
78
|
-
torchzero/modules/projections/dct.py,sha256=wxaEV6dTNiOqW_n2UHX0De6mMXTKDXK6UNcMNI4Rogk,2373
|
|
79
|
-
torchzero/modules/projections/fft.py,sha256=OpCcEM1-A2dgk1umwRsBsvK7ObiHtsBKlkkcw0IX83Q,2961
|
|
80
|
-
torchzero/modules/projections/galore.py,sha256=c9CZ0kHxpKEoyfc_lnmeHOkNp55jCppb7onN5YmWnN8,242
|
|
81
|
-
torchzero/modules/projections/projection.py,sha256=aYufSD3ftRUqVScPmqxwEFgP1P8ioxM8z9eyzaL7d10,10147
|
|
82
|
-
torchzero/modules/projections/structural.py,sha256=QaCGHmzHCXj46sM-XZ5XlYU9BnuRKI2ReR3LE8y2R4g,5740
|
|
83
|
-
torchzero/modules/quasi_newton/__init__.py,sha256=0iOlX73PHj9lQS3_2cJ5lyCdas904MnFfIvR8Popvzw,402
|
|
84
|
-
torchzero/modules/quasi_newton/cg.py,sha256=lIJvfWAZ08r0o4uqaJnRG6pvcE2kBkJUkZ1MK37KMTk,9602
|
|
85
|
-
torchzero/modules/quasi_newton/lbfgs.py,sha256=SMgesPMZ4ubVeG7R395SnAb5ffkyPHbzSQMqPlLGI7U,9211
|
|
86
|
-
torchzero/modules/quasi_newton/lsr1.py,sha256=XmYyYANzQgQuFtOMW59znQrS-mprGRXazicfB9JAup8,6059
|
|
87
|
-
torchzero/modules/quasi_newton/olbfgs.py,sha256=2YAOXlMnPGw22sNcIMH1hmggzAXQRbN59RSPUZNKUZY,8352
|
|
88
|
-
torchzero/modules/quasi_newton/quasi_newton.py,sha256=rUp4s3MbACcOjwpz00TAjl-olif50voTmC16vv5XrSE,17496
|
|
89
|
-
torchzero/modules/quasi_newton/experimental/__init__.py,sha256=3qpZGgdsx6wpoafWaNWx-eamRl1FuxVCWQZq8Y7Cl98,39
|
|
90
|
-
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py,sha256=ec6JKYX89xA_UlY9VrMB3hBjDyNKwkalS_4JQGA1qOY,10762
|
|
91
|
-
torchzero/modules/second_order/__init__.py,sha256=jolCGaIVkID9hpxgx0Tc22wgjVlwuWekWjKTMe5jKXw,114
|
|
92
|
-
torchzero/modules/second_order/newton.py,sha256=xxkrhFK4i5I9oOX3AGGh_6bXNDUSFq4D0pw3c7qgEd8,5925
|
|
93
|
-
torchzero/modules/second_order/newton_cg.py,sha256=PILHRf2koop_cywE1RNGukT16alDO7prC4C3HlZcW30,2861
|
|
94
|
-
torchzero/modules/second_order/nystrom.py,sha256=zdLSTQ_S5VViUt2sAmFNoDCCHKmHP2A7112czkZNlUk,6051
|
|
95
|
-
torchzero/modules/smoothing/__init__.py,sha256=tUTGN0A-EQC7xuLV2AuHFWk-t7D6jIJlpV_3qyfRqLk,80
|
|
96
|
-
torchzero/modules/smoothing/gaussian.py,sha256=YlT_G4MqAVkiWG56RHAwgt5SSPISpvQZQbSLh8mhF3I,6153
|
|
97
|
-
torchzero/modules/smoothing/laplacian.py,sha256=Bfrs7D59SfdU7j-97UBKD1hs0obC-ZgjJvG7oKwaa0o,5065
|
|
98
|
-
torchzero/modules/weight_decay/__init__.py,sha256=VdJfEx3uk8wYGCpMjYSeudXyGX8ONqsQYoBCE3cdM1U,72
|
|
99
|
-
torchzero/modules/weight_decay/weight_decay.py,sha256=p6jGD3hgC_rmZXiWYr7_IZWHMdVJJaT_bcHHzcdXSxU,1912
|
|
100
|
-
torchzero/modules/wrappers/__init__.py,sha256=6b5Ac-8u18IVp_Jnw1T1xQExwpQhpQ0JwNV9GyC_Yj8,31
|
|
101
|
-
torchzero/modules/wrappers/optim_wrapper.py,sha256=mcoQCUJwpMJuCDv03nDa0jZIb3Y0CyaeE1kNcJQozfo,3582
|
|
102
|
-
torchzero/optim/__init__.py,sha256=aXf7EkywqYiR50I4QeeVXro9aBhKiqfbY_BCia59sgU,46
|
|
103
|
-
torchzero/optim/utility/__init__.py,sha256=pUacok4XmebfxofE-QWZLgViajsU-3JkXcWi9OS-Jrw,24
|
|
104
|
-
torchzero/optim/utility/split.py,sha256=ZbazNuMTYunm75V_5ard0A_LletGaYAg-Pm2rANJKrE,1610
|
|
105
|
-
torchzero/optim/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
106
|
-
torchzero/optim/wrappers/nevergrad.py,sha256=2jHWQiWGjaffAqhJotMwOt03OtW-L57p8OesD2gVVow,3949
|
|
107
|
-
torchzero/optim/wrappers/nlopt.py,sha256=ZoHBf51OhwgAaExxmoFtvP8GqO9uBHdEsc4HLm0wcic,7588
|
|
108
|
-
torchzero/optim/wrappers/scipy.py,sha256=0BNBlHCbeTslXkXhnKvhuvJfNO7_CHFa2AXruYySnzM,14561
|
|
109
|
-
torchzero/utils/__init__.py,sha256=By___ngB1bcnrSZiJanvtKk8QFrPmLRhTOrkFYP2MU4,929
|
|
110
|
-
torchzero/utils/compile.py,sha256=N8AWLv_7oBUHYornmvvx_L4uynjiD-x5Hj1tBwei3-w,5127
|
|
111
|
-
torchzero/utils/derivatives.py,sha256=S4Vh2cwE2h6yvhqu799AjR4GVHOEg7yApH3SataKxnA,16881
|
|
112
|
-
torchzero/utils/numberlist.py,sha256=cbG0UsSb9WCRxVhw8sd7Yf0bDy_gSqtghiJtkUxIO6U,6139
|
|
113
|
-
torchzero/utils/ops.py,sha256=n4Su1sbgTzlHczuPEHkuWenTtNBCa_MvlQ_hCZkIPnQ,314
|
|
114
|
-
torchzero/utils/optimizer.py,sha256=-vuOZNu4luSZA5YtwC_7s-G2FvHKnM2k5KqC6bC_hcM,13097
|
|
115
|
-
torchzero/utils/optuna_tools.py,sha256=F-1Xg0n_29MVEb6lqgUFFNIl9BNJ6MOdIJPduoNH4JU,1325
|
|
116
|
-
torchzero/utils/params.py,sha256=nQo270aOURU7rJ_D102y2pSXbzhJPK0Z_ehx4mZBMes,5784
|
|
117
|
-
torchzero/utils/python_tools.py,sha256=RFBqNj8w52dpJ983pUPPDbg2x1MX_-SsBnBMffWGGIk,2066
|
|
118
|
-
torchzero/utils/tensorlist.py,sha256=qSbiliVo1euFAksdHHHRbPUdYYxfkw1dvhpXj71wGy0,53162
|
|
119
|
-
torchzero/utils/torch_tools.py,sha256=ohqnnZRlqdfp5PAfMSbQDIEKygW0_ARjxSEBp3Zo9nU,4756
|
|
120
|
-
torchzero/utils/linalg/__init__.py,sha256=Dzbho3_z7JDdKzYD-QdLArg0ZEoC2BVGdlE3JoAnXHQ,272
|
|
121
|
-
torchzero/utils/linalg/benchmark.py,sha256=wiIMn-GY2xxWbHVf8CPbJddUPeUPq9OUDkvbp1iILYI,479
|
|
122
|
-
torchzero/utils/linalg/matrix_funcs.py,sha256=-LecWrPWbJvfeCgIzUhfWARa2aSZvJ12lHX7Jno38O4,3099
|
|
123
|
-
torchzero/utils/linalg/orthogonalize.py,sha256=mDCkET7qgDZqf_y6oPYAK3d2L5HrB8gzOFPl0YoONaY,399
|
|
124
|
-
torchzero/utils/linalg/qr.py,sha256=L-RXuYV-SIHI-Llq4y1rQ_Tz-yamds0_QNZeHapbjNE,2507
|
|
125
|
-
torchzero/utils/linalg/solve.py,sha256=P0PMi0zro3G3Rd0X-JeoLk7tqYDB0js0aB4bpQ0OABU,5235
|
|
126
|
-
torchzero/utils/linalg/svd.py,sha256=wBxl-JSciINV-N6zvM4SGdveqMr6idq51h68LyQQRYg,660
|
|
127
|
-
torchzero-0.3.9.dist-info/licenses/LICENSE,sha256=r9ZciAoZoqKC_FNADE0ORukj1p1XhLXEbegdsAyqhJs,1087
|
|
128
|
-
torchzero-0.3.9.dist-info/METADATA,sha256=aENIaMgy94tD6nakRWfApleVSy6bxW8-q3-mQeVSeGA,13941
|
|
129
|
-
torchzero-0.3.9.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
|
130
|
-
torchzero-0.3.9.dist-info/top_level.txt,sha256=YDdpIOb7HyKV9THOtOYsFFMTbxvCO0kiol4-83tDj-A,21
|
|
131
|
-
torchzero-0.3.9.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|