torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
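Note: the hunks below mostly track one core rename: `Var` becomes `Objective`, `Module.step(var)` becomes `Module.apply(objective)`, and `var.update` / `var.get_update()` / `var.grad` become `objective.updates` / `objective.get_updates()` / `objective.grads`. As a rough orientation, a custom module written against the 0.4.0 names might look like the sketch below; `ScaleUpdate` and its `factor` setting are hypothetical, and only the imports and the `apply` / `get_updates` / `updates` pattern are taken from the hunks in this diff.

import torch

from torchzero.core import Module, Objective


class ScaleUpdate(Module):
    """Hypothetical module that multiplies the current update by a constant factor."""

    def __init__(self, factor: float = 0.5):
        defaults = dict(factor=factor)
        super().__init__(defaults)

    @torch.no_grad
    def apply(self, objective: Objective) -> Objective:
        factor = self.defaults["factor"]
        # get_updates() returns the current update tensors (it appears to fall back
        # to gradients when no earlier module has produced an update)
        updates = objective.get_updates()
        # write the scaled tensors back as the new update and pass control on
        objective.updates = torch._foreach_mul(updates, factor)
        return objective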
@@ -8,7 +8,7 @@ from typing import Any, Literal
  import numpy as np
  import torch

- from ...core import Module, Target, Var
+ from ...core import Module, Objective
  from ...utils import tofloat, set_storage_
  from ..functional import clip_by_finfo
@@ -139,7 +139,7 @@ class LineSearchBase(Module, ABC):
  for c, n in zip(params, new_params):
  set_storage_(c, n)

- def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
+ def _loss(self, step_size: float, var: Objective, closure, params: list[torch.Tensor],
  update: list[torch.Tensor], backward:bool=False) -> float:

  # if step_size is 0, we might already know the loss
@@ -165,16 +165,16 @@ class LineSearchBase(Module, ABC):
  # if evaluated loss at step size 0, set it to var.loss
  if step_size == 0:
  var.loss = loss
- if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+ if backward: var.grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

  return tofloat(loss)

- def _loss_derivative_gradient(self, step_size: float, var: Var, closure,
+ def _loss_derivative_gradient(self, step_size: float, var: Objective, closure,
  params: list[torch.Tensor], update: list[torch.Tensor]):
  # if step_size is 0, we might already know the derivative
- if (var.grad is not None) and (step_size == 0):
+ if (var.grads is not None) and (step_size == 0):
  loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
- derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
+ derivative = - sum(t.sum() for t in torch._foreach_mul(var.grads, update))

  else:
  # loss with a backward pass sets params.grad
@@ -184,79 +184,79 @@ class LineSearchBase(Module, ABC):
  derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
  else torch.zeros_like(p) for p in params], update))

- assert var.grad is not None
- return loss, tofloat(derivative), var.grad
+ assert var.grads is not None
+ return loss, tofloat(derivative), var.grads

- def _loss_derivative(self, step_size: float, var: Var, closure,
+ def _loss_derivative(self, step_size: float, var: Objective, closure,
  params: list[torch.Tensor], update: list[torch.Tensor]):
  return self._loss_derivative_gradient(step_size=step_size, var=var,closure=closure,params=params,update=update)[:2]

- def evaluate_f(self, step_size: float, var: Var, backward:bool=False):
+ def evaluate_f(self, step_size: float, var: Objective, backward:bool=False):
  """evaluate function value at alpha `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
+ return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates(),backward=backward)

- def evaluate_f_d(self, step_size: float, var: Var):
+ def evaluate_f_d(self, step_size: float, var: Objective):
  """evaluate function value and directional derivative in the direction of the update at step size `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+ return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

- def evaluate_f_d_g(self, step_size: float, var: Var):
+ def evaluate_f_d_g(self, step_size: float, var: Objective):
  """evaluate function value, directional derivative, and gradient list at step size `step_size`."""
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
+ return self._loss_derivative_gradient(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_updates())

- def make_objective(self, var: Var, backward:bool=False):
+ def make_objective(self, var: Objective, backward:bool=False):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
+ return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_updates(), backward=backward)

- def make_objective_with_derivative(self, var: Var):
+ def make_objective_with_derivative(self, var: Objective):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
+ return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_updates())

- def make_objective_with_derivative_and_gradient(self, var: Var):
+ def make_objective_with_derivative_and_gradient(self, var: Objective):
  closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_update())
+ return partial(self._loss_derivative_gradient, var=var, closure=closure, params=var.params, update=var.get_updates())

  @abstractmethod
- def search(self, update: list[torch.Tensor], var: Var) -> float:
+ def search(self, update: list[torch.Tensor], var: Objective) -> float:
  """Finds the step size to use"""

  @torch.no_grad
- def step(self, var: Var) -> Var:
+ def apply(self, objective: Objective) -> Objective:
  self._reset()

- params = var.params
+ params = objective.params
  self._initial_params = [p.clone() for p in params]
- update = var.get_update()
+ update = objective.get_updates()

  try:
- step_size = self.search(update=update, var=var)
+ step_size = self.search(update=update, var=objective)
  except MaxLineSearchItersReached:
  step_size = self._best_step_size

  step_size = clip_by_finfo(step_size, torch.finfo(update[0].dtype))

  # set loss_approx
- if var.loss_approx is None: var.loss_approx = self._lowest_loss
+ if objective.loss_approx is None: objective.loss_approx = self._lowest_loss

  # if this is last module, directly update parameters to avoid redundant operations
- if var.modular is not None and self is var.modular.modules[-1]:
+ if objective.modular is not None and self is objective.modular.modules[-1]:
  self.set_step_size_(step_size, params=params, update=update)

- var.stop = True; var.skip_update = True
- return var
+ objective.stop = True; objective.skip_update = True
+ return objective

  # revert parameters and multiply update by step size
  self.set_step_size_(0, params=params, update=update)
- torch._foreach_mul_(var.update, step_size)
- return var
+ torch._foreach_mul_(objective.updates, step_size)
+ return objective
@@ -284,8 +284,8 @@ class StrongWolfe(LineSearchBase):
  'init_value', 'init', 'c1', 'c2', 'a_max', 'maxiter', 'maxzoom',
  'maxeval', 'interpolation', 'adaptive', 'plus_minus', 'fallback', 'tol_change')(self.defaults)

- dir = as_tensorlist(var.get_update())
- grad_list = var.get_grad()
+ dir = as_tensorlist(var.get_updates())
+ grad_list = var.get_grads()

  g_0 = -sum(t.sum() for t in torch._foreach_mul(grad_list, dir))
  f_0 = var.get_loss(False)
@@ -11,9 +11,9 @@ class PrintUpdate(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)

- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.update}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.updates}')
+ return objective

  class PrintShape(Module):
  """Prints shapes of the update."""
@@ -21,10 +21,10 @@ class PrintShape(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)

- def step(self, var):
- shapes = [u.shape for u in var.update] if var.update is not None else None
+ def apply(self, objective):
+ shapes = [u.shape for u in objective.updates] if objective.updates is not None else None
  self.defaults["print_fn"](f'{self.defaults["text"]}{shapes}')
- return var
+ return objective

  class PrintParams(Module):
  """Prints current update."""
@@ -32,9 +32,9 @@ class PrintParams(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)

- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.params}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.params}')
+ return objective


  class PrintLoss(Module):
@@ -43,6 +43,6 @@ class PrintLoss(Module):
  defaults = dict(text=text, print_fn=print_fn)
  super().__init__(defaults)

- def step(self, var):
- self.defaults["print_fn"](f'{self.defaults["text"]}{var.get_loss(False)}')
- return var
+ def apply(self, objective):
+ self.defaults["print_fn"](f'{self.defaults["text"]}{objective.get_loss(False)}')
+ return objective
@@ -3,7 +3,7 @@ import math
  from typing import Literal
  import torch

- from ...core import Modular, Module, Var, Chainable
+ from ...core import Modular, Module, Objective, Chainable
  from ...utils import NumberList, TensorList


@@ -15,11 +15,11 @@ class EscapeAnnealing(Module):


  @torch.no_grad
- def step(self, var):
- closure = var.closure
+ def apply(self, objective):
+ closure = objective.closure
  if closure is None: raise RuntimeError("Escape requries closure")

- params = TensorList(var.params)
+ params = TensorList(objective.params)
  settings = self.settings[params[0]]
  max_region = self.get_settings(params, 'max_region', cls=NumberList)
  max_iter = settings['max_iter']
@@ -41,7 +41,7 @@ class EscapeAnnealing(Module):
  self.global_state['n_bad'] = n_bad

  # no progress
- f_0 = var.get_loss(False)
+ f_0 = objective.get_loss(False)
  if n_bad >= n_tol:
  for i in range(1, max_iter+1):
  alpha = max_region * (i / max_iter)
@@ -51,12 +51,12 @@ class EscapeAnnealing(Module):
  f_star = closure(False)

  if math.isfinite(f_star) and f_star < f_0-1e-12:
- var.update = None
- var.stop = True
- var.skip_update = True
- return var
+ objective.updates = None
+ objective.stop = True
+ objective.skip_update = True
+ return objective

  params.sub_(pert)

  self.global_state['n_bad'] = 0
- return var
+ return objective
@@ -3,74 +3,6 @@ import torch
  from ...core import Chainable, Module


- # class GradientAccumulation(Module):
- # """Uses :code:`n` steps to accumulate gradients, after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
-
- # Accumulating gradients for :code:`n` steps is equivalent to increasing batch size by :code:`n`. Increasing the batch size
- # is more computationally efficient, but sometimes it is not feasible due to memory constraints.
-
- # .. note::
- # Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
-
- # Args:
- # modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
- # n (int): number of gradients to accumulate.
- # mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
- # stop (bool, optional):
- # this module prevents next modules from stepping unless :code:`n` gradients have been accumulate. Setting this argument to False disables that. Defaults to True.
-
- # Examples:
- # Adam with gradients accumulated for 16 batches.
-
- # .. code-block:: python
-
- # opt = tz.Modular(
- # model.parameters(),
- # tz.m.GradientAccumulation(
- # [tz.m.Adam(), tz.m.LR(1e-2)],
- # n=16
- # )
- # )
-
- # """
- # def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
- # defaults = dict(n=n, mean=mean, stop=stop)
- # super().__init__(defaults)
- # self.set_child('modules', modules)
-
-
- # @torch.no_grad
- # def step(self, var):
- # accumulator = self.get_state(var.params, 'accumulator')
- # settings = self.defaults
- # n = settings['n']; mean = settings['mean']; stop = settings['stop']
- # step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
- # # add update to accumulator
- # torch._foreach_add_(accumulator, var.get_update())
-
- # # step with accumulated updates
- # if step % n == 0:
- # if mean:
- # torch._foreach_div_(accumulator, n)
-
- # var.update = [a.clone() for a in accumulator]
- # var = self.children['modules'].step(var)
-
- # # zero accumulator
- # torch._foreach_zero_(accumulator)
-
- # else:
- # # prevent update
- # if stop:
- # var.update = None
- # var.stop=True
- # var.skip_update=True
-
- # return var
-
-
-

  class GradientAccumulation(Module):
  """Uses ``n`` steps to accumulate gradients, after ``n`` gradients have been accumulated, they are passed to :code:`modules` and parameters are updates.
@@ -106,21 +38,21 @@ class GradientAccumulation(Module):


  @torch.no_grad
- def step(self, var):
- accumulator = self.get_state(var.params, 'accumulator')
+ def apply(self, objective):
+ accumulator = self.get_state(objective.params, 'accumulator')
  settings = self.defaults
  n = settings['n']; mean = settings['mean']; stop = settings['stop']
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+ step = self.increment_counter("step", 0)

  # add update to accumulator
- torch._foreach_add_(accumulator, var.get_update())
+ torch._foreach_add_(accumulator, objective.get_updates())

  # step with accumulated updates
- if step % n == 0:
+ if (step + 1) % n == 0:
  if mean:
  torch._foreach_div_(accumulator, n)

- var.update = accumulator
+ objective.updates = accumulator

  # zero accumulator
  self.clear_state_keys('accumulator')
@@ -128,9 +60,9 @@ class GradientAccumulation(Module):
  else:
  # prevent update
  if stop:
- var.update = None
- var.stop=True
- var.skip_update=True
+ objective.updates = None
+ objective.stop=True
+ objective.skip_update=True

- return var
+ return objective
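Note: the comment block removed above carried a usage example for GradientAccumulation. Assuming the 0.4.0 class keeps the same constructor (modules, n, mean, stop), the same wiring would presumably look like the following sketch; the Linear model and the `import torchzero as tz` alias are placeholders, not taken from this diff.

import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)  # placeholder model

# Adam stepping once every 16 accumulated gradients, as in the removed example
opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(
        [tz.m.Adam(), tz.m.LR(1e-2)],
        n=16,
    ),
)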
@@ -13,27 +13,27 @@ class HomotopyBase(Module):
  """transform the loss"""

  @torch.no_grad
- def step(self, var):
- if var.loss is not None:
- var.loss = self.loss_transform(var.loss)
+ def apply(self, objective):
+ if objective.loss is not None:
+ objective.loss = self.loss_transform(objective.loss)

- closure = var.closure
+ closure = objective.closure
  if closure is None: raise RuntimeError("SquareHomotopy requires closure")

  def homotopy_closure(backward=True):
  if backward:
  with torch.enable_grad():
  loss = self.loss_transform(closure(False))
- grad = torch.autograd.grad(loss, var.params, allow_unused=True)
- for p,g in zip(var.params, grad):
+ grad = torch.autograd.grad(loss, objective.params, allow_unused=True)
+ for p,g in zip(objective.params, grad):
  p.grad = g
  else:
  loss = self.loss_transform(closure(False))

  return loss

- var.closure = homotopy_closure
- return var
+ objective.closure = homotopy_closure
+ return objective

  class SquareHomotopy(HomotopyBase):
  def __init__(self): super().__init__()
@@ -57,3 +57,11 @@ class LambdaHomotopy(HomotopyBase):
  super().__init__(defaults)

  def loss_transform(self, loss): return self.defaults['fn'](loss)
+
+ class FixedLossHomotopy(HomotopyBase):
+ def __init__(self, value: float = 1):
+ defaults = dict(value=value)
+ super().__init__(defaults)
+
+ def loss_transform(self, loss): return loss / loss.detach().clip(min=torch.finfo(loss.dtype).tiny * 2)