agilerl 2.3.5.dev0__tar.gz → 2.3.5.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94):
  1. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/PKG-INFO +1 -1
  2. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/bc_lm.py +3 -3
  3. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/base.py +70 -89
  4. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/optimizer_wrapper.py +16 -16
  5. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/registry.py +77 -45
  6. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/cqn.py +5 -6
  7. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ddpg.py +14 -14
  8. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/dqn.py +2 -2
  9. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/dqn_rainbow.py +5 -6
  10. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/grpo.py +8 -8
  11. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ilql.py +2 -2
  12. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ippo.py +25 -25
  13. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/maddpg.py +22 -22
  14. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/matd3.py +28 -28
  15. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/neural_ts_bandit.py +4 -4
  16. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/neural_ucb_bandit.py +4 -4
  17. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ppo.py +29 -29
  18. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/td3.py +6 -4
  19. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/multi_agent_replay_buffer.py +23 -24
  20. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/replay_buffer.py +3 -3
  21. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/rollout_buffer.py +30 -30
  22. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/sampler.py +5 -5
  23. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/rl_data.py +8 -8
  24. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/tokenizer.py +5 -5
  25. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/mutation.py +20 -20
  26. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/tournament.py +6 -8
  27. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/base.py +33 -37
  28. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/bert.py +11 -11
  29. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/cnn.py +43 -43
  30. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/configs.py +11 -11
  31. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/dummy.py +4 -4
  32. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/gpt.py +14 -14
  33. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/lstm.py +11 -11
  34. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/mlp.py +13 -13
  35. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/multi_input.py +18 -18
  36. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/resnet.py +12 -12
  37. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/simba.py +4 -4
  38. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/actors.py +7 -7
  39. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/base.py +27 -27
  40. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/custom_modules.py +4 -4
  41. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/distributions.py +12 -12
  42. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/distributions_experimental.py +3 -3
  43. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/q_networks.py +10 -10
  44. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/value_networks.py +4 -4
  45. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/protocols.py +41 -45
  46. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/rollouts/on_policy.py +10 -10
  47. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_bandits.py +4 -4
  48. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_llm.py +3 -3
  49. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_multi_agent_off_policy.py +4 -4
  50. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_multi_agent_on_policy.py +4 -4
  51. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_off_policy.py +5 -5
  52. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_offline.py +4 -4
  53. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_on_policy.py +5 -5
  54. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/typing.py +24 -28
  55. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/algo_utils.py +67 -67
  56. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/evolvable_networks.py +26 -26
  57. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/ilql_utils.py +6 -6
  58. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/llm_utils.py +12 -12
  59. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/torch_utils.py +4 -4
  60. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/utils.py +30 -9
  61. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/pz_async_vec_env.py +48 -48
  62. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/pz_vec_env.py +10 -10
  63. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/agent.py +14 -14
  64. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/make_evolvable.py +17 -17
  65. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/pyproject.toml +4 -1
  66. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/LICENSE +0 -0
  67. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/README.md +0 -0
  68. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/__init__.py +0 -0
  69. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/__init__.py +0 -0
  70. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/__init__.py +0 -0
  71. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/__init__.py +0 -0
  72. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/data.py +0 -0
  73. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/segment_tree.py +0 -0
  74. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/__init__.py +0 -0
  75. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/language_environment.py +0 -0
  76. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/torch_datasets.py +0 -0
  77. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/__init__.py +0 -0
  78. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/__init__.py +0 -0
  79. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/custom_components.py +0 -0
  80. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/__init__.py +0 -0
  81. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/rollouts/__init__.py +0 -0
  82. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/__init__.py +0 -0
  83. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/__init__.py +0 -0
  84. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/cache.py +0 -0
  85. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/log_utils.py +0 -0
  86. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/minari_utils.py +0 -0
  87. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/probe_envs.py +0 -0
  88. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/probe_envs_ma.py +0 -0
  89. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/sampling_utils.py +0 -0
  90. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/__init__.py +0 -0
  91. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/__init__.py +0 -0
  92. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/learning.py +0 -0
  93. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  94. {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: agilerl
3
- Version: 2.3.5.dev0
3
+ Version: 2.3.5.dev1
4
4
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
5
5
  License: Apache 2.0
6
6
  Author: Nick Ustaran-Anderegg
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable, Optional, Tuple, Union
1
+ from typing import Any, Callable, Optional, Union
2
2
 
3
3
  import numpy as np
4
4
  import torch
@@ -167,7 +167,7 @@ class BC_LM(nn.Module):
167
167
  temp: float = 1.0,
168
168
  top_k: Optional[int] = None,
169
169
  top_p: Optional[float] = None,
170
- ) -> Tuple[torch.Tensor, Any]:
170
+ ) -> tuple[torch.Tensor, Any]:
171
171
  prepared_inputs = self.prepare_inputs(items)
