torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -0,0 +1,170 @@
+ from collections.abc import Iterable
+ from typing import Any
+ from functools import partial
+ import pytest
+ import torch
+ from torchzero.utils.optimizer import (
+     Optimizer,
+     get_group_vals,
+     get_params,
+     get_state_vals,
+ )
+
+
+ def _assert_same_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+     seq1=tuple(seq1)
+     seq2=tuple(seq2)
+     assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+     for t1, t2 in zip(seq1, seq2):
+         assert t1 is t2
+
+ def _assert_equals_different_storage_(seq1: Iterable[torch.Tensor], seq2: Iterable[torch.Tensor]):
+     seq1=tuple(seq1)
+     seq2=tuple(seq2)
+     assert len(seq1) == len(seq2), f'lengths do not match: {len(seq1)} != {len(seq2)}'
+     for t1, t2 in zip(seq1, seq2):
+         assert t1 is not t2
+         assert (t1 == t2).all()
+
+ def test_assert_compare_tensors():
+     t1 = [torch.randn(1, 3) for _ in range(10)]
+     t2 = [torch.randn(1, 3) for _ in range(10)]
+
+     _assert_same_storage_(t1, t1)
+     _assert_same_storage_(t2, t2)
+
+     with pytest.raises(AssertionError):
+         _assert_same_storage_(t1, t2)
+
+
+ def test_get_params():
+     param_groups = [
+         {'params': [torch.randn(1, 1, requires_grad=True), torch.randn(1, 2, requires_grad=True)]},
+         {'params': [torch.randn(2, 1, requires_grad=True), torch.randn(2, 2, requires_grad=False)], "lr": 0.1},
+         {'params': [torch.randn(3, 1, requires_grad=False)], 'lr': 0.001, 'betas': (0.9, 0.99)},
+     ]
+     param_groups[0]['params'][0].grad = torch.randn(1, 1)
+
+     params = get_params(param_groups, mode = 'requires_grad', cls = list)
+     _assert_same_storage_(params, [*param_groups[0]['params'], param_groups[1]['params'][0]])
+
+     params = get_params(param_groups, mode = 'has_grad', cls = list)
+     _assert_same_storage_(params, [param_groups[0]['params'][0]])
+
+     params = get_params(param_groups, mode = 'all', cls = list)
+     _assert_same_storage_(params, [*param_groups[0]['params'], *param_groups[1]['params'], *param_groups[2]['params']])
+
+ def test_get_group_vals():
+     param_groups = [
+         {'params': [torch.randn(2, 1, requires_grad=True), torch.randn(2, 2, requires_grad=True)], "lr": 0.1, 'beta': 0.95, 'eps': 1e-8},
+         {'params': [torch.randn(1, 1, requires_grad=True), torch.randn(1, 2, requires_grad=False)], 'lr': 0.01, 'beta': 0.99, 'eps': 1e-7},
+         {'params': [torch.randn(3, 1, requires_grad=False)], 'lr': 0.001, 'beta': 0.999, 'eps': 1e-6},
+     ]
+     param_groups[0]['params'][0].grad = torch.randn(2, 1)
+
+
+     lr = get_group_vals(param_groups, 'lr', mode = 'requires_grad', cls = list)
+     assert lr == [0.1, 0.1, 0.01], lr
+
+     lr, beta = get_group_vals(param_groups, 'lr', 'beta', mode = 'requires_grad', cls = list)
+     assert lr == [0.1, 0.1, 0.01], lr
+     assert beta == [0.95, 0.95, 0.99], beta
+
+     lr, beta, eps = get_group_vals(param_groups, ('lr', 'beta', 'eps'), mode = 'requires_grad', cls = list)
+     assert lr == [0.1, 0.1, 0.01], lr
+     assert beta == [0.95, 0.95, 0.99], beta
+     assert eps == [1e-8, 1e-8, 1e-7], eps
+
+     lr = get_group_vals(param_groups, 'lr', mode = 'has_grad', cls = list)
+     assert lr == [0.1], lr
+
+     lr, beta, eps = get_group_vals(param_groups, 'lr', 'beta', 'eps', mode = 'all', cls = list)
+     assert lr == [0.1, 0.1, 0.01, 0.01, 0.001], lr
+     assert beta == [0.95, 0.95, 0.99, 0.99, 0.999], beta
+     assert eps == [1e-8, 1e-8, 1e-7, 1e-7, 1e-6], eps
+
+ def test_get_state_vals():
+     # accessing state values of a single parameter, which acts as the key, so no tensors are passed
+     tensor = torch.randn(3,3)
+     state = {tensor: {'exp_avg': torch.ones_like(tensor)}}
+     existing_cov_exp_avg = state[tensor]['exp_avg']
+     cov_exp_avg, cov_exp_avg_sq = get_state_vals(state, [tensor], ('exp_avg', 'exp_avg_sq'), init = [torch.zeros_like, lambda x: torch.full_like(x, 2)])
+     assert torch.allclose(cov_exp_avg[0], torch.ones_like(tensor))
+     assert torch.allclose(cov_exp_avg_sq[0], torch.full_like(tensor, 2))
+     assert cov_exp_avg[0] is existing_cov_exp_avg
+     assert state[tensor]['exp_avg'] is existing_cov_exp_avg
+     assert state[tensor]['exp_avg_sq'] is cov_exp_avg_sq[0]
+
+     # accessing state values of multiple parameters
+     parameters = [torch.randn(i,2) for i in range(1, 11)]
+     state = {p: {} for p in parameters}
+     exp_avgs = get_state_vals(state, parameters, 'exp_avg', cls=list)
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert len(exp_avgs) == 10, len(exp_avgs)
+     assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avgs))
+     exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+     _assert_same_storage_(exp_avgs, exp_avgs2)
+
+     # per-parameter inits
+     parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+     state = {p: {} for p in parameters}
+     exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = [partial(torch.full_like, fill_value=i) for i in range(10)], cls=list)
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert len(exp_avgs) == 10, len(exp_avgs)
+     assert all(torch.allclose(a, torch.full_like(parameters[i], i)) for i, a in enumerate(exp_avgs)), exp_avgs
+     exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+     _assert_same_storage_(exp_avgs, exp_avgs2)
+
+     # per-parameter init with a list
+     parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+     state = {p: {} for p in parameters}
+     inits = [torch.full([i], fill_value=i) for i in range(1, 11)]
+     exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = inits, cls=list)
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert len(exp_avgs) == 10, len(exp_avgs)
+     _assert_equals_different_storage_(inits, exp_avgs) # inits are cloned
+     exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+     _assert_same_storage_(exp_avgs, exp_avgs2)
+
+     # init with a value
+     parameters = [torch.full((i,2), fill_value=i**2) for i in range(1, 11)]
+     state = {p: {} for p in parameters}
+     inits = 1
+     exp_avgs = get_state_vals(state, parameters, 'exp_avg', init = inits, cls=list)
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert len(exp_avgs) == 10, len(exp_avgs)
+     assert all(v==1 for v in exp_avgs), exp_avgs
+     assert exp_avgs == get_state_vals(state, parameters, 'exp_avg', cls=list) # no init because already initialized
+
+     # accessing multiple keys
+     parameters = [torch.randn(i,2) for i in range(1,11)]
+     state = {p: {} for p in parameters}
+     exp_avgs, exp_avg_sqs, max_avgs = get_state_vals(state, parameters, 'exp_avg', 'exp_avg_sq', 'max_avg', cls=list)
+     assert len(exp_avgs) == len(exp_avg_sqs) == len(max_avgs) == 10
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert isinstance(exp_avg_sqs, list), type(exp_avg_sqs)
+     assert isinstance(max_avgs, list), type(max_avgs)
+     assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avgs))
+     assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(exp_avg_sqs))
+     assert all(torch.allclose(a, torch.zeros_like(parameters[i])) for i, a in enumerate(max_avgs))
+     exp_avgs2 = get_state_vals(state, parameters, 'exp_avg', cls=list)
+     exp_avg_sqs2 = get_state_vals(state, parameters, 'exp_avg_sq', cls=list)
+     max_avgs2 = get_state_vals(state, parameters, 'max_avg', cls=list)
+     _assert_same_storage_(exp_avgs, exp_avgs2)
+     _assert_same_storage_(exp_avg_sqs, exp_avg_sqs2)
+     _assert_same_storage_(max_avgs, max_avgs2)
+
+     # per-key init
+     parameters = [torch.randn(i,2) for i in range(1,11)]
+     state = {p: {} for p in parameters}
+     exp_avgs, exp_avg_sqs, max_avgs = get_state_vals(state, parameters, 'exp_avg', 'exp_avg_sq', 'max_avg', init=(4,5,5.5), cls=list)
+     assert len(exp_avgs) == len(exp_avg_sqs) == len(max_avgs) == 10
+     assert isinstance(exp_avgs, list), type(exp_avgs)
+     assert isinstance(exp_avg_sqs, list), type(exp_avg_sqs)
+     assert isinstance(max_avgs, list), type(max_avgs)
+     assert all(v==4 for v in exp_avgs), exp_avgs
+     assert all(v==5 for v in exp_avg_sqs), exp_avg_sqs
+     assert all(v==5.5 for v in max_avgs), max_avgs
+     assert exp_avgs == get_state_vals(state, parameters, 'exp_avg', cls=list)
+     assert exp_avg_sqs == get_state_vals(state, parameters, 'exp_avg_sq', cls=list)
+     assert max_avgs == get_state_vals(state, parameters, 'max_avg', cls=list)
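
The tests above exercise the new torchzero.utils.optimizer helpers one at a time. As a rough illustration of how they compose, here is a minimal, hypothetical momentum-SGD-style step that uses only the call signatures the tests demonstrate (get_params, get_group_vals and get_state_vals are the real helpers; the momentum_sgd_step wrapper and its beta value are made up for this sketch):

import torch
from torchzero.utils.optimizer import get_params, get_group_vals, get_state_vals

def momentum_sgd_step(param_groups, state, beta=0.9):
    # parameters that currently hold a gradient, flattened across all groups
    params = get_params(param_groups, mode='has_grad', cls=list)
    # per-parameter 'lr' values, aligned with `params`
    lrs = get_group_vals(param_groups, 'lr', mode='has_grad', cls=list)
    # per-parameter state buffers; the default init is zeros, as the tests assert
    buffers = get_state_vals(state, params, 'momentum_buffer', cls=list)
    for p, lr, buf in zip(params, lrs, buffers):
        buf.mul_(beta).add_(p.grad)   # in-place heavy-ball momentum accumulation
        p.sub_(buf, alpha=lr)         # descend along the momentum buffer

Here `state` is the same kind of mapping the tests build by hand ({param: {} for param in params}); in an actual optimizer it would be the instance's per-parameter state dict.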
tests/test_vars.py ADDED
@@ -0,0 +1,184 @@
+ import pytest
+ import torch
+ from torchzero.core.module import Vars
+ from torchzero.utils.tensorlist import TensorList
+
+ @torch.no_grad
+ def test_vars_get_loss():
+
+     # ---------------------------- test that it works ---------------------------- #
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_1(backward=True):
+         assert not backward, 'backward = True'
+
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
+
+     assert vars.loss is None, vars.loss
+
+     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+     assert evaluated, evaluated
+     assert loss is vars.loss
+     assert vars.loss == 4.0
+     assert vars.loss_approx == 4.0
+     assert vars.grad is None, vars.grad
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+     assert vars.grad is None, vars.grad
+
+
+     # ----------------------- test that backward=True works ---------------------- #
+     params = [torch.tensor(3.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_2(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0] * 2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
+     assert vars.grad is None, vars.grad
+     assert (loss := vars.get_loss(backward=True)) == 6.0, loss
+     assert vars.grad is not None
+     assert vars.grad[0] == 2.0, vars.grad
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := vars.get_loss(backward=True)) == 6.0, loss
+     assert vars.grad[0] == 2.0, vars.grad
+
+     # get grad, which should just return already evaluated grad
+     assert (grad := vars.get_grad())[0] == 2.0, grad
+     assert grad is vars.grad, grad
+
+     # get update, which should create and return cloned grad
+     assert vars.update is None
+     assert (update := vars.get_update())[0] == 2.0, update
+     assert update is vars.update
+     assert update is not vars.grad
+     assert vars.grad is not None
+     assert update[0] == vars.grad[0]
+
+ @torch.no_grad
+ def test_vars_get_grad():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     vars = Vars(params=params, closure=closure, model=None, current_step=0)
+     assert (grad := vars.get_grad())[0] == 4.0, grad
+     assert grad is vars.grad
+
+     assert vars.loss == 4.0
+     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+     assert (loss := vars.get_loss(backward=True)) == 4.0, loss
+     assert vars.loss_approx == 4.0
+
+     assert vars.update is None, vars.update
+     assert (update := vars.get_update())[0] == 4.0, update
+
+ @torch.no_grad
+ def test_vars_get_update():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     vars = Vars(params=params, closure=closure, model=None, current_step=0)
+     assert vars.update is None, vars.update
+     assert (update := vars.get_update())[0] == 4.0, update
+     assert update is vars.update
+
+     assert (grad := vars.get_grad())[0] == 4.0, grad
+     assert grad is vars.grad
+     assert grad is not update
+
+     assert vars.loss == 4.0
+     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+     assert (loss := vars.get_loss(backward=True)) == 4.0, loss
+     assert vars.loss_approx == 4.0
+
+     assert (update := vars.get_update())[0] == 4.0, update
+
+
+ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
+     for k,v in v1.__dict__.items():
+         if not k.startswith('__'):
+             # if k == 'post_step_hooks': continue
+             if k == 'update' and clone_update:
+                 if v1.update is None or v2.update is None:
+                     assert v1.update is None and v2.update is None, f'{k} is not the same, {v1 = }, {v2 = }'
+                 else:
+                     assert (TensorList(v1.update) == TensorList(v2.update)).global_all()
+                     assert v1.update is not v2.update
+             else:
+                 assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'
+
+ def test_vars_clone():
+     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
+     def closure(backward): return 1
+     vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+
+     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+
+     vars.grad = TensorList(torch.randn(5))
+     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+
+     vars.update = TensorList(torch.randn(5) * 2)
+     vars.loss = torch.randn(1)
+     vars.loss_approx = vars.loss
+     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
+     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
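
All of the Vars tests above rely on the same closure convention: the closure accepts a backward flag, re-evaluates the loss, and only populates the .grad attributes when backward=True (the requires_grad assertions check that Vars.get_loss(backward=True) runs the closure under torch.enable_grad()). A minimal sketch of such a closure for an ordinary training step, with model, loss_fn, inputs and targets as placeholder names:

def closure(backward=True):
    # re-run the forward pass; gradients are only produced when requested
    loss = loss_fn(model(inputs), targets)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss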
torchzero/__init__.py CHANGED
@@ -1,4 +1,4 @@
- from . import tensorlist as tl # this needs to be imported first to avoid circular imports
- from .tensorlist import TensorList
- from . import optim, modules as m, core, random
- from .optim import Modular
+ from . import core, optim, utils
+ from .core import Modular
+ from .utils import compile
+ from . import modules as m
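
For downstream code, the hunk above amounts to an import-path migration. Read together with the test imports elsewhere in this diff, the mapping looks roughly like this (a hedged reading of the diff, not an official migration guide):

# Old (0.1.8) paths removed above and the 0.3.1 equivalents visible in this diff:
#   from torchzero.tensorlist import TensorList  ->  from torchzero.utils.tensorlist import TensorList
#   from torchzero.optim import Modular          ->  from torchzero.core import Modular
# Note: TensorList is no longer re-exported at the package root.
# Top-level re-exports of Modular and the modules namespace still work:
from torchzero import Modular, m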
torchzero/core/__init__.py CHANGED
@@ -1,13 +1,3 @@
- import sys
-
- from .module import (
-     OptimizationVars,
-     OptimizerModule,
-     _Chain,
-     _Chainable,
-     _get_loss,
-     _ScalarLoss,
-     _Targets,
- )
-
- from .tensorlist_optimizer import TensorListOptimizer, ParamsT, _ClosureType, _maybe_pass_backward
+ from .module import Vars, Module, Modular, Chain, maybe_chain, Chainable
+ from .transform import Transform, TensorwiseTransform, Target, apply
+ from .preconditioner import Preconditioner, TensorwisePreconditioner
+ from .preconditioner import Preconditioner, TensorwisePreconditioner