torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/core/reformulation.py (+55 −24)

@@ -3,12 +3,20 @@ from collections.abc import Callable, Sequence
 
 import torch
 
-from .chain import Chain
 from .module import Chainable, Module
-from .var import Var
+from .objective import Objective
 
 
 class Reformulation(Module, ABC):
+    """Reformulation allows defining a new closure which returns a custom loss and gradient.
+
+    If ``modules`` is passed, steps with those modules using the reformulated closure. Only the ``step`` method is supported in that case.
+
+    If ``modules`` is ``None``, sets the new closure on the objective so that all further modules use it.
+    In that case make sure this module comes first.
+
+    To use this, subclass and override ``closure`` and optionally ``pre_step``.
+    """
     def __init__(self, defaults: dict | None, modules: Chainable | None):
         super().__init__(defaults)
 
@@ -16,30 +24,52 @@ class Reformulation(Module, ABC):
             self.set_child("modules", modules)
 
     @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], objective: Objective) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
         """
-        returns (loss, gradient), if backward is False then gradient can be None.
+        returns ``(loss, gradient)``; if ``backward`` is False then gradient can be None.
 
-        If evaluating original loss/gradient at x_0, set them to ``var``.
+        If evaluating the original loss/gradient at ``x0``, set them on ``objective``.
         """
 
-    def pre_step(self, var: Var) -> Var | None:
-        """This runs once before each step, whereas `closure` may run multiple times per step if further modules
+    def pre_step(self, objective: Objective):
+        """This runs once before each step, whereas ``closure`` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
 
-    def step(self, var):
-        ret = self.pre_step(var) # pylint:disable = assignment-from-no-return
-        if isinstance(ret, Var): var = ret
+    def update(self, objective):
+        if "modules" in self.children:
+            raise RuntimeError(f"Reformulation ({self.__class__.__name__}) only supports the `step` method if it has sub-modules.")
+
+        self.pre_step(objective) # pylint:disable = assignment-from-no-return
+
+        if objective.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = objective.params, objective.closure # make sure to decouple from the `objective` object
+
+        # define modified closure and set the objective to use it
+        def modified_closure(backward=True):
+            loss, grad = self.closure(backward, closure, params, objective)
 
-        if var.closure is None: raise RuntimeError("Reformulation requires closure")
-        params, closure = var.params, var.closure
+            if grad is not None:
+                for p, g in zip(params, grad):
+                    p.grad = g
+
+            return loss
+
+        objective.closure = modified_closure
+
+    def apply(self, objective): return objective
+
+    def step(self, objective):
 
-        # step with children
         if 'modules' in self.children:
 
+            self.pre_step(objective) # pylint:disable = assignment-from-no-return
+
+            if objective.closure is None: raise RuntimeError("Reformulation requires closure")
+            params, closure = objective.params, objective.closure # make sure to decouple from the `objective` object
+
             # make a reformulated closure
             def modified_closure(backward=True):
-                loss, grad = self.closure(backward, closure, params, var)
+                loss, grad = self.closure(backward, closure, params, objective)
 
                 if grad is not None:
                     for p, g in zip(params, grad):
@@ -47,21 +77,22 @@ class Reformulation(Module, ABC):
 
                 return loss
 
-            # set it to a new Var object
-            modified_var = var.clone(clone_update=False)
-            modified_var.closure = modified_closure
+            # set it to a new Objective object
+            modified_objective = objective.clone(clone_updates=False)
+            modified_objective.closure = modified_closure
 
-            # step with child
+            # update the child
             modules = self.children['modules']
-            modified_var = modules.step(modified_var)
+            modified_objective = modules.step(modified_objective)
 
             # modified_var.loss and grad refers to loss and grad of a modified objective
             # so we only take the update
-            var.update = modified_var.update
+            objective.updates = modified_objective.updates
 
-        # or just evaluate new closure and set to update
+        # or just set closure to a modified one
+        # update already calls self.pre_step
         else:
-            loss, grad = self.closure(backward=True, closure=closure, params=params, var=var)
-            if grad is not None: var.update = list(grad)
+            self.update(objective)
+            self.apply(objective) # does nothing unless overridden
 
-        return var
+        return objective