torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +57 -0
- tests/test_identical.py +230 -0
- tests/test_module.py +50 -0
- tests/test_opts.py +884 -0
- tests/test_tensorlist.py +1787 -0
- tests/test_utils_optimizer.py +170 -0
- tests/test_vars.py +184 -0
- torchzero/__init__.py +4 -4
- torchzero/core/__init__.py +3 -13
- torchzero/core/module.py +629 -510
- torchzero/core/preconditioner.py +137 -0
- torchzero/core/transform.py +252 -0
- torchzero/modules/__init__.py +13 -21
- torchzero/modules/clipping/__init__.py +3 -0
- torchzero/modules/clipping/clipping.py +320 -0
- torchzero/modules/clipping/ema_clipping.py +135 -0
- torchzero/modules/clipping/growth_clipping.py +187 -0
- torchzero/modules/experimental/__init__.py +13 -18
- torchzero/modules/experimental/absoap.py +350 -0
- torchzero/modules/experimental/adadam.py +111 -0
- torchzero/modules/experimental/adamY.py +135 -0
- torchzero/modules/experimental/adasoap.py +282 -0
- torchzero/modules/experimental/algebraic_newton.py +145 -0
- torchzero/modules/experimental/curveball.py +89 -0
- torchzero/modules/experimental/dsoap.py +290 -0
- torchzero/modules/experimental/gradmin.py +85 -0
- torchzero/modules/experimental/reduce_outward_lr.py +35 -0
- torchzero/modules/experimental/spectral.py +286 -0
- torchzero/modules/experimental/subspace_preconditioners.py +128 -0
- torchzero/modules/experimental/tropical_newton.py +136 -0
- torchzero/modules/functional.py +209 -0
- torchzero/modules/grad_approximation/__init__.py +4 -0
- torchzero/modules/grad_approximation/fdm.py +120 -0
- torchzero/modules/grad_approximation/forward_gradient.py +81 -0
- torchzero/modules/grad_approximation/grad_approximator.py +66 -0
- torchzero/modules/grad_approximation/rfdm.py +259 -0
- torchzero/modules/line_search/__init__.py +5 -30
- torchzero/modules/line_search/backtracking.py +186 -0
- torchzero/modules/line_search/line_search.py +181 -0
- torchzero/modules/line_search/scipy.py +37 -0
- torchzero/modules/line_search/strong_wolfe.py +260 -0
- torchzero/modules/line_search/trust_region.py +61 -0
- torchzero/modules/lr/__init__.py +2 -0
- torchzero/modules/lr/lr.py +59 -0
- torchzero/modules/lr/step_size.py +97 -0
- torchzero/modules/momentum/__init__.py +14 -4
- torchzero/modules/momentum/averaging.py +78 -0
- torchzero/modules/momentum/cautious.py +181 -0
- torchzero/modules/momentum/ema.py +173 -0
- torchzero/modules/momentum/experimental.py +189 -0
- torchzero/modules/momentum/matrix_momentum.py +124 -0
- torchzero/modules/momentum/momentum.py +43 -106
- torchzero/modules/ops/__init__.py +103 -0
- torchzero/modules/ops/accumulate.py +65 -0
- torchzero/modules/ops/binary.py +240 -0
- torchzero/modules/ops/debug.py +25 -0
- torchzero/modules/ops/misc.py +419 -0
- torchzero/modules/ops/multi.py +137 -0
- torchzero/modules/ops/reduce.py +149 -0
- torchzero/modules/ops/split.py +75 -0
- torchzero/modules/ops/switch.py +68 -0
- torchzero/modules/ops/unary.py +115 -0
- torchzero/modules/ops/utility.py +112 -0
- torchzero/modules/optimizers/__init__.py +18 -10
- torchzero/modules/optimizers/adagrad.py +146 -49
- torchzero/modules/optimizers/adam.py +112 -118
- torchzero/modules/optimizers/lion.py +18 -11
- torchzero/modules/optimizers/muon.py +222 -0
- torchzero/modules/optimizers/orthograd.py +55 -0
- torchzero/modules/optimizers/rmsprop.py +103 -51
- torchzero/modules/optimizers/rprop.py +342 -99
- torchzero/modules/optimizers/shampoo.py +197 -0
- torchzero/modules/optimizers/soap.py +286 -0
- torchzero/modules/optimizers/sophia_h.py +129 -0
- torchzero/modules/projections/__init__.py +5 -0
- torchzero/modules/projections/dct.py +73 -0
- torchzero/modules/projections/fft.py +73 -0
- torchzero/modules/projections/galore.py +10 -0
- torchzero/modules/projections/projection.py +218 -0
- torchzero/modules/projections/structural.py +151 -0
- torchzero/modules/quasi_newton/__init__.py +7 -4
- torchzero/modules/quasi_newton/cg.py +218 -0
- torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
- torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
- torchzero/modules/quasi_newton/lbfgs.py +228 -0
- torchzero/modules/quasi_newton/lsr1.py +170 -0
- torchzero/modules/quasi_newton/olbfgs.py +196 -0
- torchzero/modules/quasi_newton/quasi_newton.py +475 -0
- torchzero/modules/second_order/__init__.py +3 -4
- torchzero/modules/second_order/newton.py +142 -165
- torchzero/modules/second_order/newton_cg.py +84 -0
- torchzero/modules/second_order/nystrom.py +168 -0
- torchzero/modules/smoothing/__init__.py +2 -5
- torchzero/modules/smoothing/gaussian.py +164 -0
- torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
- torchzero/modules/weight_decay/__init__.py +1 -0
- torchzero/modules/weight_decay/weight_decay.py +52 -0
- torchzero/modules/wrappers/__init__.py +1 -0
- torchzero/modules/wrappers/optim_wrapper.py +91 -0
- torchzero/optim/__init__.py +2 -10
- torchzero/optim/utility/__init__.py +1 -0
- torchzero/optim/utility/split.py +45 -0
- torchzero/optim/wrappers/nevergrad.py +2 -28
- torchzero/optim/wrappers/nlopt.py +31 -16
- torchzero/optim/wrappers/scipy.py +79 -156
- torchzero/utils/__init__.py +27 -0
- torchzero/utils/compile.py +175 -37
- torchzero/utils/derivatives.py +513 -99
- torchzero/utils/linalg/__init__.py +5 -0
- torchzero/utils/linalg/matrix_funcs.py +87 -0
- torchzero/utils/linalg/orthogonalize.py +11 -0
- torchzero/utils/linalg/qr.py +71 -0
- torchzero/utils/linalg/solve.py +168 -0
- torchzero/utils/linalg/svd.py +20 -0
- torchzero/utils/numberlist.py +132 -0
- torchzero/utils/ops.py +10 -0
- torchzero/utils/optimizer.py +284 -0
- torchzero/utils/optuna_tools.py +40 -0
- torchzero/utils/params.py +149 -0
- torchzero/utils/python_tools.py +40 -25
- torchzero/utils/tensorlist.py +1081 -0
- torchzero/utils/torch_tools.py +48 -12
- torchzero-0.3.2.dist-info/METADATA +379 -0
- torchzero-0.3.2.dist-info/RECORD +128 -0
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
- {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
- torchzero-0.3.2.dist-info/top_level.txt +3 -0
- torchzero/core/tensorlist_optimizer.py +0 -219
- torchzero/modules/adaptive/__init__.py +0 -4
- torchzero/modules/adaptive/adaptive.py +0 -192
- torchzero/modules/experimental/experimental.py +0 -294
- torchzero/modules/experimental/quad_interp.py +0 -104
- torchzero/modules/experimental/subspace.py +0 -259
- torchzero/modules/gradient_approximation/__init__.py +0 -7
- torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
- torchzero/modules/gradient_approximation/base_approximator.py +0 -105
- torchzero/modules/gradient_approximation/fdm.py +0 -125
- torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
- torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
- torchzero/modules/gradient_approximation/rfdm.py +0 -125
- torchzero/modules/line_search/armijo.py +0 -56
- torchzero/modules/line_search/base_ls.py +0 -139
- torchzero/modules/line_search/directional_newton.py +0 -217
- torchzero/modules/line_search/grid_ls.py +0 -158
- torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
- torchzero/modules/meta/__init__.py +0 -12
- torchzero/modules/meta/alternate.py +0 -65
- torchzero/modules/meta/grafting.py +0 -195
- torchzero/modules/meta/optimizer_wrapper.py +0 -173
- torchzero/modules/meta/return_overrides.py +0 -46
- torchzero/modules/misc/__init__.py +0 -10
- torchzero/modules/misc/accumulate.py +0 -43
- torchzero/modules/misc/basic.py +0 -115
- torchzero/modules/misc/lr.py +0 -96
- torchzero/modules/misc/multistep.py +0 -51
- torchzero/modules/misc/on_increase.py +0 -53
- torchzero/modules/operations/__init__.py +0 -29
- torchzero/modules/operations/multi.py +0 -298
- torchzero/modules/operations/reduction.py +0 -134
- torchzero/modules/operations/singular.py +0 -113
- torchzero/modules/optimizers/sgd.py +0 -54
- torchzero/modules/orthogonalization/__init__.py +0 -2
- torchzero/modules/orthogonalization/newtonschulz.py +0 -159
- torchzero/modules/orthogonalization/svd.py +0 -86
- torchzero/modules/regularization/__init__.py +0 -22
- torchzero/modules/regularization/dropout.py +0 -34
- torchzero/modules/regularization/noise.py +0 -77
- torchzero/modules/regularization/normalization.py +0 -328
- torchzero/modules/regularization/ortho_grad.py +0 -78
- torchzero/modules/regularization/weight_decay.py +0 -92
- torchzero/modules/scheduling/__init__.py +0 -2
- torchzero/modules/scheduling/lr_schedulers.py +0 -131
- torchzero/modules/scheduling/step_size.py +0 -80
- torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
- torchzero/modules/weight_averaging/__init__.py +0 -2
- torchzero/modules/weight_averaging/ema.py +0 -72
- torchzero/modules/weight_averaging/swa.py +0 -171
- torchzero/optim/experimental/__init__.py +0 -20
- torchzero/optim/experimental/experimental.py +0 -343
- torchzero/optim/experimental/ray_search.py +0 -83
- torchzero/optim/first_order/__init__.py +0 -18
- torchzero/optim/first_order/cautious.py +0 -158
- torchzero/optim/first_order/forward_gradient.py +0 -70
- torchzero/optim/first_order/optimizers.py +0 -570
- torchzero/optim/modular.py +0 -148
- torchzero/optim/quasi_newton/__init__.py +0 -1
- torchzero/optim/quasi_newton/directional_newton.py +0 -58
- torchzero/optim/second_order/__init__.py +0 -1
- torchzero/optim/second_order/newton.py +0 -94
- torchzero/optim/zeroth_order/__init__.py +0 -4
- torchzero/optim/zeroth_order/fdm.py +0 -87
- torchzero/optim/zeroth_order/newton_fdm.py +0 -146
- torchzero/optim/zeroth_order/rfdm.py +0 -217
- torchzero/optim/zeroth_order/rs.py +0 -85
- torchzero/random/__init__.py +0 -1
- torchzero/random/random.py +0 -46
- torchzero/tensorlist.py +0 -826
- torchzero-0.1.8.dist-info/METADATA +0 -130
- torchzero-0.1.8.dist-info/RECORD +0 -104
- torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/utils/compile.py
CHANGED
|
@@ -1,39 +1,177 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
import functools
|
|
1
|
+
import time
|
|
2
|
+
|
|
4
3
|
import torch
|
|
4
|
+
import torch.utils.benchmark
|
|
5
|
+
|
|
6
|
+
class _OptionalCompiler:
|
|
7
|
+
"""this holds .enable attribute, set to True to enable compiling library wise"""
|
|
8
|
+
def __init__(self):
|
|
9
|
+
self.enable = False
|
|
10
|
+
|
|
11
|
+
def enable_compilation(
|
|
12
|
+
self,
|
|
13
|
+
x,
|
|
14
|
+
fullgraph: bool = False,
|
|
15
|
+
dynamic: bool | None = None,
|
|
16
|
+
backend="inductor",
|
|
17
|
+
mode: str | None = "max-autotune-no-cudagraphs",
|
|
18
|
+
options: dict[str, str | int | bool] | None = None,
|
|
19
|
+
disable: bool = False,
|
|
20
|
+
):
|
|
21
|
+
"""compiles if self.compile is True otherwise returns uncompiled `x`"""
|
|
22
|
+
return _MaybeCompiledFunc(x, self, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode, options=options, disable=disable)
|
|
23
|
+
|
|
24
|
+
class _MaybeCompiledFunc:
    """Callable wrapper that compiles ``func`` with ``torch.compile`` at most once,
    and only after the owning compiler's ``enable`` flag has been set."""

    def __init__(self, func, compiler: _OptionalCompiler, **kwargs):
        self.func = func
        # kwargs are forwarded verbatim to torch.compile on first enabled call
        self.kwargs = kwargs
        self.compiled = False
        self.compiler = compiler

    def __call__(self, *args, **kwargs):
        # compile lazily: only once, and only when the global switch is on
        if not self.compiled and self.compiler.enable:
            self.func = torch.compile(self.func, **self.kwargs)
            self.compiled = True
        return self.func(*args, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
_optional_compiler = _OptionalCompiler()
"""this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""


def set_compilation(enable: bool):
    """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
    _optional_compiler.enable = enable


def enable_compilation(fn):
    """Convenience wrapper: defer to the shared module-level compiler with default settings."""
    return _optional_compiler.enable_compilation(fn)
|
|
46
|
+
|
|
47
|
+
def benchmark_compile_cuda(fn, n: int, **kwargs):
    """Benchmark ``fn(**kwargs)`` against its ``torch.compile``'d version on CUDA.

    Runs ``n`` eager warmup calls, compiles ``fn`` (printing how long the first
    compiled call — i.e. compilation — takes), then times ``n`` calls of the
    eager and compiled versions twice each, interleaved, printing both CUDA-event
    and wall-clock timings. Returns None; results are printed only.
    Requires a CUDA device.
    """
    # eager warmup
    for _ in range(n):
        fn(**kwargs)

    compiled = torch.compile(fn, mode = 'max-autotune-no-cudagraphs')

    # compiled warmup; the first call triggers (and is timed as) compilation
    for i in range(n):
        if i == 0:
            start = time.perf_counter()
            compiled(**kwargs)
            print(f'Compiling took {time.perf_counter() - start} s.')
        else:
            compiled(**kwargs)

    # two interleaved measurement rounds to reduce ordering bias
    for _ in range(2):
        _time_cuda(fn, n, kwargs, 'Uncompiled')
        _time_cuda(compiled, n, kwargs, 'Compiled')


def _time_cuda(fn, n: int, kwargs: dict, label: str):
    """Time ``n`` calls of ``fn(**kwargs)`` with CUDA events and print the result."""
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    starter.record() # type:ignore
    start = time.perf_counter()

    for _ in range(n):
        fn(**kwargs)

    ender.record() # type:ignore
    torch.cuda.synchronize()
    # elapsed_time is in milliseconds; convert to seconds
    sec = 1e-3 * starter.elapsed_time(ender)

    print(f'{label} took {sec} CUDA s, {time.perf_counter() - start} perf_counter s.')
|
|
122
|
+
|
|
123
|
+
def benchmark_compile_cpu(fn, n: int, **kwargs):
    """Benchmark ``fn(**kwargs)`` against its ``torch.compile``'d version on CPU.

    Runs ``n`` eager warmup calls, compiles ``fn`` (printing how long the first
    compiled call — i.e. compilation — takes), then times ``n`` calls of the
    eager and compiled versions twice each, interleaved, printing total and
    per-call wall-clock timings. Returns None; results are printed only.
    """
    # eager warmup
    for _ in range(n):
        fn(**kwargs)

    compiled = torch.compile(fn, mode = 'max-autotune-no-cudagraphs')

    # compiled warmup; the first call triggers (and is timed as) compilation
    for i in range(n):
        if i == 0:
            start = time.perf_counter()
            compiled(**kwargs)
            print(f'Compiling took {time.perf_counter() - start} s.')
        else:
            compiled(**kwargs)

    # two interleaved measurement rounds to reduce ordering bias
    for _ in range(2):
        _time_cpu(fn, n, kwargs, 'Uncompiled')
        _time_cpu(compiled, n, kwargs, 'Compiled')


def _time_cpu(fn, n: int, kwargs: dict, label: str):
    """Time ``n`` calls of ``fn(**kwargs)`` with perf_counter and print the result."""
    start = time.perf_counter()

    for _ in range(n):
        fn(**kwargs)

    sec = time.perf_counter() - start

    print(f'{label} took {sec} s., {sec/n} per call')
|