torchzero 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- torchzero/core/__init__.py +1 -1
- torchzero/core/module.py +72 -49
- torchzero/core/tensorlist_optimizer.py +1 -1
- torchzero/modules/adaptive/adaptive.py +11 -11
- torchzero/modules/experimental/experimental.py +41 -41
- torchzero/modules/experimental/quad_interp.py +8 -8
- torchzero/modules/experimental/subspace.py +37 -37
- torchzero/modules/gradient_approximation/base_approximator.py +19 -24
- torchzero/modules/gradient_approximation/fdm.py +1 -1
- torchzero/modules/gradient_approximation/newton_fdm.py +13 -13
- torchzero/modules/gradient_approximation/rfdm.py +1 -1
- torchzero/modules/line_search/armijo.py +8 -8
- torchzero/modules/line_search/base_ls.py +8 -8
- torchzero/modules/line_search/directional_newton.py +14 -14
- torchzero/modules/line_search/grid_ls.py +7 -7
- torchzero/modules/line_search/scipy_minimize_scalar.py +3 -3
- torchzero/modules/meta/alternate.py +4 -4
- torchzero/modules/meta/grafting.py +23 -23
- torchzero/modules/meta/optimizer_wrapper.py +14 -14
- torchzero/modules/meta/return_overrides.py +8 -8
- torchzero/modules/misc/accumulate.py +6 -6
- torchzero/modules/misc/basic.py +16 -16
- torchzero/modules/misc/lr.py +2 -2
- torchzero/modules/misc/multistep.py +7 -7
- torchzero/modules/misc/on_increase.py +9 -9
- torchzero/modules/momentum/momentum.py +4 -4
- torchzero/modules/operations/multi.py +44 -44
- torchzero/modules/operations/reduction.py +28 -28
- torchzero/modules/operations/singular.py +9 -9
- torchzero/modules/optimizers/adagrad.py +1 -1
- torchzero/modules/optimizers/adam.py +8 -8
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +1 -1
- torchzero/modules/optimizers/rprop.py +1 -1
- torchzero/modules/optimizers/sgd.py +2 -2
- torchzero/modules/orthogonalization/newtonschulz.py +3 -3
- torchzero/modules/orthogonalization/svd.py +1 -1
- torchzero/modules/regularization/dropout.py +1 -1
- torchzero/modules/regularization/noise.py +3 -3
- torchzero/modules/regularization/normalization.py +5 -5
- torchzero/modules/regularization/ortho_grad.py +1 -1
- torchzero/modules/regularization/weight_decay.py +1 -1
- torchzero/modules/scheduling/lr_schedulers.py +2 -2
- torchzero/modules/scheduling/step_size.py +8 -8
- torchzero/modules/second_order/newton.py +12 -12
- torchzero/modules/smoothing/__init__.py +1 -1
- torchzero/modules/smoothing/gaussian_smoothing.py +7 -7
- torchzero/modules/smoothing/laplacian_smoothing.py +1 -1
- torchzero/modules/weight_averaging/ema.py +3 -3
- torchzero/modules/weight_averaging/swa.py +8 -8
- torchzero/optim/first_order/forward_gradient.py +1 -1
- torchzero/optim/modular.py +4 -4
- torchzero/tensorlist.py +8 -1
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/METADATA +1 -1
- torchzero-0.1.5.dist-info/RECORD +104 -0
- torchzero-0.1.3.dist-info/RECORD +0 -104
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/LICENSE +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/WHEEL +0 -0
- {torchzero-0.1.3.dist-info → torchzero-0.1.5.dist-info}/top_level.txt +0 -0
torchzero/core/__init__.py
CHANGED
torchzero/core/module.py
CHANGED
```diff
@@ -23,8 +23,8 @@ def _get_loss(fx0, fx0_approx):
     return fx0
 
 
-class OptimizationState:
-    """Holds optimization
+class OptimizationVars:
+    """Holds optimization variables. This is usually automatically created by :any:`torchzero.optim.Modular`."""
     def __init__(self, closure: _ClosureType | None, model: torch.nn.Module | None):
 
         self.closure: _ClosureType | None = closure
@@ -121,23 +121,23 @@ class OptimizationState:
         Returns:
             A copy of this OptimizationState.
         """
-
-
-
-
+        vars = OptimizationVars(self.closure, self.model)
+        vars.fx0 = self.fx0
+        vars.fx0_approx = self.fx0_approx
+        vars.grad = self.grad
 
-        if clone_ascent and self.ascent is not None:
-        else:
+        if clone_ascent and self.ascent is not None: vars.ascent = self.ascent.clone()
+        else: vars.ascent = self.ascent
 
-        return
+        return vars
 
-    def update_attrs_(self,
+    def update_attrs_(self, vars: "OptimizationVars"):
         """Updates attributes of this state with attributes of another state.
 
         This updates `grad`, `fx0` and `fx0_approx`."""
-        if
-        if
-        if
+        if vars.grad is not None: self.grad = vars.grad
+        if vars.fx0 is not None: self.fx0 = vars.fx0
+        if vars.fx0_approx is not None: self.fx0_approx = vars.fx0_approx
 
 
     def add_post_step_hook(self, hook: Callable):
@@ -283,7 +283,7 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
         for c in self.children.values():
             self._update_child_params_(c)
 
-    def _update_params_or_step_with_next(self,
+    def _update_params_or_step_with_next(self, vars: OptimizationVars, params: TensorList | None = None) -> _ScalarLoss | None:
         """If this has no children, update params and return loss. Otherwise step with the next module.
 
         Optionally pass params to not recreate them if you've already made them.
@@ -293,29 +293,29 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
         """
         # if this has no children, update params and return loss.
         if self.next_module is None:
-            if
+            if vars.ascent is None: raise ValueError('Called _update_params_or_step_with_child but ascent_direction is None...')
             if params is None: params = self.get_params()
-            params -=
-            return
+            params -= vars.ascent # type:ignore
+            return vars.get_loss()
 
         # otherwise pass the updated ascent direction to the child
-        return self.next_module.step(
+        return self.next_module.step(vars)
 
     @torch.no_grad
-    def _step_update_closure(self,
+    def _step_update_closure(self, vars: OptimizationVars) -> _ScalarLoss | None:
         """Create a new closure which applies the `_update` method and passes it to the next module."""
-        if
+        if vars.closure is None: raise ValueError('If target == "closure", closure must be provided')
 
         params = self.get_params()
-        closure =
-        ascent_direction =
+        closure = vars.closure # closure shouldn't reference state attribute because it can be changed
+        ascent_direction = vars.ascent
 
         def update_closure(backward = True):
             loss = _maybe_pass_backward(closure, backward)
 
             # on backward, update the ascent direction
             if backward:
-                grad = self._update(
+                grad = self._update(vars, ascent_direction) # type:ignore
                 # set new ascent direction as gradients
                 # (accumulation doesn't make sense here as closure always calls zero_grad)
                 for p, g in zip(params,grad):
@@ -327,12 +327,12 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
         # if self.next_module is None:
         # raise ValueError(f'{self.__class__.__name__} has no child to step with (maybe set "target" from "closure" to something else??).')
 
-
-        return self._update_params_or_step_with_next(
+        vars.closure = update_closure
+        return self._update_params_or_step_with_next(vars)
 
 
     @torch.no_grad
-    def _step_update_target(self,
+    def _step_update_target(self, vars: OptimizationVars) -> _ScalarLoss | None:
         """Apply _update method to the ascent direction and pass it to the child, or make a step if child is None."""
         # the following code by default uses `_update` method which simple modules can override.
         # But you can also just override the entire `step`.
@@ -342,50 +342,73 @@ class OptimizerModule(TensorListOptimizer, ABC): # type:ignore
         # update ascent direction
         if self._default_step_target == 'ascent':
             # if this is the first module, it uses the gradients
-            if
-            t =
-
+            if vars.grad is None: params = self.get_params()
+            t = vars.maybe_use_grad_(params)
+            vars.ascent = self._update(vars, t)
 
         # update gradients
         elif self._default_step_target == 'grad':
             if params is None: params = self.get_params()
-            g =
-            g = self._update(
-
+            g = vars.maybe_compute_grad_(params)
+            g = self._update(vars, g)
+            vars.set_grad_(g, params)
         else:
             raise ValueError(f"Invalid {self._default_step_target = }")
 
         # peform an update with the new state, or pass it to the child.
-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars, params=params)
 
     @torch.no_grad
     def step( # type:ignore # pylint:disable=signature-differs # pylint:disable = arguments-renamed
         self,
-
+        vars: OptimizationVars
     ) -> _ScalarLoss | None:
         """Perform a single optimization step to update parameter."""
 
-        if self._default_step_target == 'closure': return self._step_update_closure(
-        return self._step_update_target(
+        if self._default_step_target == 'closure': return self._step_update_closure(vars)
+        return self._step_update_target(vars)
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars: OptimizationVars, ascent: TensorList) -> TensorList:
         """Update `ascent_direction` and return the new ascent direction (but it may update it in place).
-        Make sure it doesn't return anything from `state` to avoid future modules modifying that in-place.
+        Make sure it doesn't return anything from `self.state` to avoid future modules modifying that in-place.
 
         Before calling `_update`, if ascent direction was not provided to `step`, it will be set to the gradients.
 
         After generating a new ascent direction with this `_update` method,
         if this module has no child, ascent direction will be subtracted from params.
         Otherwise everything is passed to the child."""
+        params = self.get_params()
+        gradients = ascent.grad
+        if gradients is None: gradients = [None] * len(params)
+        settings = tuple(self.get_all_group_keys(list).items())
+
+        new_ascent = TensorList()
+        for i, (asc, param, grad) in enumerate(zip(ascent, params, gradients)):
+            kwargs = {"vars": vars, "ascent": asc, "param": param, "grad": grad}
+            kwargs.update({k:v[i] for k,v in settings})
+            new_ascent.append(self._single_tensor_update(**kwargs))
+        return new_ascent
+
+
+    def _single_tensor_update(self, vars: OptimizationVars, ascent: torch.Tensor, param: torch.Tensor, grad: torch.Tensor | None, **kwargs) -> torch.Tensor:
+        """Update function for a single tensor.
+
+        Args:
+            vars (OptimizationState): holds loss, gradients, etc.
+            ascent (torch.Tensor): update tensor.
+            param (torch.Tensor): parameter tensor.
+            grad (torch.Tensor | None): gradient tensor (may be None)
+            **kwargs: all per-parameter settings (stuff that you put in `defaults = dict(beta1=beta1, beta2=beta2, eps=eps)`).
+        """
         raise NotImplementedError()
 
-    def return_ascent(self,
+    def return_ascent(self, vars: OptimizationVars, params=None) -> TensorList:
         """step with this module and return the ascent as tensorlist"""
         if params is None: params = self.get_params()
         true_next = self.next_module
         self.next_module = _ReturnAscent(params) # type:ignore
-        ascent: TensorList = self.step(
+        ascent: TensorList = self.step(vars) # type:ignore
         self.next_module = true_next
         return ascent
 
@@ -412,8 +435,8 @@ class _ReturnAscent:
         self.next_module = None
 
     @torch.no_grad
-    def step(self,
-        update =
+    def step(self, vars: OptimizationVars) -> TensorList: # type:ignore
+        update = vars.maybe_use_grad_(self.params) # this will execute the closure which might be modified
         return update
 
 
@@ -424,13 +447,13 @@ class _MaybeReturnAscent(OptimizerModule):
         self._return_ascent = False
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars: OptimizationVars):
         assert self.next_module is None, self.next_module
 
         if self._return_ascent:
-            return
+            return vars.ascent
 
-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars)
 
 _Chainable = OptimizerModule | Iterable[OptimizerModule]
 
@@ -456,16 +479,16 @@ class _Chain(OptimizerModule):
         self._chain_modules = flat_modules
 
     @torch.no_grad
-    def step(self,
+    def step(self, vars: OptimizationVars):
         # no next module, step with the child
         if self.next_module is None:
-            return self.children['first'].step(
+            return self.children['first'].step(vars)
 
         # return ascent and pass it to next module
         # we do this because updating parameters directly is often more efficient
         params = self.get_params()
         self._last_module.next_module = _ReturnAscent(params) # type:ignore
-
+        vars.ascent: TensorList = self.children['first'].step(vars) # type:ignore
        self._last_module.next_module = None
 
-        return self._update_params_or_step_with_next(
+        return self._update_params_or_step_with_next(vars)
```
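The new `_single_tensor_update` hook shown above lets a module define its update per tensor and leave the iteration, gradient handling, and per-parameter settings to the base `_update`. A minimal sketch of a subclass, assuming `OptimizerModule` is importable from `torchzero.core` and that defaults are forwarded as keyword arguments exactly as in the hunk above (the `MulByAlpha` module and its `alpha` setting are hypothetical):

```python
import torch
from torchzero.core import OptimizerModule  # import path assumed from the package layout above

class MulByAlpha(OptimizerModule):
    """Hypothetical module: scales each tensor of the ascent direction by a per-parameter `alpha`."""
    def __init__(self, alpha: float = 0.5):
        # settings placed in `defaults` reach `_single_tensor_update` as keyword arguments
        super().__init__(dict(alpha=alpha))

    @torch.no_grad
    def _single_tensor_update(self, vars, ascent, param, grad, alpha, **kwargs):
        # `ascent` is this parameter's current update tensor; return the new one
        return ascent * alpha
```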
torchzero/core/tensorlist_optimizer.py
CHANGED
```diff
@@ -149,7 +149,7 @@ class TensorListOptimizer(torch.optim.Optimizer, ABC):
 
     # def get_group_keys[CLS: MutableSequence](self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
     def get_group_keys(self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
-        """Returns a
+        """Returns a list with the param_groups `key` setting of each param."""
 
         all_values: list[CLS] = [cls() for _ in keys]
         for group in self.param_groups:
```
torchzero/modules/adaptive/adaptive.py
CHANGED
```diff
@@ -33,16 +33,16 @@ class Cautious(OptimizerModule):
         self.mode: typing.Literal['zero', 'grad', 'backtrack'] = mode
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
-        grad =
+        grad = vars.maybe_compute_grad_(params)
 
         # mask will be > 0 for parameters where both signs are the same
         mask = (ascent * grad) > 0
         if self.mode in ('zero', 'grad'):
             if self.normalize and self.mode == 'zero':
                 fmask = mask.to(ascent[0].dtype)
-                fmask /= fmask.total_mean() + self.eps
+                fmask /= fmask.total_mean() + self.eps # type:ignore
             else:
                 fmask = mask
 
@@ -66,9 +66,9 @@ class UseGradSign(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
-        grad =
+        grad = vars.maybe_compute_grad_(params)
 
         return ascent.abs_().mul_(grad.sign())
 
@@ -80,9 +80,9 @@ class UseGradMagnitude(OptimizerModule):
         super().__init__({})
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
-        grad =
+        grad = vars.maybe_compute_grad_(params)
 
         return ascent.sign_().mul_(grad.abs())
 
@@ -109,10 +109,10 @@ class ScaleLRBySignChange(OptimizerModule):
         self.use_grad = use_grad
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
 
-        if self.use_grad: cur =
+        if self.use_grad: cur = vars.maybe_compute_grad_(params)
         else: cur = ascent
 
         nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
@@ -168,10 +168,10 @@ class NegateOnSignChange(OptimizerModule):
         self.current_step = 0
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
 
-        if self.use_grad: cur =
+        if self.use_grad: cur = vars.maybe_compute_grad_(params)
         else: cur = ascent
 
         prev = self.get_state_key('prev')
```
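For reference, the masking that `Cautious` applies in 'zero' mode keeps only the coordinates where the proposed update and the gradient agree in sign, optionally rescaling the mask by its mean. A single-tensor sketch in plain PyTorch (the library itself works on `TensorList`s and normalizes by the mean over all parameters; the `eps` default here is illustrative):

```python
import torch

def cautious_zero_mode(ascent: torch.Tensor, grad: torch.Tensor,
                       normalize: bool = True, eps: float = 1e-6) -> torch.Tensor:
    """Zero out update entries whose sign disagrees with the gradient."""
    mask = (ascent * grad) > 0               # True where update and gradient point the same way
    fmask = mask.to(ascent.dtype)
    if normalize:
        fmask = fmask / (fmask.mean() + eps) # keep the average step size roughly unchanged
    return ascent * fmask
```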
torchzero/modules/experimental/experimental.py
CHANGED
```diff
@@ -35,9 +35,9 @@ class MinibatchRprop(OptimizerModule):
         self.next_mode = next_mode
 
     @torch.no_grad
-    def step(self,
-        if
-        if
+    def step(self, vars):
+        if vars.closure is None: raise ValueError("Minibatch Rprop requires closure")
+        if vars.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
         params = self.get_params()
 
         nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
@@ -47,7 +47,7 @@ class MinibatchRprop(OptimizerModule):
             params=params
         )
 
-        g1_sign =
+        g1_sign = vars.maybe_compute_grad_(params).sign() # no inplace to not modify grads
         # initialize on 1st iteration
         if self.current_step == 0:
             magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
@@ -58,8 +58,8 @@ class MinibatchRprop(OptimizerModule):
         # first step
         ascent = g1_sign.mul_(magnitudes).mul_(allowed)
         params -= ascent
-        with torch.enable_grad():
-        f0 =
+        with torch.enable_grad(): vars.fx0_approx = vars.closure()
+        f0 = vars.fx0; f1 = vars.fx0_approx
         assert f0 is not None and f1 is not None
 
         # if loss increased, reduce all lrs and undo the update
@@ -73,9 +73,9 @@ class MinibatchRprop(OptimizerModule):
         # on `continue` we move to params after 1st update
         # therefore state must be updated to have all attributes after 1st update
         if self.next_mode == 'continue':
-
-
-            sign =
+            vars.fx0 = vars.fx0_approx
+            vars.grad = params.ensure_grad_().grad
+            sign = vars.grad.sign()
 
         else:
             sign = params.ensure_grad_().grad.sign_() # can use in-place as this is not fx0 grad
@@ -109,19 +109,19 @@ class MinibatchRprop(OptimizerModule):
 
         # update params or step
         if self.next_mode == 'continue' or (self.next_mode == 'add' and self.next_module is None):
-
-            return self._update_params_or_step_with_next(
+            vars.ascent = ascent2
+            return self._update_params_or_step_with_next(vars, params)
 
         if self.next_mode == 'add':
             # undo 1st step
             params += ascent
-
-            return self._update_params_or_step_with_next(
+            vars.ascent = ascent + ascent2
+            return self._update_params_or_step_with_next(vars, params)
 
         if self.next_mode == 'undo':
             params += ascent
-
-            return self._update_params_or_step_with_next(
+            vars.ascent = ascent2
+            return self._update_params_or_step_with_next(vars, params)
 
         raise ValueError(f'invalid next_mode: {self.next_mode}')
 
@@ -140,9 +140,9 @@ class GradMin(OptimizerModule):
         self.create_graph = create_graph
 
     @torch.no_grad
-    def step(self,
-        if
-        if
+    def step(self, vars):
+        if vars.closure is None: raise ValueError()
+        if vars.ascent is not None:
             raise ValueError("GradMin doesn't accept ascent_direction")
 
         params = self.get_params()
@@ -150,26 +150,26 @@ class GradMin(OptimizerModule):
 
         self.zero_grad()
         with torch.enable_grad():
-
-            grads = jacobian([
+            vars.fx0 = vars.closure(False)
+            grads = jacobian([vars.fx0], params, create_graph=True, batched=False) # type:ignore
             grads = TensorList(grads).squeeze_(0)
             if self.square:
                 grads = grads ** 2
             else:
                 grads = grads.abs()
 
-            if self.maximize_grad: grads: TensorList = grads - (
-            else: grads = grads + (
+            if self.maximize_grad: grads: TensorList = grads - (vars.fx0 * loss_term) # type:ignore
+            else: grads = grads + (vars.fx0 * loss_term)
             grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
 
         if self.create_graph: grad_mean.backward(create_graph=True)
         else: grad_mean.backward(retain_graph=False)
 
-        if self.maximize_grad:
-        else:
+        if self.maximize_grad: vars.grad = params.ensure_grad_().grad.neg_()
+        else: vars.grad = params.ensure_grad_().grad
 
-
-        return self._update_params_or_step_with_next(
+        vars.maybe_use_grad_(params)
+        return self._update_params_or_step_with_next(vars)
 
 
 class HVPDiagNewton(OptimizerModule):
@@ -182,26 +182,26 @@ class HVPDiagNewton(OptimizerModule):
         super().__init__(dict(eps=eps))
 
     @torch.no_grad
-    def step(self,
-        if
-        if
+    def step(self, vars):
+        if vars.closure is None: raise ValueError()
+        if vars.ascent is not None:
             raise ValueError("HVPDiagNewton doesn't accept ascent_direction")
 
         params = self.get_params()
         eps = self.get_group_key('eps')
-        grad_fx0 =
-
+        grad_fx0 = vars.maybe_compute_grad_(params).clone()
+        vars.grad = grad_fx0 # set state grad to the cloned version, since it will be overwritten
 
         params += grad_fx0 * eps
-        with torch.enable_grad(): _ =
+        with torch.enable_grad(): _ = vars.closure()
 
         params -= grad_fx0 * eps
 
         newton = grad_fx0 * ((grad_fx0 * eps) / (params.grad - grad_fx0))
         newton.nan_to_num_(0,0,0)
 
-
-        return self._update_params_or_step_with_next(
+        vars.ascent = newton
+        return self._update_params_or_step_with_next(vars)
 
 
 
@@ -219,11 +219,11 @@ class ReduceOutwardLR(OptimizerModule):
         self.invert = invert
 
     @torch.no_grad
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
         mul = self.get_group_key('mul')
 
-        if self.use_grad: cur =
+        if self.use_grad: cur = vars.maybe_compute_grad_(params)
         else: cur = ascent
 
         # mask of weights where sign matches with update sign (minus ascent sign), multiplied by `mul`.
@@ -241,7 +241,7 @@ class NoiseSign(OptimizerModule):
         self.distribution:Distributions = distribution
 
 
-    def _update(self,
+    def _update(self, vars, ascent):
         return ascent.sample_like(self.alpha, self.distribution).copysign_(ascent)
 
 class ParamSign(OptimizerModule):
@@ -250,7 +250,7 @@ class ParamSign(OptimizerModule):
         super().__init__({})
 
 
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
 
         return params.copysign(ascent)
@@ -261,7 +261,7 @@ class NegParamSign(OptimizerModule):
         super().__init__({})
 
 
-    def _update(self,
+    def _update(self, vars, ascent):
         neg_params = self.get_params().abs()
         max = neg_params.total_max()
         neg_params = neg_params.neg_().add(max)
@@ -274,7 +274,7 @@ class InvParamSign(OptimizerModule):
         self.eps = eps
 
 
-    def _update(self,
+    def _update(self, vars, ascent):
         inv_params = self.get_params().abs().add_(self.eps).reciprocal_()
         return inv_params.copysign(ascent)
 
@@ -286,7 +286,7 @@ class ParamWhereConsistentSign(OptimizerModule):
         self.eps = eps
 
 
-    def _update(self,
+    def _update(self, vars, ascent):
         params = self.get_params()
         same_sign = params.sign() == ascent.sign()
         ascent.masked_set_(same_sign, params)
```
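The `HVPDiagNewton` hunk above probes the parameters along the gradient, re-evaluates the gradient, and uses the finite difference of the two gradients as a diagonal curvature estimate. A single-tensor sketch of that step, assuming `closure()` re-evaluates the loss and writes fresh gradients into `.grad` (as torchzero closures do):

```python
import torch

def hvp_diag_newton_step(closure, param: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Finite-difference diagonal Newton step for a single parameter tensor."""
    with torch.enable_grad():
        closure()                 # populates param.grad at the current point
    g0 = param.grad.clone()

    with torch.no_grad():
        param += g0 * eps         # probe along the gradient direction
    with torch.enable_grad():
        closure()                 # param.grad now holds the gradient at the probed point
    with torch.no_grad():
        param -= g0 * eps         # undo the probe

    # (g0 * eps) / (g1 - g0) is a secant estimate of the inverse diagonal curvature
    newton = g0 * ((g0 * eps) / (param.grad - g0))
    return torch.nan_to_num(newton, nan=0.0, posinf=0.0, neginf=0.0)
```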
torchzero/modules/experimental/quad_interp.py
CHANGED
```diff
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 from ...tensorlist import TensorList
-from ...core import
+from ...core import OptimizationVars
 from ..line_search.base_ls import LineSearchBase
 
 _FloatOrTensor = float | torch.Tensor
@@ -47,12 +47,12 @@ class QuadraticInterpolation2Point(LineSearchBase):
         self.min_dist = min_dist
 
     @torch.no_grad
-    def _find_best_lr(self,
-        if
-        closure =
-        if
-        grad =
-        if grad is None: grad =
+    def _find_best_lr(self, vars: OptimizationVars, params: TensorList) -> float:
+        if vars.closure is None: raise ValueError('QuardaticLS requires closure')
+        closure = vars.closure
+        if vars.fx0 is None: vars.fx0 = vars.closure(False)
+        grad = vars.grad
+        if grad is None: grad = vars.ascent # in case we used FDM
         if grad is None: raise ValueError('QuardaticLS requires gradients.')
 
         params = self.get_params()
@@ -67,7 +67,7 @@ class QuadraticInterpolation2Point(LineSearchBase):
 
         # make min_dist relative
         min_dist = abs(lr) * self.min_dist
-        points = sorted([Point(0, _ensure_float(
+        points = sorted([Point(0, _ensure_float(vars.fx0), dfx0), Point(lr, _ensure_float(fx1))], key = lambda x: x.fx)
 
         for i in range(self.max_evals):
             # find new point
```