torchzero 0.3.13__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,21 +13,12 @@ from ..trust_region.trust_region import default_radius
  class NewtonCG(Module):
  """Newton's method with a matrix-free conjugate gradient or minimal-residual solver.

- This optimizer implements Newton's method using a matrix-free conjugate
- gradient (CG) or a minimal-residual (MINRES) solver to approximate the search direction. Instead of
- forming the full Hessian matrix, it only requires Hessian-vector products
- (HVPs). These can be calculated efficiently using automatic
- differentiation or approximated using finite differences.
+ Notes:
+ * In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

- .. note::
- In most cases NewtonCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+ * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

- .. note::
- This module requires the a closure passed to the optimizer step,
- as it needs to re-evaluate the loss and gradients for calculating HVPs.
- The closure must accept a ``backward`` argument (refer to documentation).
-
- .. warning::
+ Warning:
  CG may fail if hessian is not positive-definite.

  Args:
@@ -66,26 +57,24 @@ class NewtonCG(Module):
  NewtonCG will attempt to apply preconditioning to the output of this module.

  Examples:
- Newton-CG with a backtracking line search:
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.NewtonCG(),
- tz.m.Backtracking()
- )
-
- Truncated Newton method (useful for large-scale problems):
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.NewtonCG(maxiter=10, warm_start=True),
- tz.m.Backtracking()
- )
-
+ Newton-CG with a backtracking line search:
+
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(),
+ tz.m.Backtracking()
+ )
+ ```
+
+ Truncated Newton method (useful for large-scale problems):
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCG(maxiter=10),
+ tz.m.Backtracking()
+ )
+ ```

  """
  def __init__(
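The ``backward`` closure convention referenced in the notes above is what both Newton-CG examples assume when calling ``opt.step(closure)``. A minimal sketch of such a closure; the model, data, and loss below are placeholder assumptions, not part of the package:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)                  # placeholder model
X, y = torch.randn(64, 10), torch.randn(64, 1)  # placeholder data

opt = tz.Modular(
    model.parameters(),
    tz.m.NewtonCG(),
    tz.m.Backtracking(),
)

def closure(backward=True):
    # re-evaluated by NewtonCG whenever it needs extra loss/gradient evaluations
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

loss = opt.step(closure)
```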
@@ -95,7 +84,7 @@ class NewtonCG(Module):
  reg: float = 1e-8,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
  solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
- h: float = 1e-3,
+ h: float = 1e-3, # tuned 1e-4 or 1e-3
  miniter:int = 1,
  warm_start=False,
  inner: Chainable | None = None,
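For reference, the ``hvp_method`` options named in these signatures correspond to the standard finite-difference Hessian-vector product approximations controlled by ``h``. A generic sketch of the two formulas, not the package's internal implementation:

```python
import torch

def hvp_forward(grad_fn, x, v, h=1e-3):
    # forward difference: Hv ≈ (∇f(x + h·v) − ∇f(x)) / h, one extra gradient evaluation
    return (grad_fn(x + h * v) - grad_fn(x)) / h

def hvp_central(grad_fn, x, v, h=1e-3):
    # central difference: Hv ≈ (∇f(x + h·v) − ∇f(x − h·v)) / (2h), two extra gradient evaluations
    return (grad_fn(x + h * v) - grad_fn(x - h * v)) / (2 * h)

# sanity check on a quadratic, whose Hessian is the constant matrix A
A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
grad_fn = lambda x: A @ x
x, v = torch.randn(3), torch.randn(3)
print(hvp_central(grad_fn, x, v), A @ v)  # the two should be close
```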
@@ -187,96 +176,95 @@ class NewtonCG(Module):


  class NewtonCGSteihaug(Module):
- """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.
+ """Newton's method with trust region and a matrix-free Steihaug-Toint conjugate gradient solver.

- This optimizer implements Newton's method using a matrix-free conjugate
- gradient (CG) solver to approximate the search direction. Instead of
- forming the full Hessian matrix, it only requires Hessian-vector products
- (HVPs). These can be calculated efficiently using automatic
- differentiation or approximated using finite differences.
+ Notes:
+ * In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.

- .. note::
- In most cases NewtonCGSteihaug should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
-
- .. note::
- This module requires the a closure passed to the optimizer step,
- as it needs to re-evaluate the loss and gradients for calculating HVPs.
- The closure must accept a ``backward`` argument (refer to documentation).
-
- .. warning::
- CG may fail if hessian is not positive-definite.
+ * This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).

  Args:
- maxiter (int | None, optional):
- Maximum number of iterations for the conjugate gradient solver.
- By default, this is set to the number of dimensions in the
- objective function, which is the theoretical upper bound for CG
- convergence. Setting this to a smaller value (truncated Newton)
- can still generate good search directions. Defaults to None.
  eta (float, optional):
- whenever actual to predicted loss reduction ratio is larger than this, a step is accepted.
- nplus (float, optional):
- trust region multiplier on successful steps.
- nminus (float, optional):
- trust region multiplier on unsuccessful steps.
- init (float, optional): initial trust region.
+ if ratio of actual to predicted reduction is larger than this, step is accepted. Defaults to 0.0.
+ nplus (float, optional): increase factor on successful steps. Defaults to 3.5.
+ nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.25.
+ rho_good (float, optional):
+ if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
+ rho_bad (float, optional):
+ if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
+ init (float, optional): Initial trust region value. Defaults to 1.
+ max_attempts (int, optional):
+ maximum number of trust radius reductions per step. A zero update vector is returned when
+ this limit is exceeded. Defaults to 100.
+ max_history (int, optional):
+ CG will store this many intermediate solutions, reusing them when trust radius is reduced
+ instead of re-running CG. Each solution storage requires 2N memory. Defaults to 100.
+ boundary_tol (float | None, optional):
+ The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+ This prevents increasing trust region when solution is not on the boundary. Defaults to 1e-6.
+
+ maxiter (int | None, optional):
+ maximum number of CG iterations per step. Each iteration requires one backward pass if `hvp_method="forward"`, two otherwise. Defaults to None.
+ miniter (int, optional):
+ minimal number of CG iterations. This prevents making no progress
+ when the initial guess is below tolerance. Defaults to 1.
  tol (float, optional):
- Relative tolerance for the conjugate gradient solver to determine
- convergence. Defaults to 1e-4.
- reg (float, optional):
- Regularization parameter (damping) added to the Hessian diagonal.
- This helps ensure the system is positive-definite. Defaults to 1e-8.
+ terminates CG when norm of the residual is less than this value. Defaults to 1e-8.
+ reg (float, optional): hessian regularization. Defaults to 1e-8.
+ solver (str, optional): solver, "cg" or "minres". "cg" is recommended. Defaults to 'cg'.
+ adapt_tol (bool, optional):
+ if True, whenever trust radius collapses to smallest representable number,
+ the tolerance is multiplied by 0.1. Defaults to True.
+ npc_terminate (bool, optional):
+ whether to terminate CG/MINRES whenever negative curvature is detected. Defaults to False.
+
  hvp_method (str, optional):
- Determines how Hessian-vector products are evaluated.
+ either "forward" to use forward formula which requires one backward pass per Hvp, or "central" to use a more accurate central formula which requires two backward passes. "forward" is usually accurate enough. Defaults to "central".
+ h (float, optional): finite difference step size. Defaults to 1e-3.

- - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
- This requires creating a graph for the gradient.
- - ``"forward"``: Use a forward finite difference formula to
- approximate the HVP. This requires one extra gradient evaluation.
- - ``"central"``: Use a central finite difference formula for a
- more accurate HVP approximation. This requires two extra
- gradient evaluations.
- Defaults to "autograd".
- h (float, optional):
- The step size for finite differences if :code:`hvp_method` is
- ``"forward"`` or ``"central"``. Defaults to 1e-3.
  inner (Chainable | None, optional):
- NewtonCG will attempt to apply preconditioning to the output of this module.
+ applies preconditioning to output of this module. Defaults to None.

- Examples:
- Trust-region Newton-CG:
-
- .. code-block:: python
+ ### Examples:
+ Trust-region Newton-CG:

- opt = tz.Modular(
- model.parameters(),
- tz.m.NewtonCGSteihaug(),
- )
+ ```python
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NewtonCGSteihaug(),
+ )
+ ```

- Reference:
+ ### Reference:
  Steihaug, Trond. "The conjugate gradient method and trust regions in large scale optimization." SIAM Journal on Numerical Analysis 20.3 (1983): 626-637.
  """
  def __init__(
  self,
- maxiter: int | None = None,
+ # trust region settings
  eta: float= 0.0,
  nplus: float = 3.5,
  nminus: float = 0.25,
  rho_good: float = 0.99,
  rho_bad: float = 1e-4,
  init: float = 1,
- tol: float = 1e-8,
- reg: float = 1e-8,
- hvp_method: Literal["forward", "central"] = "forward",
- solver: Literal['cg', "minres"] = 'cg',
- h: float = 1e-3,
  max_attempts: int = 100,
  max_history: int = 100,
- boundary_tol: float = 1e-1,
+ boundary_tol: float = 1e-6, # tuned
+
+ # cg settings
+ maxiter: int | None = None,
  miniter: int = 1,
- rms_beta: float | None = None,
+ tol: float = 1e-8,
+ reg: float = 1e-8,
+ solver: Literal['cg', "minres"] = 'cg',
  adapt_tol: bool = True,
  npc_terminate: bool = False,
+
+ # hvp settings
+ hvp_method: Literal["forward", "central"] = "central",
+ h: float = 1e-3, # tuned 1e-4 or 1e-3
+
+ # inner
  inner: Chainable | None = None,
  ):
  defaults = locals().copy()
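The trust-region parameters documented above (``eta``, ``nplus``, ``nminus``, ``rho_good``, ``rho_bad``, ``boundary_tol``) follow the usual accept/shrink/grow pattern. Below is a schematic sketch of how such parameters typically interact, not the module's actual code:

```python
def update_trust_region(rho, step_norm, radius,
                        eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-6):
    """rho = actual reduction / predicted reduction for the proposed step."""
    accept = rho > eta
    if rho < rho_bad:
        # model fit the objective poorly: shrink the region and retry
        radius *= nminus
    elif rho > rho_good and step_norm >= (1 - boundary_tol) * radius:
        # good fit and the step actually reached the boundary: allow longer steps
        radius *= nplus
    return accept, radius
```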
@@ -336,19 +324,8 @@ class NewtonCGSteihaug(Module):
  raise ValueError(hvp_method)


- # ------------------------- update RMS preconditioner ------------------------ #
- b = var.get_update()
- P_mm = None
- rms_beta = self.defaults["rms_beta"]
- if rms_beta is not None:
- exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
- exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
- exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
- def _P_mm(x):
- return x / exp_avg_sq_sqrt
- P_mm = _P_mm
-
  # -------------------------------- inner step -------------------------------- #
+ b = var.get_update()
  if 'inner' in self.children:
  b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
  b = as_tensorlist(b)
@@ -392,7 +369,6 @@ class NewtonCGSteihaug(Module):
  miniter=miniter,
  npc_terminate=npc_terminate,
  history_size=max_history,
- P_mm=P_mm,
  )

  elif solver == 'minres':
@@ -14,13 +14,13 @@ class LevenbergMarquardt(TrustRegionBase):
  hess_module (Module | None, optional):
  A module that maintains a hessian approximation (not hessian inverse!).
  This includes all full-matrix quasi-newton methods, ``tz.m.Newton`` and ``tz.m.GaussNewton``.
- When using quasi-newton methods, set `inverse=False` when constructing them.
+ When using quasi-newton methods, set ``inverse=False`` when constructing them.
  y (float, optional):
  when ``y=0``, identity matrix is added to hessian, when ``y=1``, diagonal of the hessian approximation
  is added. Values in between interpolate. This should only be used with Gauss-Newton. Defaults to 0.
  eta (float, optional):
  if ratio of actual to predicted reduction is larger than this, step is accepted.
- When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+ When ``hess_module`` is ``Newton`` or ``GaussNewton``, this can be set to 0. Defaults to 0.15.
  nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
  nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
  rho_good (float, optional):
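Based on the ``hess_module`` description above, a typical pairing is a Gauss-Newton approximation inside Levenberg-Marquardt damping. The snippet below is a hedged usage sketch; ``model`` is a placeholder and any constructor details beyond those documented here are assumptions:

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),  # placeholder model
    # Gauss-Newton hessian approximation wrapped in Levenberg-Marquardt damping;
    # full-matrix quasi-newton modules would instead be constructed with inverse=False.
    tz.m.LevenbergMarquardt(hess_module=tz.m.GaussNewton(), eta=0.0),
)
```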
@@ -60,17 +60,19 @@ class TrustCG(TrustRegionBase):
  nminus: float = 0.25,
  rho_good: float = 0.99,
  rho_bad: float = 1e-4,
- boundary_tol: float | None = 1e-1,
+ boundary_tol: float | None = 1e-6, # tuned
  init: float = 1,
  max_attempts: int = 10,
  radius_strategy: _RadiusStrategy | _RADIUS_KEYS = 'default',
  reg: float = 0,
- cg_tol: float = 1e-4,
+ maxiter: int | None = None,
+ miniter: int = 1,
+ cg_tol: float = 1e-8,
  prefer_exact: bool = True,
  update_freq: int = 1,
  inner: Chainable | None = None,
  ):
- defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol)
+ defaults = dict(reg=reg, prefer_exact=prefer_exact, cg_tol=cg_tol, maxiter=maxiter, miniter=miniter)
  super().__init__(
  defaults=defaults,
  hess_module=hess_module,
@@ -93,5 +95,5 @@ class TrustCG(TrustRegionBase):
  if settings['prefer_exact'] and isinstance(H, linear_operator.ScaledIdentity):
  return H.solve_bounded(g, radius)

- x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], tol=settings["cg_tol"])
+ x, _ = cg(H.matvec, g, trust_radius=radius, reg=settings['reg'], maxiter=settings["maxiter"], miniter=settings["miniter"], tol=settings["cg_tol"])
  return x
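For readers unfamiliar with the ``trust_radius`` argument being threaded into ``cg`` here: a Steihaug-Toint CG solve stops either at convergence, at the trust-region boundary, or when negative curvature is detected. The sketch below is a generic illustration of that logic, not torchzero's ``cg``:

```python
import torch

def to_boundary(x, d, trust_radius):
    # pick tau >= 0 so that ||x + tau*d|| == trust_radius
    a, b, c = d.dot(d), 2 * x.dot(d), x.dot(x) - trust_radius ** 2
    tau = (-b + torch.sqrt(b * b - 4 * a * c)) / (2 * a)
    return x + tau * d

def steihaug_cg(matvec, b, trust_radius, tol=1e-8, maxiter=None):
    """Approximately solve H x = b subject to ||x|| <= trust_radius (schematic)."""
    x, r = torch.zeros_like(b), b.clone()   # start at x = 0, residual r = b
    d = r.clone()
    maxiter = b.numel() if maxiter is None else maxiter
    for _ in range(maxiter):
        Hd = matvec(d)
        dHd = d.dot(Hd)
        if dHd <= 0:                          # negative curvature: follow d to the boundary
            return to_boundary(x, d, trust_radius)
        alpha = r.dot(r) / dHd
        x_next = x + alpha * d
        if x_next.norm() >= trust_radius:     # step would leave the region: stop on the boundary
            return to_boundary(x, d, trust_radius)
        r_next = r - alpha * Hd
        if r_next.norm() < tol:
            return x_next
        beta = r_next.dot(r_next) / r.dot(r)
        x, r, d = x_next, r_next, r_next + beta * d
    return x
```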
@@ -1 +1 @@
- from .cd import CD, CCD, CCDLS
+ from .cd import CD
@@ -9,11 +9,10 @@ import torch

  from ...core import Module
  from ...utils import NumberList, TensorList
- from ..line_search.adaptive import adaptive_tracking

  class CD(Module):
  """Coordinate descent. Proposes a descent direction along a single coordinate.
- You can then put a line search such as ``tz.m.ScipyMinimizeScalar``, or a fixed step size.
+ A line search such as ``tz.m.ScipyMinimizeScalar(maxiter=8)`` or a fixed step size can be used after this.

  Args:
  h (float, optional): finite difference step size. Defaults to 1e-3.
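Following the recommendation in the docstring above, coordinate descent is meant to be chained with a step-size module. A minimal usage sketch; ``model`` is a placeholder:

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),  # placeholder model
    tz.m.CD(),
    tz.m.ScipyMinimizeScalar(maxiter=8),  # line search along the proposed coordinate
)
```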
@@ -121,239 +120,3 @@ class CD(Module):
  var.update = update
  return var

-
- def _icd_get_idx(self: Module, params: TensorList):
- ndim = params.global_numel()
- igrad = self.get_state(params, "igrad", cls=TensorList)
-
- # -------------------------- 1st n steps fill igrad -------------------------- #
- index = self.global_state.get('index', 0)
- self.global_state['index'] = index + 1
- if index < ndim:
- return index, igrad
-
- # ------------------ select randomly weighted by magnitudes ------------------ #
- igrad_abs = igrad.abs()
- gmin = igrad_abs.global_min()
- gmax = igrad_abs.global_max()
-
- pmin, pmax, pow = self.get_settings(params, "pmin", "pmax", "pow", cls=NumberList)
-
- p: TensorList = ((igrad_abs - gmin) / (gmax - gmin)) ** pow # pyright:ignore[reportOperatorIssue]
- p.mul_(pmax-pmin).add_(pmin)
-
- if 'np_gen' not in self.global_state:
- self.global_state['np_gen'] = np.random.default_rng(0)
- np_gen = self.global_state['np_gen']
-
- p_vec = p.to_vec()
- p_sum = p_vec.sum()
- if p_sum > 1e-12:
- return np_gen.choice(ndim, p=p_vec.div_(p_sum).numpy(force=True)), igrad
-
- # --------------------- sum is too small, do cycle again --------------------- #
- self.global_state.clear()
- self.clear_state_keys('h_vec', 'igrad', 'alphas')
-
- if 'generator' not in self.global_state:
- self.global_state['generator'] = random.Random(0)
- generator = self.global_state['generator']
- return generator.randrange(0, p_vec.numel()), igrad
-
- class CCD(Module):
- """Cumulative coordinate descent. This updates one gradient coordinate at a time and accumulates it
- to the update direction. The coordinate updated is random weighted by magnitudes of current update direction.
- As update direction ceases to be a descent direction due to old accumulated coordinates, it is decayed.
-
- Args:
- pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
- pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
- pow (int, optional): power transform to probabilities. Defaults to 2.
- decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
- decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
- nplus (float, optional): step size increase on successful steps. Defaults to 1.5.
- nminus (float, optional): step size increase on unsuccessful steps. Defaults to 0.75.
- """
- def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay:float=0.8, decay2:float=0.2, nplus=1.5, nminus=0.75):
-
- defaults = dict(pmin=pmin, pmax=pmax, pow=pow, decay=decay, decay2=decay2, nplus=nplus, nminus=nminus)
- super().__init__(defaults)
-
- @torch.no_grad
- def step(self, var):
- closure = var.closure
- if closure is None:
- raise RuntimeError("CD requires closure")
-
- params = TensorList(var.params)
- p_prev = self.get_state(params, "p_prev", init=params, cls=TensorList)
-
- f_0 = var.get_loss(False)
- step_size = self.global_state.get('step_size', 1)
-
- # ------------------------ hard reset on infinite loss ----------------------- #
- if not math.isfinite(f_0):
- del self.global_state['f_prev']
- var.update = params - p_prev
- self.global_state.clear()
- self.state.clear()
- self.global_state["step_size"] = step_size / 10
- return var
-
- # ---------------------------- soft reset if stuck --------------------------- #
- if "igrad" in self.state[params[0]]:
- n_bad = self.global_state.get('n_bad', 0)
-
- f_prev = self.global_state.get("f_prev", None)
- if f_prev is not None:
-
- decay2 = self.defaults["decay2"]
- decay = self.global_state.get("decay", self.defaults["decay"])
-
- if f_0 >= f_prev:
-
- igrad = self.get_state(params, "igrad", cls=TensorList)
- del self.global_state['f_prev']
-
- # undo previous update
- var.update = params - p_prev
-
- # increment n_bad
- self.global_state['n_bad'] = n_bad + 1
-
- # decay step size
- self.global_state['step_size'] = step_size * self.defaults["nminus"]
-
- # soft reset
- if n_bad > 0:
- igrad *= decay
- self.global_state["decay"] = decay*decay2
- self.global_state['n_bad'] = 0
-
- return var
-
- else:
- # increase step size and reset n_bad
- self.global_state['step_size'] = step_size * self.defaults["nplus"]
- self.global_state['n_bad'] = 0
- self.global_state["decay"] = self.defaults["decay"]
-
- self.global_state['f_prev'] = float(f_0)
-
- # ------------------------------ determine index ----------------------------- #
- idx, igrad = _icd_get_idx(self, params)
-
- # -------------------------- find descent direction -------------------------- #
- h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
- h = float(h_vec.flat_get(idx))
-
- params.flat_set_lambda_(idx, lambda x: x + h)
- f_p = closure(False)
-
- params.flat_set_lambda_(idx, lambda x: x - 2*h)
- f_n = closure(False)
- params.flat_set_lambda_(idx, lambda x: x + h)
-
- # ---------------------------------- adapt h --------------------------------- #
- if f_0 <= f_p and f_0 <= f_n:
- h_vec.flat_set_lambda_(idx, lambda x: max(x/2, 1e-10))
- else:
- if abs(f_0 - f_n) < 1e-12 or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
- h_vec.flat_set_lambda_(idx, lambda x: min(x*2, 1e10))
-
- # ------------------------------- update igrad ------------------------------- #
- if f_0 < f_p and f_0 < f_n: alpha = 0
- else: alpha = (f_p - f_n) / (2*h)
-
- igrad.flat_set_(idx, alpha)
-
- # ----------------------------- create the update ---------------------------- #
- var.update = igrad * step_size
- p_prev.copy_(params)
- return var
-
-
- class CCDLS(Module):
- """CCD with line search instead of adaptive step size.
-
- Args:
- pmin (float, optional): multiplier to probability of picking the lowest magnitude gradient. Defaults to 0.1.
- pmax (float, optional): multiplier to probability of picking the largest magnitude gradient. Defaults to 1.0.
- pow (int, optional): power transform to probabilities. Defaults to 2.
- decay (float, optional): accumulated gradient decay on failed step. Defaults to 0.5.
- decay2 (float, optional): decay multiplier decay on failed step. Defaults to 0.25.
- maxiter (int, optional): max number of line search iterations.
- """
- def __init__(self, pmin=0.1, pmax=1.0, pow=2, decay=0.8, decay2=0.2, maxiter=10, ):
- defaults = dict(pmin=pmin, pmax=pmax, pow=pow, maxiter=maxiter, decay=decay, decay2=decay2)
- super().__init__(defaults)
-
- @torch.no_grad
- def step(self, var):
- closure = var.closure
- if closure is None:
- raise RuntimeError("CD requires closure")
-
- params = TensorList(var.params)
- finfo = torch.finfo(params[0].dtype)
- f_0 = var.get_loss(False)
-
- # ------------------------------ determine index ----------------------------- #
- idx, igrad = _icd_get_idx(self, params)
-
- # -------------------------- find descent direction -------------------------- #
- h_vec = self.get_state(params, 'h_vec', init=lambda x: torch.full_like(x, 1e-3), cls=TensorList)
- h = float(h_vec.flat_get(idx))
-
- params.flat_set_lambda_(idx, lambda x: x + h)
- f_p = closure(False)
-
- params.flat_set_lambda_(idx, lambda x: x - 2*h)
- f_n = closure(False)
- params.flat_set_lambda_(idx, lambda x: x + h)
-
- # ---------------------------------- adapt h --------------------------------- #
- if f_0 <= f_p and f_0 <= f_n:
- h_vec.flat_set_lambda_(idx, lambda x: max(x/2, finfo.tiny * 2))
- else:
- # here eps, not tiny
- if abs(f_0 - f_n) < finfo.eps or abs((f_p - f_0) / (f_0 - f_n) - 1) < 1e-2:
- h_vec.flat_set_lambda_(idx, lambda x: min(x*2, finfo.max / 2))
-
- # ------------------------------- update igrad ------------------------------- #
- if f_0 < f_p and f_0 < f_n: alpha = 0
- else: alpha = (f_p - f_n) / (2*h)
-
- igrad.flat_set_(idx, alpha)
-
- # -------------------------------- line search ------------------------------- #
- x0 = params.clone()
- def f(a):
- params.sub_(igrad, alpha=a)
- loss = closure(False)
- params.copy_(x0)
- return loss
-
- a_prev = self.global_state.get('a_prev', 1)
- a, f_a, niter = adaptive_tracking(f, a_prev, maxiter=self.defaults['maxiter'], f_0=f_0)
- if (a is None) or (not math.isfinite(a)) or (not math.isfinite(f_a)):
- a = 0
-
- # -------------------------------- set a_prev -------------------------------- #
- decay2 = self.defaults["decay2"]
- decay = self.global_state.get("decay", self.defaults["decay"])
-
- if abs(a) > finfo.tiny * 2:
- assert f_a < f_0
- self.global_state['a_prev'] = max(min(a, finfo.max / 2), finfo.tiny * 2)
- self.global_state["decay"] = self.defaults["decay"]
-
- # ---------------------------- soft reset on fail ---------------------------- #
- else:
- igrad *= decay
- self.global_state["decay"] = decay*decay2
- self.global_state['a_prev'] = a_prev / 2
-
- # -------------------------------- set update -------------------------------- #
- var.update = igrad * a
- return var
@@ -110,7 +110,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
  for i, param in enumerate(params):
  s = state[param]
  if key not in s:
- if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+ if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
  s[key] = _make_initial_state_value(param, init, i)
  values.append(s[key])
  return values
@@ -125,7 +125,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
  s = state[param]
  for k_i, key in enumerate(keys):
  if key not in s:
- if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+ if must_exist: raise KeyError(f"Key `{key}` doesn't exist in state with keys {tuple(s.keys())}")
  k_init = init[k_i] if isinstance(init, (list,tuple)) else init
  s[key] = _make_initial_state_value(param, k_init, i)
  values[k_i].append(s[key])
@@ -67,3 +67,4 @@ def safe_dict_update_(d1_:dict, d2:dict):
  inter = set(d1_.keys()).intersection(d2.keys())
  if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
  d1_.update(d2)
+
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.13
+ Version: 0.3.14
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  Project-URL: Homepage, https://github.com/inikishev/torchzero