torchzero 0.3.8__tar.gz → 0.3.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (160)
  1. {torchzero-0.3.8 → torchzero-0.3.10}/PKG-INFO +14 -14
  2. {torchzero-0.3.8 → torchzero-0.3.10}/README.md +13 -13
  3. {torchzero-0.3.8 → torchzero-0.3.10}/pyproject.toml +1 -1
  4. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_opts.py +55 -22
  5. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_tensorlist.py +3 -3
  6. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_vars.py +61 -61
  7. torchzero-0.3.10/torchzero/core/__init__.py +2 -0
  8. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/core/module.py +49 -49
  9. torchzero-0.3.10/torchzero/core/transform.py +313 -0
  10. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/__init__.py +1 -0
  11. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/clipping.py +10 -10
  12. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/ema_clipping.py +14 -13
  13. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/growth_clipping.py +16 -18
  14. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/__init__.py +12 -3
  15. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/absoap.py +50 -156
  16. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adadam.py +15 -14
  17. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adamY.py +17 -27
  18. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/adasoap.py +20 -130
  19. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/curveball.py +12 -12
  20. torchzero-0.3.10/torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  21. torchzero-0.3.10/torchzero/modules/experimental/eigendescent.py +117 -0
  22. torchzero-0.3.10/torchzero/modules/experimental/etf.py +172 -0
  23. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/gradmin.py +2 -2
  24. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/newton_solver.py +11 -11
  25. torchzero-0.3.10/torchzero/modules/experimental/newtonnewton.py +88 -0
  26. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  27. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/soapy.py +19 -146
  28. torchzero-0.3.10/torchzero/modules/experimental/spectral.py +163 -0
  29. torchzero-0.3.10/torchzero/modules/experimental/structured_newton.py +111 -0
  30. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  31. torchzero-0.3.10/torchzero/modules/experimental/tada.py +38 -0
  32. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/fdm.py +2 -2
  33. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  34. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  35. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/rfdm.py +28 -15
  36. torchzero-0.3.10/torchzero/modules/higher_order/__init__.py +1 -0
  37. torchzero-0.3.10/torchzero/modules/higher_order/higher_order_newton.py +256 -0
  38. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/backtracking.py +42 -23
  39. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/line_search.py +40 -40
  40. torchzero-0.3.10/torchzero/modules/line_search/scipy.py +52 -0
  41. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/strong_wolfe.py +21 -32
  42. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/trust_region.py +18 -6
  43. torchzero-0.3.10/torchzero/modules/lr/__init__.py +2 -0
  44. torchzero-0.3.8/torchzero/modules/lr/step_size.py → torchzero-0.3.10/torchzero/modules/lr/adaptive.py +22 -26
  45. torchzero-0.3.10/torchzero/modules/lr/lr.py +63 -0
  46. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/averaging.py +25 -10
  47. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/cautious.py +73 -35
  48. torchzero-0.3.10/torchzero/modules/momentum/ema.py +224 -0
  49. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/experimental.py +21 -13
  50. torchzero-0.3.10/torchzero/modules/momentum/matrix_momentum.py +166 -0
  51. torchzero-0.3.10/torchzero/modules/momentum/momentum.py +63 -0
  52. torchzero-0.3.10/torchzero/modules/ops/accumulate.py +95 -0
  53. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/binary.py +36 -36
  54. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/debug.py +7 -7
  55. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/misc.py +128 -129
  56. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/multi.py +19 -19
  57. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/reduce.py +16 -16
  58. torchzero-0.3.10/torchzero/modules/ops/split.py +75 -0
  59. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/switch.py +4 -4
  60. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/unary.py +20 -20
  61. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/utility.py +37 -37
  62. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/adagrad.py +33 -24
  63. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/adam.py +31 -34
  64. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/lion.py +4 -4
  65. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/muon.py +6 -6
  66. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/orthograd.py +4 -5
  67. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/rmsprop.py +13 -16
  68. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/rprop.py +52 -49
  69. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/shampoo.py +17 -23
  70. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/soap.py +12 -19
  71. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/sophia_h.py +13 -13
  72. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/dct.py +4 -4
  73. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/fft.py +6 -6
  74. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/galore.py +1 -1
  75. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/projection.py +57 -57
  76. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/structural.py +17 -17
  77. torchzero-0.3.10/torchzero/modules/quasi_newton/__init__.py +36 -0
  78. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/cg.py +76 -26
  79. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  80. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/lbfgs.py +15 -15
  81. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/lsr1.py +18 -17
  82. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/olbfgs.py +19 -19
  83. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  84. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/newton.py +38 -21
  85. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/newton_cg.py +13 -12
  86. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/nystrom.py +19 -19
  87. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/gaussian.py +21 -21
  88. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/laplacian.py +7 -9
  89. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/weight_decay/__init__.py +1 -1
  90. torchzero-0.3.10/torchzero/modules/weight_decay/weight_decay.py +86 -0
  91. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/wrappers/optim_wrapper.py +11 -11
  92. torchzero-0.3.10/torchzero/optim/wrappers/directsearch.py +244 -0
  93. torchzero-0.3.10/torchzero/optim/wrappers/fcmaes.py +97 -0
  94. torchzero-0.3.10/torchzero/optim/wrappers/mads.py +90 -0
  95. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/nevergrad.py +4 -4
  96. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/nlopt.py +28 -14
  97. torchzero-0.3.10/torchzero/optim/wrappers/optuna.py +70 -0
  98. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/scipy.py +162 -13
  99. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/__init__.py +2 -6
  100. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/derivatives.py +2 -1
  101. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/optimizer.py +55 -74
  102. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/python_tools.py +17 -4
  103. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/PKG-INFO +14 -14
  104. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/SOURCES.txt +13 -4
  105. torchzero-0.3.8/torchzero/core/__init__.py +0 -3
  106. torchzero-0.3.8/torchzero/core/preconditioner.py +0 -138
  107. torchzero-0.3.8/torchzero/core/transform.py +0 -252
  108. torchzero-0.3.8/torchzero/modules/experimental/algebraic_newton.py +0 -145
  109. torchzero-0.3.8/torchzero/modules/experimental/spectral.py +0 -288
  110. torchzero-0.3.8/torchzero/modules/experimental/tropical_newton.py +0 -136
  111. torchzero-0.3.8/torchzero/modules/line_search/scipy.py +0 -37
  112. torchzero-0.3.8/torchzero/modules/lr/__init__.py +0 -2
  113. torchzero-0.3.8/torchzero/modules/lr/lr.py +0 -59
  114. torchzero-0.3.8/torchzero/modules/momentum/ema.py +0 -173
  115. torchzero-0.3.8/torchzero/modules/momentum/matrix_momentum.py +0 -124
  116. torchzero-0.3.8/torchzero/modules/momentum/momentum.py +0 -43
  117. torchzero-0.3.8/torchzero/modules/ops/accumulate.py +0 -65
  118. torchzero-0.3.8/torchzero/modules/ops/split.py +0 -75
  119. torchzero-0.3.8/torchzero/modules/quasi_newton/__init__.py +0 -7
  120. torchzero-0.3.8/torchzero/modules/weight_decay/weight_decay.py +0 -52
  121. {torchzero-0.3.8 → torchzero-0.3.10}/LICENSE +0 -0
  122. {torchzero-0.3.8 → torchzero-0.3.10}/docs/source/conf.py +0 -0
  123. {torchzero-0.3.8 → torchzero-0.3.10}/setup.cfg +0 -0
  124. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_identical.py +0 -0
  125. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_module.py +0 -0
  126. {torchzero-0.3.8 → torchzero-0.3.10}/tests/test_utils_optimizer.py +0 -0
  127. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/__init__.py +0 -0
  128. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/clipping/__init__.py +0 -0
  129. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/functional.py +0 -0
  130. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/grad_approximation/__init__.py +0 -0
  131. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/line_search/__init__.py +0 -0
  132. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/momentum/__init__.py +0 -0
  133. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/ops/__init__.py +0 -0
  134. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/optimizers/__init__.py +0 -0
  135. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/projections/__init__.py +0 -0
  136. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/quasi_newton/experimental/__init__.py +0 -0
  137. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/second_order/__init__.py +0 -0
  138. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/smoothing/__init__.py +0 -0
  139. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/modules/wrappers/__init__.py +0 -0
  140. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/__init__.py +0 -0
  141. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/utility/__init__.py +0 -0
  142. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/utility/split.py +0 -0
  143. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/optim/wrappers/__init__.py +0 -0
  144. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/compile.py +0 -0
  145. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/__init__.py +0 -0
  146. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/benchmark.py +0 -0
  147. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/matrix_funcs.py +0 -0
  148. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/orthogonalize.py +0 -0
  149. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/qr.py +0 -0
  150. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/solve.py +0 -0
  151. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/linalg/svd.py +0 -0
  152. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/numberlist.py +0 -0
  153. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/ops.py +0 -0
  154. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/optuna_tools.py +0 -0
  155. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/params.py +0 -0
  156. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/tensorlist.py +0 -0
  157. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero/utils/torch_tools.py +0 -0
  158. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/dependency_links.txt +0 -0
  159. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/requires.txt +0 -0
  160. {torchzero-0.3.8 → torchzero-0.3.10}/torchzero.egg-info/top_level.txt +0 -0
{torchzero-0.3.8 → torchzero-0.3.10}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchzero
- Version: 0.3.8
+ Version: 0.3.10
  Summary: Modular optimization library for PyTorch.
  Author-email: Ivan Nikishev <nkshv2@gmail.com>
  License: MIT License
@@ -157,13 +157,14 @@ for epoch in range(100):
    * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
    * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
    * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+   * `HigherOrderNewton`: Higher order Newton's method with trust region.

  * **Quasi-Newton**: Approximate second-order optimization methods.
    * `LBFGS`: Limited-memory BFGS.
    * `LSR1`: Limited-memory SR1.
    * `OnlineLBFGS`: Online LBFGS.
-   * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
-   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+   * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.

  * **Line Search**:
    * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
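The modules in the list above are meant to be composed: a quasi-newton direction is typically chained with a line search, exactly as the updated test suite further down does. A minimal sketch of that pattern (the model, data and step count are placeholders; the closure convention with a `backward` argument follows the library's own Advanced Usage description):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)                  # placeholder model
x, y = torch.randn(16, 4), torch.randn(16, 1)  # placeholder data

# Chain a full-matrix quasi-newton update with a strong-Wolfe line search.
opt = tz.Modular(model.parameters(), tz.m.BFGS(), tz.m.StrongWolfe())

def closure(backward=True):
    # Line searches re-evaluate the loss through this closure; the backward
    # flag lets them skip gradient computation when only the value is needed.
    loss = torch.nn.functional.mse_loss(model(x), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
```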
@@ -312,20 +313,20 @@ not in the module itself. Also both per-parameter settings and state are stored

  ```python
  import torch
- from torchzero.core import Module, Vars
+ from torchzero.core import Module, Var

  class HeavyBall(Module):
      def __init__(self, momentum: float = 0.9, dampening: float = 0):
          defaults = dict(momentum=momentum, dampening=dampening)
          super().__init__(defaults)

-     def step(self, vars: Vars):
-         # a module takes a Vars object, modifies it or creates a new one, and returns it
-         # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
+     def step(self, var: Var):
+         # a module takes a Var object, modifies it or creates a new one, and returns it
+         # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
          # for now we are only interested in update, and we will apply the heavyball rule to it.

-         params = vars.params
-         update = vars.get_update() # list of tensors
+         params = var.params
+         update = var.get_update() # list of tensors

          exp_avg_list = []
          for p, u in zip(params, update):
@@ -346,16 +347,15 @@ class HeavyBall(Module):
              # and it is part of self.state
              exp_avg_list.append(buf.clone())

-         # set new update to vars
-         vars.update = exp_avg_list
-         return vars
+         # set new update to var
+         var.update = exp_avg_list
+         return var
  ```

  There are a some specialized base modules that make it much easier to implement some specific things.

  * `GradApproximator` for gradient approximations
  * `LineSearch` for line searches
- * `Preconditioner` for preconditioners
  * `Projection` for projections like GaLore or into fourier domain.
  * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
  * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
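For context, a custom module like the `HeavyBall` from the snippet above plugs into a `Modular` optimizer the same way the built-in modules do. A minimal sketch (the model, data and training-loop shape are assumptions following the standard `torch.optim` pattern; `HeavyBall` refers to the class defined in the snippet above):

```python
import torch
import torchzero as tz

# HeavyBall is the custom Module from the README snippet above.
model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 4))
opt = tz.Modular(model.parameters(), HeavyBall(momentum=0.9), tz.m.LR(0.01))

for _ in range(100):
    # Assumed loop shape: no closure needed because neither module
    # requires extra loss evaluations.
    loss = model(torch.randn(8, 2)).square().mean()
    model.zero_grad()
    loss.backward()
    opt.step()
```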
@@ -376,4 +376,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci

  They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.

- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+ Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
{torchzero-0.3.8 → torchzero-0.3.10}/README.md

@@ -118,13 +118,14 @@ for epoch in range(100):
    * `NewtonCG`: Matrix-free newton's method with conjugate gradient solver.
    * `NystromSketchAndSolve`: Nyström sketch-and-solve method.
    * `NystromPCG`: NewtonCG with Nyström preconditioning (usually beats NewtonCG).
+   * `HigherOrderNewton`: Higher order Newton's method with trust region.

  * **Quasi-Newton**: Approximate second-order optimization methods.
    * `LBFGS`: Limited-memory BFGS.
    * `LSR1`: Limited-memory SR1.
    * `OnlineLBFGS`: Online LBFGS.
-   * `BFGS`, `SR1`, `DFP`, `BroydenGood`, `BroydenBad`, `Greenstadt1`, `Greenstadt2`, `ColumnUpdatingMethod`, `ThomasOptimalMethod`, `PSB`, `Pearson2`, `SSVM`: Classic full-matrix quasi-newton methods.
-   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`: Conjugate gradient methods.
+   * `BFGS`, `DFP`, `PSB`, `SR1`, `SSVM`, `BroydenBad`, `BroydenGood`, `ColumnUpdatingMethod`, `FletcherVMM`, `GradientCorrection`, `Greenstadt1`, `Greenstadt2`, `Horisho`, `McCormick`, `Pearson`, `ProjectedNewtonRaphson`, `ThomasOptimalMethod`: Classic full-matrix quasi-newton methods.
+   * `PolakRibiere`, `FletcherReeves`, `HestenesStiefel`, `DaiYuan`, `LiuStorey`, `ConjugateDescent`, `HagerZhang`, `HybridHS_DY`, `ProjectedGradientMethod`: Conjugate gradient methods.

  * **Line Search**:
    * `Backtracking`, `AdaptiveBacktracking`: Backtracking line searches (adaptive is my own).
@@ -273,20 +274,20 @@ not in the module itself. Also both per-parameter settings and state are stored

  ```python
  import torch
- from torchzero.core import Module, Vars
+ from torchzero.core import Module, Var

  class HeavyBall(Module):
      def __init__(self, momentum: float = 0.9, dampening: float = 0):
          defaults = dict(momentum=momentum, dampening=dampening)
          super().__init__(defaults)

-     def step(self, vars: Vars):
-         # a module takes a Vars object, modifies it or creates a new one, and returns it
-         # Vars has a bunch of attributes, including parameters, gradients, update, closure, loss
+     def step(self, var: Var):
+         # a module takes a Var object, modifies it or creates a new one, and returns it
+         # Var has a bunch of attributes, including parameters, gradients, update, closure, loss
          # for now we are only interested in update, and we will apply the heavyball rule to it.

-         params = vars.params
-         update = vars.get_update() # list of tensors
+         params = var.params
+         update = var.get_update() # list of tensors

          exp_avg_list = []
          for p, u in zip(params, update):
@@ -307,16 +308,15 @@ class HeavyBall(Module):
              # and it is part of self.state
              exp_avg_list.append(buf.clone())

-         # set new update to vars
-         vars.update = exp_avg_list
-         return vars
+         # set new update to var
+         var.update = exp_avg_list
+         return var
  ```

  There are a some specialized base modules that make it much easier to implement some specific things.

  * `GradApproximator` for gradient approximations
  * `LineSearch` for line searches
- * `Preconditioner` for preconditioners
  * `Projection` for projections like GaLore or into fourier domain.
  * `QuasiNewtonH` for full-matrix quasi-newton methods that update hessian inverse approximation (because they are all very similar)
  * `ConguateGradientBase` for conjugate gradient methods, basically the only difference is how beta is calculated.
@@ -337,4 +337,4 @@ There are also wrappers providing `torch.optim.Optimizer` interface for for `sci

  They are in `torchzero.optim.wrappers.scipy.ScipyMinimize`, `torchzero.optim.wrappers.nlopt.NLOptOptimizer`, and `torchzero.optim.wrappers.nevergrad.NevergradOptimizer`. Make sure closure has `backward` argument as described in **Advanced Usage**.

- Apparently https://github.com/avaneev/biteopt is diabolical so I will add a wrapper for it too very soon.
+ Apparently <https://github.com/avaneev/biteopt> is diabolical so I will add a wrapper for it too very soon.
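The note above asks that closures passed to these wrappers accept a `backward` argument. A minimal sketch of such a closure with `ScipyMinimize` (the model and data are placeholders, and it is an assumption that the wrapper takes a parameter iterable like a standard optimizer):

```python
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize

model = torch.nn.Linear(4, 1)                  # placeholder model
x, y = torch.randn(16, 4), torch.randn(16, 1)  # placeholder data

# Assumption: the wrapper takes a parameter iterable like a standard optimizer.
opt = ScipyMinimize(model.parameters())

def closure(backward=True):
    # backward=False lets gradient-free methods evaluate the loss only.
    loss = torch.nn.functional.mse_loss(model(x), y)
    if backward:
        model.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
```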
{torchzero-0.3.8 → torchzero-0.3.10}/pyproject.toml

@@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
  name = "torchzero"
  description = "Modular optimization library for PyTorch."

- version = "0.3.8"
+ version = "0.3.10"
  dependencies = [
      "torch",
      "numpy",
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_opts.py

@@ -1,4 +1,9 @@
- """snity tests to make sure everything works and converges on basic functions"""
+ """
+ Sanity tests to make sure everything works.
+
+ This will show major convergence regressions, but that is not the main purpose. Mainly this makes sure modules
+ don't error or become unhinged with different parameter shapes.
+ """
  from collections.abc import Callable
  from functools import partial

@@ -68,6 +73,7 @@ def _run_objective(opt: tz.Modular, objective: Callable, use_closure: bool, step
          assert torch.isfinite(loss), f"{opt}: Inifinite loss - {[l.item() for l in losses]}"
          losses.append(loss)

+     losses.append(objective())
      return torch.stack(losses).nan_to_num(0,10000,10000).min()

  def _run_func(opt_fn: Callable, func:str, merge: bool, use_closure: bool, steps: int):
@@ -524,7 +530,7 @@ PolyakStepSize = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
      sphere_opt=lambda p: tz.Modular(p, tz.m.PolyakStepSize()),
      needs_closure=True,
-     func='booth', steps=50, loss=1e-11, merge_invariant=True,
+     func='booth', steps=50, loss=1e-7, merge_invariant=True,
      sphere_steps=10, sphere_loss=0.002,
  )
  RandomStepSize = Run(
@@ -604,44 +610,44 @@ ScaleModulesByCosineSimilarity = Run(

  # ------------------------- momentum/matrix_momentum ------------------------- #
  MatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.01)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )
  MatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.01)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.01)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )
  MatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.01)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.01)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.MatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.05, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )

  AdaptiveMatrixMomentum_forward = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.05)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='forward'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.05)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='forward'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.002, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )
  AdaptiveMatrixMomentum_central = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.05)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='central'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.05)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='central'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.002, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )
  AdaptiveMatrixMomentum_autograd = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.05)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_mode='autograd'), tz.m.LR(0.5)),
+     func_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.05)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.AdaptiveMatrixMomentum(hvp_method='autograd'), tz.m.LR(0.5)),
      needs_closure=True,
      func='booth', steps=50, loss=0.002, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
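The only user-facing change in these runs is the keyword rename `hvp_mode` → `hvp_method` on `MatrixMomentum` and `AdaptiveMatrixMomentum`. A before/after sketch (placeholder model; the learning rates are taken from the test values above):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)  # placeholder model

# torchzero 0.3.8 spelled the hessian-vector product option as hvp_mode:
# opt = tz.Modular(model.parameters(), tz.m.MatrixMomentum(hvp_mode='forward'), tz.m.LR(0.01))

# torchzero 0.3.10 renames the keyword to hvp_method ('forward', 'central' or 'autograd'):
opt = tz.Modular(model.parameters(), tz.m.MatrixMomentum(hvp_method='forward'), tz.m.LR(0.01))
```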
@@ -719,11 +725,11 @@ Lion = Run(
  )
  # ---------------------------- optimizers/shampoo ---------------------------- #
  Shampoo = Run(
-     func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
-     sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.2)),
+     func_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(4)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.GraftModules(tz.m.Shampoo(), tz.m.RMSprop()), tz.m.LR(0.1)),
      needs_closure=False,
-     func='booth', steps=50, loss=200, merge_invariant=False,
-     sphere_steps=20, sphere_loss=1e-3, # merge and unmerge lrs are very different so need to test convergence separately somewhere
+     func='booth', steps=50, loss=0.02, merge_invariant=False,
+     sphere_steps=20, sphere_loss=1, # merge and unmerge lrs are very different so need to test convergence separately somewhere
  )

  # ------------------------- quasi_newton/quasi_newton ------------------------ #
@@ -745,7 +751,7 @@ SSVM = Run(
      func_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
      sphere_opt=lambda p: tz.Modular(p, tz.m.SSVM(1), tz.m.StrongWolfe()),
      needs_closure=True,
-     func='rosen', steps=50, loss=0.02, merge_invariant=True,
+     func='rosen', steps=50, loss=1e-10, merge_invariant=True,
      sphere_steps=10, sphere_loss=0,
  )

@@ -791,7 +797,7 @@ NewtonCG = Run(
      sphere_opt=lambda p: tz.Modular(p, tz.m.NewtonCG(), tz.m.StrongWolfe()),
      needs_closure=True,
      func='rosen', steps=20, loss=1e-7, merge_invariant=True,
-     sphere_steps=2, sphere_loss=1e-6,
+     sphere_steps=2, sphere_loss=3e-4,
  )

  # ---------------------------- smoothing/gaussian ---------------------------- #
@@ -854,8 +860,17 @@ SophiaH = Run(
      sphere_steps=10, sphere_loss=40,
  )

+ # -------------------------- optimizers/higher_order ------------------------- #
+ HigherOrderNewton = Run(
+     func_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(trust_method=None)),
+     sphere_opt=lambda p: tz.Modular(p, tz.m.HigherOrderNewton(2, trust_method=None)),
+     needs_closure=True,
+     func='rosen', steps=1, loss=2e-10, merge_invariant=True,
+     sphere_steps=1, sphere_loss=1e-10,
+ )
+
  # ------------------------------------ CGs ----------------------------------- #
- for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY):
+ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.DaiYuan, tz.m.LiuStorey, tz.m.ConjugateDescent, tz.m.HagerZhang, tz.m.HybridHS_DY, tz.m.ProjectedGradientMethod):
      for func_steps,sphere_steps_ in ([3,2], [10,10]): # CG should converge on 2D quadratic after 2nd step
          # but also test 10 to make sure it doesn't explode after converging
          Run(
@@ -868,7 +883,25 @@ for CG in (tz.m.PolakRibiere, tz.m.FletcherReeves, tz.m.HestenesStiefel, tz.m.Da

  # ------------------------------- QN stability ------------------------------- #
  # stability test
- for QN in (tz.m.BFGS, tz.m.SR1, tz.m.DFP, tz.m.BroydenGood, tz.m.BroydenBad, tz.m.Greenstadt1, tz.m.Greenstadt2, tz.m.ColumnUpdatingMethod, tz.m.ThomasOptimalMethod, tz.m.PSB, tz.m.Pearson2, tz.m.SSVM):
+ for QN in (
+     tz.m.BFGS,
+     tz.m.SR1,
+     tz.m.DFP,
+     tz.m.BroydenGood,
+     tz.m.BroydenBad,
+     tz.m.Greenstadt1,
+     tz.m.Greenstadt2,
+     tz.m.ColumnUpdatingMethod,
+     tz.m.ThomasOptimalMethod,
+     tz.m.FletcherVMM,
+     tz.m.Horisho,
+     lambda scale_first: tz.m.Horisho(scale_first=scale_first, inner=tz.m.GradientCorrection()),
+     tz.m.Pearson,
+     tz.m.ProjectedNewtonRaphson,
+     tz.m.PSB,
+     tz.m.McCormick,
+     tz.m.SSVM,
+ ):
      Run(
          func_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
          sphere_opt=lambda p: tz.Modular(p, QN(scale_first=False), tz.m.StrongWolfe()),
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_tensorlist.py

@@ -835,7 +835,7 @@ def test_global_reductions(simple_tl: TensorList, global_method, vec_equiv_metho
      expected = vec_equiv_func()

      if isinstance(result, bool): assert result == expected
-     else: assert torch.allclose(result, expected), f"Tensors not close: {result = }, {expected = }"
+     else: assert torch.allclose(result, expected, atol=1e-4), f"Tensors not close: {result = }, {expected = }"


  def test_global_vector_norm(simple_tl: TensorList):
@@ -1261,8 +1261,8 @@ def test_reduction_ops(simple_tl: TensorList, reduction_method, dim, keepdim):
          elif reduction_method == 'quantile': expected = vec.quantile(q)
          else:
              pytest.fail("Unknown global reduction")
-             assert False, 'sus'
-         assert torch.allclose(result, expected)
+             assert False, reduction_method
+         assert torch.allclose(result, expected, atol=1e-4)
      else:
          expected_list = []
          for t in simple_tl:
{torchzero-0.3.8 → torchzero-0.3.10}/tests/test_vars.py

@@ -1,10 +1,10 @@
  import pytest
  import torch
- from torchzero.core.module import Vars
+ from torchzero.core.module import Var
  from torchzero.utils.tensorlist import TensorList

  @torch.no_grad
- def test_vars_get_loss():
+ def test_var_get_loss():

      # ---------------------------- test that it works ---------------------------- #
      params = [torch.tensor(2.0, requires_grad=True)]
@@ -26,20 +26,20 @@ def test_vars_get_loss():
          assert not loss.requires_grad, "loss requires grad with backward=False"
          return loss

-     vars = Vars(params=params, closure=closure_1, model=None, current_step=0)
+     var = Var(params=params, closure=closure_1, model=None, current_step=0)

-     assert vars.loss is None, vars.loss
+     assert var.loss is None, var.loss

-     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
+     assert (loss := var.get_loss(backward=False)) == 4.0, loss
      assert evaluated, evaluated
-     assert loss is vars.loss
-     assert vars.loss == 4.0
-     assert vars.loss_approx == 4.0
-     assert vars.grad is None, vars.grad
+     assert loss is var.loss
+     assert var.loss == 4.0
+     assert var.loss_approx == 4.0
+     assert var.grad is None, var.grad

      # reevaluate, which should just return already evaluated loss
-     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-     assert vars.grad is None, vars.grad
+     assert (loss := var.get_loss(backward=False)) == 4.0, loss
+     assert var.grad is None, var.grad


      # ----------------------- test that backward=True works ---------------------- #
@@ -61,30 +61,30 @@
          assert not loss.requires_grad, "loss requires grad with backward=False"
          return loss

-     vars = Vars(params=params, closure=closure_2, model=None, current_step=0)
-     assert vars.grad is None, vars.grad
-     assert (loss := vars.get_loss(backward=True)) == 6.0, loss
-     assert vars.grad is not None
-     assert vars.grad[0] == 2.0, vars.grad
+     var = Var(params=params, closure=closure_2, model=None, current_step=0)
+     assert var.grad is None, var.grad
+     assert (loss := var.get_loss(backward=True)) == 6.0, loss
+     assert var.grad is not None
+     assert var.grad[0] == 2.0, var.grad

      # reevaluate, which should just return already evaluated loss
-     assert (loss := vars.get_loss(backward=True)) == 6.0, loss
-     assert vars.grad[0] == 2.0, vars.grad
+     assert (loss := var.get_loss(backward=True)) == 6.0, loss
+     assert var.grad[0] == 2.0, var.grad

      # get grad, which should just return already evaluated grad
-     assert (grad := vars.get_grad())[0] == 2.0, grad
-     assert grad is vars.grad, grad
+     assert (grad := var.get_grad())[0] == 2.0, grad
+     assert grad is var.grad, grad

      # get update, which should create and return cloned grad
-     assert vars.update is None
-     assert (update := vars.get_update())[0] == 2.0, update
-     assert update is vars.update
-     assert update is not vars.grad
-     assert vars.grad is not None
-     assert update[0] == vars.grad[0]
+     assert var.update is None
+     assert (update := var.get_update())[0] == 2.0, update
+     assert update is var.update
+     assert update is not var.grad
+     assert var.grad is not None
+     assert update[0] == var.grad[0]

  @torch.no_grad
- def test_vars_get_grad():
+ def test_var_get_grad():
      params = [torch.tensor(2.0, requires_grad=True)]
      evaluated = False

@@ -103,20 +103,20 @@ def test_vars_get_grad():
          assert not loss.requires_grad, "loss requires grad with backward=False"
          return loss

-     vars = Vars(params=params, closure=closure, model=None, current_step=0)
-     assert (grad := vars.get_grad())[0] == 4.0, grad
-     assert grad is vars.grad
+     var = Var(params=params, closure=closure, model=None, current_step=0)
+     assert (grad := var.get_grad())[0] == 4.0, grad
+     assert grad is var.grad

-     assert vars.loss == 4.0
-     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-     assert (loss := vars.get_loss(backward=True)) == 4.0, loss
-     assert vars.loss_approx == 4.0
+     assert var.loss == 4.0
+     assert (loss := var.get_loss(backward=False)) == 4.0, loss
+     assert (loss := var.get_loss(backward=True)) == 4.0, loss
+     assert var.loss_approx == 4.0

-     assert vars.update is None, vars.update
-     assert (update := vars.get_update())[0] == 4.0, update
+     assert var.update is None, var.update
+     assert (update := var.get_update())[0] == 4.0, update

  @torch.no_grad
- def test_vars_get_update():
+ def test_var_get_update():
      params = [torch.tensor(2.0, requires_grad=True)]
      evaluated = False

@@ -135,24 +135,24 @@ def test_vars_get_update():
          assert not loss.requires_grad, "loss requires grad with backward=False"
          return loss

-     vars = Vars(params=params, closure=closure, model=None, current_step=0)
-     assert vars.update is None, vars.update
-     assert (update := vars.get_update())[0] == 4.0, update
-     assert update is vars.update
+     var = Var(params=params, closure=closure, model=None, current_step=0)
+     assert var.update is None, var.update
+     assert (update := var.get_update())[0] == 4.0, update
+     assert update is var.update

-     assert (grad := vars.get_grad())[0] == 4.0, grad
-     assert grad is vars.grad
+     assert (grad := var.get_grad())[0] == 4.0, grad
+     assert grad is var.grad
      assert grad is not update

-     assert vars.loss == 4.0
-     assert (loss := vars.get_loss(backward=False)) == 4.0, loss
-     assert (loss := vars.get_loss(backward=True)) == 4.0, loss
-     assert vars.loss_approx == 4.0
+     assert var.loss == 4.0
+     assert (loss := var.get_loss(backward=False)) == 4.0, loss
+     assert (loss := var.get_loss(backward=True)) == 4.0, loss
+     assert var.loss_approx == 4.0

-     assert (update := vars.get_update())[0] == 4.0, update
+     assert (update := var.get_update())[0] == 4.0, update


- def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
+ def _assert_var_are_same_(v1: Var, v2: Var, clone_update: bool):
      for k,v in v1.__dict__.items():
          if not k.startswith('__'):
              # if k == 'post_step_hooks': continue
@@ -165,20 +165,20 @@ def _assert_vars_are_same_(v1: Vars, v2: Vars, clone_update: bool):
              else:
                  assert getattr(v2, k) is v, f'{k} is not the same, {v1 = }, {v2 = }'

- def test_vars_clone():
+ def test_var_clone():
      model = torch.nn.Sequential(torch.nn.Linear(2,2), torch.nn.Linear(2,4))
      def closure(backward): return 1
-     vars = Vars(params=list(model.parameters()), closure=closure, model=model, current_step=0)
+     var = Var(params=list(model.parameters()), closure=closure, model=model, current_step=0)

-     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+     _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+     _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)

-     vars.grad = TensorList(torch.randn(5))
-     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+     var.grad = TensorList(torch.randn(5))
+     _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+     _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)

-     vars.update = TensorList(torch.randn(5) * 2)
-     vars.loss = torch.randn(1)
-     vars.loss_approx = vars.loss
-     _assert_vars_are_same_(vars, vars.clone(clone_update=False), clone_update=False)
-     _assert_vars_are_same_(vars, vars.clone(clone_update=True), clone_update=True)
+     var.update = TensorList(torch.randn(5) * 2)
+     var.loss = torch.randn(1)
+     var.loss_approx = var.loss
+     _assert_var_are_same_(var, var.clone(clone_update=False), clone_update=False)
+     _assert_var_are_same_(var, var.clone(clone_update=True), clone_update=True)
torchzero-0.3.10/torchzero/core/__init__.py

@@ -0,0 +1,2 @@
+ from .module import Var, Module, Modular, Chain, maybe_chain, Chainable
+ from .transform import Transform, TensorwiseTransform, Target, apply_transform
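For reference, a minimal sketch that exercises the names this new `torchzero/core/__init__.py` exposes; nothing beyond the two exported lines above is assumed:

```python
# Import check for the names re-exported by torchzero.core as of 0.3.10.
from torchzero.core import Var, Module, Modular, Chain, maybe_chain, Chainable
from torchzero.core import Transform, TensorwiseTransform, Target, apply_transform

print(Var, Module, Transform, apply_transform)
```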