torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0

torchzero/modules/optimizers/sophia_h.py
@@ -2,7 +2,7 @@ from typing import Literal
 from collections.abc import Callable
 import torch

-from ...core import Module, Target, Transform, Chainable, apply
+from ...core import Module, Target, Transform, Chainable, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -56,8 +56,8 @@ class SophiaH(Module):
         self.set_child('inner', inner)

     @torch.no_grad
-    def step(self, vars):
-        params = vars.params
+    def step(self, var):
+        params = var.params
         settings = self.settings[params[0]]
         hvp_method = settings['hvp_method']
         fd_h = settings['fd_h']
@@ -71,15 +71,15 @@ class SophiaH(Module):
             self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
         generator = self.global_state['generator']

-        beta1, beta2, precond_scale, clip, eps = self.get_settings(
-            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', params=params, cls=NumberList)
+        beta1, beta2, precond_scale, clip, eps = self.get_settings(params,
+            'beta1', 'beta2', 'precond_scale', 'clip', 'eps', cls=NumberList)

-        exp_avg, h_exp_avg = self.get_state('exp_avg', 'h_exp_avg', params=params, cls=TensorList)
+        exp_avg, h_exp_avg = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)

         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

-        closure = vars.closure
+        closure = var.closure
         assert closure is not None

         h = None
@@ -90,12 +90,12 @@ class SophiaH(Module):
             u = [torch.randn(p.shape, device=p.device, dtype=p.dtype, generator=generator) for p in params]

             if hvp_method == 'autograd':
-                if grad is None: grad = vars.get_grad(create_graph=True)
+                if grad is None: grad = var.get_grad(create_graph=True)
                 assert grad is not None
                 Hvp = hvp(params, grad, u, retain_graph=i < n_samples-1)

             elif hvp_method == 'forward':
-                loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=vars.get_grad(), normalize=True)
+                loss, Hvp = hvp_fd_forward(closure, params, u, h=fd_h, g_0=var.get_grad(), normalize=True)

             elif hvp_method == 'central':
                 loss, Hvp = hvp_fd_central(closure, params, u, h=fd_h, normalize=True)
@@ -109,11 +109,11 @@ class SophiaH(Module):
         assert h is not None
         if n_samples > 1: torch._foreach_div_(h, n_samples)

-        update = vars.get_update()
+        update = var.get_update()
         if 'inner' in self.children:
-            update = apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars)
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)

-        vars.update = sophia_H(
+        var.update = sophia_H(
             tensors=TensorList(update),
             h=TensorList(h) if h is not None else None,
             exp_avg_=exp_avg,
@@ -126,4 +126,4 @@ class SophiaH(Module):
             eps=eps,
             step=step,
         )
-        return vars
+        return var
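
The SophiaH hunks above illustrate the API changes that recur throughout this diff: the `Vars` class and its `vars` arguments become `Var`/`var`, `apply` becomes `apply_transform`, and the per-parameter helpers `get_settings`/`get_state` now take the parameter list as the first positional argument instead of a `params=` keyword. Below is a minimal, hypothetical sketch of that new calling convention; `ToyModule` is not part of torchzero and only mirrors the signature visible in the hunk.

import torch

class ToyModule:
    """Toy stand-in (not torchzero code) mimicking the 0.3.10-style helper signature."""
    def __init__(self):
        # per-parameter state dictionaries, keyed by the parameter tensor itself
        self.state: dict[torch.Tensor, dict[str, torch.Tensor]] = {}

    def get_state(self, params, *keys, cls=list):
        # `params` comes first, followed by any number of state names;
        # each per-key list of buffers is wrapped in `cls` (TensorList in the diff above)
        out = []
        for key in keys:
            out.append(cls(self.state.setdefault(p, {}).setdefault(key, torch.zeros_like(p)) for p in params))
        return out

params = [torch.zeros(3), torch.zeros(2)]
exp_avg, h_exp_avg = ToyModule().get_state(params, 'exp_avg', 'h_exp_avg')  # mirrors the SophiaH call site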

torchzero/modules/projections/dct.py
@@ -34,8 +34,8 @@ class DCTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

     @torch.no_grad
-    def project(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def project(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         dims = settings['dims']
         norm = settings['norm']

@@ -54,8 +54,8 @@ class DCTProjection(Projection):
         return projected

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def unproject(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         dims = settings['dims']
         norm = settings['norm']


torchzero/modules/projections/fft.py
@@ -45,8 +45,8 @@ class FFTProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

     @torch.no_grad
-    def project(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def project(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         one_d = settings['one_d']
         norm = settings['norm']

@@ -60,14 +60,14 @@ class FFTProjection(Projection):
         return [torch.view_as_real(torch.fft.rfftn(t, norm=norm)) if t.numel() > 1 else t for t in tensors] # pylint:disable=not-callable

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        settings = self.settings[vars.params[0]]
+    def unproject(self, tensors, var, current):
+        settings = self.settings[var.params[0]]
         one_d = settings['one_d']
         norm = settings['norm']

         if one_d:
             vec = torch.view_as_complex(tensors[0])
             unprojected_vec = torch.fft.irfft(vec, n=self.global_state['length'], norm=norm) # pylint:disable=not-callable
-            return vec_to_tensors(unprojected_vec, reference=vars.params)
+            return vec_to_tensors(unprojected_vec, reference=var.params)

-        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, vars.params)] # pylint:disable=not-callable
+        return [torch.fft.irfftn(torch.view_as_complex(t.contiguous()), s=p.shape, norm=norm) if t.numel() > 1 else t for t, p in zip(tensors, var.params)] # pylint:disable=not-callable

torchzero/modules/projections/galore.py
@@ -6,5 +6,5 @@ from typing import Any, Literal

 import torch

-from ...core import Chainable, Module, Vars
+from ...core import Chainable, Module, Var
 from .projection import Projection

torchzero/modules/projections/projection.py
@@ -6,15 +6,15 @@ from typing import Any, Literal
 import warnings
 import torch

-from ...core import Chainable, Module, Vars
+from ...core import Chainable, Module, Var
 from ...utils import vec_to_tensors


-def _make_projected_closure(closure, vars: Vars, projection: "Projection",
+def _make_projected_closure(closure, var: Var, projection: "Projection",
                             params: list[torch.Tensor], projected_params: list[torch.Tensor]):

     def projected_closure(backward=True):
-        unprojected_params = projection.unproject(projected_params, vars, current='params')
+        unprojected_params = projection.unproject(projected_params, var, current='params')

         with torch.no_grad():
             for p, new_p in zip(params, unprojected_params):
@@ -23,7 +23,7 @@ def _make_projected_closure(closure, vars: Vars, projection: "Projection",
         if backward:
             loss = closure()
             grads = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-            projected_grads = projection.project(grads, vars, current='grads')
+            projected_grads = projection.project(grads, var, current='grads')
             for p, g in zip(projected_params, projected_grads):
                 p.grad = g

@@ -38,15 +38,15 @@ def _projected_get_grad_override(
     retain_graph: bool | None = None,
     create_graph: bool = False,
     projection: Any = ...,
-    unprojected_vars: Any = ...,
+    unprojected_var: Any = ...,
     self: Any = ...,
 ):
     assert isinstance(projection, Projection)
-    assert isinstance(unprojected_vars, Vars)
-    assert isinstance(self, Vars)
+    assert isinstance(unprojected_var, Var)
+    assert isinstance(self, Var)

     if self.grad is not None: return self.grad
-    grads = unprojected_vars.get_grad(retain_graph, create_graph)
+    grads = unprojected_var.get_grad(retain_graph, create_graph)
     projected_grads = list(projection.project(grads, self, current='grads'))
     self.grad = projected_grads
     for p, g in zip(self.params, projected_grads):
@@ -85,56 +85,56 @@ class Projection(Module, ABC):
         self._projected_params = None

     @abstractmethod
-    def project(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+    def project(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
         """projects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

     @abstractmethod
-    def unproject(self, tensors: list[torch.Tensor], vars: Vars, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
+    def unproject(self, tensors: list[torch.Tensor], var: Var, current: Literal['params', 'grads', 'update']) -> Iterable[torch.Tensor]:
         """unprojects `tensors`. Note that this can be called multiple times per step with `params`, `grads`, and `update`."""

     @torch.no_grad
-    def step(self, vars: Vars):
-        projected_vars = vars.clone(clone_update=False)
+    def step(self, var: Var):
+        projected_var = var.clone(clone_update=False)
         update_is_grad = False

         # closure will calculate projected update and grad if needed
-        if self._project_params and vars.closure is not None:
-            if self._project_update and vars.update is not None: projected_vars.update = list(self.project(vars.update, vars=vars, current='update'))
+        if self._project_params and var.closure is not None:
+            if self._project_update and var.update is not None: projected_var.update = list(self.project(var.update, var=var, current='update'))
             else:
                 update_is_grad = True
-            if self._project_grad and vars.grad is not None: projected_vars.grad = list(self.project(vars.grad, vars=vars, current='grads'))
+            if self._project_grad and var.grad is not None: projected_var.grad = list(self.project(var.grad, var=var, current='grads'))

         # project update and grad, unprojected attributes are deleted
         else:
             if self._project_update:
-                if vars.update is None:
+                if var.update is None:
                     # update is None, meaning it will be set to `grad`.
                     # we can project grad and use it for update
-                    grad = vars.get_grad()
-                    projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
-                    if self._project_grad: projected_vars.update = [g.clone() for g in projected_vars.grad]
-                    else: projected_vars.update = projected_vars.grad.copy() # don't clone because grad shouldn't be used
-                    del vars.update
+                    grad = var.get_grad()
+                    projected_var.grad = list(self.project(grad, var=var, current='grads'))
+                    if self._project_grad: projected_var.update = [g.clone() for g in projected_var.grad]
+                    else: projected_var.update = projected_var.grad.copy() # don't clone because grad shouldn't be used
+                    del var.update
                     update_is_grad = True

                 else:
-                    update = vars.get_update()
-                    projected_vars.update = list(self.project(update, vars=vars, current='update'))
-                    del update, vars.update
+                    update = var.get_update()
+                    projected_var.update = list(self.project(update, var=var, current='update'))
+                    del update, var.update

-            if self._project_grad and projected_vars.grad is None:
-                grad = vars.get_grad()
-                projected_vars.grad = list(self.project(grad, vars=vars, current='grads'))
+            if self._project_grad and projected_var.grad is None:
+                grad = var.get_grad()
+                projected_var.grad = list(self.project(grad, var=var, current='grads'))

         original_params = None
         if self._project_params:
-            original_params = [p.clone() for p in vars.params]
-            projected_params = self.project(vars.params, vars=vars, current='params')
+            original_params = [p.clone() for p in var.params]
+            projected_params = self.project(var.params, var=var, current='params')

         else:
             # make fake params for correct shapes and state storage
             # they reuse update or grad storage for memory efficiency
-            projected_params = projected_vars.update if projected_vars.update is not None else projected_vars.grad
+            projected_params = projected_var.update if projected_var.update is not None else projected_var.grad
         assert projected_params is not None

         if self._projected_params is None:
@@ -148,22 +148,22 @@ class Projection(Module, ABC):

         # project closure
         if self._project_params:
-            closure = vars.closure; params = vars.params
-            projected_vars.closure = _make_projected_closure(closure, vars=vars, projection=self, params=params,
+            closure = var.closure; params = var.params
+            projected_var.closure = _make_projected_closure(closure, var=var, projection=self, params=params,
                 projected_params=self._projected_params)

         else:
-            projected_vars.closure = None
+            projected_var.closure = None

         # step
-        projected_vars.params = self._projected_params
-        projected_vars.get_grad = partial(
+        projected_var.params = self._projected_params
+        projected_var.get_grad = partial(
             _projected_get_grad_override,
             projection=self,
-            unprojected_vars=vars,
-            self=projected_vars,
+            unprojected_var=var,
+            self=projected_var,
         )
-        projected_vars = self.children['modules'].step(projected_vars)
+        projected_var = self.children['modules'].step(projected_var)

         # empty fake params storage
         # this doesn't affect update/grad because it is a different python object, set_ changes storage on an object
@@ -172,28 +172,28 @@ class Projection(Module, ABC):
                 p.set_(torch.empty(0, device=p.device, dtype=p.dtype)) # pyright: ignore[reportArgumentType]

         # unproject
-        unprojected_vars = projected_vars.clone(clone_update=False)
-        unprojected_vars.closure = vars.closure
-        unprojected_vars.params = vars.params
-        unprojected_vars.grad = vars.grad
+        unprojected_var = projected_var.clone(clone_update=False)
+        unprojected_var.closure = var.closure
+        unprojected_var.params = var.params
+        unprojected_var.grad = var.grad

         if self._project_update:
-            assert projected_vars.update is not None
-            unprojected_vars.update = list(self.unproject(projected_vars.update, vars=vars, current='grads' if update_is_grad else 'update'))
-            del projected_vars.update
+            assert projected_var.update is not None
+            unprojected_var.update = list(self.unproject(projected_var.update, var=var, current='grads' if update_is_grad else 'update'))
+            del projected_var.update

         # unprojecting grad doesn't make sense?
         # if self._project_grad:
-        #     assert projected_vars.grad is not None
-        #     unprojected_vars.grad = list(self.unproject(projected_vars.grad, vars=vars))
+        #     assert projected_var.grad is not None
+        #     unprojected_var.grad = list(self.unproject(projected_var.grad, var=var))

-        del projected_vars
+        del projected_var

         if original_params is not None:
-            for p, o in zip(unprojected_vars.params, original_params):
+            for p, o in zip(unprojected_var.params, original_params):
                 p.set_(o) # pyright: ignore[reportArgumentType]

-        return unprojected_vars
+        return unprojected_var



@@ -206,12 +206,12 @@ class FlipConcatProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
+    def project(self, tensors, var, current):
         return [torch.cat([u.view(-1) for u in tensors], dim=-1).flip(0)]

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        return vec_to_tensors(vec=tensors[0].flip(0), reference=vars.params)
+    def unproject(self, tensors, var, current):
+        return vec_to_tensors(vec=tensors[0].flip(0), reference=var.params)


 class NoopProjection(Projection):
@@ -221,11 +221,11 @@ class NoopProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
+    def project(self, tensors, var, current):
         return tensors

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
+    def unproject(self, tensors, var, current):
         return tensors

 class MultipyProjection(Projection):
@@ -235,10 +235,10 @@ class MultipyProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
+    def project(self, tensors, var, current):
         return torch._foreach_mul(tensors, 2)

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
+    def unproject(self, tensors, var, current):
         return torch._foreach_div(tensors, 2)


torchzero/modules/projections/structural.py
@@ -17,12 +17,12 @@ class VectorProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
+    def project(self, tensors, var, current):
         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        return vec_to_tensors(vec=tensors[0], reference=vars.params)
+    def unproject(self, tensors, var, current):
+        return vec_to_tensors(vec=tensors[0], reference=var.params)



@@ -33,8 +33,8 @@ class TensorizeProjection(Projection):
         super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
-        params = vars.params
+    def project(self, tensors, var, current):
+        params = var.params
         max_side = self.settings[params[0]]['max_side']
         num_elems = sum(t.numel() for t in tensors)

@@ -60,12 +60,12 @@ class TensorizeProjection(Projection):
         return [vec.view(dims)]

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
+    def unproject(self, tensors, var, current):
         remainder = self.global_state['remainder']
         # warnings.warn(f'{tensors[0].shape = }')
         vec = tensors[0].view(-1)
         if remainder > 0: vec = vec[:-remainder]
-        return vec_to_tensors(vec, vars.params)
+        return vec_to_tensors(vec, var.params)

 class BlockPartition(Projection):
     """splits parameters into blocks (for now flatttens them and chunks)"""
@@ -74,9 +74,9 @@ class BlockPartition(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

     @torch.no_grad
-    def project(self, tensors, vars, current):
+    def project(self, tensors, var, current):
         partitioned = []
-        for p,t in zip(vars.params, tensors):
+        for p,t in zip(var.params, tensors):
             settings = self.settings[p]
             max_size = settings['max_size']
             n = t.numel()
@@ -101,10 +101,10 @@ class BlockPartition(Projection):
         return partitioned

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
+    def unproject(self, tensors, var, current):
         ti = iter(tensors)
         unprojected = []
-        for p in vars.params:
+        for p in var.params:
             settings = self.settings[p]
             n = p.numel()

@@ -130,19 +130,19 @@ class TensorNormsProjection(Projection):
         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

     @torch.no_grad
-    def project(self, tensors, vars, current):
-        orig = self.get_state(f'{current}_orig', params=vars.params)
+    def project(self, tensors, var, current):
+        orig = self.get_state(var.params, f'{current}_orig')
         torch._foreach_copy_(orig, tensors)

         norms = torch._foreach_norm(tensors)
-        self.get_state(f'{current}_orig_norms', params=vars.params, init=norms, cls=TensorList).set_(norms)
+        self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)

         return [torch.stack(norms)]

     @torch.no_grad
-    def unproject(self, tensors, vars, current):
-        orig = self.get_state(f'{current}_orig', params=vars.params)
-        orig_norms = torch.stack(self.get_state(f'{current}_orig_norms', params=vars.params))
+    def unproject(self, tensors, var, current):
+        orig = self.get_state(var.params, f'{current}_orig')
+        orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
         target_norms = tensors[0]

         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)

torchzero/modules/quasi_newton/__init__.py
@@ -1,7 +1,36 @@
-from .cg import PolakRibiere, FletcherReeves, HestenesStiefel, DaiYuan, LiuStorey, ConjugateDescent, HagerZhang, HybridHS_DY
+from .cg import (
+    ConjugateDescent,
+    DaiYuan,
+    FletcherReeves,
+    HagerZhang,
+    HestenesStiefel,
+    HybridHS_DY,
+    LiuStorey,
+    PolakRibiere,
+    ProjectedGradientMethod,
+)
 from .lbfgs import LBFGS
+from .lsr1 import LSR1
 from .olbfgs import OnlineLBFGS
-# from .experimental import ModularLBFGS

-from .quasi_newton import BFGS, SR1, DFP, BroydenGood, BroydenBad, Greenstadt1, Greenstadt2, ColumnUpdatingMethod, ThomasOptimalMethod, PSB, Pearson2, SSVM
-from .lsr1 import LSR1
+# from .experimental import ModularLBFGS
+from .quasi_newton import (
+    BFGS,
+    DFP,
+    PSB,
+    SR1,
+    SSVM,
+    BroydenBad,
+    BroydenGood,
+    ColumnUpdatingMethod,
+    FletcherVMM,
+    GradientCorrection,
+    Greenstadt1,
+    Greenstadt2,
+    Horisho,
+    McCormick,
+    NewSSM,
+    Pearson,
+    ProjectedNewtonRaphson,
+    ThomasOptimalMethod,
+)
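
Per the __init__ hunk above, torchzero.modules.quasi_newton in 0.3.10 re-exports several additional classes (FletcherVMM, GradientCorrection, Horisho, McCormick, NewSSM, Pearson, ProjectedNewtonRaphson, plus ProjectedGradientMethod from the CG module), while Pearson2 from 0.3.8 is replaced by Pearson in this import list. A hedged usage sketch, assuming torchzero 0.3.10 is installed:

# names taken directly from the 0.3.10 import block above
from torchzero.modules.quasi_newton import BFGS, NewSSM, ProjectedNewtonRaphson, ProjectedGradientMethod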