torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
@@ -38,11 +38,11 @@ class _MaybeCompiledFunc:
 _optional_compiler = _OptionalCompiler()
 """this holds .enable attribute, set to True to enable compiling for a few functions that benefit from it."""
 
-def set_compilation(enable: bool=True):
+def enable_compilation(enable: bool=True):
     """`enable` is False by default. When True, certain functions will be compiled, which may not work on some systems like Windows, but it usually improves performance."""
     _optional_compiler.enable = enable
 
-def enable_compilation(fn): return _optional_compiler.enable_compilation(fn)
+def allow_compile(fn): return _optional_compiler.enable_compilation(fn)
 
 def benchmark_compile_cuda(fn, n: int, **kwargs):
     # warmup
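
A hedged usage sketch of the two renamed helpers in this hunk: `enable_compilation` flips the global opt-in flag and `allow_compile` wraps a function so it is only compiled once that flag is set. The import path `torchzero.utils.compile` is assumed from the file list above; adjust if the helpers are re-exported elsewhere.

    import torch
    from torchzero.utils.compile import enable_compilation, allow_compile  # assumed path

    enable_compilation(True)   # off by default; compiling may not work on some systems (e.g. Windows)

    @allow_compile
    def lerp_sub(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor:
        # compiled via torch.compile only when compilation has been enabled
        return a.lerp(b, weight) - b

    print(lerp_sub(torch.randn(3), torch.randn(3), 0.5))
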
@@ -4,6 +4,7 @@ import torch
 import torch.autograd.forward_ad as fwAD
 
 from .torch_tools import swap_tensors_no_use_count_check, vec_to_tensors
+from .tensorlist import TensorList
 
 def _jacobian(outputs: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
     flat_outputs = torch.cat([i.reshape(-1) for i in outputs])
@@ -261,7 +262,7 @@ def jvp_fd_central(
     params: Iterable[torch.Tensor],
     tangent: Iterable[torch.Tensor],
     h=1e-3,
-    normalize=False,
+    normalize=True,
 ) -> tuple[torch.Tensor | None, torch.Tensor]:
     """Jacobian vector product using central finite difference formula.
 
@@ -310,7 +311,7 @@ def jvp_fd_forward(
     tangent: Iterable[torch.Tensor],
     h=1e-3,
     v_0=None,
-    normalize=False,
+    normalize=True,
 ) -> tuple[torch.Tensor | None, torch.Tensor]:
     """Jacobian vector product using forward finite difference formula.
     Loss at initial point can be specified in the `v_0` argument.
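
For context on the `normalize` defaults flipping to `True` in these two hunks, here is an illustrative stand-alone sketch (not the library's code) of the central-difference JVP and of the normalization trick: difference along the unit tangent so `h` is well scaled, then rescale the result by the tangent norm.

    import torch

    def jvp_fd_central_sketch(f, x: torch.Tensor, v: torch.Tensor, h: float = 1e-3, normalize: bool = True):
        # J(x) @ v ~= (f(x + h*v) - f(x - h*v)) / (2h)
        v_norm = torch.linalg.vector_norm(v)
        if normalize: v = v / v_norm          # step along the unit direction
        res = (f(x + h * v) - f(x - h * v)) / (2 * h)
        if normalize: res = res * v_norm      # undo the scaling so the result is still J @ v
        return res

    f = lambda x: torch.stack([x.sum() ** 2, (x ** 2).sum()])
    x, v = torch.randn(4), torch.randn(4)
    print(jvp_fd_central_sketch(f, x, v))
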
@@ -357,52 +358,18 @@ def jvp_fd_forward(
     if normalize: res = res * tangent_norm
     return v_0, res
 
-def hvp(
-    params: Iterable[torch.Tensor],
-    grads: Iterable[torch.Tensor],
-    vec: Iterable[torch.Tensor],
-    retain_graph=None,
-    create_graph=False,
-    allow_unused=None,
-):
-    """Hessian-vector product
-
-    Example:
-    ```python
-    model = nn.Linear(4, 2)
-    X = torch.randn(10, 4)
-    y = torch.randn(10, 2)
-
-    y_hat = model(X)
-    loss = F.mse_loss(y_hat, y)
-    loss.backward(create_graph=True)
-
-    grads = [p.grad for p in model.parameters()]
-    vec = [torch.randn_like(p) for p in model.parameters()]
-
-    # list of tensors, same layout as model.parameters()
-    hvp(model.parameters(), grads, vec=vec)
-    ```
-    """
-    params = list(params)
-    g = list(grads)
-    vec = list(vec)
-
-    with torch.enable_grad():
-        return torch.autograd.grad(g, params, vec, create_graph=create_graph, retain_graph=retain_graph, allow_unused=allow_unused)
-
 
 @torch.no_grad
 def hvp_fd_central(
     closure,
     params: Iterable[torch.Tensor],
-    vec: Iterable[torch.Tensor],
+    x: Iterable[torch.Tensor],
     h=1e-3,
-    normalize=False,
+    normalize=True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-    """Hessian-vector product using central finite difference formula.
+    """Returns ``(loss_approx, Hx)``.
 
-    Please note that this will clear :code:`grad` attributes in params.
+    Please note that this will clear ``grad`` attributes in params.
 
     Example:
     ```python
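
The autograd-based `hvp` helper removed in this hunk was a thin wrapper around `torch.autograd.grad`; for anyone who relied on it, a minimal equivalent is just the double-backward call its body contained (this mirrors the deleted code, it is not a new torchzero API):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(4, 2)
    loss = F.mse_loss(model(torch.randn(10, 4)), torch.randn(10, 2))
    grads = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
    vec = [torch.randn_like(p) for p in model.parameters()]

    # Hessian-vector product: differentiate <grads, vec> w.r.t. the parameters
    hvp = torch.autograd.grad(grads, list(model.parameters()), grad_outputs=vec)
    print([t.shape for t in hvp])
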
@@ -424,48 +391,48 @@ def hvp_fd_central(
     ```
     """
     params = list(params)
-    vec = list(vec)
+    x = list(x)
 
     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in vec])) # pylint:disable=not-callable
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.view(-1) for t in x])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-        vec = torch._foreach_div(vec, vec_norm)
+        x = torch._foreach_div(x, vec_norm)
 
-    vec_h = torch._foreach_mul(vec, h)
-    torch._foreach_add_(params, vec_h)
+    xh = torch._foreach_mul(x, h)
+    torch._foreach_add_(params, xh)
     with torch.enable_grad(): loss = closure()
     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
-    torch._foreach_sub_(params, vec_h)
-    torch._foreach_sub_(params, vec_h)
+    torch._foreach_sub_(params, xh)
+    torch._foreach_sub_(params, xh)
     with torch.enable_grad(): loss = closure()
     g_minus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
-    torch._foreach_add_(params, vec_h)
+    torch._foreach_add_(params, xh)
     for p in params: p.grad = None
 
-    hvp_ = g_plus
-    torch._foreach_sub_(hvp_, g_minus)
-    torch._foreach_div_(hvp_, 2*h)
+    hx = g_plus
+    torch._foreach_sub_(hx, g_minus)
+    torch._foreach_div_(hx, 2*h)
 
-    if normalize: torch._foreach_mul_(hvp_, vec_norm)
-    return loss, hvp_
+    if normalize: torch._foreach_mul_(hx, vec_norm)
+    return loss, hx
 
 @torch.no_grad
 def hvp_fd_forward(
     closure,
     params: Iterable[torch.Tensor],
-    vec: Iterable[torch.Tensor],
+    x: Iterable[torch.Tensor],
     h=1e-3,
     g_0=None,
-    normalize=False,
+    normalize=True,
 ) -> tuple[torch.Tensor | None, list[torch.Tensor]]:
-    """Hessian-vector product using forward finite difference formula.
+    """Returns ``(loss_approx, Hx)``.
 
-    Gradient at initial point can be specified in the `g_0` argument.
+    Gradient at initial point can be specified in the ``g_0`` argument.
 
-    Please note that this will clear :code:`grad` attributes in params.
+    Please note that this will clear ``grad`` attributes in params.
 
     Example:
     ```python
@@ -492,16 +459,16 @@ def hvp_fd_forward(
     """
 
     params = list(params)
-    vec = list(vec)
+    x = list(x)
     loss = None
 
     vec_norm = None
     if normalize:
-        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in vec])) # pylint:disable=not-callable
+        vec_norm = torch.linalg.vector_norm(torch.cat([t.ravel() for t in x])) # pylint:disable=not-callable
         if vec_norm == 0: return None, [torch.zeros_like(p) for p in params]
-        vec = torch._foreach_div(vec, vec_norm)
+        x = torch._foreach_div(x, vec_norm)
 
-    vec_h = torch._foreach_mul(vec, h)
+    xh = torch._foreach_mul(x, h)
 
     if g_0 is None:
         with torch.enable_grad(): loss = closure()
@@ -509,18 +476,75 @@ def hvp_fd_forward(
     else:
         g_0 = list(g_0)
 
-    torch._foreach_add_(params, vec_h)
+    torch._foreach_add_(params, xh)
     with torch.enable_grad():
         l = closure()
         if loss is None: loss = l
     g_plus = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
-    torch._foreach_sub_(params, vec_h)
+    torch._foreach_sub_(params, xh)
     for p in params: p.grad = None
 
-    hvp_ = g_plus
-    torch._foreach_sub_(hvp_, g_0)
-    torch._foreach_div_(hvp_, h)
+    hx = g_plus
+    torch._foreach_sub_(hx, g_0)
+    torch._foreach_div_(hx, h)
+
+    if normalize: torch._foreach_mul_(hx, vec_norm)
+    return loss, hx
+
+@torch.no_grad
+def hessian_fd(fn, params: Sequence[torch.Tensor], eps: float = 1e-4, full: bool = True):
+    """returns ``f(x), g(x), H(x)``, where ``g(x)`` is a tensor list.
+
+    Number of evals for full is: 4n^2 - 2n + 1
+
+    Number of evals for upper is: 2n^2 + 1.
+    """
+    params = TensorList(params)
+    p_0 = params.clone()
+    n = sum(t.numel() for t in params)
+    device = params[0].device; dtype = params[0].dtype
+    fx = fn()
+    g = params.zeros_like()
+    H = torch.zeros((n, n), device=device, dtype=dtype)
+
+    for i in range(n):
+        for j in (range(n) if full else range(i, n)):
+            if i == j:
+                params.flat_set_lambda_(i, lambda x: x + eps)
+                f_plus = fn()
+
+                params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                f_minus = fn()
+
+                # params.flat_set_lambda_(i, lambda x: x + eps)
+                g.flat_set_(i, (f_plus - f_minus) / (2*eps))
+                H[i, i] = (f_plus - 2 * fx + f_minus) / (eps ** 2)
+
+            else:
+                params.flat_set_lambda_(i, lambda x: x + eps)
+                params.flat_set_lambda_(j, lambda x: x + eps)
+                f_pp = fn()
+
+                params.flat_set_lambda_(i, lambda x: x - 2 * eps)
+                f_np = fn()
+
+                params.flat_set_lambda_(j, lambda x: x - 2 * eps)
+                f_nn = fn()
+
+                params.flat_set_lambda_(i, lambda x: x + 2 * eps)
+                f_pn = fn()
+
+                # params.flat_set_lambda_(i, lambda x: x - eps)
+                # params.flat_set_lambda_(j, lambda x: x + eps)
+
+                H[i, j] = (f_pp - f_np - f_pn + f_nn) / (4 * eps ** 2)
+                if not full: H[j, i] = H[i, j]
+
+            params.copy_(p_0) # otherwise inaccuracy builds up
+
+    if full:
+        H = H + H.T
+        H /= 2
 
-    if normalize: torch._foreach_mul_(hvp_, vec_norm)
-    return loss, hvp_
+    return fx, g, H
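
A hedged usage sketch for the new `hessian_fd` added in this hunk, assuming it is importable from `torchzero.utils.derivatives` (the file this hunk edits). Note that `fn` must recompute the loss from `params`, since the routine perturbs them in place; on a quadratic the finite-difference Hessian should recover the matrix almost exactly.

    import torch
    from torchzero.utils.derivatives import hessian_fd  # assumed import path

    params = [torch.tensor([1.0, 2.0], dtype=torch.float64), torch.tensor([3.0], dtype=torch.float64)]
    A = torch.diag(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64))

    def fn():
        x = torch.cat([p.ravel() for p in params])
        return 0.5 * x @ A @ x        # quadratic, so the true Hessian is A everywhere

    fx, g, H = hessian_fd(fn, params, eps=1e-3)
    print(torch.allclose(H, A, atol=1e-2))   # finite-difference H should be close to A
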
@@ -64,22 +64,15 @@ def get_group_vals(param_groups: Iterable[Mapping[str, Any]],
         values[i].extend(group_value for _ in range(num_params))
     return values
 
-_InitLiterals = Literal['param', 'grad']
-Init = _InitLiterals | Any | list[_InitLiterals | Any] | tuple[_InitLiterals | Any]
+Init = Any
 
-def _make_initial_state_value(param: torch.Tensor, init: Init, i: int | None):
-    if callable(init): return init(param)
+def _make_initial_state_value(tensor: torch.Tensor, init: Init, i: int | None):
+    if callable(init): return init(tensor)
     if isinstance(init, torch.Tensor): return init.detach().clone()
 
-    if isinstance(init, str):
-        if init in ('param','params'): return param.detach().clone()
-        if init in ('grad', 'grads'):
-            if param.grad is None: raise RuntimeError('init is set to "grad, but param.grad is None"')
-            return param.grad.detach().clone()
-
     if isinstance(init, (list,tuple)):
         if i is None: raise RuntimeError(f'init is per-parameter ({type(init)}) but parameter index i is None')
-        return _make_initial_state_value(param, init[i], None)
+        return _make_initial_state_value(tensor, init[i], None)
 
     return init
 
@@ -133,72 +126,6 @@ def get_state_vals(state: Mapping[torch.Tensor, MutableMapping[str, Any]], param
     return values
 
 
-class Optimizer(torch.optim.Optimizer, ABC):
-    """subclass of torch.optim.Optimizer with some helper methods for fast experimentation, it's not used anywhere in torchzero.
-
-    Args:
-        params (iterable): an iterable of :class:`torch.Tensor` s or
-            :class:`dict` s. Specifies what Tensors should be optimized.
-        defaults (dict | None): a dict containing default values of optimization
-            options (used when a parameter group doesn't specify them).
-    """
-    def __init__(self, params, defaults: dict[str, Any] | None = None, **_defaults):
-        if defaults is None: defaults = {}
-        defaults.update(_defaults)
-
-        super().__init__(params, defaults)
-        self.global_state = self.state[self.param_groups[0]['params'][0]]
-        """state of 1st parameter, can be used as global state which is how L-BFGS uses it in pytorch, and there is some kind of good reason to do it like that"""
-
-    def get_params(self, mode: ParamFilter = 'requires_grad', cls: type[ListLike] = TensorList) -> ListLike:
-        return get_params(self.param_groups, mode, cls)
-
-    @overload
-    def group_vals(self, key: str, *,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike: ...
-    @overload
-    def group_vals(self, key: list[str] | tuple[str,...], *,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-    @overload
-    def group_vals(self, key: str, key2: str, *keys: str,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> list[ListLike]: ...
-
-    def group_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                   mode: ParamFilter = 'requires_grad', cls: type[ListLike] = NumberList) -> ListLike | list[ListLike]:
-        return get_group_vals(self.param_groups, key, key2, *keys, mode = mode, cls = cls) # pyright:ignore[reportArgumentType]
-
-
-    @overload
-    def state_vals(self, key: str, *,
-                   init: Init = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> ListLike: ...
-    @overload
-    def state_vals(self, key: list[str] | tuple[str,...], *,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-    @overload
-    def state_vals(self, key: str, key2: str, *keys: str,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> list[ListLike]: ...
-
-    def state_vals(self, key: str | list[str] | tuple[str,...], key2: str | None = None, *keys: str,
-                   init: Init | Sequence[Init] = torch.zeros_like,
-                   mode: ParamFilter | list[torch.Tensor] | tuple[torch.Tensor, ...] = 'requires_grad',
-                   cls: type[ListLike] = TensorList) -> ListLike | list[ListLike]:
-
-        if isinstance(mode, (list,tuple)): params = mode
-        else: params = self.get_params(mode)
-
-        return get_state_vals(self.state, params, key, key2, *keys, init = init, cls = cls) # type:ignore[reportArgumentType]
-
-
-    # shut up pylance
-    @abstractmethod
-    def step(self, closure) -> Any: ... # pylint:disable=signature-differs # pyright:ignore[reportIncompatibleMethodOverride]
-
 def zero_grad_(params: Iterable[torch.Tensor], set_to_none):
     if set_to_none:
         for p in params:
@@ -1,3 +1,4 @@
+import importlib
 import functools
 import operator
 from typing import Any, TypeVar, overload
@@ -40,6 +41,11 @@ def generic_ne(x: int | float | Iterable[int | float], y: int | float | Iterable
         return any(i!=y for i in x)
     return any(i!=j for i,j in zip(x,y))
 
+def generic_is_none(x: Any | Iterable[Any]):
+    """returns True if x is None or iterable with all elements set to None"""
+    if x is None: return True
+    if isinstance(x, Iterable): return all(i is None for i in x)
+    return False
 
 def zipmap(self, fn: Callable, other: Any | list | tuple, *args, **kwargs):
     """If `other` is list/tuple, applies `fn` to self zipped with `other`.
@@ -68,3 +74,28 @@ def safe_dict_update_(d1_:dict, d2:dict):
     if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
     d1_.update(d2)
 
+# lazy loader from https://stackoverflow.com/a/78312674/15673832
+class LazyLoader:
+    'thin shell class to wrap modules. load real module on first access and pass thru'
+
+    def __init__(self, modname):
+        self._modname = modname
+        self._mod = None
+
+    def __getattr__(self, attr):
+        'import module on first attribute access'
+
+        try:
+            return getattr(self._mod, attr)
+
+        except Exception as e :
+            if self._mod is None :
+                # module is unset, load it
+                self._mod = importlib.import_module (self._modname)
+            else :
+                # module is set, got different exception from getattr (). reraise it
+                raise e
+
+        # retry getattr if module was just loaded for first time
+        # call this outside exception handler in case it raises new exception
+        return getattr (self._mod, attr)
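
A hedged sketch of how a lazy loader like the one added above is typically used: the wrapper is created at import time for an optional heavy dependency, but the real `importlib.import_module` call only happens on first attribute access (module path `torchzero.utils.python_tools` assumed from the file list).

    from torchzero.utils.python_tools import LazyLoader  # assumed import path

    np = LazyLoader("numpy")      # nothing imported yet; cheap even if numpy is never used
    print(np.arange(3))           # first attribute access imports numpy, then passes through
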
@@ -22,7 +22,6 @@ from typing_extensions import Self, TypeAlias, Unpack
 
 from .metrics import Metrics, evaluate_metric, calculate_metric_list
 from .numberlist import NumberList, as_numberlist, maybe_numberlist
-from .ops import where_
 from .python_tools import generic_ne, zipmap
 
 _Scalar = int | float | bool | complex
@@ -346,6 +345,10 @@ class TensorList(list[torch.Tensor | Any]):
     def global_all(self): return builtins.all(self.all())
     def global_numel(self) -> int: return builtins.sum(self.numel())
 
+    def global_allclose(self, other: _TensorSeq, rtol: float = 0.00001, atol: float = 1e-8, equal_nan: bool = False) -> bool:
+        bools = self.zipmap_args(torch.allclose, other, rtol, atol, equal_nan)
+        return all(bools)
+
     def empty_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.empty_like(i, **kwargs) for i in self)
     def zeros_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.zeros_like(i, **kwargs) for i in self)
     def ones_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(torch.ones_like(i, **kwargs) for i in self)
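
A small hedged example of the new `global_allclose`: it zips two tensor lists, applies `torch.allclose` pairwise, and reduces with `all` (import path `torchzero.utils.tensorlist` assumed from the file list).

    import torch
    from torchzero.utils.tensorlist import TensorList  # assumed import path

    a = TensorList([torch.zeros(2), torch.ones(3)])
    b = TensorList([torch.zeros(2), torch.ones(3) + 1e-9])
    print(a.global_allclose(b))   # True: every pair passes torch.allclose
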
@@ -509,7 +512,6 @@ class TensorList(list[torch.Tensor | Any]):
         torch._foreach_mul_(self, other)
         return self
 
-    # TODO: benchmark
     def lazy_mul(self, other: int | float | list[int | float] | tuple[int | float], clone=False):
         if generic_ne(other, 1):
             return self * other
@@ -536,6 +538,13 @@ class TensorList(list[torch.Tensor | Any]):
         torch._foreach_pow_(self, exponent)
         return self
 
+    def lazy_pow(self, other: int | float | list[int | float] | tuple[int | float]):
+        if generic_ne(other, 1): return self.pow(other)
+        return self
+    def lazy_pow_(self, other: int | float | list[int | float] | tuple[int | float]):
+        if generic_ne(other, 1): return self.pow_(other)
+        return self
+
     def rpow(self, input: _Scalar | _TensorSeq): return self.__class__(torch._foreach_pow(input, self))
     def rpow_(self, input: _TensorSeq):
         torch._foreach_pow_(input, self)
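
The `lazy_pow`/`lazy_pow_` additions follow the same pattern as the existing `lazy_mul`: skip the foreach kernel entirely when the exponent is 1 (per `generic_ne`, a scalar 1 or an all-ones list). A hedged illustration of the idea outside of TensorList:

    import torch

    def lazy_pow(tensors: list[torch.Tensor], exponent):
        # no-op when the exponent is 1 (or an all-ones list), otherwise one foreach call
        exps = exponent if isinstance(exponent, (list, tuple)) else [exponent] * len(tensors)
        if all(e == 1 for e in exps):
            return tensors
        return torch._foreach_pow(tensors, exps)

    ts = [torch.full((2,), 3.0), torch.full((3,), 2.0)]
    print(lazy_pow(ts, 1) is ts)      # True: nothing computed
    print(lazy_pow(ts, [2, 3]))       # powers applied per tensor
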
@@ -984,9 +993,6 @@ class TensorList(list[torch.Tensor | Any]):
     def where(self, condition: "torch.Tensor | _TensorSeq", other: _STOrSTSeq):
         """self where condition is true other otherwise"""
         return self.zipmap_args(_MethodCallerWithArgs('where'), condition, other)
-    def where_(self, condition: "torch.Tensor | _TensorSeq", other: "torch.Tensor | _TensorSeq"):
-        """self where condition is true other otherwise"""
-        return self.zipmap_args_inplace_(where_, condition, other)
 
     def masked_fill(self, mask: "torch.Tensor | _TensorSeq", fill_value: "_Scalar | _ScalarSeq"):
         """Same as tensor[mask] = value (not in-place), where value must be scalar/scalars"""
@@ -0,0 +1,68 @@
+import itertools
+from collections.abc import Callable
+from importlib.util import find_spec
+from typing import TYPE_CHECKING, cast
+
+import torch
+
+from .python_tools import LazyLoader
+
+lazy_thoad = LazyLoader("thoad")
+if TYPE_CHECKING:
+    import thoad
+    lazy_thoad = cast(thoad, lazy_thoad)
+
+def thoad_single_tensor(
+    ctrl: "thoad.Controller",
+    params: list[torch.Tensor],
+    order: int
+) -> torch.Tensor:
+    """treats params as if they were concatenated into a vector."""
+
+    if not all(p.requires_grad for p in params):
+        raise ValueError("All parameters must have requires_grad=True")
+
+    if order < 1:
+        raise ValueError("Order must be at least 1")
+
+    # we need parameter sizes and total size N
+    # final tensor is (N, N, ..., N) with `order` dimensions.
+    param_numels = [p.numel() for p in params]
+    total_params = sum(param_numels)
+
+    final_shape = (total_params,) * order
+    p = params[0]
+    T = torch.zeros(final_shape, device=p.device, dtype=p.dtype)
+
+    # start/end indices for each parameter in the flattened vector.
+    offsets = torch.cumsum(torch.tensor([0] + param_numels), dim=0)
+
+    # for order=2 this iterates through (p0,p0), (p0,p1), (p1,p0), (p1,p1), etc.
+    param_indices = range(len(params))
+    for block_indices in itertools.product(param_indices, repeat=order):
+
+        block_params = tuple(params[i] for i in block_indices)
+        block_tensor, _ = ctrl.fetch_hgrad(variables=block_params) # (1, *p1.shape, *p2.shape, ...).
+        block_tensor = block_tensor.squeeze(0) # (*p1.shape, *p2.shape, ...)
+
+        # convert (*p1.shape, *p2.shape) to (p1.numel(), p2.numel())
+        block_flat_shape = tuple(param_numels[i] for i in block_indices)
+        block_tensor_flat = block_tensor.reshape(block_flat_shape)
+
+        # place the flattened block into T
+        slicing = tuple(
+            slice(offsets[i], offsets[i+1]) for i in block_indices
+        )
+        T[slicing] = block_tensor_flat
+
+    ctrl.clear()
+    return T
+
+def thoad_derivatives(
+    ctrl: "thoad.Controller",
+    params: list[torch.Tensor],
+    order: int,
+):
+    """returns all derivatives up to ``order`` in ascending order, all as single tensors
+    as if parameters were concatenated to a vector"""
+    return [thoad_single_tensor(ctrl, params, o) for o in range(1, order+1)]
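
To illustrate how `thoad_single_tensor` packs per-parameter blocks into one (N, ..., N) tensor without needing `thoad` installed, here is a hedged sketch using a stand-in controller that only mimics the `fetch_hgrad`/`clear` interface used above; it returns zero blocks and is not the real thoad API. The `torchzero.utils.thoad_tools` import path is assumed from the file list.

    import torch
    from torchzero.utils.thoad_tools import thoad_single_tensor  # assumed import path

    class _FakeController:
        def fetch_hgrad(self, variables):
            # real thoad returns (tensor of shape (1, *v1.shape, *v2.shape, ...), metadata)
            shape = (1, *[d for v in variables for d in v.shape])
            return torch.zeros(shape), None
        def clear(self): pass

    params = [torch.randn(2, 3, requires_grad=True), torch.randn(4, requires_grad=True)]
    T = thoad_single_tensor(_FakeController(), params, order=2)
    print(T.shape)   # torch.Size([10, 10]), since 2*3 + 4 = 10 flattened entries
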
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: torchzero
-Version: 0.3.15
+Version: 0.4.0
 Summary: Modular optimization library for PyTorch.
 Author-email: Ivan Nikishev <nkshv2@gmail.com>
 Project-URL: Homepage, https://github.com/inikishev/torchzero