torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/utils/linalg/solve.py
@@ -1,12 +1,41 @@
+# pyright: reportArgumentType=false
 from collections.abc import Callable
-from typing import overload
+from typing import Any, overload
+
 import torch
 
-from .. import TensorList, generic_zeros_like, generic_vector_norm, generic_numel, generic_randn_like, generic_eq
+from .. import (
+    TensorList,
+    generic_eq,
+    generic_finfo_eps,
+    generic_numel,
+    generic_randn_like,
+    generic_vector_norm,
+    generic_zeros_like,
+)
+
+
+def _make_A_mm_reg(A_mm: Callable | torch.Tensor, reg):
+    if callable(A_mm):
+        def A_mm_reg(x): # A_mm with regularization
+            Ax = A_mm(x)
+            if not generic_eq(reg, 0): Ax += x*reg
+            return Ax
+        return A_mm_reg
+
+    if not isinstance(A_mm, torch.Tensor): raise TypeError(type(A_mm))
+
+    def Ax_reg(x): # A_mm with regularization
+        if A_mm.ndim == 1: Ax = A_mm * x
+        else: Ax = A_mm @ x
+        if reg != 0: Ax += x*reg
+        return Ax
+    return Ax_reg
+
 
 @overload
 def cg(
-    A_mm: Callable[[torch.Tensor], torch.Tensor],
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
     b: torch.Tensor,
     x0_: torch.Tensor | None = None,
     tol: float | None = 1e-4,
@@ -24,17 +53,17 @@ def cg(
 ) -> TensorList: ...
 
 def cg(
-    A_mm: Callable,
+    A_mm: Callable | torch.Tensor,
     b: torch.Tensor | TensorList,
     x0_: torch.Tensor | TensorList | None = None,
     tol: float | None = 1e-4,
     maxiter: int | None = None,
     reg: float | list[float] | tuple[float] = 0,
 ):
-    def A_mm_reg(x): # A_mm with regularization
-        Ax = A_mm(x)
-        if not generic_eq(reg, 0): Ax += x*reg
-        return Ax
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+
+    if tol is None: tol = eps
 
     if maxiter is None: maxiter = generic_numel(b)
     if x0_ is None: x0_ = generic_zeros_like(b)
@@ -44,9 +73,10 @@ def cg(
     p = residual.clone() # search direction
     r_norm = generic_vector_norm(residual)
     init_norm = r_norm
-    if tol is not None and r_norm < tol: return x
+    if r_norm < tol: return x
     k = 0
 
+
     while True:
         Ap = A_mm_reg(p)
         step_size = (r_norm**2) / p.dot(Ap)
@@ -55,7 +85,7 @@ def cg(
         new_r_norm = generic_vector_norm(residual)
 
         k += 1
-        if tol is not None and new_r_norm <= tol * init_norm: return x
+        if new_r_norm <= tol * init_norm: return x
         if k >= maxiter: return x
 
         beta = (new_r_norm**2) / (r_norm**2)
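Reviewer note: `cg` now accepts either a matrix-free operator or an explicit tensor for `A_mm`, routed through the new `_make_A_mm_reg` helper (a 1-D tensor is treated as a diagonal, a 2-D tensor as a full matrix). A minimal sketch of both call styles, assuming the import path from the file list above; the matrix and vector are illustrative:

```python
import torch
from torchzero.utils.linalg.solve import cg  # path per the file list above

A = torch.tensor([[4.0, 1.0],
                  [1.0, 3.0]])  # small SPD system
b = torch.tensor([1.0, 2.0])

x_dense   = cg(A, b, tol=1e-8)                # A_mm as an explicit tensor (new in this release)
x_matfree = cg(lambda v: A @ v, b, tol=1e-8)  # A_mm as a matrix-free callable (as before)

assert torch.allclose(x_dense, x_matfree, atol=1e-5)
assert torch.allclose(A @ x_dense, b, atol=1e-4)
```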
@@ -131,6 +161,8 @@ def nystrom_pcg(
         generator=generator,
     )
     lambd += reg
+    eps = torch.finfo(b.dtype).eps ** 2
+    if tol is None: tol = eps
 
     def A_mm_reg(x): # A_mm with regularization
         Ax = A_mm(x)
@@ -150,7 +182,7 @@ def nystrom_pcg(
     p = z.clone() # search direction
 
     init_norm = torch.linalg.vector_norm(residual) # pylint:disable=not-callable
-    if tol is not None and init_norm < tol: return x
+    if init_norm < tol: return x
     k = 0
     while True:
         Ap = A_mm_reg(p)
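Reviewer note: the same behavioral change threads through `cg`, `nystrom_pcg` and the new solvers below: `tol=None` used to disable the convergence checks outright (hence the removed `tol is not None and ...` guards), whereas it now falls back to a machine-epsilon tolerance, so the guards collapse to plain comparisons. Sketched with `cg` on an illustrative system:

```python
import torch
from torchzero.utils.linalg.solve import cg  # path per the file list above

A = torch.tensor([[2.0, 0.0], [0.0, 5.0]], dtype=torch.float64)
b = torch.ones(2, dtype=torch.float64)

# tol=None now means "iterate until float64 eps (or maxiter)" instead of "never check".
x = cg(A, b, tol=None)
assert torch.allclose(A @ x, b)
```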
@@ -160,10 +192,217 @@ def nystrom_pcg(
         residual -= step_size * Ap
 
         k += 1
-        if tol is not None and torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
+        if torch.linalg.vector_norm(residual) <= tol * init_norm: return x # pylint:disable=not-callable
         if k >= maxiter: return x
 
         z = P_inv @ residual
         beta = residual.dot(z) / rz
         p = z + p*beta
 
+
+def _safe_clip(x: torch.Tensor):
+    """makes sure scalar tensor x is not smaller than epsilon"""
+    assert x.numel() == 1, x.shape
+    eps = torch.finfo(x.dtype).eps
+    if x.abs() < eps: return x.new_full(x.size(), eps).copysign(x)
+    return x
+
+def _trust_tau(x, d, trust_region):
+    xx = x.dot(x)
+    xd = x.dot(d)
+    dd = _safe_clip(d.dot(d))
+
+    rad = (xd**2 - dd * (xx - trust_region**2)).clip(min=0).sqrt()
+    tau = (-xd + rad) / dd
+
+    return x + tau * d
+
+
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    trust_region: float,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+) -> torch.Tensor: ...
+@overload
+def steihaug_toint_cg(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    trust_region: float,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+) -> TensorList: ...
+def steihaug_toint_cg(
+    A_mm: Callable | torch.Tensor,
+    b: torch.Tensor | TensorList,
+    trust_region: float,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+):
+    """
+    Solution is bounded to have L2 norm no larger than :code:`trust_region`. If solution exceeds :code:`trust_region`, CG is terminated early, so it is also faster.
+    """
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+
+    x = x0
+    if x is None: x = generic_zeros_like(b)
+    r = b
+    d = r.clone()
+
+    eps = generic_finfo_eps(b)**2
+    if tol is None: tol = eps
+
+    if generic_vector_norm(r) < tol:
+        return x
+
+    if maxiter is None:
+        maxiter = generic_numel(b)
+
+    for _ in range(maxiter):
+        Ad = A_mm_reg(d)
+
+        d_Ad = d.dot(Ad)
+        if d_Ad <= eps:
+            return _trust_tau(x, d, trust_region)
+
+        alpha = r.dot(r) / d_Ad
+        p_next = x + alpha * d
+
+        # check if the step exceeds the trust-region boundary
+        if generic_vector_norm(p_next) >= trust_region:
+            return _trust_tau(x, d, trust_region)
+
+        # update step, residual and direction
+        x = p_next
+        r_next = r - alpha * Ad
+
+        if generic_vector_norm(r_next) < tol:
+            return x
+
+        beta = r_next.dot(r_next) / r.dot(r)
+        d = r_next + beta * d
+        r = r_next
+
+    return x
+
+
+# Liu, Yang, and Fred Roosta. "MINRES: From negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32.4 (2022): 2636-2661.
+@overload
+def minres(
+    A_mm: Callable[[torch.Tensor], torch.Tensor] | torch.Tensor,
+    b: torch.Tensor,
+    x0: torch.Tensor | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> torch.Tensor: ...
+@overload
+def minres(
+    A_mm: Callable[[TensorList], TensorList],
+    b: TensorList,
+    x0: TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+) -> TensorList: ...
+def minres(
+    A_mm,
+    b,
+    x0: torch.Tensor | TensorList | None = None,
+    tol: float | None = 1e-4,
+    maxiter: int | None = None,
+    reg: float | list[float] | tuple[float] = 0,
+    npc_terminate: bool=True,
+    trust_region: float | None = None,
+):
+    A_mm_reg = _make_A_mm_reg(A_mm, reg)
+    eps = generic_finfo_eps(b)
+    if tol is None: tol = eps**2
+
+    if maxiter is None: maxiter = generic_numel(b)
+    if x0 is None:
+        R = b
+        x0 = generic_zeros_like(b)
+    else:
+        R = b - A_mm_reg(x0)
+
+    X: Any = x0
+    beta = b_norm = generic_vector_norm(b)
+    if b_norm < eps**2:
+        return generic_zeros_like(b)
+
+
+    V = b / beta
+    V_prev = generic_zeros_like(b)
+    D = generic_zeros_like(b)
+    D_prev = generic_zeros_like(b)
+
+    c = -1
+    phi = tau = beta
+    s = delta1 = e = 0
+
+
+    for _ in range(maxiter):
+
+        P = A_mm_reg(V)
+        alpha = V.dot(P)
+        P -= beta*V_prev
+        P -= alpha*V
+        beta = generic_vector_norm(P)
+
+        delta2 = c*delta1 + s*alpha
+        gamma1 = s*delta1 - c*alpha
+        e_next = s*beta
+        delta1 = -c*beta
+
+        cgamma1 = c*gamma1
+        if trust_region is not None and cgamma1 >= 0:
+            if npc_terminate: return _trust_tau(X, R, trust_region)
+            return _trust_tau(X, D, trust_region)
+
+        if npc_terminate and cgamma1 >= 0:
+            return R
+
+        gamma2 = (gamma1**2 + beta**2)**(1/2)
+
+        if abs(gamma2) <= eps: # singular system
+            # c=0; s=1; tau=0
+            if trust_region is None: return X
+            return _trust_tau(X, D, trust_region)
+
+        c = gamma1 / gamma2
+        s = beta/gamma2
+        tau = c*phi
+        phi = s*phi
+
+        D_prev = D
+        D = (V - delta2*D - e*D_prev) / gamma2
+        e = e_next
+        X = X + tau*D
+
+        if trust_region is not None:
+            if generic_vector_norm(X) > trust_region:
+                return _trust_tau(X, D, trust_region)
+
+        if (abs(beta) < eps) or (phi / b_norm <= tol):
+            # R = zeros(R)
+            return X
+
+        V_prev = V
+        V = P/beta
+        R = s**2*R - phi*c*V
+
+    return X
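Reviewer note: `steihaug_toint_cg` is the standard CG-Steihaug trust-region solve. `_trust_tau` picks the positive root tau of ||x + tau*d|| = trust_region, so both negative curvature (`d_Ad <= eps`) and boundary crossings exit on the trust-region sphere; `minres` follows the cited Liu & Roosta formulation and reuses the same `_trust_tau` boundary handling. A hedged sketch on an indefinite system (illustrative values, import path per the file list above):

```python
import torch
from torchzero.utils.linalg.solve import steihaug_toint_cg

H = torch.tensor([[1.0,  0.0],
                  [0.0, -2.0]])  # indefinite: CG hits negative curvature on iteration 1
g = torch.tensor([1.0, 1.0])

step = steihaug_toint_cg(H, g, trust_region=0.5)
# negative curvature was detected, so the step lands on the trust-region boundary
assert torch.isclose(torch.linalg.vector_norm(step), torch.tensor(0.5))
```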
torchzero/utils/numberlist.py
@@ -129,4 +129,6 @@ class NumberList(list[int | float | Any]):
         return self.__class__(fn(i, *args, **kwargs) for i in self)
 
     def clamp(self, min=None, max=None):
+        return self.zipmap_args(_clamp, min, max)
+    def clip(self, min=None, max=None):
         return self.zipmap_args(_clamp, min, max)
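Reviewer note: `NumberList` gains `clip` as a straight alias of `clamp`, mirroring torch's own `clip`/`clamp` pairing. A sketch of the expected behavior, assuming `_clamp` clamps elementwise:

```python
from torchzero.utils.numberlist import NumberList  # path per the file list above

nl = NumberList([0.5, 2.0, -1.0])
assert nl.clip(min=0.0, max=1.0) == nl.clamp(min=0.0, max=1.0) == [0.5, 1.0, 0.0]
```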
torchzero/utils/optimizer.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Mapping, MutableSequence, Sequence, MutableMapping
 from typing import Any, Literal, TypeVar, overload
 
@@ -132,65 +133,7 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     return values
 
 
-
-def loss_at_params(closure, params: Iterable[torch.Tensor],
-                   new_params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-    params = TensorList(params)
-
-    old_params = params.clone() if restore else None
-
-    if isinstance(new_params, Sequence) and isinstance(new_params[0], torch.Tensor):
-        # when not restoring, copy new_params to params to avoid unexpected bugs due to shared storage
-        # when restoring params will be set back to old_params so its fine
-        if restore: params.set_(new_params)
-        else: params.copy_(new_params) # type:ignore
-
-    else:
-        new_params = totensor(new_params)
-        params.from_vec_(new_params)
-
-    if backward: loss = closure()
-    else: loss = closure(False)
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return tofloat(loss)
-
-def loss_grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_at_params(closure, params: Iterable[torch.Tensor], new_params: Sequence[torch.Tensor], restore=False):
-    return loss_grad_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-def loss_grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    params = TensorList(params)
-    old_params = params.clone() if restore else None
-    loss = loss_at_params(closure, params, new_params, backward=True, restore=False)
-    grad = params.ensure_grad_().grad.to_vec()
-
-    if restore:
-        assert old_params is not None
-        params.set_(old_params)
-
-    return loss, grad
-
-def grad_vec_at_params(closure, params: Iterable[torch.Tensor], new_params: Any, restore=False):
-    return loss_grad_vec_at_params(closure=closure,params=params,new_params=new_params,restore=restore)[1]
-
-
-
-class Optimizer(torch.optim.Optimizer):
+class Optimizer(torch.optim.Optimizer, ABC):
     """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
 
     Args:
@@ -251,21 +194,10 @@ class Optimizer(torch.optim.Optimizer):
 
         return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
 
-    def loss_at_params(self, closure, params: Sequence[torch.Tensor] | Any, backward: bool, restore=False):
-        return loss_at_params(closure=closure,params=self.get_params(),new_params=params,backward=backward,restore=restore)
-
-    def loss_grad_at_params(self, closure, params: Sequence[torch.Tensor] | Any, restore=False):
-        return loss_grad_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_at_params(self, closure, new_params: Sequence[torch.Tensor], restore=False):
-        return self.loss_grad_at_params(closure=closure,params=new_params,restore=restore)[1]
-
-    def loss_grad_vec_at_params(self, closure, params: Any, restore=False):
-        return loss_grad_vec_at_params(closure=closure,params=self.get_params(),new_params=params,restore=restore)
-
-    def grad_vec_at_params(self, closure, params: Any, restore=False):
-        return self.loss_grad_vec_at_params(closure=closure,params=params,restore=restore)[1]
 
+    # shut up pylance
+    @abstractmethod
+    def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
 
 def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     if set_to_none:
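Reviewer note: with `ABC` in the bases and `step` declared `@abstractmethod`, `Optimizer` can no longer be instantiated directly; every subclass has to provide `step`. A hypothetical subclass (not part of the package), assuming the constructor still mirrors `torch.optim.Optimizer(params, defaults)`:

```python
import torch
from torchzero.utils.optimizer import Optimizer  # path per the file list above

class SignSGD(Optimizer):  # hypothetical example subclass
    def __init__(self, params, lr: float = 1e-2):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # descend along the sign of the gradient
                    p.add_(p.grad.sign(), alpha=-group["lr"])
        return loss
```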
@@ -281,4 +213,53 @@ def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
         else:
             grad.requires_grad_(False)
 
-    torch._foreach_zero_(grads)
+    torch._foreach_zero_(grads)
+
+
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, *,
+                  must_exist: bool = False, init: Init = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: list[str] | tuple[str,...], *,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+@overload
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str, key2: str, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> list[ListLike]: ...
+
+def unpack_states(states: Sequence[MutableMapping[str, Any]], tensors: Sequence[torch.Tensor],
+                  key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
+                  must_exist: bool = False, init: Init | Sequence[Init] = torch.zeros_like,
+                  cls: type[ListLike] = list) -> ListLike | list[ListLike]:
+
+    # single key, return single cls
+    if isinstance(key, str) and key2 is None:
+        values = cls()
+        for i,s in enumerate(states):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                s[key] = _make_initial_state_value(tensors[i], init, i)
+            values.append(s[key])
+        return values
+
+    # multiple keys
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls() for _ in keys]
+    for i,s in enumerate(states):
+        for k_i, key in enumerate(keys):
+            if key not in s:
+                if must_exist: raise KeyError(f"Key {key} doesn't exist in state with keys {tuple(s.keys())}")
+                k_init = init[k_i] if isinstance(init, (list,tuple)) else init
+                s[key] = _make_initial_state_value(tensors[i], k_init, i)
+            values[k_i].append(s[key])
+
+    return values
+
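Reviewer note: `unpack_states` replaces the per-key state plumbing removed above: given one state dict per tensor, it returns the values for a key (creating them with `init` on first access) as a single list, or one list per key when several keys are passed. A sketch, assuming `init` callables are applied per tensor as in torch optimizers:

```python
import torch
from torchzero.utils.optimizer import unpack_states  # path per the file list above

params = [torch.randn(3), torch.randn(2)]
states = [{} for _ in params]  # one state dict per tensor

# single key -> a single list, lazily initialized with torch.zeros_like
exp_avg = unpack_states(states, params, "exp_avg")

# multiple keys -> one list per key, with per-key initializers
exp_avg, exp_avg_sq = unpack_states(
    states, params, "exp_avg", "exp_avg_sq",
    init=(torch.zeros_like, torch.zeros_like),
)

# must_exist=True raises KeyError instead of initializing
try:
    unpack_states(states, params, "missing", must_exist=True)
except KeyError:
    pass
```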
torchzero/utils/python_tools.py
@@ -1,7 +1,7 @@
 import functools
 import operator
-from typing import Any, TypeVar
-from collections.abc import Iterable, Callable
+from typing import Any, TypeVar, overload
+from collections.abc import Iterable, Callable, Mapping, MutableSequence
 from collections import UserDict
 
 
@@ -17,8 +17,8 @@ def flatten(iterable: Iterable) -> list[Any]:
     raise TypeError(f'passed object is not an iterable, {type(iterable) = }')
 
 X = TypeVar("X")
-# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
-def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]: # pylint:disable=E0602
+# def reduce_dim[X](x:Iterable[Iterable[X]]) -> list[X]:
+def reduce_dim(x:Iterable[Iterable[X]]) -> list[X]:
     """Reduces one level of nesting. Takes an iterable of iterables of X, and returns an iterable of X."""
     return functools.reduce(operator.iconcat, x, [])
 
@@ -31,6 +31,16 @@ def generic_eq(x: int | float | Iterable[int | float], y: int | float | Iterable
         return all(i==y for i in x)
     return all(i==j for i,j in zip(x,y))
 
+def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable[int | float]) -> bool:
+    """generic not equals function that supports scalars and lists of numbers. Faster than not generic_eq"""
+    if isinstance(x, (int,float)):
+        if isinstance(y, (int,float)): return x!=y
+        return any(i!=x for i in y)
+    if isinstance(y, (int,float)):
+        return any(i!=y for i in x)
+    return any(i!=j for i,j in zip(x,y))
+
+
 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
     Otherwise applies `fn` to this sequence and `other`.
@@ -38,3 +48,16 @@ def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     if isinstance(other, (list, tuple)): return self.__class__(fn(i, j, *args, **kwargs) for i, j in zip(self, other))
     return self.__class__(fn(i, other, *args, **kwargs) for i in self)
 
+ListLike = TypeVar('ListLike', bound=MutableSequence)
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, *, cls:type[ListLike]=list) -> ListLike: ...
+@overload
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str, *keys:str, cls:type[ListLike]=list) -> list[ListLike]: ...
+def unpack_dicts(dicts: Iterable[Mapping[str, Any]], key:str, key2: str | None = None, *keys:str, cls:type[ListLike]=list) -> ListLike | list[ListLike]:
+    k1 = (key,) if isinstance(key, str) else tuple(key)
+    k2 = () if key2 is None else (key2,)
+    keys = k1 + k2 + keys
+
+    values = [cls(s[k] for s in dicts) for k in keys] # pyright:ignore[reportCallIssue]
+    if len(values) == 1: return values[0]
+    return values
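Reviewer note: `unpack_dicts` is the lighter sibling of `unpack_states` for plain mappings (e.g. `param_groups`): it pulls one or several keys across a sequence of dicts, returning a single list for one key and a list of lists otherwise. A quick sketch with illustrative values:

```python
from torchzero.utils.python_tools import unpack_dicts  # path per the file list above

groups = [
    {"lr": 1e-3, "momentum": 0.9},
    {"lr": 1e-2, "momentum": 0.0},
]

lrs = unpack_dicts(groups, "lr")                       # -> [0.001, 0.01]
lrs, momenta = unpack_dicts(groups, "lr", "momentum")  # -> one list per key
assert lrs == [1e-3, 1e-2] and momenta == [0.9, 0.0]
```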