torchzero 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +47 -36
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +8 -2
  9. torchzero/core/chain.py +47 -0
  10. torchzero/core/functional.py +103 -0
  11. torchzero/core/modular.py +233 -0
  12. torchzero/core/module.py +132 -643
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +56 -23
  15. torchzero/core/transform.py +261 -365
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +132 -34
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +0 -1
  27. torchzero/modules/adaptive/__init__.py +1 -1
  28. torchzero/modules/adaptive/adagrad.py +163 -213
  29. torchzero/modules/adaptive/adahessian.py +74 -103
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +49 -30
  32. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/lion.py +5 -10
  36. torchzero/modules/adaptive/lmadagrad.py +87 -32
  37. torchzero/modules/adaptive/mars.py +5 -5
  38. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  39. torchzero/modules/adaptive/msam.py +70 -52
  40. torchzero/modules/adaptive/muon.py +59 -124
  41. torchzero/modules/adaptive/natural_gradient.py +33 -28
  42. torchzero/modules/adaptive/orthograd.py +11 -15
  43. torchzero/modules/adaptive/rmsprop.py +83 -75
  44. torchzero/modules/adaptive/rprop.py +48 -47
  45. torchzero/modules/adaptive/sam.py +55 -45
  46. torchzero/modules/adaptive/shampoo.py +123 -129
  47. torchzero/modules/adaptive/soap.py +207 -143
  48. torchzero/modules/adaptive/sophia_h.py +106 -130
  49. torchzero/modules/clipping/clipping.py +15 -18
  50. torchzero/modules/clipping/ema_clipping.py +31 -25
  51. torchzero/modules/clipping/growth_clipping.py +14 -17
  52. torchzero/modules/conjugate_gradient/cg.py +26 -37
  53. torchzero/modules/experimental/__init__.py +3 -6
  54. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  55. torchzero/modules/experimental/curveball.py +25 -41
  56. torchzero/modules/experimental/gradmin.py +2 -2
  57. torchzero/modules/{higher_order → experimental}/higher_order_newton.py +14 -40
  58. torchzero/modules/experimental/newton_solver.py +22 -53
  59. torchzero/modules/experimental/newtonnewton.py +20 -17
  60. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  61. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  62. torchzero/modules/experimental/spsa1.py +5 -5
  63. torchzero/modules/experimental/structural_projections.py +1 -4
  64. torchzero/modules/functional.py +8 -1
  65. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  66. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  67. torchzero/modules/grad_approximation/rfdm.py +20 -17
  68. torchzero/modules/least_squares/gn.py +90 -42
  69. torchzero/modules/line_search/__init__.py +1 -1
  70. torchzero/modules/line_search/_polyinterp.py +3 -1
  71. torchzero/modules/line_search/adaptive.py +3 -3
  72. torchzero/modules/line_search/backtracking.py +3 -3
  73. torchzero/modules/line_search/interpolation.py +160 -0
  74. torchzero/modules/line_search/line_search.py +42 -51
  75. torchzero/modules/line_search/strong_wolfe.py +5 -5
  76. torchzero/modules/misc/debug.py +12 -12
  77. torchzero/modules/misc/escape.py +10 -10
  78. torchzero/modules/misc/gradient_accumulation.py +10 -78
  79. torchzero/modules/misc/homotopy.py +16 -8
  80. torchzero/modules/misc/misc.py +120 -122
  81. torchzero/modules/misc/multistep.py +63 -61
  82. torchzero/modules/misc/regularization.py +49 -44
  83. torchzero/modules/misc/split.py +30 -28
  84. torchzero/modules/misc/switch.py +37 -32
  85. torchzero/modules/momentum/averaging.py +14 -14
  86. torchzero/modules/momentum/cautious.py +34 -28
  87. torchzero/modules/momentum/momentum.py +11 -11
  88. torchzero/modules/ops/__init__.py +4 -4
  89. torchzero/modules/ops/accumulate.py +21 -21
  90. torchzero/modules/ops/binary.py +67 -66
  91. torchzero/modules/ops/higher_level.py +19 -19
  92. torchzero/modules/ops/multi.py +44 -41
  93. torchzero/modules/ops/reduce.py +26 -23
  94. torchzero/modules/ops/unary.py +53 -53
  95. torchzero/modules/ops/utility.py +47 -46
  96. torchzero/modules/projections/galore.py +1 -1
  97. torchzero/modules/projections/projection.py +43 -43
  98. torchzero/modules/quasi_newton/__init__.py +2 -0
  99. torchzero/modules/quasi_newton/damping.py +1 -1
  100. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  101. torchzero/modules/quasi_newton/lsr1.py +7 -7
  102. torchzero/modules/quasi_newton/quasi_newton.py +25 -16
  103. torchzero/modules/quasi_newton/sg2.py +292 -0
  104. torchzero/modules/restarts/restars.py +26 -24
  105. torchzero/modules/second_order/__init__.py +6 -3
  106. torchzero/modules/second_order/ifn.py +58 -0
  107. torchzero/modules/second_order/inm.py +101 -0
  108. torchzero/modules/second_order/multipoint.py +40 -80
  109. torchzero/modules/second_order/newton.py +105 -228
  110. torchzero/modules/second_order/newton_cg.py +102 -154
  111. torchzero/modules/second_order/nystrom.py +158 -178
  112. torchzero/modules/second_order/rsn.py +237 -0
  113. torchzero/modules/smoothing/laplacian.py +13 -12
  114. torchzero/modules/smoothing/sampling.py +11 -10
  115. torchzero/modules/step_size/adaptive.py +23 -23
  116. torchzero/modules/step_size/lr.py +15 -15
  117. torchzero/modules/termination/termination.py +32 -30
  118. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  119. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  120. torchzero/modules/trust_region/trust_cg.py +1 -1
  121. torchzero/modules/trust_region/trust_region.py +27 -22
  122. torchzero/modules/variance_reduction/svrg.py +21 -18
  123. torchzero/modules/weight_decay/__init__.py +2 -1
  124. torchzero/modules/weight_decay/reinit.py +83 -0
  125. torchzero/modules/weight_decay/weight_decay.py +12 -13
  126. torchzero/modules/wrappers/optim_wrapper.py +57 -50
  127. torchzero/modules/zeroth_order/cd.py +9 -6
  128. torchzero/optim/root.py +3 -3
  129. torchzero/optim/utility/split.py +2 -1
  130. torchzero/optim/wrappers/directsearch.py +27 -63
  131. torchzero/optim/wrappers/fcmaes.py +14 -35
  132. torchzero/optim/wrappers/mads.py +11 -31
  133. torchzero/optim/wrappers/moors.py +66 -0
  134. torchzero/optim/wrappers/nevergrad.py +4 -4
  135. torchzero/optim/wrappers/nlopt.py +31 -25
  136. torchzero/optim/wrappers/optuna.py +6 -13
  137. torchzero/optim/wrappers/pybobyqa.py +124 -0
  138. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  139. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  140. torchzero/optim/wrappers/scipy/brute.py +48 -0
  141. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  142. torchzero/optim/wrappers/scipy/direct.py +69 -0
  143. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  144. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  145. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  146. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  147. torchzero/optim/wrappers/wrapper.py +121 -0
  148. torchzero/utils/__init__.py +7 -25
  149. torchzero/utils/compile.py +2 -2
  150. torchzero/utils/derivatives.py +112 -88
  151. torchzero/utils/optimizer.py +4 -77
  152. torchzero/utils/python_tools.py +31 -0
  153. torchzero/utils/tensorlist.py +11 -5
  154. torchzero/utils/thoad_tools.py +68 -0
  155. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  156. torchzero-0.4.0.dist-info/RECORD +191 -0
  157. tests/test_vars.py +0 -185
  158. torchzero/modules/experimental/momentum.py +0 -160
  159. torchzero/modules/higher_order/__init__.py +0 -1
  160. torchzero/optim/wrappers/scipy.py +0 -572
  161. torchzero/utils/linalg/__init__.py +0 -12
  162. torchzero/utils/linalg/matrix_funcs.py +0 -87
  163. torchzero/utils/linalg/orthogonalize.py +0 -12
  164. torchzero/utils/linalg/svd.py +0 -20
  165. torchzero/utils/ops.py +0 -10
  166. torchzero-0.3.14.dist-info/RECORD +0 -167
  167. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  168. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  169. {torchzero-0.3.14.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
tests/test_objective.py ADDED
@@ -0,0 +1,188 @@
+ import pytest
+ import torch
+ from torchzero.core import Objective
+ from torchzero.utils.tensorlist import TensorList
+
+ @torch.no_grad
+ def test_get_loss():
+
+     # ---------------------------- test that it works ---------------------------- #
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_1(backward=True):
+         assert not backward, 'backward = True'
+
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure_1, model=None, current_step=0)
+
+     assert obj.loss is None, obj.loss
+
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert evaluated, evaluated
+     assert loss is obj.loss
+     assert obj.loss == 4.0
+     assert obj.loss_approx == 4.0
+     assert obj.grads is None, obj.grads
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert obj.grads is None, obj.grads
+
+
+     # ----------------------- test that backward=True works ---------------------- #
+     params = [torch.tensor(3.0, requires_grad=True)]
+     evaluated = False
+
+     def closure_2(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0] * 2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure_2, model=None, current_step=0)
+     assert obj.grads is None, obj.grads
+     assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+     assert obj.grads is not None
+     assert obj.grads[0] == 2.0, obj.grads
+
+     # reevaluate, which should just return already evaluated loss
+     assert (loss := obj.get_loss(backward=True)) == 6.0, loss
+     assert obj.grads[0] == 2.0, obj.grads
+
+     # get grad, which should just return already evaluated grad
+     assert (grad := obj.get_grads())[0] == 2.0, grad
+     assert grad is obj.grads, grad
+
+     # get update, which should create and return cloned grad
+     assert obj.updates is None
+     assert (update := obj.get_updates())[0] == 2.0, update
+     assert update is obj.updates
+     assert update is not obj.grads
+     assert obj.grads is not None
+     assert update[0] == obj.grads[0]
+
+ @torch.no_grad
+ def test_get_grad():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure, model=None, current_step=0)
+     assert (grad := obj.get_grads())[0] == 4.0, grad
+     assert grad is obj.grads
+
+     assert obj.loss == 4.0
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+     assert obj.loss_approx == 4.0
+
+     assert obj.updates is None, obj.updates
+     assert (update := obj.get_updates())[0] == 4.0, update
+
+ @torch.no_grad
+ def test_get_update():
+     params = [torch.tensor(2.0, requires_grad=True)]
+     evaluated = False
+
+     def closure(backward=True):
+         # ensure closure only evaluates once
+         nonlocal evaluated
+         assert evaluated is False, 'closure was evaluated twice'
+         evaluated = True
+
+         loss = params[0]**2
+         if backward:
+             assert loss.requires_grad, "loss does not require grad so `with torch.enable_grad()` context didn't work"
+             params[0].grad = None
+             loss.backward()
+         else:
+             assert not loss.requires_grad, "loss requires grad with backward=False"
+         return loss
+
+     obj = Objective(params=params, closure=closure, model=None, current_step=0)
+     assert obj.updates is None, obj.updates
+     assert (update := obj.get_updates())[0] == 4.0, update
+     assert update is obj.updates
+
+     assert (grad := obj.get_grads())[0] == 4.0, grad
+     assert grad is obj.grads
+     assert grad is not update
+
+     assert obj.loss == 4.0
+     assert (loss := obj.get_loss(backward=False)) == 4.0, loss
+     assert (loss := obj.get_loss(backward=True)) == 4.0, loss
+     assert obj.loss_approx == 4.0
+
+     assert (update := obj.get_updates())[0] == 4.0, update
+
+
+ def _assert_objectives_are_same_(o1: Objective, o2: Objective, clone_update: bool):
+     for k,v in o1.__dict__.items():
+         if not k.startswith('__'):
+             # if k == 'post_step_hooks': continue
+             if k == 'storage': continue
+             elif k == 'updates' and clone_update:
+                 if o1.updates is None or o2.updates is None:
+                     assert o1.updates is None and o2.updates is None, f'`{k}` attribute is not the same, {o1.updates = }, {o2.updates = }'
+                 else:
+                     assert (TensorList(o1.updates) == TensorList(o2.updates)).global_all()
+                     assert o1.updates is not o2.updates
+             elif k == 'params':
+                 for p1, p2 in zip(o1.params, o2.params):
+                     assert p1.untyped_storage() == p2.untyped_storage()
+             else:
+                 assert getattr(o2, k) is v, f'`{k}` attribute is not the same, {getattr(o1, k) = }, {getattr(o2, k) = }'
+
+ def test_var_clone():
+     model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
+     def closure(backward): return 1
+     obj = Objective(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+     obj.grads = TensorList(torch.randn(5))
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
+
+     obj.updates = TensorList(torch.randn(5) * 2)
+     obj.loss = torch.randn(1)
+     obj.loss_approx = obj.loss
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=False), clone_update=False)
+     _assert_objectives_are_same_(obj, obj.clone(clone_updates=True), clone_update=True)
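
The tests above pin down the closure contract that `Objective` relies on: the closure takes a `backward` flag, re-evaluates the loss, and only populates `.grad` when asked. A minimal sketch of such a closure for an arbitrary model and loss function (all names here are placeholders, not part of torchzero):

```python
def make_closure(model, loss_fn, inputs, targets):
    params = list(model.parameters())

    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            for p in params:
                p.grad = None   # clear stale gradients, as the tests above do
            loss.backward()
        return loss

    return closure
```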
tests/test_opts.py CHANGED
@@ -4,15 +4,23 @@ Sanity tests to make sure everything works.
  This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
  don't error or become unhinged with different parameter shapes.
  """
+ import random
  from collections.abc import Callable
  from functools import partial

+ import numpy as np
  import pytest
  import torch
+
  import torchzero as tz

  PRINT = False # set to true in nbs

+ # seed
+ torch.manual_seed(0)
+ np.random.seed(0)
+ random.seed(0)
+
  def _booth(x, y):
      return (x + 2 * y - 7) ** 2 + (2 * x + y - 5) ** 2

@@ -51,7 +59,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
      losses = []
      for i in range(steps):
          if clear and i == steps//2:
-             for m in opt.unrolled_modules: m.reset() # clear on middle step to see if there are any issues with it
+             for m in opt.flat_modules: m.reset() # clear on middle step to see if there are any issues with it

          if use_closure:
              def closure(backward=True):
@@ -283,8 +291,8 @@ ClipNormGrowth_additive = Run(
      sphere_steps=10, sphere_loss=10,
  )
  ClipNormGrowth_global = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(parameterwise=False), tz.m.LR(0.1)),
+     func_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.ClipNormGrowth(tensorwise=False), tz.m.LR(0.1)),
      needs_closure=False,
      func='booth', steps=50, loss=1e-6, merge_invariant=True,
      sphere_steps=10, sphere_loss=10,
@@ -340,56 +348,56 @@ RandomizedFDM_central2 = Run(
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_forward2 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward2', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_backward2 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward2', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_forward3 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward3', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_backward3 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='backward3', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_central4 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='central4', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_forward4 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward4', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )
  RandomizedFDM_forward5 = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.RandomizedFDM(formula='forward5', seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=10, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=420,
  )


@@ -427,35 +435,35 @@ ForwardGradient = Run(
      sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=40, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=450,
  )
  ForwardGradient_forward = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='forward'), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=40, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=450,
  )
  ForwardGradient_central = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.01)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(seed=0, jvp_method='central'), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=40, merge_invariant=True,
-     sphere_steps=100, sphere_loss=450,
+     sphere_steps=200, sphere_loss=450,
  )
  ForwardGradient_4samples = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.1)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=0.1, merge_invariant=True,
-     sphere_steps=100, sphere_loss=400,
+     sphere_steps=100, sphere_loss=420,
  )
  ForwardGradient_4samples_no_pre_generate = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.1)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.ForwardGradient(n_samples=4, seed=0, pre_generate=False), tz.m.LR(0.001)),
      needs_closure=True,
      func='booth', steps=50, loss=0.1, merge_invariant=True,
-     sphere_steps=100, sphere_loss=400,
+     sphere_steps=100, sphere_loss=420,
  )

  # ------------------------- line_search/backtracking ------------------------- #
@@ -598,15 +606,15 @@ ScaleModulesByCosineSimilarity = Run(

  # ------------------------- momentum/matrix_momentum ------------------------- #
  MatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='forward'),),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward')),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_forward'),),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward')),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.01,
  )
  MatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='central')),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central')),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.01, hvp_method='fd_central')),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central')),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.01,
@@ -620,15 +628,15 @@ MatrixMomentum_forward = Run(
  )

  AdaptiveMatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.05,
  )
  AdaptiveMatrixMomentum_central = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.05,
@@ -642,15 +650,15 @@ AdaptiveMatrixMomentum_autograd = Run(
  )

  StochasticAdaptiveMatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='forward', adaptive=True, adapt_freq=1)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='forward', adaptive=True, adapt_freq=1)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_forward', adaptive=True, adapt_freq=1)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.05,
  )
  StochasticAdaptiveMatrixMomentum_central = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='central', adaptive=True, adapt_freq=1)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='central', adaptive=True, adapt_freq=1)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.05, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(0.5, hvp_method='fd_central', adaptive=True, adapt_freq=1)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.05,
@@ -720,10 +728,11 @@ Adam = Run(
  # ------------------------------ optimizers/soap ----------------------------- #
  SOAP = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(0.4)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(), tz.m.LR(1)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.SOAP(precond_freq=1), tz.m.LR(1)),
      needs_closure=False,
+     # merge and unmerge lrs are very different so need to test convergence separately somewhere
      func='rosen', steps=50, loss=4, merge_invariant=False,
-     sphere_steps=20, sphere_loss=25, # merge and unmerge lrs are very different so need to test convergence separately somewhere
+     sphere_steps=20, sphere_loss=25,
  )
  # ------------------------------ optimizers/lion ----------------------------- #
  Lion = Run(
@@ -735,11 +744,12 @@ Lion = Run(
  )
  # ---------------------------- optimizers/shampoo ---------------------------- #
  Shampoo = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
+     func_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.Graft(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
      needs_closure=False,
+     # merge and unmerge lrs are very different so need to test convergence separately somewhere
      func='booth', steps=50, loss=0.02, merge_invariant=False,
-     sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
+     sphere_steps=20, sphere_loss=1,
  )

  # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -755,6 +765,7 @@ SR1 = Run(
      sphere_opt=lambda p: tz.Modular(p, tz.m.SR1(scale_first=True), tz.m.StrongWolfe(fallback=False)),
      needs_closure=True,
      func='rosen', steps=50, loss=1e-12, merge_invariant=True,
+     # this reaches 1e-13 on github so don't change to 0
      sphere_steps=10, sphere_loss=0,
  )
  SSVM = Run(
@@ -806,7 +817,7 @@ NewtonCG = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
      sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe(fallback=True)),
      needs_closure=True,
-     func='rosen', steps=20, loss=1e-7, merge_invariant=True,
+     func='rosen', steps=20, loss=1e-10, merge_invariant=True,
      sphere_steps=2, sphere_loss=3e-4,
  )

@@ -872,8 +883,8 @@ SophiaH = Run(

  # -------------------------- higher_order ------------------------- #
  HigherOrderNewton = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+     func_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(trust_method=None)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.experimental.HigherOrderNewton(2, trust_method=None)),
      needs_closure=True,
      func='rosen', steps=1, loss=2e-10, merge_invariant=True,
      sphere_steps=1, sphere_loss=1e-10,
tests/test_tensorlist.py CHANGED
@@ -1567,13 +1567,6 @@ def test_where(simple_tl: TensorList):
      assert_tl_allclose(result_module, expected_tl)


-     # Test inplace where_ (needs TensorList other)
-     tl_copy = simple_tl.clone()
-     result_inplace = tl_copy.where_(condition_tl, other_tl)
-     assert result_inplace is tl_copy
-     assert_tl_allclose(tl_copy, expected_tl)
-
-
  def test_masked_fill(simple_tl: TensorList):
      mask_tl = simple_tl.lt(0)
      fill_value_scalar = 99.0
@@ -1600,7 +1593,6 @@ def test_select_set_(simple_tl: TensorList):
      mask_tl = simple_tl.gt(0.5)
      value_scalar = -1.0
      value_list_scalar = [-1.0, -2.0, -3.0]
-     value_tl = simple_tl.clone().mul_(0.1)

      # Set with scalar value
      tl_copy_scalar = simple_tl.clone()
tests/test_utils_optimizer.py CHANGED
@@ -4,7 +4,6 @@ from functools import partial
  import pytest
  import torch
  from torchzero.utils.optimizer import (
-     Optimizer,
      get_group_vals,
      get_params,
      get_state_vals,
torchzero/__init__.py CHANGED
@@ -1,4 +1,4 @@
  from . import core, optim, utils
  from .core import Modular
- from .utils import set_compilation
+ from .utils.compile import enable_compilation
  from . import modules as m
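
The only user-facing change in the top-level package is the compilation toggle: the `set_compilation` re-export is gone and `enable_compilation` from `torchzero.utils.compile` takes its place. A minimal sketch of how downstream imports shift, assuming the new name covers the old use case:

```python
# torchzero 0.3.14 (removed in 0.4.0)
from torchzero.utils import set_compilation

# torchzero 0.4.0, per the __init__.py diff above
from torchzero import enable_compilation               # re-exported at the package top level
from torchzero.utils.compile import enable_compilation  # or imported from its module directly
```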
torchzero/core/__init__.py CHANGED
@@ -1,2 +1,8 @@
- from .module import Chain, Chainable, Modular, Module, Var, maybe_chain
- from .transform import Target, TensorwiseTransform, Transform, apply_transform
+ from .transform import TensorTransform, Transform
+ from .module import Chainable, Module
+ from .objective import DerivativesMethod, HessianMethod, HVPMethod, Objective
+
+ # order is important to avoid circular imports
+ from .modular import Modular
+ from .functional import apply, step, step_tensors, update
+ from .chain import Chain, maybe_chain
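
Put together, the public surface of `torchzero.core` in 0.4.0 (names only, exactly as re-exported in the `__init__.py` above) looks like this:

```python
# names exported by torchzero.core in 0.4.0, per the __init__.py diff above
from torchzero.core import (
    Transform, TensorTransform,                              # transform.py
    Chainable, Module,                                       # module.py
    DerivativesMethod, HessianMethod, HVPMethod, Objective,  # objective.py
    Modular,                                                 # modular.py
    apply, step, step_tensors, update,                       # functional.py
    Chain, maybe_chain,                                      # chain.py
)
```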
torchzero/core/chain.py ADDED
@@ -0,0 +1,47 @@
+ from collections.abc import Iterable
+
+ from ..utils.python_tools import flatten
+ from .module import Module, Chainable
+ from .functional import _chain_step
+
+ class Chain(Module):
+     """Chain modules, mostly used internally"""
+     def __init__(self, *modules: Module | Iterable[Module]):
+         super().__init__()
+         flat_modules: list[Module] = flatten(modules)
+         for i, module in enumerate(flat_modules):
+             self.set_child(f'module_{i}', module)
+
+     def update(self, objective):
+         if len(self.children) > 1:
+             raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+         if len(self.children) == 0: return
+         return self.children['module_0'].update(objective)
+
+     def apply(self, objective):
+         if len(self.children) > 1:
+             raise RuntimeError("can't call `update` on Chain with more than one child, as `update` and `apply` have to be called sequentially. Use the `step` method instead of update-apply.")
+
+         if len(self.children) == 0: return objective
+         return self.children['module_0'].apply(objective)
+
+     def step(self, objective):
+         children = [self.children[f'module_{i}'] for i in range(len(self.children))]
+         return _chain_step(objective, children)
+
+     def __repr__(self):
+         s = self.__class__.__name__
+         if self.children:
+             if s == 'Chain': s = 'C' # to shorten it
+             s = f'{s}({", ".join(str(m) for m in self.children.values())})'
+         return s
+
+ def maybe_chain(*modules: Chainable) -> Module:
+     """Returns a single module directly if only one is provided, otherwise wraps them in a ``Chain``."""
+     flat_modules: list[Module] = flatten(modules)
+     if len(flat_modules) == 1:
+         return flat_modules[0]
+     return Chain(*flat_modules)
+
+
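
A small usage sketch of `maybe_chain` as defined above. The `tz.m.LR` modules are stand-ins for illustration; any `Module` instances would do:

```python
import torchzero as tz
from torchzero.core import maybe_chain

a, b = tz.m.LR(0.1), tz.m.LR(0.5)   # placeholder modules

single = maybe_chain(a)      # a single module is returned unchanged
both = maybe_chain(a, b)     # several modules are wrapped in a Chain
print(both)                  # Chain's repr shortens its own class name to "C(...)"
```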
torchzero/core/functional.py ADDED
@@ -0,0 +1,103 @@
+ from collections.abc import Mapping, Sequence, Iterable, Callable
+ from typing import TYPE_CHECKING, Any
+
+ import torch
+
+ from .objective import Objective
+
+ if TYPE_CHECKING:
+     from .module import Module
+     from .transform import Transform
+
+
+
+ def update(
+     objective: "Objective",
+     module: "Transform",
+     states: list[dict[str, Any]] | None = None,
+     settings: Sequence[Mapping[str, Any]] | None = None,
+ ) -> None:
+     if states is None:
+         assert settings is None
+         module.update(objective)
+
+     else:
+         assert settings is not None
+         module.update_states(objective, states, settings)
+
+ def apply(
+     objective: "Objective",
+     module: "Transform",
+     states: list[dict[str, Any]] | None = None,
+     settings: Sequence[Mapping[str, Any]] | None = None,
+ ) -> "Objective":
+     if states is None:
+         assert settings is None
+         return module.apply(objective)
+
+     else:
+         assert settings is not None
+         return module.apply_states(objective, states, settings)
+
+ def _chain_step(objective: "Objective", modules: "Sequence[Module]"):
+     """steps with ``modules`` and returns updated objective, this is used within ``step`` and within ``Chain.step``"""
+     # step
+     for i, module in enumerate(modules):
+         if i!=0: objective = objective.clone(clone_updates=False)
+
+         objective = module.step(objective)
+         if objective.stop: break
+
+     return objective
+
+ def step(objective: "Objective", modules: "Module | Sequence[Module]"):
+     """doesn't apply hooks!"""
+     if not isinstance(modules, Sequence):
+         modules = (modules, )
+
+     if len(modules) == 0:
+         raise RuntimeError("`modules` is an empty sequence")
+
+     # if closure is None, assume backward has been called and gather grads
+     if objective.closure is None:
+         objective.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in objective.params]
+
+     # step and return
+     return _chain_step(objective, modules)
+
+
+ def step_tensors(
+     modules: "Module | Sequence[Module]",
+     tensors: Sequence[torch.Tensor],
+     params: Iterable[torch.Tensor] | None = None,
+     grads: Sequence[torch.Tensor] | None = None,
+     loss: torch.Tensor | None = None,
+     closure: Callable | None = None,
+     objective: "Objective | None" = None
+ ) -> list[torch.Tensor]:
+     if objective is not None:
+         if any(i is not None for i in (params, grads, loss, closure)):
+             raise RuntimeError("Specify either `objective` or `(params, grads, loss, closure)`")
+
+     if not isinstance(modules, Sequence):
+         modules = (modules, )
+
+     # make fake params if they are only used for shapes
+     if params is None:
+         params = [t.view_as(t).requires_grad_() for t in tensors]
+
+     # create objective
+     if objective is None:
+         objective = Objective(params=params, loss=loss, closure=closure)
+
+     if grads is not None:
+         objective.grads = list(grads)
+
+     objective.updates = list(tensors)
+
+     # step with modules
+     # this won't update parameters in-place because objective.Modular is None
+     objective = _chain_step(objective, modules)
+
+     # return updates
+     return objective.get_updates()
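
As a rough illustration of the new functional entry point defined above, the sketch below pushes raw tensors through a module without touching any real parameters. `tz.m.LR` is used only because it appears elsewhere in this diff; the exact scaling behaviour is an assumption:

```python
import torch
import torchzero as tz
from torchzero.core import step_tensors

# treat these as gradient-like tensors; fake parameters are created internally for shapes
tensors = [torch.randn(3), torch.randn(2, 2)]

updates = step_tensors(tz.m.LR(0.1), tensors)
# `updates` is a list[torch.Tensor] with the same shapes, presumably scaled by the 0.1 learning rate
```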