torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/reduce.py
@@ -5,7 +5,7 @@ from typing import Any, cast
 
 import torch
 
-from ...core import Chainable, Module, Target, Vars, maybe_chain
+from ...core import Chainable, Module, Target, Var, maybe_chain
 
 
 class ReduceOperation(Module, ABC):
@@ -26,25 +26,25 @@ class ReduceOperation(Module, ABC):
             raise ValueError('At least one operand must be a module')
 
     @abstractmethod
-    def transform(self, vars: Vars, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *operands: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        """applies the operation to operands"""
        raise NotImplementedError
 
     @torch.no_grad
-    def step(self, vars: Vars) -> Vars:
+    def step(self, var: Var) -> Var:
         # pass cloned update to all module operands
         processed_operands: list[Any | list[torch.Tensor]] = self.operands.copy()
 
         for i, v in enumerate(self.operands):
             if f'operand_{i}' in self.children:
                 v: Module
-                updated_vars = v.step(vars.clone(clone_update=True))
-                processed_operands[i] = updated_vars.get_update()
-                vars.update_attrs_from_clone_(updated_vars) # update loss, grad, etc if this module calculated them
+                updated_var = v.step(var.clone(clone_update=True))
+                processed_operands[i] = updated_var.get_update()
+                var.update_attrs_from_clone_(updated_var) # update loss, grad, etc if this module calculated them
 
-        transformed = self.transform(vars, *processed_operands)
-        vars.update = transformed
-        return vars
+        transformed = self.transform(var, *processed_operands)
+        var.update = transformed
+        return var
 
 class Sum(ReduceOperation):
     USE_MEAN = False
@@ -52,7 +52,7 @@ class Sum(ReduceOperation):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         sum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -76,9 +76,9 @@ class WeightedSum(ReduceOperation):
         super().__init__(defaults=defaults, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
-        weights = self.settings[vars.params[0]]['weights']
+        weights = self.settings[var.params[0]]['weights']
         sum = cast(list, sorted_inputs[0])
         torch._foreach_mul_(sum, weights[0])
         if len(sorted_inputs) > 1:
@@ -98,7 +98,7 @@ class Median(ReduceOperation):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         res = []
         lists = [i for i in inputs if isinstance(i, list)]
         floats = [i for i in inputs if isinstance(i, (int,float))]
@@ -111,7 +111,7 @@ class Prod(ReduceOperation):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         prod = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -125,7 +125,7 @@ class MaximumModules(ReduceOperation):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         maximum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
@@ -139,7 +139,7 @@ class MinimumModules(ReduceOperation):
         super().__init__({}, *inputs)
 
     @torch.no_grad
-    def transform(self, vars: Vars, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
+    def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
         sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
         minimum = cast(list, sorted_inputs[0])
         if len(sorted_inputs) > 1:
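For orientation, here is a hedged sketch of a custom combiner written against the renamed Var API shown above. The MeanModules class is hypothetical (not part of the package), and the import path for ReduceOperation is assumed from the file list; the base class's step already replaces module operands with their update lists before calling transform.

# Hypothetical example, not shipped with torchzero 0.3.10.
from typing import Any, cast

import torch

from torchzero.core import Var                              # renamed from Vars in 0.3.10
from torchzero.modules.ops.reduce import ReduceOperation    # assumed import path


class MeanModules(ReduceOperation):
    """Averages the updates produced by its module operands."""

    def __init__(self, *inputs):
        super().__init__({}, *inputs)

    @torch.no_grad
    def transform(self, var: Var, *inputs: Any | list[torch.Tensor]) -> list[torch.Tensor]:
        # module operands have already been converted to lists of update tensors
        lists = [i for i in inputs if isinstance(i, list)]
        mean = cast(list, lists[0])
        for other in lists[1:]:
            torch._foreach_add_(mean, other)
        torch._foreach_div_(mean, len(lists))
        return mean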
torchzero/modules/ops/split.py
@@ -3,46 +3,46 @@ from typing import cast
 
 import torch
 
-from ...core import Chainable, Module, Vars
+from ...core import Chainable, Module, Var
 
 
 def _split(
     module: Module,
     idxs,
     params,
-    vars: Vars,
+    var: Var,
 ):
     split_params = [p for i,p in enumerate(params) if i in idxs]
 
     split_grad = None
-    if vars.grad is not None:
-        split_grad = [g for i,g in enumerate(vars.grad) if i in idxs]
+    if var.grad is not None:
+        split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
 
     split_update = None
-    if vars.update is not None:
-        split_update = [u for i,u in enumerate(vars.update) if i in idxs]
+    if var.update is not None:
+        split_update = [u for i,u in enumerate(var.update) if i in idxs]
 
-    split_vars = vars.clone(clone_update=False)
-    split_vars.params = split_params
-    split_vars.grad = split_grad
-    split_vars.update = split_update
+    split_var = var.clone(clone_update=False)
+    split_var.params = split_params
+    split_var.grad = split_grad
+    split_var.update = split_update
 
-    split_vars = module.step(split_vars)
+    split_var = module.step(split_var)
 
-    if (vars.grad is None) and (split_vars.grad is not None):
-        vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+    if (var.grad is None) and (split_var.grad is not None):
+        var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
-    if split_vars.update is not None:
+    if split_var.update is not None:
 
-        if vars.update is None:
-            if vars.grad is None: vars.update = [cast(torch.Tensor, None) for _ in vars.params]
-            else: vars.update = [g.clone() for g in vars.grad]
+        if var.update is None:
+            if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
+            else: var.update = [g.clone() for g in var.grad]
 
-        for idx, u in zip(idxs, split_vars.update):
-            vars.update[idx] = u
+        for idx, u in zip(idxs, split_var.update):
+            var.update[idx] = u
 
-    vars.update_attrs_from_clone_(split_vars)
-    return vars
+    var.update_attrs_from_clone_(split_var)
+    return var
 
 class Split(Module):
     """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
@@ -53,9 +53,9 @@ class Split(Module):
         if true is not None: self.set_child('true', true)
         if false is not None: self.set_child('false', false)
 
-    def step(self, vars):
+    def step(self, var):
 
-        params = vars.params
+        params = var.params
         filter = self.settings[params[0]]['filter']
 
         true_idxs = []
@@ -66,10 +66,10 @@ class Split(Module):
 
         if 'true' in self.children:
             true = self.children['true']
-            vars = _split(true, idxs=true_idxs, params=params, vars=vars)
+            var = _split(true, idxs=true_idxs, params=params, var=var)
 
         if 'false' in self.children:
             false = self.children['false']
-            vars = _split(false, idxs=false_idxs, params=params, vars=vars)
+            var = _split(false, idxs=false_idxs, params=params, var=var)
 
-        return vars
+        return var
torchzero/modules/ops/switch.py
@@ -23,16 +23,16 @@ class Alternate(Module):
         self.global_state['steps_to_next'] = steps[0] if isinstance(steps, list) else steps
 
     @torch.no_grad
-    def step(self, vars):
+    def step(self, var):
         # get current module
         current_module_idx = self.global_state.setdefault('current_module_idx', 0)
         module = self.children[f'module_{current_module_idx}']
 
         # step
-        vars = module.step(vars.clone(clone_update=False))
+        var = module.step(var.clone(clone_update=False))
 
         # number of steps until next module
-        steps = self.settings[vars.params[0]]['steps']
+        steps = self.settings[var.params[0]]['steps']
         if isinstance(steps, int): steps = [steps]*len(self.children)
 
         if 'steps_to_next' not in self.global_state:
@@ -51,7 +51,7 @@ class Alternate(Module):
 
         self.global_state['steps_to_next'] = steps[self.global_state['current_module_idx']]
 
-        return vars
+        return var
 
 class Switch(Alternate):
     """switch to next module after some steps"""
torchzero/modules/ops/unary.py
@@ -3,7 +3,7 @@ from collections import deque
 import torch
 
 from ...core import TensorwiseTransform, Target, Transform
-from ...utils import TensorList
+from ...utils import TensorList, unpack_dicts,unpack_states
 
 class UnaryLambda(Transform):
     def __init__(self, fn, target: "Target" = 'update'):
@@ -11,8 +11,8 @@ class UnaryLambda(Transform):
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        return self.settings[params[0]]['fn'](tensors)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        return settings[0]['fn'](tensors)
 
 class UnaryParameterwiseLambda(TensorwiseTransform):
     def __init__(self, fn, target: "Target" = 'update'):
@@ -20,8 +20,8 @@ class UnaryParameterwiseLambda(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensor, param, grad, vars):
-        return self.settings[param]['fn'](tensor)
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        return settings['fn'](tensor)
 
 class CustomUnaryOperation(Transform):
     def __init__(self, name: str, target: "Target" = 'update'):
@@ -29,35 +29,35 @@ class CustomUnaryOperation(Transform):
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        return getattr(tensors, self.settings[params[0]]['name'])()
+    def apply(self, tensors, params, grads, loss, states, settings):
+        return getattr(tensors, settings[0]['name'])()
 
 
 class Abs(Transform):
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_abs_(tensors)
         return tensors
 
 class Sign(Transform):
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sign_(tensors)
         return tensors
 
 class Exp(Transform):
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_exp_(tensors)
         return tensors
 
 class Sqrt(Transform):
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_sqrt_(tensors)
         return tensors
 
@@ -66,8 +66,8 @@ class Reciprocal(Transform):
         defaults = dict(eps = eps)
         super().__init__(defaults, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        eps = self.get_settings('eps', params=params)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        eps = [s['eps'] for s in settings]
         if any(e != 0 for e in eps): torch._foreach_add_(tensors, eps)
         torch._foreach_reciprocal_(tensors)
         return tensors
@@ -75,7 +75,7 @@ class Reciprocal(Transform):
 class Negate(Transform):
     def __init__(self, target: "Target" = 'update'): super().__init__({}, uses_grad=False, target=target)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         torch._foreach_neg_(tensors)
         return tensors
 
@@ -97,8 +97,8 @@ class NanToNum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        nan, posinf, neginf = self.get_settings('nan', 'posinf', 'neginf', params=params)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        nan, posinf, neginf = unpack_dicts(settings, 'nan', 'posinf', 'neginf')
         return [t.nan_to_num_(nan_i, posinf_i, neginf_i) for t, nan_i, posinf_i, neginf_i in zip(tensors, nan, posinf, neginf)]
 
 class Rescale(Transform):
@@ -108,8 +108,8 @@ class Rescale(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        min,max = self.get_settings('min','max', params=params)
-        tensorwise = self.settings[params[0]]['tensorwise']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        min, max = unpack_dicts(settings, 'min','max')
+        tensorwise = settings[0]['tensorwise']
         dim = None if tensorwise else 'global'
-        return TensorList(tensors).rescale(min=min, max=max, eps=self.settings[params[0]]['eps'], dim=dim)
+        return TensorList(tensors).rescale(min=min, max=max, eps=settings[0]['eps'], dim=dim)
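Taken together, these hunks show the new extension point: subclasses now override apply(self, tensors, params, grads, loss, states, settings) (or apply_tensor for TensorwiseTransform) instead of transform(..., vars), and per-parameter settings arrive as a list of dicts rather than being fetched through self.get_settings. A minimal sketch written against that signature follows; the ClampUpdate class and its clip_value option are illustrative, not part of the package.

# Hypothetical example following the Transform.apply signature visible above.
import torch

from torchzero.core import Transform


class ClampUpdate(Transform):
    """Elementwise clamp of the update, with a per-parameter clip value."""

    def __init__(self, clip_value: float = 1.0, target="update"):
        defaults = dict(clip_value=clip_value)
        super().__init__(defaults, uses_grad=False, target=target)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        # `settings` holds one dict per tensor, in the same order as `tensors`
        clip = [s['clip_value'] for s in settings]
        return [t.clamp_(-c, c) for t, c in zip(tensors, clip)]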
torchzero/modules/ops/utility.py
@@ -9,47 +9,47 @@ from ...utils.tensorlist import Distributions, TensorList
 class Clone(Transform):
     def __init__(self): super().__init__({}, uses_grad=False)
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars): return [t.clone() for t in tensors]
+    def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
 
 class Grad(Module):
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [g.clone() for g in vars.get_grad()]
-        return vars
+    def step(self, var):
+        var.update = [g.clone() for g in var.get_grad()]
+        return var
 
 class Params(Module):
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [p.clone() for p in vars.params]
-        return vars
+    def step(self, var):
+        var.update = [p.clone() for p in var.params]
+        return var
 
 class Update(Module):
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [u.clone() for u in vars.get_update()]
-        return vars
+    def step(self, var):
+        var.update = [u.clone() for u in var.get_update()]
+        return var
 
 class Zeros(Module):
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [torch.zeros_like(p) for p in vars.params]
-        return vars
+    def step(self, var):
+        var.update = [torch.zeros_like(p) for p in var.params]
+        return var
 
 class Ones(Module):
     def __init__(self):
         super().__init__({})
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [torch.ones_like(p) for p in vars.params]
-        return vars
+    def step(self, var):
+        var.update = [torch.ones_like(p) for p in var.params]
+        return var
 
 class Fill(Module):
     def __init__(self, value: float):
@@ -57,9 +57,9 @@ class Fill(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [torch.full_like(p, self.settings[p]['value']) for p in vars.params]
-        return vars
+    def step(self, var):
+        var.update = [torch.full_like(p, self.settings[p]['value']) for p in var.params]
+        return var
 
 class RandomSample(Module):
     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
@@ -67,20 +67,20 @@ class RandomSample(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, vars):
-        vars.update = TensorList(vars.params).sample_like(
-            eps=self.get_settings('eps',params=vars.params), distribution=self.settings[vars.params[0]]['distribution']
+    def step(self, var):
+        var.update = TensorList(var.params).sample_like(
+            eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
         )
-        return vars
+        return var
 
 class Randn(Module):
     def __init__(self):
         super().__init__({})
 
     @torch.no_grad
-    def step(self, vars):
-        vars.update = [torch.randn_like(p) for p in vars.params]
-        return vars
+    def step(self, var):
+        var.update = [torch.randn_like(p) for p in var.params]
+        return var
 
 class Uniform(Module):
     def __init__(self, low: float, high: float):
@@ -88,25 +88,25 @@ class Uniform(Module):
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, vars):
-        low,high = self.get_settings('low','high', params=vars.params)
-        vars.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(vars.params, low, high)]
-        return vars
+    def step(self, var):
+        low,high = self.get_settings(var.params, 'low','high')
+        var.update = [torch.empty_like(t).uniform_(l,h) for t,l,h in zip(var.params, low, high)]
+        return var
 
 class GradToNone(Module):
     def __init__(self): super().__init__()
-    def step(self, vars):
-        vars.grad = None
-        return vars
+    def step(self, var):
+        var.grad = None
+        return var
 
 class UpdateToNone(Module):
     def __init__(self): super().__init__()
-    def step(self, vars):
-        vars.update = None
-        return vars
+    def step(self, var):
+        var.update = None
+        return var
 
 class Identity(Module):
     def __init__(self, *args, **kwargs): super().__init__()
-    def step(self, vars): return vars
+    def step(self, var): return var
 
 NoOp = Identity
torchzero/modules/optimizers/adagrad.py
@@ -1,18 +1,17 @@
 from operator import itemgetter
+from typing import Literal
 
 import torch
-
 from ...core import (
     Chainable,
     Module,
-    Preconditioner,
     Target,
-    TensorwisePreconditioner,
+    TensorwiseTransform,
     Transform,
-    Vars,
-    apply,
+    Var,
+    apply_transform,
 )
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ...utils.linalg import matrix_power_eigh
 from ..functional import add_power_, lerp_power_, root
 
@@ -31,7 +30,6 @@ def adagrad_(
     inner: Module | None = None,
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
     """returns `tensors_`"""
     clr = alpha / (1 + step * lr_decay)
@@ -40,7 +38,7 @@ def adagrad_(
 
     if inner is not None:
         assert params is not None
-        tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
@@ -79,19 +77,19 @@ class Adagrad(Transform):
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        lr_decay,alpha,eps = self.get_settings('lr_decay', 'alpha', 'eps', params=params, cls=NumberList)
+        lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
 
-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(self.settings[params[0]])
+        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
 
-        sq_sum = self.get_state('sq_sum', params=params, cls=TensorList)
+        sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
 
         # initialize accumulator on 1st step
         if step == 1:
-            sq_sum.set_(tensors.full_like(self.get_settings('initial_accumulator_value', params=params)))
+            sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
 
         return adagrad_(
             tensors,
@@ -107,40 +105,51 @@ class Adagrad(Transform):
             inner=self.children.get("inner", None),
             params=params,
             grads=grads,
-            vars=vars,
         )
 
 
 
-class FullMatrixAdagrad(TensorwisePreconditioner):
-    def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=False, update_freq=1, inner: Chainable | None = None):
-        defaults = dict(beta=beta, decay=decay)
+class FullMatrixAdagrad(TensorwiseTransform):
+    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=False, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
         G = tensor.ravel()
         GG = torch.outer(G, G)
         decay = settings['decay']
         beta = settings['beta']
-
-        if 'GG' not in state: state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+        init = settings['init']
+
+        if 'GG' not in state:
+            if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+            elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
+            elif init == 'ones': state['GG'] = torch.ones_like(GG)
+            elif init == 'GGT': state['GG'] = GG.clone()
+            else: raise ValueError(init)
        if decay is not None: state['GG'].mul_(decay)
 
        if beta is not None: state['GG'].lerp_(GG, 1-beta)
        else: state['GG'].add_(GG)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
        GG = state['GG']
+        sqrt = settings['sqrt']
 
        if tensor.numel() == 1:
-            return tensor / (GG**(1/2)).squeeze()
+            GG = GG.squeeze()
+            if sqrt: return tensor / GG.sqrt()
+            return tensor / GG
 
        try:
-            B = matrix_power_eigh(GG, -1/2)
+            if sqrt: B = matrix_power_eigh(GG, -1/2)
+            else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
+
        except torch.linalg.LinAlgError:
-            return tensor.div_(tensor.abs().max()) # conservative scaling
+            scale = 1 / tensor.abs().max()
+            return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
 
        return (B @ tensor.ravel()).view_as(tensor)
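The FullMatrixAdagrad rework above adds an init option for the accumulator and a sqrt flag. Stripped of the module plumbing, the rule accumulates GG ≈ Σ g gᵀ and then either preconditions with GG^(-1/2) (sqrt=True, via matrix_power_eigh) or solves GG x = g (sqrt=False). Below is a standalone sketch in plain torch, assuming a single flattened gradient and the beta=None, decay=None accumulation path; the eigenvalue clamp is an illustrative stabilizer, not necessarily what matrix_power_eigh does internally.

# Standalone illustration, not the package's implementation.
import torch

def full_matrix_adagrad_step(GG: torch.Tensor, g: torch.Tensor, sqrt: bool = True):
    GG = GG + torch.outer(g, g)                                  # accumulate G G^T
    if sqrt:
        L, Q = torch.linalg.eigh(GG)                             # GG = Q diag(L) Q^T
        B = Q @ torch.diag(L.clamp_min(1e-12).rsqrt()) @ Q.T     # GG^(-1/2)
        return GG, B @ g
    return GG, torch.linalg.solve(GG, g)                         # sqrt=False branch

# Usage: GG starts at torch.eye(g.numel()) for init='identity'; the module
# reshapes the preconditioned vector back to the parameter's original shape.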