torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -1,39 +1,177 @@
1
- """Experimental and I need to test this on Windows."""
2
- import warnings
3
- import functools
1
+ import time
2
+
4
3
  import torch
4
+ import torch.utils.benchmark
5
+
6
+ class _OptionalCompiler:
7
+ """this holds .enable attribute, set to True to enable compiling library wise"""
8
+ def __init__(self):
9
+ self.enable = False
10
+
11
+ def enable_compilation(
12
+ self,
13
+ x,
14
+ fullgraph: bool = False,
15
+ dynamic: bool | None = None,
16
+ backend="inductor",
17
+ mode: str | None = "max-autotune-no-cudagraphs",
18
+ options: dict[str, str | int | bool] | None = None,
19
+ disable: bool = False,
20
+ ):
21
+ """compiles if self.compile is True otherwise returns uncompiled `x`"""
22
+ return _MaybeCompiledFunc(x, self, fullgraph=fullgraph, dynamic=dynamic, backend=backend, mode=mode, options=options, disable=disable)
23
+
24
class _MaybeCompiledFunc:
    """Callable proxy around ``func`` that compiles it lazily, at most once,
    the first time it is called while the owning compiler has ``enable`` set."""

    def __init__(self, func, compiler: _OptionalCompiler, **kwargs):
        self.func = func
        # kwargs are forwarded verbatim to torch.compile
        self.kwargs = kwargs
        self.compiled = False
        self.compiler = compiler

    def __call__(self, *args, **kwargs):
        # Compile exactly once, and only after compilation has been enabled.
        needs_compile = self.compiler.enable and not self.compiled
        if needs_compile:
            self.func = torch.compile(self.func, **self.kwargs)
            self.compiled = True
        return self.func(*args, **kwargs)
36
+
37
+
38
_optional_compiler = _OptionalCompiler()
"""this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""


def set_compilation(enable: bool):
    """Toggle library-wide compilation.

    ``enable`` is False by default. When True, certain functions will be
    compiled, which may not work on some systems like Windows, but it
    usually improves performance.
    """
    _optional_compiler.enable = enable


def enable_compilation(fn):
    """Wrap ``fn`` so it gets compiled lazily once compilation is enabled."""
    return _optional_compiler.enable_compilation(fn)
46
+
47
def benchmark_compile_cuda(fn, n: int, **kwargs):
    """Benchmark ``fn`` against its ``torch.compile``'d version on CUDA.

    Warms up both versions with ``n`` calls each (the first compiled call
    triggers compilation and is timed separately), then times ``n`` calls of
    the uncompiled and compiled versions twice each, interleaved, printing
    both CUDA-event (device) time and ``perf_counter`` (wall) time per run.

    Args:
        fn: callable to benchmark; invoked as ``fn(**kwargs)``.
        n: number of calls per warmup and per timed run.
        **kwargs: forwarded to ``fn`` on every call.
    """
    # eager warmup
    for _ in range(n):
        fn(**kwargs)

    compiled = torch.compile(fn, mode='max-autotune-no-cudagraphs')

    # compiled warmup; the first call triggers compilation, so time it separately
    for i in range(n):
        if i == 0:
            start = time.perf_counter()
            compiled(**kwargs)
            print(f'Compiling took {time.perf_counter() - start} s.')
        else:
            compiled(**kwargs)

    def _timed_run(label, f):
        # time n calls of f with CUDA events (device time) and perf_counter (wall time)
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        starter.record() # type:ignore
        start = time.perf_counter()

        for _ in range(n):
            f(**kwargs)

        ender.record() # type:ignore
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds
        sec = 1e-3 * starter.elapsed_time(ender)
        print(f'{label} took {sec} CUDA s, {time.perf_counter() - start} perf_counter s.')

    # two interleaved uncompiled/compiled rounds to reduce ordering effects
    for _ in range(2):
        _timed_run('Uncompiled', fn)
        _timed_run('Compiled', compiled)
122
+
123
def benchmark_compile_cpu(fn, n: int, **kwargs):
    """Benchmark ``fn`` against its ``torch.compile``'d version on CPU.

    Warms up both versions with ``n`` calls each (the first compiled call
    triggers compilation and is timed separately), then times ``n`` calls of
    the uncompiled and compiled versions twice each, interleaved, printing
    total wall-clock time and per-call time for every run.

    Args:
        fn: callable to benchmark; invoked as ``fn(**kwargs)``.
        n: number of calls per warmup and per timed run (must be > 0 for the
            per-call printout, which divides by ``n``).
        **kwargs: forwarded to ``fn`` on every call.
    """
    # eager warmup
    for _ in range(n):
        fn(**kwargs)

    compiled = torch.compile(fn, mode='max-autotune-no-cudagraphs')

    # compiled warmup; the first call triggers compilation, so time it separately
    for i in range(n):
        if i == 0:
            start = time.perf_counter()
            compiled(**kwargs)
            print(f'Compiling took {time.perf_counter() - start} s.')
        else:
            compiled(**kwargs)

    def _timed_run(label, f):
        # wall-clock time for n calls of f
        start = time.perf_counter()

        for _ in range(n):
            f(**kwargs)

        sec = time.perf_counter() - start
        print(f'{label} took {sec} s., {sec/n} per call')

    # two interleaved uncompiled/compiled rounds to reduce ordering effects
    for _ in range(2):
        _timed_run('Uncompiled', fn)
        _timed_run('Compiled', compiled)