torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
docs/source/conf.py ADDED
@@ -0,0 +1,57 @@
+ # Configuration file for the Sphinx documentation builder.
+ #
+ # For the full list of built-in configuration values, see the documentation:
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+ # -- Project information -----------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+ import sys, os
+ #sys.path.insert(0, os.path.abspath('.../src'))
+
+ project = 'torchzero'
+ copyright = '2024, Ivan Nikishev'
+ author = 'Ivan Nikishev'
+
+ # -- General configuration ---------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+ # https://sphinx-intro-tutorial.readthedocs.io/en/latest/sphinx_extensions.html
+ extensions = [
+     'sphinx.ext.autodoc',
+     'sphinx.ext.autosummary',
+     'sphinx.ext.viewcode',
+     'sphinx.ext.autosectionlabel',
+     'sphinx.ext.githubpages',
+     'sphinx.ext.napoleon',
+     'autoapi.extension',
+     # 'sphinx_rtd_theme',
+ ]
+ autosummary_generate = True
+ autoapi_dirs = ['../../src']
+ autoapi_type = "python"
+ # autoapi_ignore = ["*/tensorlist.py"]
+
+ # https://sphinx-autoapi.readthedocs.io/en/latest/reference/config.html#confval-autoapi_options
+ autoapi_options = [
+     "members",
+     "undoc-members",
+     "show-inheritance",
+     "show-module-summary",
+     "imported-members",
+ ]
+
+
+ templates_path = ['_templates']
+ exclude_patterns = []
+
+ # -- Options for HTML output -------------------------------------------------
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+ #html_theme = 'alabaster'
+ html_theme = 'furo'
+ html_static_path = ['_static']
+
+
+ # OTHER STUFF I FOUND ON THE INTERNET AND PUT THERE HOPING IT DOES SOMETHING USEFUL
+ source_suffix = ['.rst', '.md']
+ master_doc = 'index'
tests/test_identical.py ADDED
@@ -0,0 +1,230 @@
+ from collections.abc import Callable, Sequence
+ import pytest
+ import torch
+ import torchzero as tz
+
+ def _booth(x, y):
+     return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2
+
+ _BOOTH_X0 = torch.tensor([0., -8.])
+
+ def _get_trajectory(opt_fn: Callable, x0: torch.Tensor, merge: bool, use_closure: bool, steps: int):
+     """Returns a Tensor - trajectory of `opt_fn` on the booth function."""
+     trajectory = []
+     if merge:
+         params = x0.clone().requires_grad_()
+         optimizer = opt_fn([params])
+     else:
+         params = [x0[0].clone().requires_grad_(), x0[1].clone().requires_grad_()]
+         optimizer = opt_fn(params)
+
+     for _ in range(steps):
+         if use_closure:
+             def closure(backward=True):
+                 trajectory.append(torch.stack([p.clone() for p in params]))
+
+                 loss = _booth(*params)
+                 if backward:
+                     optimizer.zero_grad()
+                     loss.backward()
+                 return loss
+
+             loss = optimizer.step(closure)
+             assert torch.isfinite(loss), f'non-finite loss {loss}'
+             for p in params: assert torch.isfinite(p), f'non-finite params {params}'
+
+         else:
+             trajectory.append(torch.stack([p.clone() for p in params]))
+
+             loss = _booth(*params)
+             assert torch.isfinite(loss), f'non-finite loss {loss}'
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             for p in params: assert torch.isfinite(p), f'non-finite params {params}'
+
+
+     return torch.stack(trajectory, 0), optimizer
+
+ def _compare_trajectories(opt1, t1:torch.Tensor, opt2, t2:torch.Tensor):
+     assert torch.allclose(t1, t2, rtol=1e-4, atol=1e-6), f'trajectories dont match. opts:\n{opt1}\n{opt2}\ntrajectories:\n{t1}\n{t2}'
+
+ def _assert_identical_opts(opt_fns: Sequence[Callable], merge: bool, use_closure: bool, device, steps: int):
+     """checks that all `opt_fns` have identical trajectories on booth"""
+     x0 = _BOOTH_X0.clone().to(device=device)
+     base_opt = None
+     base_trajectory = None
+     for opt_fn in opt_fns:
+         t, opt = _get_trajectory(opt_fn, x0, merge, use_closure, steps)
+         if base_trajectory is None or base_opt is None:
+             base_trajectory = t
+             base_opt = opt
+         else: _compare_trajectories(base_opt, base_trajectory, opt, t)
+
+ def _assert_identical_merge(opt_fn: Callable, use_closure, device, steps: int):
+     """checks that trajectories match with x and y parameters split and merged"""
+     x0 = _BOOTH_X0.clone().to(device=device)
+     merged, merged_opt = _get_trajectory(opt_fn, x0, merge=True, use_closure=use_closure, steps=steps)
+     unmerged, unmerged_opt = _get_trajectory(opt_fn, x0, merge=False, use_closure=use_closure, steps=steps)
+     _compare_trajectories(merged_opt, merged, unmerged_opt, unmerged)
+
+ def _assert_identical_closure(opt_fn: Callable, merge, device, steps: int):
+     """checks that trajectories match with and without closure"""
+     x0 = _BOOTH_X0.clone().to(device=device)
+     closure, closure_opt = _get_trajectory(opt_fn, x0, merge=merge, use_closure=True, steps=steps)
+     no_closure, no_closure_opt = _get_trajectory(opt_fn, x0, merge=merge, use_closure=False, steps=steps)
+     _compare_trajectories(closure_opt, closure, no_closure_opt, no_closure)
+
+ def _assert_identical_merge_closure(opt_fn: Callable, device, steps: int):
+     """checks that trajectories match with x and y parameters split and merged and with and without closure"""
+     x0 = _BOOTH_X0.clone().to(device=device)
+     merge_closure, opt_merge_closure = _get_trajectory(opt_fn, x0, merge=True, use_closure=True, steps=steps)
+     merge_no_closure, opt_merge_no_closure = _get_trajectory(opt_fn, x0, merge=True, use_closure=False, steps=steps)
+     no_merge_closure, opt_no_merge_closure = _get_trajectory(opt_fn, x0, merge=False, use_closure=True, steps=steps)
+     no_merge_no_closure, opt_no_merge_no_closure = _get_trajectory(opt_fn, x0, merge=False, use_closure=False, steps=steps)
+
+     _compare_trajectories(opt_merge_closure, merge_closure, opt_merge_no_closure, merge_no_closure)
+     _compare_trajectories(opt_merge_closure, merge_closure, opt_no_merge_closure, no_merge_closure)
+     _compare_trajectories(opt_merge_closure, merge_closure, opt_no_merge_no_closure, no_merge_no_closure)
+
+ def _assert_identical_device(opt_fn: Callable, merge: bool, use_closure: bool, steps: int):
+     """checks that trajectories match on cpu and cuda."""
+     if not torch.cuda.is_available(): return
+     cpu, cpu_opt = _get_trajectory(opt_fn, _BOOTH_X0.clone().cpu(), merge=merge, use_closure=use_closure, steps=steps)
+     cuda, cuda_opt = _get_trajectory(opt_fn, _BOOTH_X0.clone().cuda(), merge=merge, use_closure=use_closure, steps=steps)
+     _compare_trajectories(cpu_opt, cpu, cuda_opt, cuda.to(cpu))
+
+ @pytest.mark.parametrize('amsgrad', [True, False])
+ def test_adam(amsgrad):
+     # torch_fn = lambda p: torch.optim.Adam(p, lr=1, amsgrad=amsgrad)
+     # pytorch applies debiasing separately so it is applied before epsilon
+     tz_fn = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad))
+     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1)) # test LR fusing
+     tz_fn3 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.LR(1), tz.m.Add(1), tz.m.Sub(1))
+     tz_fn4 = lambda p: tz.Modular(p, tz.m.Adam(amsgrad=amsgrad), tz.m.Add(1), tz.m.Sub(1), tz.m.LR(1))
+     tz_fn5 = lambda p: tz.Modular(p, tz.m.Clone(), tz.m.Adam(amsgrad=amsgrad))
+     tz_fn_ops = lambda p: tz.Modular(
+         p,
+         tz.m.DivModules(
+             tz.m.EMA(0.9, debiased=True),
+             [tz.m.SqrtEMASquared(0.999, debiased=True, amsgrad=amsgrad), tz.m.Add(1e-8)]
+         ))
+     tz_fn_ops2 = lambda p: tz.Modular(
+         p,
+         tz.m.DivModules(
+             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
+             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Debias2(beta=0.999), tz.m.Add(1e-8)]
+         ))
+     tz_fn_ops3 = lambda p: tz.Modular(
+         p,
+         tz.m.DivModules(
+             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9, beta2=0.999)],
+             [tz.m.EMASquared(0.999, amsgrad=amsgrad), tz.m.Sqrt(), tz.m.Add(1e-8)]
+         ))
+     tz_fn_ops4 = lambda p: tz.Modular(
+         p,
+         tz.m.DivModules(
+             [tz.m.EMA(0.9), tz.m.Debias(beta1=0.9)],
+             [
+                 tz.m.Pow(2),
+                 tz.m.EMA(0.999),
+                 tz.m.AccumulateMaximum() if amsgrad else tz.m.Identity(),
+                 tz.m.Sqrt(),
+                 tz.m.Debias2(beta=0.999),
+                 tz.m.Add(1e-8)]
+         ))
+     tz_fns = (tz_fn, tz_fn2, tz_fn3, tz_fn4, tz_fn5, tz_fn_ops, tz_fn_ops2, tz_fn_ops3, tz_fn_ops4)
+
+     _assert_identical_opts(tz_fns, merge=True, use_closure=True, device='cpu', steps=10)
+     for fn in tz_fns:
+         _assert_identical_merge_closure(fn, device='cpu', steps=10)
+         _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+ @pytest.mark.parametrize('beta1', [0.5, 0.9])
+ @pytest.mark.parametrize('beta2', [0.99, 0.999])
+ @pytest.mark.parametrize('eps', [1e-1, 1e-8])
+ @pytest.mark.parametrize('amsgrad', [True, False])
+ @pytest.mark.parametrize('lr', [0.1, 1])
+ def test_adam_hyperparams(beta1, beta2, eps, amsgrad, lr):
+     tz_fn = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad), tz.m.LR(lr))
+     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adam(beta1, beta2, eps, amsgrad=amsgrad, alpha=lr))
+     _assert_identical_opts([tz_fn, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
+
+ @pytest.mark.parametrize('centered', [True, False])
+ def test_rmsprop(centered):
+     torch_fn = lambda p: torch.optim.RMSprop(p, 1, centered=centered)
+     tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(centered=centered, init='zeros'))
+     tz_fn2 = lambda p: tz.Modular(
+         p,
+         tz.m.Div([tz.m.CenteredSqrtEMASquared(0.99) if centered else tz.m.SqrtEMASquared(0.99), tz.m.Add(1e-8)]),
+     )
+     tz_fn3 = lambda p: tz.Modular(
+         p,
+         tz.m.Div([tz.m.CenteredEMASquared(0.99) if centered else tz.m.EMASquared(0.99), tz.m.Sqrt(), tz.m.Add(1e-8)]),
+     )
+     tz_fns = (tz_fn, tz_fn2, tz_fn3)
+     _assert_identical_opts([torch_fn, *tz_fns], merge=True, use_closure=True, device='cpu', steps=10)
+     for fn in tz_fns:
+         _assert_identical_merge_closure(fn, device='cpu', steps=10)
+         _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+
+ @pytest.mark.parametrize('beta', [0.5, 0.9])
+ @pytest.mark.parametrize('eps', [1e-1, 1e-8])
+ @pytest.mark.parametrize('centered', [True, False])
+ @pytest.mark.parametrize('lr', [0.1, 1])
+ def test_rmsprop_hyperparams(beta, eps, centered, lr):
+     tz_fn = lambda p: tz.Modular(p, tz.m.RMSprop(beta, eps, centered, init='zeros'), tz.m.LR(lr))
+     torch_fn = lambda p: torch.optim.RMSprop(p, lr, beta, eps=eps, centered=centered)
+     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=10)
+
+
+
+ @pytest.mark.parametrize('nplus', (1.2, 2))
+ @pytest.mark.parametrize('nminus', (0.5, 0.9))
+ @pytest.mark.parametrize('lb', [1e-8, 1])
+ @pytest.mark.parametrize('ub', [50, 1.5])
+ @pytest.mark.parametrize('lr', [0.1, 1])
+ def test_rprop(nplus, nminus, lb, ub, lr):
+     tz_fn = lambda p: tz.Modular(p, tz.m.LR(lr), tz.m.Rprop(nplus, nminus, lb, ub, alpha=lr, backtrack=False))
+     torch_fn = lambda p: torch.optim.Rprop(p, lr, (nminus, nplus), (lb, ub))
+     _assert_identical_opts([torch_fn, tz_fn], merge=True, use_closure=True, device='cpu', steps=30)
+     _assert_identical_merge_closure(tz_fn, 'cpu', 30)
+     _assert_identical_device(tz_fn, merge=True, use_closure=True, steps=10)
+
+ def test_adagrad():
+     torch_fn = lambda p: torch.optim.Adagrad(p, 1)
+     tz_fn = lambda p: tz.Modular(p, tz.m.Adagrad(), tz.m.LR(1))
+     tz_fn2 = lambda p: tz.Modular(
+         p,
+         tz.m.Div([tz.m.Pow(2), tz.m.AccumulateSum(), tz.m.Sqrt(), tz.m.Add(1e-10)]),
+     )
+
+     tz_fns = (tz_fn, tz_fn2)
+     _assert_identical_opts([torch_fn, *tz_fns], merge=True, use_closure=True, device='cpu', steps=10)
+     for fn in tz_fns:
+         _assert_identical_merge_closure(fn, device='cpu', steps=10)
+         _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
+
+
+ @pytest.mark.parametrize('initial_accumulator_value', [0, 1])
+ @pytest.mark.parametrize('eps', [1e-2, 1e-10])
+ @pytest.mark.parametrize('lr', [0.1, 1])
+ def test_adagrad_hyperparams(initial_accumulator_value, eps, lr):
+     torch_fn = lambda p: torch.optim.Adagrad(p, lr, initial_accumulator_value=initial_accumulator_value, eps=eps)
+     tz_fn1 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps), tz.m.LR(lr))
+     tz_fn2 = lambda p: tz.Modular(p, tz.m.Adagrad(initial_accumulator_value=initial_accumulator_value, eps=eps, alpha=lr))
+     _assert_identical_opts([torch_fn, tz_fn1, tz_fn2], merge=True, use_closure=True, device='cpu', steps=10)
+
+
+ @pytest.mark.parametrize('tensorwise', [True, False])
+ def test_graft(tensorwise):
+     graft1 = lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.LBFGS(), tz.m.RMSprop(), tensorwise=tensorwise), tz.m.LR(1e-1))
+     graft2 = lambda p: tz.Modular(p, tz.m.LBFGS(), tz.m.Graft([tz.m.Grad(), tz.m.RMSprop()], tensorwise=tensorwise), tz.m.LR(1e-1))
+     _assert_identical_opts([graft1, graft2], merge=True, use_closure=True, device='cpu', steps=10)
+     for fn in [graft1, graft2]:
+         if tensorwise: _assert_identical_closure(fn, merge=True, device='cpu', steps=10)
+         else: _assert_identical_merge_closure(fn, device='cpu', steps=10)
+         _assert_identical_device(fn, merge=True, use_closure=True, steps=10)
+
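For orientation, the pattern these tests exercise is: build an optimizer by chaining modules inside tz.Modular, then step it either directly or through a closure that takes a backward flag. The following sketch is not part of the diff; it only reuses calls that appear verbatim in the test above (tz.Modular, tz.m.Adam, tz.m.LR, optimizer.step(closure), optimizer.zero_grad()), applied to a made-up model and data for illustration.

    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    # chain modules: an Adam update followed by a learning-rate scaling
    optimizer = tz.Modular(model.parameters(), tz.m.Adam(), tz.m.LR(1e-2))

    x, y = torch.randn(64, 10), torch.randn(64, 1)

    def closure(backward=True):
        # the closure recomputes the loss; with backward=False the loss is
        # re-evaluated without building gradients, as in the test's closure
        loss = torch.nn.functional.mse_loss(model(x), y)
        if backward:
            optimizer.zero_grad()
            loss.backward()
        return loss

    loss = optimizer.step(closure)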
tests/test_module.py ADDED
@@ -0,0 +1,50 @@
+ from collections.abc import Iterable
+
+ import torch
+ from torchzero.core.module import Module, _make_param_groups
+ from torchzero.utils.optimizer import get_params
+ from torchzero.utils.params import _add_defaults_to_param_groups_
+
+ def _assert_same_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+     seq1=tuple(seq1)
+     seq2=tuple(seq2)
+     assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+     for t1, t2 in zip(seq1, seq2):
+         assert t1 is t2
+
+
+ def test_process_parameters():
+     model = torch.nn.Sequential(torch.nn.Linear(3, 6), torch.nn.Linear(6, 3))
+
+     # iterable of parameters
+     _assert_same_storage_(model.parameters(), get_params(_make_param_groups(model.parameters(), differentiable=False), 'all'))
+
+     # named parameters
+     _assert_same_storage_(model.parameters(), get_params(_make_param_groups(model.named_parameters(), differentiable=False), 'all'))
+
+     # param groups
+     param_groups = [{'params': model[0].parameters(), 'lr': 0.1}, {'params': model[1].parameters()}]
+     _assert_same_storage_(model.parameters(), get_params(_make_param_groups(param_groups, differentiable=False), 'all'))
+
+     # check that param groups dict is correct
+     param_groups = [
+         {'params': model[0].parameters(), 'lr': 0.1},
+         {'params': model[1].parameters()}
+     ]
+     expected = [
+         {'params': list(model[0].parameters()), 'lr': 0.1},
+         {'params': list(model[1].parameters())}
+     ]
+     assert _make_param_groups(param_groups, differentiable=False) == expected
+
+     # named params
+     _names = {'param_names': ['weight','bias']}
+     param_groups = [
+         {'params': model[0].named_parameters(), 'lr': 0.1},
+         {'params': model[1].named_parameters()}
+     ]
+     expected = [
+         {'params': list(model[0].parameters()), 'lr': 0.1, **_names},
+         {'params': list(model[1].parameters()), 'lr': 0.01, **_names}
+     ]
+     assert _add_defaults_to_param_groups_(_make_param_groups(param_groups, differentiable=False), {"lr": 0.01}) == expected
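The test above checks that the param-group helpers accept the same "params" formats as torch.optim (an iterable of tensors, named parameters, or group dicts) without copying the tensors. A minimal sketch of that flow, using only the calls and signatures shown in the test (the model shape and group values are illustrative):

    import torch
    from torchzero.core.module import _make_param_groups
    from torchzero.utils.optimizer import get_params
    from torchzero.utils.params import _add_defaults_to_param_groups_

    model = torch.nn.Sequential(torch.nn.Linear(3, 6), torch.nn.Linear(6, 3))

    # normalize torch.optim-style group dicts into groups with concrete parameter lists
    groups = _make_param_groups(
        [{'params': model[0].parameters(), 'lr': 0.1},
         {'params': model[1].parameters()}],
        differentiable=False,
    )

    # fill per-group defaults only where a group did not set the option itself
    groups = _add_defaults_to_param_groups_(groups, {"lr": 0.01})

    # get_params(..., 'all') yields the same tensor objects as model.parameters()
    params = get_params(groups, 'all')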