torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/optim/wrappers/directsearch.py
@@ -33,8 +33,45 @@ class DirectSearch(Optimizer):
     solution.
 
     Args:
-        params (_type_): _description_
-        maxevals (_type_, optional): _description_. Defaults to DEFAULT_PARAMS['maxevals'].
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+
+        rho: Choice of the forcing function.
+
+        sketch_dim: Reduced dimension to generate polling directions in.
+
+        sketch_type: Sketching technique to be used.
+
+        maxevals: Maximum number of calls to f performed by the algorithm.
+
+        poll_type: Type of polling directions generated in the reduced spaces.
+
+        alpha0: Initial value for the stepsize parameter.
+
+        alpha_max: Maximum value for the stepsize parameter.
+
+        alpha_min: Minimum value for the stepsize parameter.
+
+        gamma_inc: Increase factor for the stepsize update.
+
+        gamma_dec: Decrease factor for the stepsize update.
+
+        verbose:
+            Boolean indicating whether information should be displayed during an algorithmic run.
+
+        print_freq:
+            Value indicating how frequently information should be displayed.
+
+        use_stochastic_three_points:
+            Boolean indicating whether the specific stochastic three points method should be used.
+
+        poll_scale_prob: Probability of scaling the polling directions.
+
+        poll_scale_factor: Factor used to scale the polling directions.
+
+        rho_uses_normd:
+            Boolean indicating whether the forcing function should account for the norm of the direction.
+
+
     """
     def __init__(
         self,
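
A minimal usage sketch for the wrapper documented above. It assumes DirectSearch is importable from torchzero.optim.wrappers.directsearch (the module in the file list) and that, like the other wrappers in this diff, a single closure-based step() call runs the whole derivative-free search; the model, data, and the backward-flag convention are illustrative assumptions, not taken from the package.

import torch
import torch.nn.functional as F
from torchzero.optim.wrappers.directsearch import DirectSearch  # import path assumed from the file list

model = torch.nn.Linear(4, 1)                            # placeholder model
x, y = torch.randn(64, 4), torch.randn(64, 1)

opt = DirectSearch(model.parameters(), maxevals=2000)    # rho, sketch_dim, poll_type, ... left at defaults

def closure(backward=True):
    # derivative-free wrapper: the closure only has to return the loss; the backward
    # argument is accepted defensively in case the wrapper calls closure(False)
    return F.mse_loss(model(x), y)

opt.step(closure)  # assumed to perform the full direct-search run in this one step
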
torchzero/optim/wrappers/fcmaes.py
@@ -27,18 +27,25 @@ class FcmaesWrapper(Optimizer):
     Note that this performs full minimization on each step, so only perform one step with this.
 
     Args:
-        params (_type_): _description_
-        lb (float): _description_
-        ub (float): _description_
-        optimizer (fcmaes.optimizer.Optimizer | None, optional): _description_. Defaults to None.
-        max_evaluations (int | None, optional): _description_. Defaults to 50000.
-        value_limit (float | None, optional): _description_. Defaults to np.inf.
-        num_retries (int | None, optional): _description_. Defaults to 1.
-        workers (int, optional): _description_. Defaults to 1.
-        popsize (int | None, optional): _description_. Defaults to 31.
-        capacity (int | None, optional): _description_. Defaults to 500.
-        stop_fitness (float | None, optional): _description_. Defaults to -np.inf.
-        statistic_num (int | None, optional): _description_. Defaults to 0.
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
+        optimizer (fcmaes.optimizer.Optimizer | None, optional):
+            optimizer to use. Default is a sequence of differential evolution and CMA-ES.
+        max_evaluations (int | None, optional):
+            Forced termination of all optimization runs after `max_evaluations` function evaluations.
+            Only used if optimizer is undefined, otherwise this setting is defined in the optimizer. Defaults to 50000.
+        value_limit (float | None, optional): Upper limit for optimized function values to be stored. Defaults to np.inf.
+        num_retries (int | None, optional): Number of optimization retries. Defaults to 1.
+        popsize (int | None, optional):
+            CMA-ES population size used for all CMA-ES runs.
+            Not used for differential evolution.
+            Ignored if parameter optimizer is defined. Defaults to 31.
+        capacity (int | None, optional): capacity of the evaluation store.. Defaults to 500.
+        stop_fitness (float | None, optional):
+            Limit for fitness value. optimization runs terminate if this value is reached. Defaults to -np.inf.
+        statistic_num (int | None, optional):
+            if > 0 stores the progress of the optimization. Defines the size of this store. Defaults to 0.
     """
     def __init__(
         self,
@@ -49,7 +56,7 @@ class FcmaesWrapper(Optimizer):
         max_evaluations: int | None = 50000,
         value_limit: float | None = np.inf,
         num_retries: int | None = 1,
-        workers: int = 1,
+        # workers: int = 1,
         popsize: int | None = 31,
         capacity: int | None = 500,
         stop_fitness: float | None = -np.inf,
@@ -60,6 +67,7 @@ class FcmaesWrapper(Optimizer):
         kwargs = locals().copy()
         del kwargs['self'], kwargs['params'], kwargs['lb'], kwargs['ub'], kwargs['__class__']
         self._kwargs = kwargs
+        self._kwargs['workers'] = 1
 
     def _objective(self, x: np.ndarray, params: TensorList, closure) -> float:
         if self.raised: return np.inf
torchzero/optim/wrappers/mads.py
@@ -31,16 +31,15 @@ class MADS(Optimizer):
     solution.
 
     Args:
-        params (params): params
-        lb (float): lower bounds
-        ub (float): upper bounds
+        params: iterable of parameters to optimize or dicts defining parameter groups.
+        lb (float): lower bounds, this can also be specified in param_groups.
+        ub (float): upper bounds, this can also be specified in param_groups.
         dp (float, optional): Initial poll size as percent of bounds. Defaults to 0.1.
         dm (float, optional): Initial mesh size as percent of bounds. Defaults to 0.01.
-        dp_tol (_type_, optional): Minimum poll size stopping criteria. Defaults to -float('inf').
-        nitermax (_type_, optional): Maximum objective function evaluations. Defaults to float('inf').
+        dp_tol (float, optional): Minimum poll size stopping criteria. Defaults to -float('inf').
+        nitermax (float, optional): Maximum objective function evaluations. Defaults to float('inf').
         displog (bool, optional): whether to show log. Defaults to False.
         savelog (bool, optional): whether to save log. Defaults to False.
-
     """
     def __init__(
         self,
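
Several wrappers in this diff (FcmaesWrapper, MADS, NevergradWrapper) now document that lb/ub "can also be specified in param_groups". A hedged sketch of what that looks like in the standard torch.optim param-group format; the group layout and bound values are illustrative, and the import path is assumed from the file list.

import torch
from torchzero.optim.wrappers.mads import MADS  # import path assumed

model = torch.nn.ModuleDict({"encoder": torch.nn.Linear(8, 4), "head": torch.nn.Linear(4, 1)})

opt = MADS(
    [
        {"params": model["encoder"].parameters(), "lb": -1.0, "ub": 1.0},   # tight box for this group
        {"params": model["head"].parameters(), "lb": -10.0, "ub": 10.0},    # wider box for the output layer
    ],
    dp=0.1,   # initial poll size as percent of bounds (per the docstring above)
    dm=0.01,  # initial mesh size as percent of bounds
)
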
torchzero/optim/wrappers/nevergrad.py
@@ -29,6 +29,12 @@ class NevergradWrapper(Optimizer):
             use certain rule for first 50% of the steps, and then switch to another rule.
             This parameter doesn't actually limit the maximum number of steps!
             But it doesn't have to be exact. Defaults to None.
+        lb (float | None, optional):
+            lower bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
+        ub (float, optional):
+            upper bounds, this can also be specified in param_groups. Bounds are optional, however
+            some nevergrad algorithms will raise an exception of bounds are not specified.
         mutable_sigma (bool, optional):
             nevergrad parameter, sets whether the mutation standard deviation must mutate as well
             (for mutation based algorithms). Defaults to False.
@@ -44,11 +50,20 @@ class NevergradWrapper(Optimizer):
         params,
         opt_cls:"type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
         budget: int | None = None,
-        mutable_sigma = False,
         lb: float | None = None,
         ub: float | None = None,
+        mutable_sigma = False,
         use_init = True,
     ):
+        """_summary_
+
+        Args:
+            params (_type_): _description_
+            opt_cls (type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]): _description_
+            budget (int | None, optional): _description_. Defaults to None.
+            mutable_sigma (bool, optional): _description_. Defaults to False.
+            use_init (bool, optional): _description_. Defaults to True.
+        """
         defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
         super().__init__(params, defaults)
         self.opt_cls = opt_cls
torchzero/optim/wrappers/optuna.py
@@ -23,7 +23,7 @@ class OptunaSampler(Optimizer):
     Note - optuna is surprisingly scalable to large number of parameters (up to 10,000), despite literally requiring a for-loop because it only supports scalars. Default TPESampler is good for BBO. Maybe not for NNs...
 
     Args:
-        params (_type_): parameters
+        params: iterable of parameters to optimize or dicts defining parameter groups.
         lb (float): lower bounds.
         ub (float): upper bounds.
         sampler (optuna.samplers.BaseSampler | type[optuna.samplers.BaseSampler] | None, optional): sampler. Defaults to None.
torchzero/optim/wrappers/scipy.py
@@ -139,9 +139,11 @@ class ScipyMinimize(Optimizer):
 
         # make bounds
         lb, ub = self.group_vals('lb', 'ub', cls=list)
-        bounds = []
-        for p, l, u in zip(params, lb, ub):
-            bounds.extend([(l, u)] * p.numel())
+        bounds = None
+        if any(b is not None for b in lb) or any(b is not None for b in ub):
+            bounds = []
+            for p, l, u in zip(params, lb, ub):
+                bounds.extend([(l, u)] * p.numel())
 
         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
             x0 = x0.astype(np.float64) # those methods error without this
torchzero/utils/__init__.py
@@ -18,6 +18,6 @@ from .params import (
     _copy_param_groups,
     _make_param_groups,
 )
-from .python_tools import flatten, generic_eq, reduce_dim, unpack_dicts
-from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like
+from .python_tools import flatten, generic_eq, generic_ne, reduce_dim, unpack_dicts
+from .tensorlist import TensorList, as_tensorlist, Distributions, generic_clamp, generic_numel, generic_vector_norm, generic_zeros_like, generic_randn_like, generic_finfo_eps
 from .torch_tools import tofloat, tolist, tonumpy, totensor, vec_to_tensors, vec_to_tensors_, set_storage_
torchzero/utils/derivatives.py
@@ -158,7 +158,7 @@ def hessian_mat(
     method="func",
     vectorize=False,
     outer_jacobian_strategy="reverse-mode",
-):
+) -> torch.Tensor:
     """
     returns hessian matrix for parameters (as if they were flattened and concatenated into a vector).
 
@@ -190,7 +190,7 @@ def hessian_mat(
         return loss
 
     if method == 'func':
-        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph))
+        return torch.func.hessian(func)(torch.cat([p.view(-1) for p in params]).detach().requires_grad_(create_graph)) # pyright:ignore[reportReturnType]
 
     if method == 'autograd.functional':
         return torch.autograd.functional.hessian(
@@ -199,7 +199,7 @@ def hessian_mat(
             create_graph=create_graph,
             vectorize=vectorize,
             outer_jacobian_strategy=outer_jacobian_strategy,
-        )
+        ) # pyright:ignore[reportReturnType]
     raise ValueError(method)
 
 def jvp(fn, params: Iterable[torch.Tensor], tangent: Iterable[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
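
For reference, the method='func' branch changed above computes the Hessian of the loss re-expressed as a function of all parameters flattened into one vector. A self-contained sketch of that same pattern on a toy objective (this mirrors the torch.func.hessian call in the hunk; it is not hessian_mat itself):

import torch

params = [torch.randn(3), torch.randn(2, 2)]

def loss_from_flat(flat: torch.Tensor) -> torch.Tensor:
    # rebuild per-parameter views from the flat vector, matching the docstring's
    # "as if they were flattened and concatenated into a vector"
    views, offset = [], 0
    for p in params:
        views.append(flat[offset:offset + p.numel()].view_as(p))
        offset += p.numel()
    return sum((v ** 3).sum() for v in views)  # toy loss with a non-constant Hessian

flat = torch.cat([p.view(-1) for p in params]).detach()
H = torch.func.hessian(loss_from_flat)(flat)    # shape (7, 7): one row/column per flattened element
print(H.shape)
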
torchzero/utils/linalg/__init__.py
@@ -2,4 +2,4 @@ from .matrix_funcs import inv_sqrt_2x2, eigvals_func, singular_vals_func, matrix
 from .orthogonalize import gram_schmidt
 from .qr import qr_householder
 from .svd import randomized_svd
-from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve
+from .solve import cg, nystrom_approximation, nystrom_sketch_and_solve, steihaug_toint_cg
torchzero/utils/linalg/solve.py
@@ -1,12 +1,41 @@
+# pyright: reportArgumentType=false
 from collections.abc import Callable
-from typing import overload
+from typing import Any, overload
+
 import torch
 
-from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_numel, generic_randn_like, generic_eq
+from .. import (
+    TensorList,
+    generic_eq,
+    generic_finfo_eps,
+    generic_numel,
+    generic_randn_like,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+
+
+def _make_A_mm_reg(A_mm: Callable | torch.Tensor, reg):
+    if callable(A_mm):
+        def A_mm_reg(x): # A_mm with regularization
+            Ax = A_mm(x)
+            if not generic_eq(reg, 0): Ax += x*reg
+            return Ax
+        return A_mm_reg
+
+    if not isinstance(A_mm, torch.Tensor): raise TypeError(type(A_mm))
+
+    def Ax_reg(x): # A_mm with regularization
+        if A_mm.ndim == 1: Ax = A_mm * x
+        else: Ax = A_mm @ x
+        if reg != 0: Ax += x*reg
+        return Ax
+    return Ax_reg
+
 
 @overload
 def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0_: torch.Tensor | None = None,
     tol: float | None = 1e-4,
@@ -24,17 +53,17 @@ def cg(
 ) -> TensorList: ...
 
 def cg(
-    A_mm: Callable,
+    A_mm: Callable | torch.Tensor,
     b: torch.Tensor | TensorList,
     x0_: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-4,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ):
-    def A_mm_reg(x): # A_mm with regularization
-        Ax = A_mm(x)
-        if not generic_eq(reg, 0): Ax += x*reg
-        return Ax
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+
+    if tol is None: tol = eps
 
     if maxiter is None: maxiter = generic_numel(b)
     if x0_ is None: x0_ = generic_zeros_like(b)
@@ -44,9 +73,10 @@ def cg(
     p = residual.clone() # search direction
     r_norm = generic_vector_norm(residual)
     init_norm = r_norm
-    if tol is not None and r_norm < tol: return x
+    if r_norm < tol: return x
     k = 0
 
+
     while True:
         Ap = A_mm_reg(p)
         step_size = (r_norm**2) / p.dot(Ap)
@@ -55,7 +85,7 @@ def cg(
         new_r_norm = generic_vector_norm(residual)
 
         k += 1
-        if tol is not None and new_r_norm <= tol * init_norm: return x
+        if new_r_norm <= tol * init_norm: return x
         if k >= maxiter: return x
 
         beta = (new_r_norm**2) / (r_norm**2)
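
With _make_A_mm_reg, cg now accepts A_mm as a callable, as a dense matrix, or as a 1-D tensor treated as a diagonal. Below is a small self-check of the new dense-matrix form against torch.linalg.solve; the import path is assumed from the file list, and the matrix is a toy SPD example.

import torch
from torchzero.utils.linalg.solve import cg  # import path assumed

torch.manual_seed(0)
M = torch.randn(10, 10)
A = M @ M.T + 10 * torch.eye(10)   # symmetric positive definite, as CG requires
b = torch.randn(10)

x_dense = cg(A, b, tol=1e-8)                 # dense-matrix form, new in this version
x_call  = cg(lambda v: A @ v, b, tol=1e-8)   # callable form, as before
x_ref   = torch.linalg.solve(A, b)

print(torch.allclose(x_dense, x_ref, atol=1e-3), torch.allclose(x_call, x_ref, atol=1e-3))
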
@@ -131,6 +161,8 @@ def nystrom_pcg(
         generator=generator,
     )
     lambd += reg
+    eps = torch.finfo(b.dtype).eps ** 2
+    if tol is None: tol = eps
 
     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)
@@ -150,7 +182,7 @@ def nystrom_pcg(
     p = z.clone() # search direction
 
     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
-    if tol is not None and init_norm < tol: return x
+    if init_norm < tol: return x
     k = 0
     while True:
         Ap = A_mm_reg(p)
@@ -160,10 +192,217 @@ def nystrom_pcg(
         residual -= step_size * Ap
 
         k += 1
-        if tol is not None and torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
+        if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
         if k >= maxiter: return x
 
         z = P_inv @ residual
         beta = residual.dot(z) / rz
         p = z + p*beta
 
+
+def _safe_clip(x: torch.Tensor):
+    """makes sure scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+def _trust_tau(x, d, trust_region):
+    xx = x.dot(x)
+    xd = x.dot(d)
+    dd = _safe_clip(d.dot(d))
+
+    rad = (xd**2 - dd * (xx - trust_region**2)).clip(min=0).sqrt()
+    tau = (-xd + rad) / dd
+
+    return x + tau * d
+
+
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    trust_region: float,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+) -> torch.Tensor: ...
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    trust_region: float,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+) -> TensorList: ...
+def steihaug_toint_cg(
+    A_mm: Callable | torch.Tensor,
+    b: torch.Tensor | TensorList,
+    trust_region: float,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+):
+    """
+    Solution is bounded to have L2 norm no larger than :code:`trust_region`. If solution exceeds :code:`trust_region`, CG is terminated early, so it is also faster.
+    """
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+
+    x = x0
+    if x is None: x = generic_zeros_like(b)
+    r = b
+    d = r.clone()
+
+    eps = generic_finfo_eps(b)**2
+    if tol is None: tol = eps
+
+    if generic_vector_norm(r) < tol:
+        return x
+
+    if maxiter is None:
+        maxiter = generic_numel(b)
+
+    for _ in range(maxiter):
+        Ad = A_mm_reg(d)
+
+        d_Ad = d.dot(Ad)
+        if d_Ad <= eps:
+            return _trust_tau(x, d, trust_region)
+
+        alpha = r.dot(r) / d_Ad
+        p_next = x + alpha * d
+
+        # check if the step exceeds the trust-region boundary
+        if generic_vector_norm(p_next) >= trust_region:
+            return _trust_tau(x, d, trust_region)
+
+        # update step, residual and direction
+        x = p_next
+        r_next = r - alpha * Ad
+
+        if generic_vector_norm(r_next) < tol:
+            return x
+
+        beta = r_next.dot(r_next) / r.dot(r)
+        d = r_next + beta * d
+        r = r_next
+
+    return x
+
+
+
+# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
+@overload
+def minres(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> torch.Tensor: ...
+@overload
+def minres(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> TensorList: ...
+def minres(
+    A_mm,
+    b,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+):
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+    if tol is None: tol = eps**2
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0 is None:
+        R = b
+        x0 = generic_zeros_like(b)
+    else:
+        R = b - A_mm_reg(x0)
+
+    X: Any = x0
+    beta = b_norm = generic_vector_norm(b)
+    if b_norm < eps**2:
+        return generic_zeros_like(b)
+
+
+    V = b / beta
+    V_prev = generic_zeros_like(b)
+    D = generic_zeros_like(b)
+    D_prev = generic_zeros_like(b)
+
+    c = -1
+    phi = tau = beta
+    s = delta1 = e = 0
+
+
+    for _ in range(maxiter):
+
+        P = A_mm_reg(V)
+        alpha = V.dot(P)
+        P -= beta*V_prev
+        P -= alpha*V
+        beta = generic_vector_norm(P)
+
+        delta2 = c*delta1 + s*alpha
+        gamma1 = s*delta1 - c*alpha
+        e_next = s*beta
+        delta1 = -c*beta
+
+        cgamma1 = c*gamma1
+        if trust_region is not None and cgamma1 >= 0:
+            if npc_terminate: return _trust_tau(X, R, trust_region)
+            return _trust_tau(X, D, trust_region)
+
+        if npc_terminate and cgamma1 >= 0:
+            return R
+
+        gamma2 = (gamma1**2 + beta**2)**(1/2)
+
+        if abs(gamma2) <= eps: # singular system
+            # c=0; s=1; tau=0
+            if trust_region is None: return X
+            return _trust_tau(X, D, trust_region)
+
+        c = gamma1 / gamma2
+        s = beta/gamma2
+        tau = c*phi
+        phi = s*phi
+
+        D_prev = D
+        D = (V - delta2*D - e*D_prev) / gamma2
+        e = e_next
+        X = X + tau*D
+
+        if trust_region is not None:
+            if generic_vector_norm(X) > trust_region:
+                return _trust_tau(X, D, trust_region)
+
+        if (abs(beta) < eps) or (phi / b_norm <= tol):
+            # R = zeros(R)
+            return X
+
+        V_prev = V
+        V = P/beta
+        R = s**2*R - phi*c*V
+
+    return X
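
Both new solvers take the same matrix-or-callable A_mm; steihaug_toint_cg always requires a trust_region, while minres clips to one only when it is given. When the boundary is hit, both return a point on it via _trust_tau, which solves ||x + tau*d|| = trust_region for the positive root of the resulting quadratic in tau. A short sketch of the intended call pattern; the import path is assumed from the file list (minres does not appear in the linalg __init__ hunk above), and the matrix is a toy example.

import torch
from torchzero.utils.linalg.solve import steihaug_toint_cg, minres  # import path assumed

torch.manual_seed(0)
M = torch.randn(20, 20)
H = M @ M.T + torch.eye(20)   # stand-in for a model Hessian (SPD here; both solvers also detect non-positive curvature)
g = torch.randn(20)

# approximately solve H x = g subject to ||x|| <= 0.5, i.e. a trust-region subproblem
step_cg = steihaug_toint_cg(H, g, trust_region=0.5)
step_mr = minres(H, g, trust_region=0.5)

print(step_cg.norm() <= 0.5 + 1e-5, step_mr.norm() <= 0.5 + 1e-5)
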
torchzero/utils/numberlist.py
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
         return self.__class__(fn(i, *args, **kwargs) for i in self)
 
     def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
+    def clip(self, min=None, max=None):
         return self.zipmap_args(_clamp, min, max)
torchzero/utils/python_tools.py
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
         return all(i==y for i in x)
     return all(i==j for i,j in zip(x,y))
 
+def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+    """generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
+    if isinstance(x, (int,float)):
+        if isinstance(y, (int,float)): return x!=y
+        return any(i!=x for i in y)
+    if isinstance(y, (int,float)):
+        return any(i!=y for i in x)
+    return any(i!=j for i,j in zip(x,y))
+
+
 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
     Otherwise applies `fn` to this sequence and `other`.
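
generic_ne complements generic_eq; its docstring claims it is faster than `not generic_eq`. A small illustration, copying the function body from the hunk above so it runs standalone:

from collections.abc import Iterable

def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
    """generic not-equals for scalars and lists of numbers (copied from the hunk above)"""
    if isinstance(x, (int, float)):
        if isinstance(y, (int, float)): return x != y
        return any(i != x for i in y)
    if isinstance(y, (int, float)):
        return any(i != y for i in x)
    return any(i != j for i, j in zip(x, y))

print(generic_ne(0, 0))            # False: equal scalars
print(generic_ne(0, [0, 0, 1]))    # True: one element differs
print(generic_ne([1, 2], [1, 2]))  # False: element-wise equal
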