torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff reflects the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- tests/test_identical.py +22 -22
- tests/test_module_autograd.py +586 -0
- tests/test_objective.py +188 -0
- tests/test_opts.py +225 -214
- tests/test_tensorlist.py +0 -8
- tests/test_utils_optimizer.py +0 -1
- torchzero/__init__.py +2 -2
- torchzero/core/__init__.py +7 -4
- torchzero/core/chain.py +20 -23
- torchzero/core/functional.py +90 -24
- torchzero/core/modular.py +53 -57
- torchzero/core/module.py +132 -52
- torchzero/core/objective.py +948 -0
- torchzero/core/reformulation.py +55 -24
- torchzero/core/transform.py +261 -367
- torchzero/linalg/__init__.py +11 -0
- torchzero/linalg/eigh.py +253 -0
- torchzero/linalg/linalg_utils.py +14 -0
- torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
- torchzero/linalg/matrix_power.py +28 -0
- torchzero/linalg/orthogonalize.py +93 -0
- torchzero/{utils/linalg → linalg}/qr.py +16 -2
- torchzero/{utils/linalg → linalg}/solve.py +74 -88
- torchzero/linalg/svd.py +47 -0
- torchzero/linalg/torch_linalg.py +168 -0
- torchzero/modules/__init__.py +4 -3
- torchzero/modules/adaptive/__init__.py +11 -3
- torchzero/modules/adaptive/adagrad.py +167 -217
- torchzero/modules/adaptive/adahessian.py +76 -105
- torchzero/modules/adaptive/adam.py +53 -76
- torchzero/modules/adaptive/adan.py +50 -31
- torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
- torchzero/modules/adaptive/aegd.py +12 -12
- torchzero/modules/adaptive/esgd.py +98 -119
- torchzero/modules/adaptive/ggt.py +186 -0
- torchzero/modules/adaptive/lion.py +7 -11
- torchzero/modules/adaptive/lre_optimizers.py +299 -0
- torchzero/modules/adaptive/mars.py +7 -7
- torchzero/modules/adaptive/matrix_momentum.py +48 -52
- torchzero/modules/adaptive/msam.py +71 -53
- torchzero/modules/adaptive/muon.py +67 -129
- torchzero/modules/adaptive/natural_gradient.py +63 -41
- torchzero/modules/adaptive/orthograd.py +11 -15
- torchzero/modules/adaptive/psgd/__init__.py +5 -0
- torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
- torchzero/modules/adaptive/psgd/psgd.py +1390 -0
- torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
- torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
- torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
- torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
- torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
- torchzero/modules/adaptive/rmsprop.py +83 -75
- torchzero/modules/adaptive/rprop.py +48 -47
- torchzero/modules/adaptive/sam.py +55 -45
- torchzero/modules/adaptive/shampoo.py +149 -130
- torchzero/modules/adaptive/soap.py +207 -143
- torchzero/modules/adaptive/sophia_h.py +106 -130
- torchzero/modules/clipping/clipping.py +22 -25
- torchzero/modules/clipping/ema_clipping.py +31 -25
- torchzero/modules/clipping/growth_clipping.py +14 -17
- torchzero/modules/conjugate_gradient/cg.py +27 -38
- torchzero/modules/experimental/__init__.py +7 -6
- torchzero/modules/experimental/adanystrom.py +258 -0
- torchzero/modules/experimental/common_directions_whiten.py +142 -0
- torchzero/modules/experimental/coordinate_momentum.py +36 -0
- torchzero/modules/experimental/cubic_adam.py +160 -0
- torchzero/modules/experimental/curveball.py +25 -41
- torchzero/modules/experimental/eigen_sr1.py +182 -0
- torchzero/modules/experimental/eigengrad.py +207 -0
- torchzero/modules/experimental/gradmin.py +2 -2
- torchzero/modules/experimental/higher_order_newton.py +14 -40
- torchzero/modules/experimental/l_infinity.py +1 -1
- torchzero/modules/experimental/matrix_nag.py +122 -0
- torchzero/modules/experimental/newton_solver.py +23 -54
- torchzero/modules/experimental/newtonnewton.py +45 -48
- torchzero/modules/experimental/reduce_outward_lr.py +7 -7
- torchzero/modules/experimental/scipy_newton_cg.py +21 -24
- torchzero/modules/experimental/spsa1.py +3 -3
- torchzero/modules/experimental/structural_projections.py +1 -4
- torchzero/modules/grad_approximation/fdm.py +2 -2
- torchzero/modules/grad_approximation/forward_gradient.py +7 -7
- torchzero/modules/grad_approximation/grad_approximator.py +23 -16
- torchzero/modules/grad_approximation/rfdm.py +24 -21
- torchzero/modules/least_squares/gn.py +121 -50
- torchzero/modules/line_search/backtracking.py +4 -4
- torchzero/modules/line_search/line_search.py +33 -33
- torchzero/modules/line_search/strong_wolfe.py +4 -4
- torchzero/modules/misc/debug.py +12 -12
- torchzero/modules/misc/escape.py +10 -10
- torchzero/modules/misc/gradient_accumulation.py +11 -79
- torchzero/modules/misc/homotopy.py +16 -8
- torchzero/modules/misc/misc.py +121 -123
- torchzero/modules/misc/multistep.py +52 -53
- torchzero/modules/misc/regularization.py +49 -44
- torchzero/modules/misc/split.py +31 -29
- torchzero/modules/misc/switch.py +37 -32
- torchzero/modules/momentum/averaging.py +14 -14
- torchzero/modules/momentum/cautious.py +37 -31
- torchzero/modules/momentum/momentum.py +12 -12
- torchzero/modules/ops/__init__.py +4 -4
- torchzero/modules/ops/accumulate.py +21 -21
- torchzero/modules/ops/binary.py +67 -66
- torchzero/modules/ops/higher_level.py +20 -20
- torchzero/modules/ops/multi.py +44 -41
- torchzero/modules/ops/reduce.py +26 -23
- torchzero/modules/ops/unary.py +53 -53
- torchzero/modules/ops/utility.py +47 -46
- torchzero/modules/{functional.py → opt_utils.py} +1 -1
- torchzero/modules/projections/galore.py +1 -1
- torchzero/modules/projections/projection.py +46 -43
- torchzero/modules/quasi_newton/__init__.py +1 -1
- torchzero/modules/quasi_newton/damping.py +2 -2
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
- torchzero/modules/quasi_newton/lbfgs.py +10 -10
- torchzero/modules/quasi_newton/lsr1.py +10 -10
- torchzero/modules/quasi_newton/quasi_newton.py +54 -39
- torchzero/modules/quasi_newton/sg2.py +69 -205
- torchzero/modules/restarts/restars.py +39 -37
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/ifn.py +31 -62
- torchzero/modules/second_order/inm.py +57 -53
- torchzero/modules/second_order/multipoint.py +40 -80
- torchzero/modules/second_order/newton.py +165 -196
- torchzero/modules/second_order/newton_cg.py +105 -157
- torchzero/modules/second_order/nystrom.py +216 -185
- torchzero/modules/second_order/rsn.py +132 -125
- torchzero/modules/smoothing/laplacian.py +13 -12
- torchzero/modules/smoothing/sampling.py +10 -10
- torchzero/modules/step_size/adaptive.py +24 -24
- torchzero/modules/step_size/lr.py +17 -17
- torchzero/modules/termination/termination.py +32 -30
- torchzero/modules/trust_region/cubic_regularization.py +3 -3
- torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
- torchzero/modules/trust_region/trust_cg.py +2 -2
- torchzero/modules/trust_region/trust_region.py +27 -22
- torchzero/modules/variance_reduction/svrg.py +23 -21
- torchzero/modules/weight_decay/__init__.py +2 -1
- torchzero/modules/weight_decay/reinit.py +83 -0
- torchzero/modules/weight_decay/weight_decay.py +17 -18
- torchzero/modules/wrappers/optim_wrapper.py +14 -14
- torchzero/modules/zeroth_order/cd.py +10 -7
- torchzero/optim/mbs.py +291 -0
- torchzero/optim/root.py +3 -3
- torchzero/optim/utility/split.py +2 -1
- torchzero/optim/wrappers/directsearch.py +27 -63
- torchzero/optim/wrappers/fcmaes.py +14 -35
- torchzero/optim/wrappers/mads.py +11 -31
- torchzero/optim/wrappers/moors.py +66 -0
- torchzero/optim/wrappers/nevergrad.py +4 -13
- torchzero/optim/wrappers/nlopt.py +31 -25
- torchzero/optim/wrappers/optuna.py +8 -13
- torchzero/optim/wrappers/pybobyqa.py +124 -0
- torchzero/optim/wrappers/scipy/__init__.py +7 -0
- torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
- torchzero/optim/wrappers/scipy/brute.py +48 -0
- torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
- torchzero/optim/wrappers/scipy/direct.py +69 -0
- torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
- torchzero/optim/wrappers/scipy/experimental.py +141 -0
- torchzero/optim/wrappers/scipy/minimize.py +151 -0
- torchzero/optim/wrappers/scipy/sgho.py +111 -0
- torchzero/optim/wrappers/wrapper.py +121 -0
- torchzero/utils/__init__.py +7 -25
- torchzero/utils/benchmarks/__init__.py +0 -0
- torchzero/utils/benchmarks/logistic.py +122 -0
- torchzero/utils/compile.py +2 -2
- torchzero/utils/derivatives.py +97 -73
- torchzero/utils/optimizer.py +4 -77
- torchzero/utils/python_tools.py +31 -0
- torchzero/utils/tensorlist.py +11 -5
- torchzero/utils/thoad_tools.py +68 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
- torchzero-0.4.1.dist-info/RECORD +209 -0
- tests/test_vars.py +0 -185
- torchzero/core/var.py +0 -376
- torchzero/modules/adaptive/lmadagrad.py +0 -186
- torchzero/modules/experimental/momentum.py +0 -160
- torchzero/optim/wrappers/scipy.py +0 -572
- torchzero/utils/linalg/__init__.py +0 -12
- torchzero/utils/linalg/matrix_funcs.py +0 -87
- torchzero/utils/linalg/orthogonalize.py +0 -12
- torchzero/utils/linalg/svd.py +0 -20
- torchzero/utils/ops.py +0 -10
- torchzero-0.3.15.dist-info/RECORD +0 -175
- /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
- {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/momentum.py (deleted, -160 lines)

```diff
@@ -1,160 +0,0 @@
-from collections.abc import Sequence
-from functools import partial
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Target, Transform
-from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import ema_, ema_sq_, sqrt_ema_sq_
-from ..momentum.momentum import nag_
-from ..ops.higher_level import EMASquared, SqrtEMASquared
-
-
-def precentered_ema_sq_(
-    tensors: TensorList,
-    exp_avg_: TensorList,
-    exp_avg_sq_: TensorList,
-    beta1: float | NumberList,
-    beta2: float | NumberList,
-    step: int,
-    min_step: int,
-    pow: float,
-    max_exp_avg_sq_: TensorList | None,
-):
-    """
-    Squared EMA of (update - 1st EMA). Starts taking effect after `min_step` to avoid division by epsilon.
-
-    returns `exp_avg_sq_` or `max_exp_avg_sq_`.
-    """
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=False)
-
-    if step < min_step: centered_update = tensors
-    else: centered_update = tensors - exp_avg_
-
-    exp_avg_sq_=ema_sq_(
-        centered_update,
-        exp_avg_sq_=exp_avg_sq_,
-        beta=beta2,
-        pow=pow,
-        max_exp_avg_sq_=max_exp_avg_sq_,
-    )
-    return exp_avg_sq_
-
-class PrecenteredEMASquared(Transform):
-    """Maintains un-squared EMA, the updates are centered by it before being fed into squared EMA."""
-    def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
-        defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
-        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])
-
-        if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
-        else:
-            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
-            max_exp_avg_sq = None
-
-        return precentered_ema_sq_(
-            TensorList(tensors),
-            exp_avg_ = exp_avg,
-            exp_avg_sq_=exp_avg_sq,
-            beta1=beta1,
-            beta2=beta2,
-            step = step,
-            min_step=min_step,
-            pow=pow,
-            max_exp_avg_sq_=max_exp_avg_sq,
-        ).clone()
-
-
-def nag_ema_sq_(
-    tensors: TensorList,
-    exp_avg_sq_: TensorList,
-    beta: float | NumberList,
-    max_exp_avg_sq_: TensorList | None,
-    pow: float,
-    lerp:bool=True,
-):
-    """
-    Nesterov EMA of squared tensors.
-
-    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
-    """
-    if pow == 1: tensors = tensors.abs()
-    elif pow%2 == 0: tensors = tensors.pow(pow)
-    else: tensors = tensors.pow(pow).abs()
-
-    exp_avg_sq_=nag_(tensors,velocity_=exp_avg_sq_,momentum=beta,dampening=0,lerp=lerp,)
-
-    # AMSGrad
-    if max_exp_avg_sq_ is not None:
-        max_exp_avg_sq_.maximum_(exp_avg_sq_)
-        exp_avg_sq_ = max_exp_avg_sq_
-
-    return exp_avg_sq_
-
-def sqrt_nag_ema_sq_(
-    tensors: TensorList,
-    exp_avg_sq_: TensorList,
-    beta: float | NumberList,
-    max_exp_avg_sq_: TensorList | None,
-    debiased: bool,
-    step: int,
-    pow: float,
-    lerp:bool=False,
-):
-    """
-    Square root of nesterov EMA of squared tensors.
-
-    Returns new tensors.
-    """
-    return sqrt_ema_sq_(tensors=tensors,exp_avg_sq_=exp_avg_sq_,beta=beta,max_exp_avg_sq_=max_exp_avg_sq_,
-                        pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
-
-class NesterovEMASquared(EMASquared):
-    """squared momentum with nesterov momentum rule"""
-    EMA_SQ_FN = staticmethod(nag_ema_sq_)
-
-class SqrtNesterovEMASquared(SqrtEMASquared):
-    """square root of squared momentum with nesterov momentum rule"""
-    SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
-
-
-def coordinate_momentum_(
-    tensors: TensorList,
-    velocity_: TensorList,
-    p: float | NumberList,
-):
-    """
-    sets `velocity_` to p% random values from `tensors`.
-
-    Returns `velocity_`
-    """
-    mask = tensors.bernoulli_like(p).as_bool()
-    velocity_.masked_set_(mask, tensors)
-    return velocity_
-
-
-class CoordinateMomentum(Transform):
-    """Maintains a momentum buffer, on each step each value in the buffer has :code:`p` chance to be updated with the new value.
-
-    Args:
-        p (float, optional): _description_. Defaults to 0.1.
-        target (Target, optional): _description_. Defaults to 'update'.
-    """
-    def __init__(self, p: float = 0.1, target: Target = 'update'):
-        defaults = dict(p=p)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def apply_tensors(self, tensors, params, grads, loss, states, settings):
-        p = NumberList(s['p'] for s in settings)
-        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
-        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
```