agilerl 2.3.5.dev0__tar.gz → 2.3.5.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/PKG-INFO +1 -1
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/bc_lm.py +3 -3
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/base.py +70 -89
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/optimizer_wrapper.py +16 -16
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/registry.py +77 -45
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/cqn.py +5 -6
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ddpg.py +14 -14
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/dqn.py +2 -2
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/dqn_rainbow.py +5 -6
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/grpo.py +8 -8
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ilql.py +2 -2
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ippo.py +25 -25
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/maddpg.py +22 -22
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/matd3.py +28 -28
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/neural_ts_bandit.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/neural_ucb_bandit.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/ppo.py +29 -29
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/td3.py +6 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/multi_agent_replay_buffer.py +23 -24
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/replay_buffer.py +3 -3
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/rollout_buffer.py +30 -30
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/sampler.py +5 -5
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/rl_data.py +8 -8
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/tokenizer.py +5 -5
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/mutation.py +20 -20
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/tournament.py +6 -8
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/base.py +33 -37
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/bert.py +11 -11
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/cnn.py +43 -43
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/configs.py +11 -11
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/dummy.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/gpt.py +14 -14
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/lstm.py +11 -11
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/mlp.py +13 -13
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/multi_input.py +18 -18
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/resnet.py +12 -12
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/simba.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/actors.py +7 -7
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/base.py +27 -27
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/custom_modules.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/distributions.py +12 -12
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/distributions_experimental.py +3 -3
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/q_networks.py +10 -10
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/value_networks.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/protocols.py +41 -45
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/rollouts/on_policy.py +10 -10
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_bandits.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_llm.py +3 -3
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_multi_agent_off_policy.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_multi_agent_on_policy.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_off_policy.py +5 -5
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_offline.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/train_on_policy.py +5 -5
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/typing.py +24 -28
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/algo_utils.py +67 -67
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/evolvable_networks.py +26 -26
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/ilql_utils.py +6 -6
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/llm_utils.py +12 -12
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/torch_utils.py +4 -4
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/utils.py +30 -9
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/pz_async_vec_env.py +48 -48
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/pz_vec_env.py +10 -10
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/agent.py +14 -14
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/make_evolvable.py +17 -17
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/pyproject.toml +4 -1
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/LICENSE +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/README.md +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/data.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/training/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/cache.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.3.5.dev0 → agilerl-2.3.5.dev1}/agilerl/wrappers/utils.py +0 -0
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable, Optional,
|
|
1
|
+
from typing import Any, Callable, Optional, Union
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -167,7 +167,7 @@ class BC_LM(nn.Module):
|
|
|
167
167
|
temp: float = 1.0,
|
|
168
168
|
top_k: Optional[int] = None,
|
|
169
169
|
top_p: Optional[float] = None,
|
|
170
|
-
) ->
|
|
170
|
+
) -> tuple[torch.Tensor, Any]:
|
|
171
171
|
prepared_inputs = self.prepare_inputs(items)
|
|
172
172
|
tokens = prepared_inputs["tokens"]
|
|
173
173
|
scores, model_outputs = self.score(
|
|
@@ -189,7 +189,7 @@ class BC_LM(nn.Module):
|
|
|
189
189
|
temp: float = 1.0,
|
|
190
190
|
top_k: Optional[int] = None,
|
|
191
191
|
top_p: Optional[float] = None,
|
|
192
|
-
) ->
|
|
192
|
+
) -> tuple[torch.Tensor, Any]:
|
|
193
193
|
scores, model_outputs = self.score(
|
|
194
194
|
(
|
|
195
195
|
tokens.unsqueeze(1),
|
|
@@ -12,12 +12,8 @@ from importlib.metadata import version
|
|
|
12
12
|
from typing import (
|
|
13
13
|
Any,
|
|
14
14
|
Callable,
|
|
15
|
-
Dict,
|
|
16
15
|
Iterable,
|
|
17
|
-
List,
|
|
18
16
|
Optional,
|
|
19
|
-
Tuple,
|
|
20
|
-
Type,
|
|
21
17
|
TypeVar,
|
|
22
18
|
Union,
|
|
23
19
|
cast,
|
|
@@ -30,16 +26,13 @@ from accelerate import Accelerator
|
|
|
30
26
|
from accelerate.utils import broadcast_object_list
|
|
31
27
|
from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
|
|
32
28
|
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
|
|
33
|
-
from deepspeed.runtime.engine import DeepSpeedEngine
|
|
34
29
|
from gymnasium import spaces
|
|
35
|
-
from numpy.typing import ArrayLike
|
|
36
30
|
from peft import PeftModel, set_peft_model_state_dict
|
|
37
31
|
from safetensors.torch import load_file
|
|
38
32
|
from tensordict import TensorDict
|
|
39
33
|
from torch._dynamo import OptimizedModule
|
|
40
34
|
from torch.optim import AdamW
|
|
41
35
|
from torch.optim.lr_scheduler import SequentialLR
|
|
42
|
-
from vllm.distributed.parallel_state import destroy_model_parallel
|
|
43
36
|
|
|
44
37
|
from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
|
|
45
38
|
from agilerl.algorithms.core.registry import (
|
|
@@ -107,7 +100,7 @@ class _RegistryMeta(type):
|
|
|
107
100
|
initializing with specified network groups and optimizers."""
|
|
108
101
|
|
|
109
102
|
def __call__(
|
|
110
|
-
cls:
|
|
103
|
+
cls: type[SelfEvolvableAlgorithm], *args, **kwargs
|
|
111
104
|
) -> SelfEvolvableAlgorithm:
|
|
112
105
|
# Create the instance
|
|
113
106
|
instance: SelfEvolvableAlgorithm = super().__call__(*args, **kwargs)
|
|
@@ -124,7 +117,7 @@ class RegistryMeta(_RegistryMeta, ABCMeta): ...
|
|
|
124
117
|
|
|
125
118
|
def get_checkpoint_dict(
|
|
126
119
|
agent: SelfEvolvableAlgorithm, using_deepspeed: bool = False
|
|
127
|
-
) ->
|
|
120
|
+
) -> dict[str, Any]:
|
|
128
121
|
"""Returns a dictionary of the agent's attributes to save in a checkpoint.
|
|
129
122
|
|
|
130
123
|
Note: Accelerator is always excluded from the checkpoint as it cannot be serialized.
|
|
@@ -152,7 +145,7 @@ def get_checkpoint_dict(
|
|
|
152
145
|
attribute_dict.pop("rollout_buffer")
|
|
153
146
|
|
|
154
147
|
# Get checkpoint dictionaries for evolvable modules and optimizers
|
|
155
|
-
network_info:
|
|
148
|
+
network_info: dict[str, dict[str, Any]] = {"modules": {}, "optimizers": {}}
|
|
156
149
|
for attr in agent.evolvable_attributes():
|
|
157
150
|
evolvable_obj: EvolvableAttributeType = getattr(agent, attr)
|
|
158
151
|
if isinstance(evolvable_obj, OptimizerWrapper):
|
|
@@ -186,14 +179,14 @@ def get_checkpoint_dict(
|
|
|
186
179
|
|
|
187
180
|
|
|
188
181
|
def get_optimizer_cls(
|
|
189
|
-
optimizer_cls: Union[str,
|
|
190
|
-
) -> Union[
|
|
182
|
+
optimizer_cls: Union[str, dict[str, str]],
|
|
183
|
+
) -> Union[type[torch.optim.Optimizer], dict[str, type[torch.optim.Optimizer]]]:
|
|
191
184
|
"""Returns the optimizer class from the string or dictionary of optimizer classes.
|
|
192
185
|
|
|
193
186
|
:param optimizer_cls: The optimizer class or dictionary of optimizer classes.
|
|
194
|
-
:type optimizer_cls: Union[str,
|
|
187
|
+
:type optimizer_cls: Union[str, dict[str, str]]
|
|
195
188
|
:return: The optimizer class or dictionary of optimizer classes.
|
|
196
|
-
:rtype: Union[
|
|
189
|
+
:rtype: Union[type[torch.optim.Optimizer], dict[str, type[torch.optim.Optimizer]]]
|
|
197
190
|
"""
|
|
198
191
|
if isinstance(optimizer_cls, dict):
|
|
199
192
|
optimizer_cls = {
|
|
@@ -313,20 +306,20 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
313
306
|
raise NotImplementedError
|
|
314
307
|
|
|
315
308
|
@abstractmethod
|
|
316
|
-
def test(self, *args, **kwargs) ->
|
|
309
|
+
def test(self, *args, **kwargs) -> np.ndarray:
|
|
317
310
|
"""Abstract method for testing the algorithm."""
|
|
318
311
|
raise NotImplementedError
|
|
319
312
|
|
|
320
313
|
@staticmethod
|
|
321
|
-
def get_state_dim(observation_space: GymSpaceType) ->
|
|
314
|
+
def get_state_dim(observation_space: GymSpaceType) -> tuple[int, ...]:
|
|
322
315
|
"""Returns the dimension of the state space as it pertains to the underlying
|
|
323
316
|
networks (i.e. the input size of the networks).
|
|
324
317
|
|
|
325
318
|
:param observation_space: The observation space of the environment.
|
|
326
|
-
:type observation_space: spaces.Space or
|
|
319
|
+
:type observation_space: spaces.Space or list[spaces.Space].
|
|
327
320
|
|
|
328
321
|
:return: The dimension of the state space.
|
|
329
|
-
:rtype:
|
|
322
|
+
:rtype: tuple[int, ...].
|
|
330
323
|
"""
|
|
331
324
|
warnings.warn(
|
|
332
325
|
"This method is deprecated. Use get_input_size_from_space instead.",
|
|
@@ -335,12 +328,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
335
328
|
return get_input_size_from_space(observation_space)
|
|
336
329
|
|
|
337
330
|
@staticmethod
|
|
338
|
-
def get_action_dim(action_space: GymSpaceType) ->
|
|
331
|
+
def get_action_dim(action_space: GymSpaceType) -> tuple[int, ...]:
|
|
339
332
|
"""Returns the dimension of the action space as it pertains to the underlying
|
|
340
333
|
networks (i.e. the output size of the networks).
|
|
341
334
|
|
|
342
335
|
:param action_space: The action space of the environment.
|
|
343
|
-
:type action_space: spaces.Space or
|
|
336
|
+
:type action_space: spaces.Space or list[spaces.Space].
|
|
344
337
|
|
|
345
338
|
:return: The dimension of the action space.
|
|
346
339
|
:rtype: int.
|
|
@@ -354,7 +347,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
354
347
|
@staticmethod
|
|
355
348
|
def inspect_attributes(
|
|
356
349
|
agent: SelfEvolvableAlgorithm, input_args_only: bool = False
|
|
357
|
-
) ->
|
|
350
|
+
) -> dict[str, Any]:
|
|
358
351
|
"""
|
|
359
352
|
Inspect and retrieve the attributes of the current object, excluding attributes related to the
|
|
360
353
|
underlying evolvable networks (i.e. `EvolvableModule`, `torch.optim.Optimizer`) and with
|
|
@@ -451,21 +444,21 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
451
444
|
|
|
452
445
|
@classmethod
|
|
453
446
|
def population(
|
|
454
|
-
cls:
|
|
447
|
+
cls: type[SelfEvolvableAlgorithm],
|
|
455
448
|
size: int,
|
|
456
449
|
observation_space: GymSpaceType,
|
|
457
450
|
action_space: GymSpaceType,
|
|
458
|
-
wrapper_cls: Optional[
|
|
459
|
-
wrapper_kwargs:
|
|
451
|
+
wrapper_cls: Optional[type[SelfAgentWrapper]] = None,
|
|
452
|
+
wrapper_kwargs: dict[str, Any] = {},
|
|
460
453
|
**kwargs,
|
|
461
|
-
) ->
|
|
454
|
+
) -> list[Union[SelfEvolvableAlgorithm, SelfAgentWrapper]]:
|
|
462
455
|
"""Creates a population of algorithms.
|
|
463
456
|
|
|
464
457
|
:param size: The size of the population.
|
|
465
458
|
:type size: int.
|
|
466
459
|
|
|
467
460
|
:return: A list of algorithms.
|
|
468
|
-
:rtype:
|
|
461
|
+
:rtype: list[SelfEvolvableAlgorithm].
|
|
469
462
|
"""
|
|
470
463
|
if wrapper_cls is not None:
|
|
471
464
|
return [
|
|
@@ -549,11 +542,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
549
542
|
hp_value = getattr(self, hp)
|
|
550
543
|
hp_spec = self.registry.hp_config[hp]
|
|
551
544
|
dtype = type(hp_value)
|
|
552
|
-
if dtype not in [int, float]:
|
|
545
|
+
if dtype not in [int, float, np.ndarray]:
|
|
553
546
|
raise TypeError(
|
|
554
547
|
f"Can't mutate hyperparameter {hp} of type {dtype}. AgileRL only supports "
|
|
555
|
-
"mutating integer
|
|
548
|
+
"mutating integer, float, and numpy ndarray hyperparameters."
|
|
556
549
|
)
|
|
550
|
+
|
|
557
551
|
hp_spec.dtype = dtype
|
|
558
552
|
|
|
559
553
|
def _wrap_attr(self, attr: EvolvableAttributeType) -> EvolvableAttributeType:
|
|
@@ -637,7 +631,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
637
631
|
"""
|
|
638
632
|
self.training = training
|
|
639
633
|
|
|
640
|
-
def get_lr_names(self) ->
|
|
634
|
+
def get_lr_names(self) -> list[str]:
|
|
641
635
|
"""Returns the learning rates of the algorithm."""
|
|
642
636
|
return [opt.lr for opt in self.registry.optimizers]
|
|
643
637
|
|
|
@@ -695,14 +689,14 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
695
689
|
for name, obj in self.evolvable_attributes(networks_only=True).items():
|
|
696
690
|
setattr(self, name, compile_model(obj, self.torch_compiler))
|
|
697
691
|
|
|
698
|
-
def to_device(self, *experiences: TorchObsType) ->
|
|
692
|
+
def to_device(self, *experiences: TorchObsType) -> tuple[TorchObsType, ...]:
|
|
699
693
|
"""Moves experiences to the device.
|
|
700
694
|
|
|
701
695
|
:param experiences: Experiences to move to device
|
|
702
|
-
:type experiences:
|
|
696
|
+
:type experiences: tuple[torch.Tensor[float], ...]
|
|
703
697
|
|
|
704
698
|
:return: Experiences on the device
|
|
705
|
-
:rtype:
|
|
699
|
+
:rtype: tuple[torch.Tensor[float], ...]
|
|
706
700
|
"""
|
|
707
701
|
device = self.device if self.accelerator is None else self.accelerator.device
|
|
708
702
|
on_device = []
|
|
@@ -861,12 +855,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
861
855
|
:param path: Location to load checkpoint from
|
|
862
856
|
:type path: string
|
|
863
857
|
"""
|
|
864
|
-
checkpoint:
|
|
858
|
+
checkpoint: dict[str, Any] = torch.load(
|
|
865
859
|
path, map_location=self.device, pickle_module=dill, weights_only=False
|
|
866
860
|
)
|
|
867
861
|
|
|
868
862
|
# Recreate evolvable modules
|
|
869
|
-
network_info:
|
|
863
|
+
network_info: dict[str, dict[str, Any]] = checkpoint["network_info"]
|
|
870
864
|
network_names = network_info["network_names"]
|
|
871
865
|
for name in network_names:
|
|
872
866
|
net_dict = {
|
|
@@ -967,7 +961,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
967
961
|
|
|
968
962
|
@classmethod
|
|
969
963
|
def load(
|
|
970
|
-
cls:
|
|
964
|
+
cls: type[SelfEvolvableAlgorithm],
|
|
971
965
|
path: str,
|
|
972
966
|
device: DeviceType = "cpu",
|
|
973
967
|
accelerator: Optional[Accelerator] = None,
|
|
@@ -984,12 +978,12 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
984
978
|
:return: An instance of the algorithm
|
|
985
979
|
:rtype: RLAlgorithm
|
|
986
980
|
"""
|
|
987
|
-
checkpoint:
|
|
981
|
+
checkpoint: dict[str, Any] = torch.load(
|
|
988
982
|
path, map_location=device, pickle_module=dill, weights_only=False
|
|
989
983
|
)
|
|
990
984
|
|
|
991
985
|
# Reconstruct evolvable modules in algorithm
|
|
992
|
-
network_info: Optional[
|
|
986
|
+
network_info: Optional[dict[str, dict[str, Any]]] = checkpoint.get(
|
|
993
987
|
"network_info"
|
|
994
988
|
)
|
|
995
989
|
if network_info is None:
|
|
@@ -1001,7 +995,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
1001
995
|
)
|
|
1002
996
|
|
|
1003
997
|
network_names = network_info["network_names"]
|
|
1004
|
-
loaded_modules:
|
|
998
|
+
loaded_modules: dict[str, EvolvableAttributeType] = {}
|
|
1005
999
|
for name in network_names:
|
|
1006
1000
|
net_dict = {
|
|
1007
1001
|
k: v for k, v in network_info["modules"].items() if k.startswith(name)
|
|
@@ -1021,7 +1015,7 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
1021
1015
|
|
|
1022
1016
|
# Reconstruct the modules
|
|
1023
1017
|
module_cls: Union[
|
|
1024
|
-
|
|
1018
|
+
type[EvolvableModule], dict[str, type[EvolvableModule]]
|
|
1025
1019
|
] = net_dict[f"{name}_cls"]
|
|
1026
1020
|
if isinstance(module_cls, dict):
|
|
1027
1021
|
for agent_id, mod_cls in module_cls.items():
|
|
@@ -1187,7 +1181,7 @@ class RLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1187
1181
|
:type observations: ObservationType
|
|
1188
1182
|
|
|
1189
1183
|
:return: Preprocessed observations
|
|
1190
|
-
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or
|
|
1184
|
+
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
|
|
1191
1185
|
"""
|
|
1192
1186
|
return preprocess_observation(
|
|
1193
1187
|
self.observation_space,
|
|
@@ -1201,13 +1195,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1201
1195
|
"""Base object for all multi-agent algorithms in the AgileRL framework.
|
|
1202
1196
|
|
|
1203
1197
|
:param observation_spaces: The observation spaces of the agent environments.
|
|
1204
|
-
:type observation_spaces: Union[
|
|
1198
|
+
:type observation_spaces: Union[list[spaces.Space], spaces.Dict]
|
|
1205
1199
|
:param action_spaces: The action spaces of the agent environments.
|
|
1206
|
-
:type action_spaces: Union[
|
|
1200
|
+
:type action_spaces: Union[list[spaces.Space], spaces.Dict]
|
|
1207
1201
|
:param index: The index of the individual in the population.
|
|
1208
1202
|
:type index: int.
|
|
1209
1203
|
:param agent_ids: The agent IDs of the agents in the environment.
|
|
1210
|
-
:type agent_ids: Optional[
|
|
1204
|
+
:type agent_ids: Optional[list[int]], optional
|
|
1211
1205
|
:param learn_step: Learning frequency, defaults to 2048
|
|
1212
1206
|
:type learn_step: int, optional
|
|
1213
1207
|
:param device: Device to run the algorithm on, defaults to "cpu"
|
|
@@ -1224,13 +1218,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1224
1218
|
:type name: Optional[str], optional
|
|
1225
1219
|
"""
|
|
1226
1220
|
|
|
1227
|
-
possible_observation_spaces:
|
|
1228
|
-
possible_action_spaces:
|
|
1221
|
+
possible_observation_spaces: dict[str, spaces.Space]
|
|
1222
|
+
possible_action_spaces: dict[str, spaces.Space]
|
|
1229
1223
|
|
|
1230
|
-
shared_agent_ids:
|
|
1231
|
-
grouped_agents:
|
|
1232
|
-
unique_observation_spaces:
|
|
1233
|
-
unique_action_spaces:
|
|
1224
|
+
shared_agent_ids: list[str]
|
|
1225
|
+
grouped_agents: dict[str, list[str]]
|
|
1226
|
+
unique_observation_spaces: dict[str, spaces.Space]
|
|
1227
|
+
unique_action_spaces: dict[str, spaces.Space]
|
|
1234
1228
|
|
|
1235
1229
|
def __init__(
|
|
1236
1230
|
self,
|
|
@@ -1396,14 +1390,14 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1396
1390
|
|
|
1397
1391
|
def preprocess_observation(
|
|
1398
1392
|
self, observation: ObservationType
|
|
1399
|
-
) ->
|
|
1393
|
+
) -> dict[str, TorchObsType]:
|
|
1400
1394
|
"""Preprocesses observations for forward pass through neural network.
|
|
1401
1395
|
|
|
1402
1396
|
:param observations: Observations of environment
|
|
1403
1397
|
:type observations: numpy.ndarray[float] or dict[str, numpy.ndarray[float]]
|
|
1404
1398
|
|
|
1405
1399
|
:return: Preprocessed observations
|
|
1406
|
-
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or
|
|
1400
|
+
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
|
|
1407
1401
|
"""
|
|
1408
1402
|
preprocessed = {}
|
|
1409
1403
|
for agent_id, agent_obs in observation.items():
|
|
@@ -1421,10 +1415,10 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1421
1415
|
"""Extract action masks from info dictionary
|
|
1422
1416
|
|
|
1423
1417
|
:param infos: Info dict
|
|
1424
|
-
:type infos:
|
|
1418
|
+
:type infos: dict[str, dict[...]]
|
|
1425
1419
|
|
|
1426
1420
|
:return: Action masks
|
|
1427
|
-
:rtype:
|
|
1421
|
+
:rtype: dict[str, np.ndarray]
|
|
1428
1422
|
"""
|
|
1429
1423
|
# Get dict of form {"agent_id" : [1, 0, 0, 0]...} etc
|
|
1430
1424
|
action_masks = {
|
|
@@ -1437,14 +1431,14 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1437
1431
|
|
|
1438
1432
|
def extract_agent_masks(
|
|
1439
1433
|
self, infos: Optional[InfosDict] = None
|
|
1440
|
-
) ->
|
|
1434
|
+
) -> tuple[ArrayDict, ArrayDict]:
|
|
1441
1435
|
"""Extract env_defined_actions from info dictionary and determine agent masks
|
|
1442
1436
|
|
|
1443
1437
|
:param infos: Info dict
|
|
1444
|
-
:type infos:
|
|
1438
|
+
:type infos: dict[str, dict[...]]
|
|
1445
1439
|
|
|
1446
1440
|
:return: Env defined actions and agent masks
|
|
1447
|
-
:rtype:
|
|
1441
|
+
:rtype: tuple[ArrayDict, ArrayDict]
|
|
1448
1442
|
"""
|
|
1449
1443
|
# Deal with case of no env_defined_actions defined in the info dict
|
|
1450
1444
|
# Deal with empty info dicts for each sub agent
|
|
@@ -1506,7 +1500,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1506
1500
|
net_config: Optional[NetConfigType] = None,
|
|
1507
1501
|
flatten: bool = True,
|
|
1508
1502
|
return_encoders: bool = False,
|
|
1509
|
-
) -> Union[NetConfigType,
|
|
1503
|
+
) -> Union[NetConfigType, tuple[NetConfigType, dict[str, NetConfigType]]]:
|
|
1510
1504
|
"""Extract an appropriate net config for each sub-agent from the passed net config dictionary. If
|
|
1511
1505
|
grouped_agents is True, the net config will be built for the grouped agents i.e. through their
|
|
1512
1506
|
common prefix in their agent_id, whenever the passed net config is None.
|
|
@@ -1539,7 +1533,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1539
1533
|
# Helper function to append unique configs to the unique_configs dictionary
|
|
1540
1534
|
# -> Access to unique configs is relevant for algorithms with networks that process
|
|
1541
1535
|
# multiple agents' observations (e.g. shared critic in MADDPG)
|
|
1542
|
-
def _add_to_encoder_configs(config:
|
|
1536
|
+
def _add_to_encoder_configs(config: dict[str, Any], agent_id: str = "") -> None:
|
|
1543
1537
|
config = config_from_dict(config)
|
|
1544
1538
|
config_key = "mlp_config" if isinstance(config, MlpNetConfig) else agent_id
|
|
1545
1539
|
|
|
@@ -1697,7 +1691,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1697
1691
|
self,
|
|
1698
1692
|
group_outputs: ArrayDict,
|
|
1699
1693
|
vect_dim: int,
|
|
1700
|
-
grouped_agents:
|
|
1694
|
+
grouped_agents: dict[str, list[str]],
|
|
1701
1695
|
) -> ArrayDict:
|
|
1702
1696
|
"""Disassembles batched output by shared policies into their grouped agents' outputs.
|
|
1703
1697
|
|
|
@@ -1705,13 +1699,13 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1705
1699
|
i.e. any given agent will always terminate at the same timestep in different vectorized environments.
|
|
1706
1700
|
|
|
1707
1701
|
:param group_outputs: Dictionary to be disassembled, has the form {'agent': [4, 7, 8]}
|
|
1708
|
-
:type group_outputs:
|
|
1702
|
+
:type group_outputs: dict[str, np.ndarray]
|
|
1709
1703
|
:param vect_dim: Vectorization dimension size, i.e. number of vect envs
|
|
1710
1704
|
:type vect_dim: int
|
|
1711
1705
|
:param grouped_agents: Dictionary of grouped agent IDs
|
|
1712
|
-
:type grouped_agents:
|
|
1706
|
+
:type grouped_agents: dict[str, list[str]]
|
|
1713
1707
|
:return: Assembled dictionary, e.g. {'agent_0': 4, 'agent_1': 7, 'agent_2': 8}
|
|
1714
|
-
:rtype:
|
|
1708
|
+
:rtype: dict[str, np.ndarray]
|
|
1715
1709
|
"""
|
|
1716
1710
|
output_dict = {}
|
|
1717
1711
|
for group_id, agent_ids in grouped_agents.items():
|
|
@@ -1728,9 +1722,9 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1728
1722
|
"""Sums the rewards for grouped agents
|
|
1729
1723
|
|
|
1730
1724
|
:param rewards: Reward dictionary from environment
|
|
1731
|
-
:type rewards:
|
|
1725
|
+
:type rewards: dict[str, np.ndarray]
|
|
1732
1726
|
:return: Summed rewards dictionary
|
|
1733
|
-
:rtype:
|
|
1727
|
+
:rtype: dict[str, np.ndarray]
|
|
1734
1728
|
"""
|
|
1735
1729
|
reward_shape = list(rewards.values())[0]
|
|
1736
1730
|
reward_shape = (
|
|
@@ -1751,11 +1745,11 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1751
1745
|
"""Assembles individual agent outputs into batched outputs for shared policies.
|
|
1752
1746
|
|
|
1753
1747
|
:param agent_outputs: Dictionary with individual agent outputs, e.g. {'agent_0': 4, 'agent_1': 7, 'agent_2': 8}
|
|
1754
|
-
:type agent_outputs:
|
|
1748
|
+
:type agent_outputs: dict[str, np.ndarray]
|
|
1755
1749
|
:param vect_dim: Vectorization dimension size, i.e. number of vect envs
|
|
1756
1750
|
:type vect_dim: int
|
|
1757
1751
|
:return: Assembled dictionary with the form {'agent': [4, 7, 8]}
|
|
1758
|
-
:rtype:
|
|
1752
|
+
:rtype: dict[str, np.ndarray]
|
|
1759
1753
|
"""
|
|
1760
1754
|
group_outputs = {}
|
|
1761
1755
|
for group_id in self.shared_agent_ids:
|
|
@@ -1846,7 +1840,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1846
1840
|
:type observations: numpy.ndarray[float] or dict[str, numpy.ndarray[float]]
|
|
1847
1841
|
|
|
1848
1842
|
:return: Preprocessed observations
|
|
1849
|
-
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or
|
|
1843
|
+
:rtype: torch.Tensor[float] or dict[str, torch.Tensor[float]] or tuple[torch.Tensor[float], ...]
|
|
1850
1844
|
"""
|
|
1851
1845
|
return cast(TorchObsType, observation)
|
|
1852
1846
|
|
|
@@ -1890,6 +1884,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1890
1884
|
path + "/attributes.pt",
|
|
1891
1885
|
pickle_module=dill,
|
|
1892
1886
|
)
|
|
1887
|
+
if self.accelerator is not None:
|
|
1888
|
+
self.accelerator.wait_for_everyone()
|
|
1893
1889
|
|
|
1894
1890
|
# TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
|
|
1895
1891
|
# handle _load_distributed_actor if deepspeed is used.
|
|
@@ -1907,28 +1903,22 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1907
1903
|
if weights_only:
|
|
1908
1904
|
if self.use_separate_reference_adapter:
|
|
1909
1905
|
self._update_existing_adapter(
|
|
1910
|
-
self.accelerator,
|
|
1911
|
-
self.actor,
|
|
1912
1906
|
path,
|
|
1913
1907
|
"reference",
|
|
1914
1908
|
)
|
|
1915
1909
|
|
|
1916
1910
|
self._update_existing_adapter(
|
|
1917
|
-
self.accelerator,
|
|
1918
|
-
self.actor,
|
|
1919
1911
|
path,
|
|
1920
1912
|
"actor",
|
|
1921
1913
|
)
|
|
1922
1914
|
else:
|
|
1923
1915
|
self._load_distributed_actor(path, tag="save_checkpoint")
|
|
1924
1916
|
|
|
1925
|
-
checkpoint["accelerator"] = (
|
|
1926
|
-
Accelerator() if self.accelerator is not None else None
|
|
1927
|
-
)
|
|
1928
|
-
self.accelerator = None
|
|
1929
1917
|
for attr, value in checkpoint.items():
|
|
1930
1918
|
setattr(self, attr, value)
|
|
1931
1919
|
|
|
1920
|
+
self.device = self.accelerator.device
|
|
1921
|
+
|
|
1932
1922
|
self.optimizer = None
|
|
1933
1923
|
self.optimizer = OptimizerWrapper(
|
|
1934
1924
|
optimizer_cls=self._select_optim_class(),
|
|
@@ -1937,7 +1927,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1937
1927
|
lr=self.lr,
|
|
1938
1928
|
lr_name="lr",
|
|
1939
1929
|
)
|
|
1940
|
-
self.wrap_models()
|
|
1941
1930
|
else:
|
|
1942
1931
|
super().load_checkpoint(path + "/attributes.pt")
|
|
1943
1932
|
|
|
@@ -1964,11 +1953,11 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1964
1953
|
"""
|
|
1965
1954
|
)
|
|
1966
1955
|
|
|
1967
|
-
def _select_optim_class(self) -> Union[
|
|
1956
|
+
def _select_optim_class(self) -> Union[type[OptimizerType], type[DummyOptimizer]]:
|
|
1968
1957
|
"""Select the optimizer class based on the accelerator and deepspeed config.
|
|
1969
1958
|
|
|
1970
1959
|
:return: Optimizer class
|
|
1971
|
-
:rtype: Union[
|
|
1960
|
+
:rtype: Union[type[torch.optim.Optimizer], type[DummyOptimizer]]
|
|
1972
1961
|
"""
|
|
1973
1962
|
if self.accelerator is None:
|
|
1974
1963
|
return AdamW
|
|
@@ -2089,8 +2078,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2089
2078
|
None,
|
|
2090
2079
|
)
|
|
2091
2080
|
if self.use_vllm:
|
|
2092
|
-
destroy_model_parallel()
|
|
2093
|
-
del self.llm.llm_engine.model_executor.driver_worker
|
|
2094
2081
|
self.llm = None
|
|
2095
2082
|
gc.collect()
|
|
2096
2083
|
torch.cuda.empty_cache()
|
|
@@ -2201,7 +2188,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2201
2188
|
lr: float,
|
|
2202
2189
|
accelerator: Optional[Accelerator] = None,
|
|
2203
2190
|
scheduler_config: Optional[CosineLRScheduleConfig] = None,
|
|
2204
|
-
) ->
|
|
2191
|
+
) -> tuple[Optional[Accelerator], Optional[SequentialLR]]:
|
|
2205
2192
|
"""Update the learning rate of the optimizer
|
|
2206
2193
|
|
|
2207
2194
|
:param optimizer: Optimizer
|
|
@@ -2268,20 +2255,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2268
2255
|
"Recompile method is not available for LLM finetuning algorithms."
|
|
2269
2256
|
)
|
|
2270
2257
|
|
|
2271
|
-
@staticmethod
|
|
2272
2258
|
def _update_existing_adapter(
|
|
2273
|
-
|
|
2274
|
-
wrapped_model: DeepSpeedEngine,
|
|
2259
|
+
self,
|
|
2275
2260
|
checkpoint_dir: str,
|
|
2276
2261
|
adapter_name: str,
|
|
2277
2262
|
) -> None:
|
|
2278
2263
|
"""
|
|
2279
2264
|
Overwrite weights of an existing adapter in-place without creating new parameters.
|
|
2280
2265
|
|
|
2281
|
-
:param accelerator: Accelerator
|
|
2282
|
-
:type accelerator: Accelerator
|
|
2283
|
-
:param wrapped_model: Wrapped model
|
|
2284
|
-
:type wrapped_model: DeepSpeedEngine
|
|
2285
2266
|
:param checkpoint_dir: Checkpoint directory
|
|
2286
2267
|
:type checkpoint_dir: str
|
|
2287
2268
|
:param adapter_name: Adapter name
|
|
@@ -2290,7 +2271,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2290
2271
|
:return: None
|
|
2291
2272
|
:rtype: None
|
|
2292
2273
|
"""
|
|
2293
|
-
base_model = accelerator.unwrap_model(
|
|
2274
|
+
base_model = self.accelerator.unwrap_model(self.actor)
|
|
2294
2275
|
if hasattr(base_model, "module"):
|
|
2295
2276
|
base_model = base_model.module
|
|
2296
2277
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
-
from typing import Any,
|
|
2
|
+
from typing import Any, Optional, Union
|
|
3
3
|
|
|
4
4
|
import torch.nn as nn
|
|
5
5
|
from peft import PeftModel
|
|
@@ -10,9 +10,9 @@ from agilerl.protocols import EvolvableAlgorithm
|
|
|
10
10
|
from agilerl.typing import OptimizerType, StateDict
|
|
11
11
|
from agilerl.utils.llm_utils import DummyOptimizer
|
|
12
12
|
|
|
13
|
-
ModuleList =
|
|
13
|
+
ModuleList = list[EvolvableModule]
|
|
14
14
|
_Optimizer = Union[
|
|
15
|
-
|
|
15
|
+
type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
|
|
16
16
|
]
|
|
17
17
|
_Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModel]
|
|
18
18
|
|
|
@@ -21,7 +21,7 @@ def init_from_multiple(
|
|
|
21
21
|
networks: ModuleList,
|
|
22
22
|
optimizer_cls: OptimizerType,
|
|
23
23
|
lr: float,
|
|
24
|
-
optimizer_kwargs:
|
|
24
|
+
optimizer_kwargs: dict[str, Any],
|
|
25
25
|
) -> Optimizer:
|
|
26
26
|
"""
|
|
27
27
|
Initialize an optimizer from a list of networks.
|
|
@@ -33,7 +33,7 @@ def init_from_multiple(
|
|
|
33
33
|
:param lr: The learning rate of the optimizer.
|
|
34
34
|
:type lr: float
|
|
35
35
|
:param optimizer_kwargs: The keyword arguments to be passed to the optimizer.
|
|
36
|
-
:type optimizer_kwargs:
|
|
36
|
+
:type optimizer_kwargs: dict[str, Any]
|
|
37
37
|
"""
|
|
38
38
|
opt_args = []
|
|
39
39
|
for i, net in enumerate(networks):
|
|
@@ -51,7 +51,7 @@ def init_from_single(
|
|
|
51
51
|
network: EvolvableModule,
|
|
52
52
|
optimizer_cls: OptimizerType,
|
|
53
53
|
lr: float,
|
|
54
|
-
optimizer_kwargs:
|
|
54
|
+
optimizer_kwargs: dict[str, Any],
|
|
55
55
|
) -> Optimizer:
|
|
56
56
|
"""
|
|
57
57
|
Initialize an optimizer from a single network.
|
|
@@ -67,15 +67,15 @@ class OptimizerWrapper:
|
|
|
67
67
|
to be able to reinitialize them after mutating an individual.
|
|
68
68
|
|
|
69
69
|
:param optimizer_cls: The optimizer class to be initialized.
|
|
70
|
-
:type optimizer_cls:
|
|
70
|
+
:type optimizer_cls: type[torch.optim.Optimizer]
|
|
71
71
|
:param networks: The network/s that the optimizer will update.
|
|
72
72
|
:type networks: EvolvableModule, ModuleDict
|
|
73
73
|
:param lr: The learning rate of the optimizer.
|
|
74
74
|
:type lr: float
|
|
75
75
|
:param optimizer_kwargs: The keyword arguments to be passed to the optimizer.
|
|
76
|
-
:type optimizer_kwargs:
|
|
76
|
+
:type optimizer_kwargs: dict[str, Any]
|
|
77
77
|
:param network_names: The attribute names of the networks in the parent container.
|
|
78
|
-
:type network_names:
|
|
78
|
+
:type network_names: list[str]
|
|
79
79
|
:param lr_name: The attribute name of the learning rate in the parent container.
|
|
80
80
|
:type lr_name: str
|
|
81
81
|
"""
|
|
@@ -87,8 +87,8 @@ class OptimizerWrapper:
|
|
|
87
87
|
optimizer_cls: _Optimizer,
|
|
88
88
|
networks: _Module,
|
|
89
89
|
lr: float,
|
|
90
|
-
optimizer_kwargs: Optional[
|
|
91
|
-
network_names: Optional[
|
|
90
|
+
optimizer_kwargs: Optional[dict[str, Any]] = None,
|
|
91
|
+
network_names: Optional[list[str]] = None,
|
|
92
92
|
lr_name: Optional[str] = None,
|
|
93
93
|
) -> None:
|
|
94
94
|
|
|
@@ -208,7 +208,7 @@ class OptimizerWrapper:
|
|
|
208
208
|
current_frame = inspect.currentframe()
|
|
209
209
|
return current_frame.f_back.f_back.f_locals["self"]
|
|
210
210
|
|
|
211
|
-
def _infer_network_attr_names(self, container: Any) ->
|
|
211
|
+
def _infer_network_attr_names(self, container: Any) -> list[str]:
|
|
212
212
|
"""
|
|
213
213
|
Infer attribute names of the networks being optimized.
|
|
214
214
|
|
|
@@ -263,7 +263,7 @@ class OptimizerWrapper:
|
|
|
263
263
|
Load the state of the optimizer from the passed state dictionary.
|
|
264
264
|
|
|
265
265
|
:param state_dict: State dictionary of the optimizer.
|
|
266
|
-
:type state_dict:
|
|
266
|
+
:type state_dict: dict[str, Any]
|
|
267
267
|
"""
|
|
268
268
|
if isinstance(self.networks[0], ModuleDict):
|
|
269
269
|
assert (
|
|
@@ -293,7 +293,7 @@ class OptimizerWrapper:
|
|
|
293
293
|
|
|
294
294
|
return self.optimizer.state_dict()
|
|
295
295
|
|
|
296
|
-
def optimizer_cls_names(self) -> Union[str,
|
|
296
|
+
def optimizer_cls_names(self) -> Union[str, dict[str, str]]:
|
|
297
297
|
"""
|
|
298
298
|
Return the names of the optimizers.
|
|
299
299
|
"""
|
|
@@ -304,7 +304,7 @@ class OptimizerWrapper:
|
|
|
304
304
|
}
|
|
305
305
|
return self.optimizer_cls.__name__
|
|
306
306
|
|
|
307
|
-
def checkpoint_dict(self, name: str) ->
|
|
307
|
+
def checkpoint_dict(self, name: str) -> dict[str, Any]:
|
|
308
308
|
"""
|
|
309
309
|
Return a dictionary of the optimizer's state and parameters.
|
|
310
310
|
|
|
@@ -312,7 +312,7 @@ class OptimizerWrapper:
|
|
|
312
312
|
:type name: str
|
|
313
313
|
|
|
314
314
|
:return: A dictionary of the optimizer's state and parameters.
|
|
315
|
-
:rtype:
|
|
315
|
+
:rtype: dict[str, Any]
|
|
316
316
|
"""
|
|
317
317
|
return {
|
|
318
318
|
f"{name}_cls": self.optimizer_cls_names(),
|