172
172
  tokens = prepared_inputs["tokens"]
173
173
  scores, model_outputs = self.score(
@@ -189,7 +189,7 @@ class BC_LM(nn.Module):
189
189
  temp: float = 1.0,
190
190
  top_k: Optional[int] = None,
191
191
  top_p: Optional[float] = None,
192
- ) -> Tuple[torch.Tensor, Any]:
192
+ ) -> tuple[torch.Tensor, Any]:
193
193
  scores, model_outputs = self.score(
194
194
  (
195
195
  tokens.unsqueeze(1),
@@ -12,12 +12,8 @@ from importlib.metadata import version
12
12
  from typing import (
13
13
  Any,
14
14
  Callable,
15
- Dict,
16
15
  Iterable,
17
- List,
18
16
  Optional,
19
- Tuple,
20
- Type,
21
17
  TypeVar,
22
18
  Union,
23
19
  cast,
@@ -30,16 +26,13 @@ from accelerate import Accelerator
30
26
  from accelerate.utils import broadcast_object_list
31
27
  from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
32
28
  from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
33
- from deepspeed.runtime.engine import DeepSpeedEngine
34
29
  from gymnasium import spaces
35
- from numpy.typing import ArrayLike
36
30
  from peft import PeftModel, set_peft_model_state_dict
37
31
  from safetensors.torch import load_file
38
32
  from tensordict import TensorDict
39
33
  from torch._dynamo import OptimizedModule
40
34
  from torch.optim import AdamW
41
35
  from torch.optim.lr_scheduler import SequentialLR
42
- from vllm.distributed.parallel_state import destroy_model_parallel
43
36
 
44
37
  from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
45
38
  from agilerl.algorithms.core.registry import (
@@ -107,7 +100,7 @@ class _RegistryMeta(type):
107
100
  initializing with specified network groups and optimizers."""
108
101
 
109
102
  def __call__(
110
- cls: Type[SelfEvolvableAlgorithm], *args, **kwargs
103
+ cls: type[SelfEvolvableAlgorithm], *args, **kwargs
111
104
  ) -> SelfEvolvableAlgorithm:
112
105
  # Create the instance
113
106
  instance: SelfEvolvableAlgorithm = super().__call__(*args, **kwargs)
@@ -124,7 +117,7 @@ class RegistryMeta(_RegistryMeta, ABCMeta): ...
124
117
 
125
118
  def get_checkpoint_dict(
126
119
  agent: SelfEvolvableAlgorithm, using_deepspeed: bool = False
127
- ) -> Dict[str, Any]:
120
+ ) -> dict[str, Any]:
128
121
  """Returns a dictionary of the agent's attributes to save in a checkpoint.
129
122
 
130
123
  Note: Accelerator is always excluded from the checkpoint as it cannot be serialized.
@@ -152,7 +145,7 @@ def get_checkpoint_dict(
152
145
  attribute_dict.pop("rollout_buffer")
153
146
 
154
147
  # Get checkpoint dictionaries for evolvable modules and optimizers
155
- network_info: Dict[str, Dict[str, Any]] = {"modules": {}, "optimizers": {}}
148
+ network_info: dict[str, dict[str, Any]] = {"modules": {}, "optimizers": {}}
156
149
  for attr in agent.evolvable_attributes():
157
150
  evolvable_obj: EvolvableAttributeType = getattr(agent, attr)
158
151
  if isinstance(evolvable_obj, OptimizerWrapper):
@@ -186,14 +179,14 @@ def get_checkpoint_dict(
186
179
 
187
180
 
188
181
  def get_optimizer_cls(
189
- optimizer_cls: Union[str, Dict[str, str]],
190
- ) -> Union[Type[torch.optim.Optimizer], Dict[str, Type[torch.optim.Optimizer]]]:
182
+ optimizer_cls: Union[str, dict[str, str]],
183
+ ) -> Union[type[torch.optim.Optimizer], dict[str, type[torch.optim.Optimizer]]]:
191
184
  """Returns the optimizer class from the string or dictionary of optimizer classes.
192
185
 
193
186
  :param optimizer_cls: The optimizer class or dictionary of optimizer classes.
194
- :type optimizer_cls: Union[str, Dict[str, str]]
187
+ :type optimizer_cls: Union[str, dict[str, str]]
195
188
  :return: The optimizer class or dictionary of optimizer classes.
196
- :rtype: Union[Type[torch.optim.Optimizer], Dict[str, Type[torch.optim.Optimizer]]]
189
+ :rtype: Union[type[torch.optim.Optimizer], dict[str, type[torch.optim.Optimizer]]]
197
190
  """
198
191
  if isinstance(optimizer_cls, dict):
199
192
  optimizer_cls = {
@@ -313,20 +306,20 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
313
306
  raise NotImplementedError
314
307
 
315
308
  @abstractmethod
316
- def test(self, *args, **kwargs) -> ArrayLike:
309
+ def test(self, *args, **kwargs) -> np.ndarray:
317
310
  """Abstract method for testing the algorithm."""
318
311
  raise NotImplementedError
319
312
 
320
313
  @staticmethod
321
- def get_state_dim(observation_space: GymSpaceType) -> Tuple[int, ...]:
314
+ def get_state_dim(observation_space: GymSpaceType) -> tuple[int, ...]:
322
315
  """Returns the dimension of the state space as it pertains to the underlying
323
316
  networks (i.e. the input size of the networks).
324
317
 
325
318
  :param observation_space: The observation space of the environment.
326
- :type observation_space: spaces.Space or List[spaces.Space].
319
+ :type observation_space: spaces.Space or list[spaces.Space].
327
320
 
328
321
  :return: The dimension of the state space.
329
- :rtype: Tuple[int, ...].
322
+ :rtype: tuple[int, ...].
330
323
  """
331
324
  warnings.warn(
332
325
  "This method is deprecated. Use get_input_size_from_space instead.",
@@ -335,12 +328,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
335
328
  return get_input_size_from_space(observation_space)
336
329
 
337
330
  @staticmethod
338
- def get_action_dim(action_space: GymSpaceType) -> Tuple[int, ...]:
331
+ def get_action_dim(action_space: GymSpaceType) -> tuple[int, ...]:
339
332
  """Returns the dimension of the action space as it pertains to the underlying
340
333
  networks (i.e. the output size of the networks).
341
334
 
342
335
  :param action_space: The action space of the environment.
343
- :type action_space: spaces.Space or List[spaces.Space].
336
+ :type action_space: spaces.Space or list[spaces.Space].
344
337
 
345
338
  :return: The dimension of the action space.
346
339
  :rtype: int.
@@ -354,7 +347,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
354
347
  @staticmethod
355
348
  def inspect_attributes(
356
349
  agent: SelfEvolvableAlgorithm, input_args_only: bool = False
357
- ) -> Dict[str, Any]:
350
+ ) -> dict[str, Any]:
358
351
  """
359
352
  Inspect and retrieve the attributes of the current object, excluding attributes related to the
360
353
  underlying evolvable networks (i.e. `EvolvableModule`, `torch.optim.Optimizer`) and with
@@ -451,21 +444,21 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
451
444
 
452
445
  @classmethod
453
446
  def population(
454
- cls: Type[SelfEvolvableAlgorithm],
447
+ cls: type[SelfEvolvableAlgorithm],
455
448
  size: int,
456
449
  observation_space: GymSpaceType,
457
450
  action_space: GymSpaceType,
458
- wrapper_cls: Optional[Type[SelfAgentWrapper]] = None,
459
- wrapper_kwargs: Dict[str, Any] = {},
451
+ wrapper_cls: Optional[type[SelfAgentWrapper]] = None,
452
+ wrapper_kwargs: dict[str, Any] = {},
460
453
  **kwargs,
461
- ) -> List[Union[SelfEvolvableAlgorithm, SelfAgentWrapper]]:
454
+ ) -> list[Union[SelfEvolvableAlgorithm, SelfAgentWrapper]]:
462
455
  """Creates a population of algorithms.
463
456
 
464
457
  :param size: The size of the population.
465
458
  :type size: int.
466
459
 
467
460
  :return: A list of algorithms.
468
- :rtype: List[SelfEvolvableAlgorithm].
461
+ :rtype: list[SelfEvolvableAlgorithm].
469
462
  """
470
463
  if wrapper_cls is not None:
471
464
  return [
@@ -549,11 +542,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
549
542
  hp_value = getattr(self, hp)
550
543
  hp_spec = self.registry.hp_config[hp]
551
544
  dtype = type(hp_value)
552
- if dtype not in [int, float]:
545
+ if dtype not in [int, float, np.ndarray]:
553
546
  raise TypeError(
554
547
  f"Can't mutate hyperparameter {hp} of type {dtype}. AgileRL only supports "
555
- "mutating integer or float hyperparameters."
548
+ "mutating integer, float, and numpy ndarray hyperparameters."
556
549
  )
550
+
557
551
  hp_spec.dtype = dtype
558
552
 
559
553
  def _wrap_attr(self, attr: EvolvableAttributeType) -> EvolvableAttributeType:
@@ -637,7 +631,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
637
631
  """
638
632
  self.training = training
639
633
 
640
- def get_lr_names(self) -> List[str]:
634
+ def get_lr_names(self) -> list[str]:
641
635
  """Returns the learning rates of the algorithm."""
642
636
  return [opt.lr for opt in self.registry.optimizers]
643
637
 
@@ -695,14 +689,14 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
695
689
  for name, obj in self.evolvable_attributes(networks_only=True).items():
696
690
  setattr(self, name, compile_model(obj, self.torch_compiler))
697
691
 
698
- def to_device(self, *experiences: TorchObsType) -> Tuple[TorchObsType, ...]:
692
+ def to_device(self, *experiences: TorchObsType) -> tuple[TorchObsType, ...]:
699
693
  """Moves experiences to the device.
700
694
 
701
695
  :param experiences: Experiences to move to device
702
- :type experiences: Tuple[torch.Tensor[float], ...]
696
+ :type experiences: tuple[torch.Tensor[float], ...]
703
697
 
704
698
  :return: Experiences on the device
705
- :rtype: Tuple[torch.Tensor[float], ...]
699
+ :rtype: tuple[torch.Tensor[float], ...]
706
700
  """
707
701
  device = self.device if self.accelerator is None else self.accelerator.device
708
702
  on_device = []
@@ -861,12 +855,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
861
855
  :param path: Location to load checkpoint from
862
856
  :type path: string
863
857
  """
864
- checkpoint: Dict[str, Any] = torch.load(
858
+ checkpoint: dict[str, Any] = torch.load(
865
859
  path, map_location=self.device, pickle_module=dill, weights_only=False
866
860
  )
867
861
 
868
862
  # Recreate evolvable modules
869
- network_info: Dict[str, Dict[str, Any]] = checkpoint["network_info"]
863
+ network_info: dict[str, dict[str, Any]] = checkpoint["network_info"]
870
864
  network_names = network_info["network_names"]
871
865
  for name in network_names:
872
866
  net_dict = {
@@ -967,7 +961,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
967
961
 
968
962
  @classmethod
969
963
  def load(
970
- cls: Type[SelfEvolvableAlgorithm],
964
+ cls: type[SelfEvolvableAlgorithm],
971
965
  path: str,
972
966
  device: DeviceType = "cpu",
973
967
  accelerator: Optional[Accelerator] = None,
@@ -984,12 +978,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
984
978
  :return: An instance of the algorithm
985
979
  :rtype: RLAlgorithm
986
980
  """
987
- checkpoint: Dict[str, Any] = torch.load(
981
+ checkpoint: dict[str, Any] = torch.load(
988
982
  path, map_location=device, pickle_module=dill, weights_only=False
989
983
  )
990
984
 
991
985
  # Reconstruct evolvable modules in algorithm
992
- network_info: Optional[Dict[str, Dict[str, Any]]] = checkpoint.get(
986
+ network_info: Optional[dict[str, dict[str, Any]]] = checkpoint.get(
993
987
  "network_info"
994
988
  )
995
989
  if network_info is None:
@@ -1001,7 +995,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
1001
995
  )
1002
996
 
1003
997
  network_names = network_info["network_names"]
1004
- loaded_modules: Dict[str, EvolvableAttributeType] = {}
998
+ loaded_modules: dict[str, EvolvableAttributeType] = {}
1005
999
  for name in network_names:
1006
1000
  net_dict = {
1007
1001
  k: v for k, v in network_info["modules"].items() if k.startswith(name)
@@ -1021,7 +1015,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
1021
1015
 
1022
1016
  # Reconstruct the modules
1023
1017
  module_cls: Union[
1024
- Type[EvolvableModule], Dict[str, Type[EvolvableModule]]
1018
+ type[EvolvableModule], dict[str, type[EvolvableModule]]
1025
1019
  ] = net_dict[f"{name}_cls"]
1026
1020
  if isinstance(module_cls, dict):
1027
1021
  for agent_id, mod_cls in module_cls.items():
@@ -1187,7 +1181,7 @@ class RLAlgorithm(EvolvableAlgorithm, ABC):
1187
1181
  :type observations: ObservationType
1188
1182
 
1189
1183
  :return: Preprocessed observations
1190
- :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or Tuple[torch.Tensor[float], ...]
1184
+ :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
1191
1185
  """
1192
1186
  return preprocess_observation(
1193
1187
  self.observation_space,
@@ -1201,13 +1195,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1201
1195
  """Base object for all multi-agent algorithms in the AgileRL framework.
1202
1196
 
1203
1197
  :param observation_spaces: The observation spaces of the agent environments.
1204
- :type observation_spaces: Union[List[spaces.Space], spaces.Dict]
1198
+ :type observation_spaces: Union[list[spaces.Space], spaces.Dict]
1205
1199
  :param action_spaces: The action spaces of the agent environments.
1206
- :type action_spaces: Union[List[spaces.Space], spaces.Dict]
1200
+ :type action_spaces: Union[list[spaces.Space], spaces.Dict]
1207
1201
  :param index: The index of the individual in the population.
1208
1202
  :type index: int.
1209
1203
  :param agent_ids: The agent IDs of the agents in the environment.
1210
- :type agent_ids: Optional[List[int]], optional
1204
+ :type agent_ids: Optional[list[int]], optional
1211
1205
  :param learn_step: Learning frequency, defaults to 2048
1212
1206
  :type learn_step: int, optional
1213
1207
  :param device: Device to run the algorithm on, defaults to "cpu"
@@ -1224,13 +1218,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1224
1218
  :type name: Optional[str], optional
1225
1219
  """
1226
1220
 
1227
- possible_observation_spaces: Dict[str, spaces.Space]
1228
- possible_action_spaces: Dict[str, spaces.Space]
1221
+ possible_observation_spaces: dict[str, spaces.Space]
1222
+ possible_action_spaces: dict[str, spaces.Space]
1229
1223
 
1230
- shared_agent_ids: List[str]
1231
- grouped_agents: Dict[str, List[str]]
1232
- unique_observation_spaces: Dict[str, spaces.Space]
1233
- unique_action_spaces: Dict[str, spaces.Space]
1224
+ shared_agent_ids: list[str]
1225
+ grouped_agents: dict[str, list[str]]
1226
+ unique_observation_spaces: dict[str, spaces.Space]
1227
+ unique_action_spaces: dict[str, spaces.Space]
1234
1228
 
1235
1229
  def __init__(
1236
1230
  self,
@@ -1396,14 +1390,14 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1396
1390
 
1397
1391
  def preprocess_observation(
1398
1392
  self, observation: ObservationType
1399
- ) -> Dict[str, TorchObsType]:
1393
+ ) -> dict[str, TorchObsType]:
1400
1394
  """Preprocesses observations for forward pass through neural network.
1401
1395
 
1402
1396
  :param observations: Observations of environment
1403
1397
  :type observations: numpy.ndarray[float] or dict[str, numpy.ndarray[float]]
1404
1398
 
1405
1399
  :return: Preprocessed observations
1406
- :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or Tuple[torch.Tensor[float], ...]
1400
+ :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
1407
1401
  """
1408
1402
  preprocessed = {}
1409
1403
  for agent_id, agent_obs in observation.items():
@@ -1421,10 +1415,10 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1421
1415
  """Extract action masks from info dictionary
1422
1416
 
1423
1417
  :param infos: Info dict
1424
- :type infos: Dict[str, Dict[...]]
1418
+ :type infos: dict[str, dict[...]]
1425
1419
 
1426
1420
  :return: Action masks
1427
- :rtype: Dict[str, np.ndarray]
1421
+ :rtype: dict[str, np.ndarray]
1428
1422
  """
1429
1423
  # Get dict of form {"agent_id" : [1, 0, 0, 0]...} etc
1430
1424
  action_masks = {
@@ -1437,14 +1431,14 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1437
1431
 
1438
1432
  def extract_agent_masks(
1439
1433
  self, infos: Optional[InfosDict] = None
1440
- ) -> Tuple[ArrayDict, ArrayDict]:
1434
+ ) -> tuple[ArrayDict, ArrayDict]:
1441
1435
  """Extract env_defined_actions from info dictionary and determine agent masks
1442
1436
 
1443
1437
  :param infos: Info dict
1444
- :type infos: Dict[str, Dict[...]]
1438
+ :type infos: dict[str, dict[...]]
1445
1439
 
1446
1440
  :return: Env defined actions and agent masks
1447
- :rtype: Tuple[ArrayDict, ArrayDict]
1441
+ :rtype: tuple[ArrayDict, ArrayDict]
1448
1442
  """
1449
1443
  # Deal with case of no env_defined_actions defined in the info dict
1450
1444
  # Deal with empty info dicts for each sub agent
@@ -1506,7 +1500,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1506
1500
  net_config: Optional[NetConfigType] = None,
1507
1501
  flatten: bool = True,
1508
1502
  return_encoders: bool = False,
1509
- ) -> Union[NetConfigType, Tuple[NetConfigType, Dict[str, NetConfigType]]]:
1503
+ ) -> Union[NetConfigType, tuple[NetConfigType, dict[str, NetConfigType]]]:
1510
1504
  """Extract an appropriate net config for each sub-agent from the passed net config dictionary. If
1511
1505
  grouped_agents is True, the net config will be built for the grouped agents i.e. through their
1512
1506
  common prefix in their agent_id, whenever the passed net config is None.
@@ -1539,7 +1533,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1539
1533
  # Helper function to append unique configs to the unique_configs dictionary
1540
1534
  # -> Access to unique configs is relevant for algorithms with networks that process
1541
1535
  # multiple agents' observations (e.g. shared critic in MADDPG)
1542
- def _add_to_encoder_configs(config: Dict[str, Any], agent_id: str = "") -> None:
1536
+ def _add_to_encoder_configs(config: dict[str, Any], agent_id: str = "") -> None:
1543
1537
  config = config_from_dict(config)
1544
1538
  config_key = "mlp_config" if isinstance(config, MlpNetConfig) else agent_id
1545
1539
 
@@ -1697,7 +1691,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1697
1691
  self,
1698
1692
  group_outputs: ArrayDict,
1699
1693
  vect_dim: int,
1700
- grouped_agents: Dict[str, List[str]],
1694
+ grouped_agents: dict[str, list[str]],
1701
1695
  ) -> ArrayDict:
1702
1696
  """Disassembles batched output by shared policies into their grouped agents' outputs.
1703
1697
 
@@ -1705,13 +1699,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1705
1699
  i.e. any given agent will always terminate at the same timestep in different vectorized environments.
1706
1700
 
1707
1701
  :param group_outputs: Dictionary to be disassembled, has the form {'agent': [4, 7, 8]}
1708
- :type group_outputs: Dict[str, np.ndarray]
1702
+ :type group_outputs: dict[str, np.ndarray]
1709
1703
  :param vect_dim: Vectorization dimension size, i.e. number of vect envs
1710
1704
  :type vect_dim: int
1711
1705
  :param grouped_agents: Dictionary of grouped agent IDs
1712
- :type grouped_agents: Dict[str, List[str]]
1706
+ :type grouped_agents: dict[str, list[str]]
1713
1707
  :return: Assembled dictionary, e.g. {'agent_0': 4, 'agent_1': 7, 'agent_2': 8}
1714
- :rtype: Dict[str, np.ndarray]
1708
+ :rtype: dict[str, np.ndarray]
1715
1709
  """
1716
1710
  output_dict = {}
1717
1711
  for group_id, agent_ids in grouped_agents.items():
@@ -1728,9 +1722,9 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1728
1722
  """Sums the rewards for grouped agents
1729
1723
 
1730
1724
  :param rewards: Reward dictionary from environment
1731
- :type rewards: Dict[str, np.ndarray]
1725
+ :type rewards: dict[str, np.ndarray]
1732
1726
  :return: Summed rewards dictionary
1733
- :rtype: Dict[str, np.ndarray]
1727
+ :rtype: dict[str, np.ndarray]
1734
1728
  """
1735
1729
  reward_shape = list(rewards.values())[0]
1736
1730
  reward_shape = (
@@ -1751,11 +1745,11 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1751
1745
  """Assembles individual agent outputs into batched outputs for shared policies.
1752
1746
 
1753
1747
  :param agent_outputs: Dictionary with individual agent outputs, e.g. {'agent_0': 4, 'agent_1': 7, 'agent_2': 8}
1754
- :type agent_outputs: Dict[str, np.ndarray]
1748
+ :type agent_outputs: dict[str, np.ndarray]
1755
1749
  :param vect_dim: Vectorization dimension size, i.e. number of vect envs
1756
1750
  :type vect_dim: int
1757
1751
  :return: Assembled dictionary with the form {'agent': [4, 7, 8]}
1758
- :rtype: Dict[str, np.ndarray]
1752
+ :rtype: dict[str, np.ndarray]
1759
1753
  """
1760
1754
  group_outputs = {}
1761
1755
  for group_id in self.shared_agent_ids:
@@ -1846,7 +1840,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1846
1840
  :type observations: numpy.ndarray[float] or dict[str, numpy.ndarray[float]]
1847
1841
 
1848
1842
  :return: Preprocessed observations
1849
- :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or Tuple[torch.Tensor[float], ...]
1843
+ :rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
1850
1844
  """
1851
1845
  return cast(TorchObsType, observation)
1852
1846
 
@@ -1890,6 +1884,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1890
1884
  path + "/attributes.pt",
1891
1885
  pickle_module=dill,
1892
1886
  )
1887
+ if self.accelerator is not None:
1888
+ self.accelerator.wait_for_everyone()
1893
1889
 
1894
1890
  # TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
1895
1891
  # handle _load_distributed_actor if deepspeed is used.
@@ -1907,28 +1903,22 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1907
1903
  if weights_only:
1908
1904
  if self.use_separate_reference_adapter:
1909
1905
  self._update_existing_adapter(
1910
- self.accelerator,
1911
- self.actor,
1912
1906
  path,
1913
1907
  "reference",
1914
1908
  )
1915
1909
 
1916
1910
  self._update_existing_adapter(
1917
- self.accelerator,
1918
- self.actor,
1919
1911
  path,
1920
1912
  "actor",
1921
1913
  )
1922
1914
  else:
1923
1915
  self._load_distributed_actor(path, tag="save_checkpoint")
1924
1916
 
1925
- checkpoint["accelerator"] = (
1926
- Accelerator() if self.accelerator is not None else None
1927
- )
1928
- self.accelerator = None
1929
1917
  for attr, value in checkpoint.items():
1930
1918
  setattr(self, attr, value)
1931
1919
 
1920
+ self.device = self.accelerator.device
1921
+
1932
1922
  self.optimizer = None
1933
1923
  self.optimizer = OptimizerWrapper(
1934
1924
  optimizer_cls=self._select_optim_class(),
@@ -1937,7 +1927,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1937
1927
  lr=self.lr,
1938
1928
  lr_name="lr",
1939
1929
  )
1940
- self.wrap_models()
1941
1930
  else:
1942
1931
  super().load_checkpoint(path + "/attributes.pt")
1943
1932
 
@@ -1964,11 +1953,11 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1964
1953
  """
1965
1954
  )
1966
1955
 
1967
- def _select_optim_class(self) -> Union[Type[OptimizerType], Type[DummyOptimizer]]:
1956
+ def _select_optim_class(self) -> Union[type[OptimizerType], type[DummyOptimizer]]:
1968
1957
  """Select the optimizer class based on the accelerator and deepspeed config.
1969
1958
 
1970
1959
  :return: Optimizer class
1971
- :rtype: Union[Type[torch.optim.Optimizer], Type[DummyOptimizer]]
1960
+ :rtype: Union[type[torch.optim.Optimizer], type[DummyOptimizer]]
1972
1961
  """
1973
1962
  if self.accelerator is None:
1974
1963
  return AdamW
@@ -2089,8 +2078,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2089
2078
  None,
2090
2079
  )
2091
2080
  if self.use_vllm:
2092
- destroy_model_parallel()
2093
- del self.llm.llm_engine.model_executor.driver_worker
2094
2081
  self.llm = None
2095
2082
  gc.collect()
2096
2083
  torch.cuda.empty_cache()
@@ -2201,7 +2188,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2201
2188
  lr: float,
2202
2189
  accelerator: Optional[Accelerator] = None,
2203
2190
  scheduler_config: Optional[CosineLRScheduleConfig] = None,
2204
- ) -> Tuple[Optional[Accelerator], Optional[SequentialLR]]:
2191
+ ) -> tuple[Optional[Accelerator], Optional[SequentialLR]]:
2205
2192
  """Update the learning rate of the optimizer
2206
2193
 
2207
2194
  :param optimizer: Optimizer
@@ -2268,20 +2255,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2268
2255
  "Recompile method is not available for LLM finetuning algorithms."
2269
2256
  )
2270
2257
 
2271
- @staticmethod
2272
2258
  def _update_existing_adapter(
2273
- accelerator: Accelerator,
2274
- wrapped_model: DeepSpeedEngine,
2259
+ self,
2275
2260
  checkpoint_dir: str,
2276
2261
  adapter_name: str,
2277
2262
  ) -> None:
2278
2263
  """
2279
2264
  Overwrite weights of an existing adapter in-place without creating new parameters.
2280
2265
 
2281
- :param accelerator: Accelerator
2282
- :type accelerator: Accelerator
2283
- :param wrapped_model: Wrapped model
2284
- :type wrapped_model: DeepSpeedEngine
2285
2266
  :param checkpoint_dir: Checkpoint directory
2286
2267
  :type checkpoint_dir: str
2287
2268
  :param adapter_name: Adapter name
@@ -2290,7 +2271,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2290
2271
  :return: None
2291
2272
  :rtype: None
2292
2273
  """
2293
- base_model = accelerator.unwrap_model(wrapped_model)
2274
+ base_model = self.accelerator.unwrap_model(self.actor)
2294
2275
  if hasattr(base_model, "module"):
2295
2276
  base_model = base_model.module
2296
2277
 
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Any, Dict, List, Optional, Type, Union
2
+ from typing import Any, Optional, Union
3
3
 
4
4
  import torch.nn as nn
5
5
  from peft import PeftModel
@@ -10,9 +10,9 @@ from agilerl.protocols import EvolvableAlgorithm
10
10
  from agilerl.typing import OptimizerType, StateDict
11
11
  from agilerl.utils.llm_utils import DummyOptimizer
12
12
 
13
- ModuleList = List[EvolvableModule]
13
+ ModuleList = list[EvolvableModule]
14
14
  _Optimizer = Union[
15
- Type[OptimizerType], Dict[str, Type[OptimizerType]], Type[DummyOptimizer]
15
+ type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
16
16
  ]
17
17
  _Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModel]
18
18
 
@@ -21,7 +21,7 @@ def init_from_multiple(
21
21
  networks: ModuleList,
22
22
  optimizer_cls: OptimizerType,
23
23
  lr: float,
24
- optimizer_kwargs: Dict[str, Any],
24
+ optimizer_kwargs: dict[str, Any],
25
25
  ) -> Optimizer:
26
26
  """
27
27
  Initialize an optimizer from a list of networks.
@@ -33,7 +33,7 @@ def init_from_multiple(
33
33
  :param lr: The learning rate of the optimizer.
34
34
  :type lr: float
35
35
  :param optimizer_kwargs: The keyword arguments to be passed to the optimizer.
36
- :type optimizer_kwargs: Dict[str, Any]
36
+ :type optimizer_kwargs: dict[str, Any]
37
37
  """
38
38
  opt_args = []
39
39
  for i, net in enumerate(networks):
@@ -51,7 +51,7 @@ def init_from_single(
51
51
  network: EvolvableModule,
52
52
  optimizer_cls: OptimizerType,
53
53
  lr: float,
54
- optimizer_kwargs: Dict[str, Any],
54
+ optimizer_kwargs: dict[str, Any],
55
55
  ) -> Optimizer:
56
56
  """
57
57
  Initialize an optimizer from a single network.
@@ -67,15 +67,15 @@ class OptimizerWrapper:
67
67
  to be able to reinitialize them after mutating an individual.
68
68
 
69
69
  :param optimizer_cls: The optimizer class to be initialized.
70
- :type optimizer_cls: Type[torch.optim.Optimizer]
70
+ :type optimizer_cls: type[torch.optim.Optimizer]
71
71
  :param networks: The network/s that the optimizer will update.
72
72
  :type networks: EvolvableModule, ModuleDict
73
73
  :param lr: The learning rate of the optimizer.
74
74
  :type lr: float
75
75
  :param optimizer_kwargs: The keyword arguments to be passed to the optimizer.
76
- :type optimizer_kwargs: Dict[str, Any]
76
+ :type optimizer_kwargs: dict[str, Any]
77
77
  :param network_names: The attribute names of the networks in the parent container.
78
- :type network_names: List[str]
78
+ :type network_names: list[str]
79
79
  :param lr_name: The attribute name of the learning rate in the parent container.
80
80
  :type lr_name: str
81
81
  """
@@ -87,8 +87,8 @@ class OptimizerWrapper:
87
87
  optimizer_cls: _Optimizer,
88
88
  networks: _Module,
89
89
  lr: float,
90
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
91
- network_names: Optional[List[str]] = None,
90
+ optimizer_kwargs: Optional[dict[str, Any]] = None,
91
+ network_names: Optional[list[str]] = None,
92
92
  lr_name: Optional[str] = None,
93
93
  ) -> None:
94
94
 
@@ -208,7 +208,7 @@ class OptimizerWrapper:
208
208
  current_frame = inspect.currentframe()
209
209
  return current_frame.f_back.f_back.f_locals["self"]
210
210
 
211
- def _infer_network_attr_names(self, container: Any) -> List[str]:
211
+ def _infer_network_attr_names(self, container: Any) -> list[str]:
212
212
  """
213
213
  Infer attribute names of the networks being optimized.
214
214
 
@@ -263,7 +263,7 @@ class OptimizerWrapper:
263
263
  Load the state of the optimizer from the passed state dictionary.
264
264
 
265
265
  :param state_dict: State dictionary of the optimizer.
266
- :type state_dict: Dict[str, Any]
266
+ :type state_dict: dict[str, Any]
267
267
  """
268
268
  if isinstance(self.networks[0], ModuleDict):
269
269
  assert (
@@ -293,7 +293,7 @@ class OptimizerWrapper:
293
293
 
294
294
  return self.optimizer.state_dict()
295
295
 
296
- def optimizer_cls_names(self) -> Union[str, Dict[str, str]]:
296
+ def optimizer_cls_names(self) -> Union[str, dict[str, str]]:
297
297
  """
298
298
  Return the names of the optimizers.
299
299
  """
@@ -304,7 +304,7 @@ class OptimizerWrapper:
304
304
  }
305
305
  return self.optimizer_cls.__name__
306
306
 
307
- def checkpoint_dict(self, name: str) -> Dict[str, Any]:
307
+ def checkpoint_dict(self, name: str) -> dict[str, Any]:
308
308
  """
309
309
  Return a dictionary of the optimizer's state and parameters.
310
310
 
@@ -312,7 +312,7 @@ class OptimizerWrapper:
312
312
  :type name: str
313
313
 
314
314
  :return: A dictionary of the optimizer's state and parameters.
315
- :rtype: Dict[str, Any]
315
+ :rtype: dict[str, Any]
316
316
  """
317
317
  return {
318
318
  f"{name}_cls": self.optimizer_cls_names(),