torchzero 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. tests/test_identical.py +22 -22
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +225 -214
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +2 -2
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +53 -57
  12. torchzero/core/module.py +132 -52
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +11 -0
  17. torchzero/linalg/eigh.py +253 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +93 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +16 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +74 -88
  24. torchzero/linalg/svd.py +47 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/__init__.py +4 -3
  27. torchzero/modules/adaptive/__init__.py +11 -3
  28. torchzero/modules/adaptive/adagrad.py +167 -217
  29. torchzero/modules/adaptive/adahessian.py +76 -105
  30. torchzero/modules/adaptive/adam.py +53 -76
  31. torchzero/modules/adaptive/adan.py +50 -31
  32. torchzero/modules/adaptive/adaptive_heavyball.py +12 -7
  33. torchzero/modules/adaptive/aegd.py +12 -12
  34. torchzero/modules/adaptive/esgd.py +98 -119
  35. torchzero/modules/adaptive/ggt.py +186 -0
  36. torchzero/modules/adaptive/lion.py +7 -11
  37. torchzero/modules/adaptive/lre_optimizers.py +299 -0
  38. torchzero/modules/adaptive/mars.py +7 -7
  39. torchzero/modules/adaptive/matrix_momentum.py +48 -52
  40. torchzero/modules/adaptive/msam.py +71 -53
  41. torchzero/modules/adaptive/muon.py +67 -129
  42. torchzero/modules/adaptive/natural_gradient.py +63 -41
  43. torchzero/modules/adaptive/orthograd.py +11 -15
  44. torchzero/modules/adaptive/psgd/__init__.py +5 -0
  45. torchzero/modules/adaptive/psgd/_psgd_utils.py +37 -0
  46. torchzero/modules/adaptive/psgd/psgd.py +1390 -0
  47. torchzero/modules/adaptive/psgd/psgd_dense_newton.py +174 -0
  48. torchzero/modules/adaptive/psgd/psgd_kron_newton.py +203 -0
  49. torchzero/modules/adaptive/psgd/psgd_kron_whiten.py +185 -0
  50. torchzero/modules/adaptive/psgd/psgd_lra_newton.py +118 -0
  51. torchzero/modules/adaptive/psgd/psgd_lra_whiten.py +116 -0
  52. torchzero/modules/adaptive/rmsprop.py +83 -75
  53. torchzero/modules/adaptive/rprop.py +48 -47
  54. torchzero/modules/adaptive/sam.py +55 -45
  55. torchzero/modules/adaptive/shampoo.py +149 -130
  56. torchzero/modules/adaptive/soap.py +207 -143
  57. torchzero/modules/adaptive/sophia_h.py +106 -130
  58. torchzero/modules/clipping/clipping.py +22 -25
  59. torchzero/modules/clipping/ema_clipping.py +31 -25
  60. torchzero/modules/clipping/growth_clipping.py +14 -17
  61. torchzero/modules/conjugate_gradient/cg.py +27 -38
  62. torchzero/modules/experimental/__init__.py +7 -6
  63. torchzero/modules/experimental/adanystrom.py +258 -0
  64. torchzero/modules/experimental/common_directions_whiten.py +142 -0
  65. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  66. torchzero/modules/experimental/cubic_adam.py +160 -0
  67. torchzero/modules/experimental/curveball.py +25 -41
  68. torchzero/modules/experimental/eigen_sr1.py +182 -0
  69. torchzero/modules/experimental/eigengrad.py +207 -0
  70. torchzero/modules/experimental/gradmin.py +2 -2
  71. torchzero/modules/experimental/higher_order_newton.py +14 -40
  72. torchzero/modules/experimental/l_infinity.py +1 -1
  73. torchzero/modules/experimental/matrix_nag.py +122 -0
  74. torchzero/modules/experimental/newton_solver.py +23 -54
  75. torchzero/modules/experimental/newtonnewton.py +45 -48
  76. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  77. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  78. torchzero/modules/experimental/spsa1.py +3 -3
  79. torchzero/modules/experimental/structural_projections.py +1 -4
  80. torchzero/modules/grad_approximation/fdm.py +2 -2
  81. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  82. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  83. torchzero/modules/grad_approximation/rfdm.py +24 -21
  84. torchzero/modules/least_squares/gn.py +121 -50
  85. torchzero/modules/line_search/backtracking.py +4 -4
  86. torchzero/modules/line_search/line_search.py +33 -33
  87. torchzero/modules/line_search/strong_wolfe.py +4 -4
  88. torchzero/modules/misc/debug.py +12 -12
  89. torchzero/modules/misc/escape.py +10 -10
  90. torchzero/modules/misc/gradient_accumulation.py +11 -79
  91. torchzero/modules/misc/homotopy.py +16 -8
  92. torchzero/modules/misc/misc.py +121 -123
  93. torchzero/modules/misc/multistep.py +52 -53
  94. torchzero/modules/misc/regularization.py +49 -44
  95. torchzero/modules/misc/split.py +31 -29
  96. torchzero/modules/misc/switch.py +37 -32
  97. torchzero/modules/momentum/averaging.py +14 -14
  98. torchzero/modules/momentum/cautious.py +37 -31
  99. torchzero/modules/momentum/momentum.py +12 -12
  100. torchzero/modules/ops/__init__.py +4 -4
  101. torchzero/modules/ops/accumulate.py +21 -21
  102. torchzero/modules/ops/binary.py +67 -66
  103. torchzero/modules/ops/higher_level.py +20 -20
  104. torchzero/modules/ops/multi.py +44 -41
  105. torchzero/modules/ops/reduce.py +26 -23
  106. torchzero/modules/ops/unary.py +53 -53
  107. torchzero/modules/ops/utility.py +47 -46
  108. torchzero/modules/{functional.py → opt_utils.py} +1 -1
  109. torchzero/modules/projections/galore.py +1 -1
  110. torchzero/modules/projections/projection.py +46 -43
  111. torchzero/modules/quasi_newton/__init__.py +1 -1
  112. torchzero/modules/quasi_newton/damping.py +2 -2
  113. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +1 -1
  114. torchzero/modules/quasi_newton/lbfgs.py +10 -10
  115. torchzero/modules/quasi_newton/lsr1.py +10 -10
  116. torchzero/modules/quasi_newton/quasi_newton.py +54 -39
  117. torchzero/modules/quasi_newton/sg2.py +69 -205
  118. torchzero/modules/restarts/restars.py +39 -37
  119. torchzero/modules/second_order/__init__.py +2 -2
  120. torchzero/modules/second_order/ifn.py +31 -62
  121. torchzero/modules/second_order/inm.py +57 -53
  122. torchzero/modules/second_order/multipoint.py +40 -80
  123. torchzero/modules/second_order/newton.py +165 -196
  124. torchzero/modules/second_order/newton_cg.py +105 -157
  125. torchzero/modules/second_order/nystrom.py +216 -185
  126. torchzero/modules/second_order/rsn.py +132 -125
  127. torchzero/modules/smoothing/laplacian.py +13 -12
  128. torchzero/modules/smoothing/sampling.py +10 -10
  129. torchzero/modules/step_size/adaptive.py +24 -24
  130. torchzero/modules/step_size/lr.py +17 -17
  131. torchzero/modules/termination/termination.py +32 -30
  132. torchzero/modules/trust_region/cubic_regularization.py +3 -3
  133. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  134. torchzero/modules/trust_region/trust_cg.py +2 -2
  135. torchzero/modules/trust_region/trust_region.py +27 -22
  136. torchzero/modules/variance_reduction/svrg.py +23 -21
  137. torchzero/modules/weight_decay/__init__.py +2 -1
  138. torchzero/modules/weight_decay/reinit.py +83 -0
  139. torchzero/modules/weight_decay/weight_decay.py +17 -18
  140. torchzero/modules/wrappers/optim_wrapper.py +14 -14
  141. torchzero/modules/zeroth_order/cd.py +10 -7
  142. torchzero/optim/mbs.py +291 -0
  143. torchzero/optim/root.py +3 -3
  144. torchzero/optim/utility/split.py +2 -1
  145. torchzero/optim/wrappers/directsearch.py +27 -63
  146. torchzero/optim/wrappers/fcmaes.py +14 -35
  147. torchzero/optim/wrappers/mads.py +11 -31
  148. torchzero/optim/wrappers/moors.py +66 -0
  149. torchzero/optim/wrappers/nevergrad.py +4 -13
  150. torchzero/optim/wrappers/nlopt.py +31 -25
  151. torchzero/optim/wrappers/optuna.py +8 -13
  152. torchzero/optim/wrappers/pybobyqa.py +124 -0
  153. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  154. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  155. torchzero/optim/wrappers/scipy/brute.py +48 -0
  156. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  157. torchzero/optim/wrappers/scipy/direct.py +69 -0
  158. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  159. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  160. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  161. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  162. torchzero/optim/wrappers/wrapper.py +121 -0
  163. torchzero/utils/__init__.py +7 -25
  164. torchzero/utils/benchmarks/__init__.py +0 -0
  165. torchzero/utils/benchmarks/logistic.py +122 -0
  166. torchzero/utils/compile.py +2 -2
  167. torchzero/utils/derivatives.py +97 -73
  168. torchzero/utils/optimizer.py +4 -77
  169. torchzero/utils/python_tools.py +31 -0
  170. torchzero/utils/tensorlist.py +11 -5
  171. torchzero/utils/thoad_tools.py +68 -0
  172. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/METADATA +1 -1
  173. torchzero-0.4.1.dist-info/RECORD +209 -0
  174. tests/test_vars.py +0 -185
  175. torchzero/core/var.py +0 -376
  176. torchzero/modules/adaptive/lmadagrad.py +0 -186
  177. torchzero/modules/experimental/momentum.py +0 -160
  178. torchzero/optim/wrappers/scipy.py +0 -572
  179. torchzero/utils/linalg/__init__.py +0 -12
  180. torchzero/utils/linalg/matrix_funcs.py +0 -87
  181. torchzero/utils/linalg/orthogonalize.py +0 -12
  182. torchzero/utils/linalg/svd.py +0 -20
  183. torchzero/utils/ops.py +0 -10
  184. torchzero-0.3.15.dist-info/RECORD +0 -175
  185. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  186. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/WHEEL +0 -0
  187. {torchzero-0.3.15.dist-info → torchzero-0.4.1.dist-info}/top_level.txt +0 -0
torchzero/core/module.py CHANGED
@@ -1,24 +1,18 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import ChainMap, defaultdict
-from collections.abc import Callable, Iterable, MutableMapping, Sequence
-from operator import itemgetter
-from typing import Any, Literal, cast, final, overload
+from collections.abc import Callable, Iterable, Sequence
+from typing import Any, overload, TYPE_CHECKING
 
 import torch
 
-from ..utils import (
-    Init,
-    ListLike,
-    Params,
-    _make_param_groups,
-    get_state_vals,
-    vec_to_tensors,
-)
-from ..utils.derivatives import flatten_jacobian, hvp, hvp_fd_central, hvp_fd_forward
-from ..utils.linalg.linear_operator import LinearOperator
-from ..utils.python_tools import flatten
-from .var import Var
+from ..linalg.linear_operator import LinearOperator
+from ..utils.optimizer import Init, ListLike, get_state_vals
+from ..utils.params import Params, _make_param_groups
+from .functional import step_tensors
+
+if TYPE_CHECKING:
+    from .objective import Objective
 
 
 class Module(ABC):
@@ -36,11 +30,12 @@ class Module(ABC):
     """
     def __init__(self, defaults: dict[str, Any] | None = None):
         if defaults is None: defaults = {}
+        if any(isinstance(v, Module) for v in defaults.values()): raise RuntimeError("Passed a module to defaults")
         self.defaults: dict[str, Any] = defaults
 
         # settings are stored like state in per-tensor defaultdict, with per-parameter overrides possible
         # 0 - this module specific per-parameter setting overrides set via `set_param_groups` - highest priority
-        # 1 - global per-parameter setting overrides in param_groups passed to Modular - medium priority
+        # 1 - global per-parameter setting overrides in param_groups passed to Optimizer - medium priority
         # 2 - `defaults` - lowest priority
         self.settings: defaultdict[torch.Tensor, ChainMap[str, Any]] = defaultdict(lambda: ChainMap({}, {}, self.defaults))
         """per-parameter settings."""
@@ -55,7 +50,7 @@ class Module(ABC):
         """A dictionary of child modules."""
 
         self._overridden_keys = set()
-        """tracks keys overridden with `set_param_groups`, only used to not give a warning"""
+        """tracks keys overridden with ``set_param_groups``, only used to not give a warning"""
 
 
     def set_param_groups(self, param_groups: Params):
@@ -71,7 +66,12 @@ class Module(ABC):
             self.settings[param].maps[0].update(settings) # set module-specific per-parameter settings
         return self
 
-    def set_child(self, key: str, module: "Module | Sequence[Module]"):
+    def set_child(self, key: str, module: "Module | Sequence[Module] | None"):
+        if key in self.children:
+            warnings.warn(f"set_child overwriting child `{key}`")
+
+        if module is None: return
+
         from .chain import maybe_chain
         self.children[key] = maybe_chain(module)
 
@@ -85,6 +85,62 @@ class Module(ABC):
     def get_children_sequence(self, prefix = 'module_'):
         return [self.children[f'{prefix}{i}'] for i in range(len(self.children)) if f'{prefix}{i}' in self.children]
 
+    def inner_step(
+        self,
+        key: str,
+        objective: "Objective",
+        must_exist: bool = True,
+    ) -> "Objective":
+        """Passes ``objective`` to child and returns it."""
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return objective
+
+        return child.step(objective)
+
+
+    def inner_step_tensors(
+        self,
+        key: str,
+        tensors: list[torch.Tensor],
+        clone: bool,
+        params: Iterable[torch.Tensor] | None = None,
+        grads: Sequence[torch.Tensor] | None = None,
+        loss: torch.Tensor | None = None,
+        closure: Callable | None = None,
+        objective: "Objective | None" = None,
+        must_exist: bool = True
+    ) -> list[torch.Tensor]:
+        """Steps with child module. Can be used to apply transforms to any internal buffers.
+
+        If ``objective`` is specified, the other arguments shouldn't be specified.
+
+        Args:
+            key (str): Child module key.
+            tensors (Sequence[torch.Tensor]): tensors to pass to child module.
+            clone (bool):
+                If ``key`` exists, whether to clone ``tensors`` to avoid modifying buffers in-place.
+                If ``key`` doesn't exist, ``tensors`` are always returned without cloning.
+            params (Iterable[torch.Tensor] | None, optional): pass None if ``tensors`` have different shape. Defaults to None.
+            grads (Sequence[torch.Tensor] | None, optional): grads. Defaults to None.
+            loss (torch.Tensor | None, optional): loss. Defaults to None.
+            closure (Callable | None, optional): closure. Defaults to None.
+            must_exist (bool, optional): if True, raises ``KeyError`` when ``key`` doesn't exist. Defaults to True.
+        """
+
+        child = self.children.get(key, None)
+
+        if child is None:
+            if must_exist: raise KeyError(f"child `{key}` doesn't exist")
+            return tensors
+
+        if clone: tensors = [t.clone() for t in tensors]
+        return step_tensors(modules=child, tensors=tensors, params=params, grads=grads,
+                            loss=loss, closure=closure, objective=objective)
+
+
     def __repr__(self):
         s = self.__class__.__name__
         if self.children:
@@ -106,7 +162,6 @@ class Module(ABC):
 
     def get_settings(self, params: Sequence[torch.Tensor], key: str | list[str] | tuple[str,...], key2: str | None = None,
                      *keys: str, cls: type[ListLike] = list) -> ListLike | list[ListLike]:
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.settings, params, key, key2, *keys, must_exist=True, cls=cls) # pyright:ignore[reportArgumentType]
 
 
@@ -176,13 +231,8 @@ class Module(ABC):
         - if state_keys has multiple keys and keys has a single key, return cls.
         - if state_keys has multiple keys and keys has multiple keys, return list of cls.
         """
-        # if isinstance(params, Vars): params = params.params
         return get_state_vals(self.state, params, key, key2, *keys, must_exist=must_exist, init=init, cls=cls) # pyright:ignore[reportArgumentType]
 
-    # def first_setting(self, *keys:str, params:Sequence[torch.Tensor]):
-    #     # if isinstance(params, Vars): params = params.params
-    #     return itemgetter(*keys)(self.settings[params[0]])
-
     def clear_state_keys(self, *keys:str):
         for s in self.state.values():
             for k in keys:
@@ -223,7 +273,7 @@ class Module(ABC):
         return state_dict
 
     def _load_state_dict(self, state_dict: dict[str, Any], id_to_tensor: dict[int, torch.Tensor]):
-        """loads state_dict, ``id_to_tensor`` is passed by ``Modular``"""
+        """loads state_dict, ``id_to_tensor`` is passed by ``Optimizer``"""
         # load state
         state = state_dict['state']
         self.state.clear()
@@ -248,36 +298,73 @@ class Module(ABC):
         # extra info
         self._extra_unpack(state_dict['extra'])
 
-    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
-    def step(self, var: Var) -> Var:
-        """performs a step, returns new ``var`` but may update it in-place."""
-        self.update(var)
-        return self.apply(var)
+    def get_generator(self, device: torch.types.Device, seed: int | None):
+        """If ``seed=None``, returns ``None``.
+
+        Otherwise, if a generator on this device and with this seed hasn't been created,
+        creates it and stores it in the global state.
+
+        Returns ``torch.Generator``."""
+        if seed is None: return None
 
-    def update(self, var:Var) -> Any:
-        """Updates the internal state of this module. This should not modify ``var.update``.
+        if device is None: device_obj = torch.get_default_device()
+        else: device_obj = torch.device(device)
+        key = f"__generator-{seed}-{device_obj.type}:{device_obj.index}"
+
+        if key not in self.global_state:
+            self.global_state[key] = torch.Generator(device).manual_seed(seed)
+
+        return self.global_state[key]
+
+    def increment_counter(self, key: str, start: int):
+        """first value is ``start``"""
+        value = self.global_state.get(key, start - 1) + 1
+        self.global_state[key] = value
+        return value
+
+    # ---------------------------- OVERRIDABLE METHODS --------------------------- #
+    def update(self, objective: "Objective") -> None:
+        """Updates internal state of this module. This should not modify ``objective.update``.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
 
-    def apply(self, var: Var) -> Var:
-        """Applies this module to ``var.get_update()``.
-        This should not modify the internal state of this module if possible.
+    @abstractmethod
+    def apply(self, objective: "Objective") -> "Objective":
+        """Updates ``objective`` using the internal state of this module.
+
+        If the ``update`` method is defined, ``apply`` shouldn't modify the internal state of this module if possible.
 
         Specifying ``update`` and ``apply`` methods is optional and allows certain meta-modules to be used,
-        such as ``tz.m.Online`` or trust regions. Alternatively, simply override the ``step`` method.
+        such as ``tz.m.Online`` or trust regions. Alternatively, define all logic within the ``apply`` method.
+
+        ``update`` is guaranteed to be called at least once before ``apply``.
+
+        Args:
+            objective (Objective): ``Objective`` object
         """
-        return self.step(var)
+        # if apply is empty, it should be defined explicitly.
+        raise NotImplementedError(f"{self.__class__.__name__} doesn't implement `apply`.")
+
+    def step(self, objective: "Objective") -> "Objective":
+        """Perform a step with this module. Calls ``update``, then ``apply``."""
+        self.update(objective)
+        return self.apply(objective)
 
-    def get_H(self, var: Var) -> LinearOperator | None:
+    def get_H(self, objective: "Objective") -> LinearOperator | None:
         """returns a ``LinearOperator`` corresponding to hessian or hessian approximation.
         The hessian approximation is assumed to be for all parameters concatenated to a vector."""
         # if this method is not defined it searches in children
         # this should be overwritten to return None if child params are different from this modules params
         H = None
         for k,v in self.children.items():
-            H_v = v.get_H(var)
+            H_v = v.get_H(objective)
 
             if (H is not None) and (H_v is not None):
                 raise RuntimeError(f"Two children of {self} have a hessian, second one is {k}={v}")
@@ -307,21 +394,14 @@ class Module(ABC):
         """
         for c in self.children.values(): c.reset_for_online()
 
-    def _extra_pack(self):
-        """extra information to store in state_dict of this optimizer.
-        Will be passed to ``_extra_unpack`` when loading the state_dict."""
+    def _extra_pack(self) -> dict:
+        """extra information to store in ``state_dict`` of this optimizer.
+        Will be passed to ``_extra_unpack`` when loading the ``state_dict``."""
         return {}
 
-    def _extra_unpack(self, x):
-        """``_extra_pack`` return will be passed to this method when loading state_dict.
+    def _extra_unpack(self, d: dict):
+        """``_extra_pack`` return will be passed to this method when loading ``state_dict``.
         This method is called after loading the rest of the state dict"""
 
-    def get_generator(self, device: torch.types.Device, seed: int | None):
-        if seed is None: return None
-
-        if 'generator' not in self.global_state:
-            self.global_state['generator'] = torch.Generator(device).manual_seed(seed)
-
-        return self.global_state['generator']
 
 Chainable = Module | Sequence[Module]
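
Note on the new API: in 0.4.1 ``apply`` is abstract, ``step`` simply calls ``update`` followed by ``apply``, and seeded generators and counters live in ``global_state``. The sketch below is not code from the package; the class name and the no-op ``apply`` are illustrative only, and it deliberately leaves the ``Objective`` object untouched, since its interface is defined in torchzero/core/objective.py and is outside this diff.

# Hypothetical example of a custom module targeting the 0.4.1 Module API shown above.
import torch
from torchzero.core.module import Module

class StepCounter(Module):
    """Counts steps and reserves a seed; `apply` is a placeholder pass-through."""

    def __init__(self, seed: int | None = 0):
        # `defaults` must not contain Module instances (0.4.1 raises RuntimeError for that)
        super().__init__(defaults=dict(seed=seed))

    def update(self, objective):
        # internal bookkeeping only; `update` must not modify `objective.update`
        self.increment_counter("num_steps", start=1)

    def apply(self, objective):
        # `apply` is abstract in 0.4.1 and `update` is guaranteed to run before it.
        # A real module would transform `objective` here, e.g. drawing seeded noise
        # from self.get_generator(device=..., seed=self.defaults["seed"]).
        return objective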