agilerl 2.3.4.dev1__tar.gz → 2.3.5.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/PKG-INFO +3 -3
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/README.md +0 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/base.py +129 -23
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/dqn_rainbow.py +25 -24
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/grpo.py +527 -153
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ppo.py +49 -13
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/dummy.py +20 -13
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/base.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/q_networks.py +2 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_llm.py +43 -31
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/typing.py +20 -2
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/algo_utils.py +109 -25
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/llm_utils.py +37 -12
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/log_utils.py +1 -2
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/utils.py +20 -6
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/pyproject.toml +4 -2
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/LICENSE +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/registry.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ilql.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ippo.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/data.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/sampler.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/base.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/bert.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/configs.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/simba.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/actors.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/protocols.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_bandits.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_multi_agent_off_policy.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_multi_agent_on_policy.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_off_policy.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_offline.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_on_policy.py +1 -1
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/cache.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: agilerl
|
|
3
|
-
Version: 2.3.
|
|
3
|
+
Version: 2.3.5.dev0
|
|
4
4
|
Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
|
|
5
5
|
License: Apache 2.0
|
|
6
6
|
Author: Nick Ustaran-Anderegg
|
|
@@ -37,14 +37,14 @@ Requires-Dist: redis (>=4.4.4,<5.0.0)
|
|
|
37
37
|
Requires-Dist: scipy (>=1.12.0,<2.0.0)
|
|
38
38
|
Requires-Dist: tensordict (>=0.8,<0.9)
|
|
39
39
|
Requires-Dist: termcolor (>=1.1.0,<2.0.0)
|
|
40
|
-
Requires-Dist: torch (==2.
|
|
40
|
+
Requires-Dist: torch (==2.7.1)
|
|
41
41
|
Requires-Dist: tqdm (>=4.66.4,<5.0.0)
|
|
42
42
|
Requires-Dist: transformers (>=4.48.1,<5.0.0)
|
|
43
43
|
Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
|
|
44
|
+
Requires-Dist: vllm (==0.10.0)
|
|
44
45
|
Requires-Dist: wandb (>=0.17.6,<0.18.0)
|
|
45
46
|
Description-Content-Type: text/markdown
|
|
46
47
|
|
|
47
|
-
# AgileRL
|
|
48
48
|
<p align="center">
|
|
49
49
|
<img src=https://user-images.githubusercontent.com/47857277/222710068-e09a4e3c-368c-458a-9e01-b68674806887.png height="120">
|
|
50
50
|
</p>
|
|
@@ -30,13 +30,16 @@ from accelerate import Accelerator
|
|
|
30
30
|
from accelerate.utils import broadcast_object_list
|
|
31
31
|
from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
|
|
32
32
|
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
|
|
33
|
+
from deepspeed.runtime.engine import DeepSpeedEngine
|
|
33
34
|
from gymnasium import spaces
|
|
34
35
|
from numpy.typing import ArrayLike
|
|
35
|
-
from peft import PeftModel
|
|
36
|
+
from peft import PeftModel, set_peft_model_state_dict
|
|
37
|
+
from safetensors.torch import load_file
|
|
36
38
|
from tensordict import TensorDict
|
|
37
39
|
from torch._dynamo import OptimizedModule
|
|
38
40
|
from torch.optim import AdamW
|
|
39
41
|
from torch.optim.lr_scheduler import SequentialLR
|
|
42
|
+
from vllm.distributed.parallel_state import destroy_model_parallel
|
|
40
43
|
|
|
41
44
|
from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
|
|
42
45
|
from agilerl.algorithms.core.registry import (
|
|
@@ -70,6 +73,7 @@ from agilerl.typing import (
|
|
|
70
73
|
)
|
|
71
74
|
from agilerl.utils.algo_utils import (
|
|
72
75
|
CosineLRScheduleConfig,
|
|
76
|
+
check_supported_space,
|
|
73
77
|
chkpt_attribute_to_device,
|
|
74
78
|
clone_llm,
|
|
75
79
|
create_warmup_cosine_scheduler,
|
|
@@ -869,7 +873,11 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
|
|
|
869
873
|
k: v for k, v in network_info["modules"].items() if k.startswith(name)
|
|
870
874
|
}
|
|
871
875
|
|
|
872
|
-
module_cls = net_dict
|
|
876
|
+
module_cls = net_dict.get(f"{name}_cls", None)
|
|
877
|
+
if module_cls is None:
|
|
878
|
+
# This allows us to super this method in the LLMAlgorithm class
|
|
879
|
+
# as we don't want to reinstantiate the network in this class
|
|
880
|
+
break
|
|
873
881
|
init_dict = net_dict[f"{name}_init_dict"]
|
|
874
882
|
|
|
875
883
|
module_dict_cls = net_dict.get(f"{name}_module_dict_cls", None)
|
|
@@ -1164,12 +1172,8 @@ class RLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1164
1172
|
|
|
1165
1173
|
super().__init__(index, hp_config, device, accelerator, torch_compiler, name)
|
|
1166
1174
|
|
|
1167
|
-
|
|
1168
|
-
|
|
1169
|
-
), "Observation space must be an instance of gymnasium.spaces.Space."
|
|
1170
|
-
assert isinstance(
|
|
1171
|
-
action_space, spaces.Space
|
|
1172
|
-
), "Action space must be an instance of gymnasium.spaces.Space."
|
|
1175
|
+
check_supported_space(observation_space)
|
|
1176
|
+
check_supported_space(action_space)
|
|
1173
1177
|
|
|
1174
1178
|
self.observation_space = observation_space
|
|
1175
1179
|
self.action_space = action_space
|
|
@@ -1257,12 +1261,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1257
1261
|
assert len(agent_ids) == len(
|
|
1258
1262
|
observation_spaces
|
|
1259
1263
|
), "Number of agent IDs must match number of observation spaces."
|
|
1260
|
-
|
|
1261
|
-
isinstance(_space, spaces.Space) for _space in observation_spaces
|
|
1262
|
-
), "Observation spaces must be instances of gymnasium.spaces.Space."
|
|
1263
|
-
assert all(
|
|
1264
|
-
isinstance(_space, spaces.Space) for _space in action_spaces
|
|
1265
|
-
), "Action spaces must be instances of gymnasium.spaces.Space."
|
|
1264
|
+
|
|
1266
1265
|
self.possible_observation_spaces = spaces.Dict(
|
|
1267
1266
|
{
|
|
1268
1267
|
agent_id: space
|
|
@@ -1284,6 +1283,11 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1284
1283
|
f"Observation spaces must be a list or dictionary of spaces.Space objects. Got {type(observation_spaces)}."
|
|
1285
1284
|
)
|
|
1286
1285
|
|
|
1286
|
+
for obs_space in self.possible_observation_spaces.values():
|
|
1287
|
+
check_supported_space(obs_space)
|
|
1288
|
+
for action_space in self.possible_action_spaces.values():
|
|
1289
|
+
check_supported_space(action_space)
|
|
1290
|
+
|
|
1287
1291
|
self.agent_ids = list(self.possible_observation_spaces.keys())
|
|
1288
1292
|
self.n_agents = len(self.agent_ids)
|
|
1289
1293
|
self.placeholder_value = placeholder_value
|
|
@@ -1823,8 +1827,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1823
1827
|
and self.zero_stage > 2
|
|
1824
1828
|
and self.accelerator.is_main_process
|
|
1825
1829
|
):
|
|
1826
|
-
|
|
1827
|
-
"
|
|
1830
|
+
raise NotImplementedError(
|
|
1831
|
+
"DeepSpeed ZeRO Stage 3 is not yet supported in AgileRL. This feature is in development and will be available in a future release."
|
|
1828
1832
|
)
|
|
1829
1833
|
|
|
1830
1834
|
seed = 42
|
|
@@ -1848,19 +1852,44 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1848
1852
|
|
|
1849
1853
|
# TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
|
|
1850
1854
|
# handle _save_distributed_actor if deepspeed is used.
|
|
1851
|
-
def save_checkpoint(self, path: str) -> None:
|
|
1855
|
+
def save_checkpoint(self, path: str, weights_only: bool = False) -> None:
|
|
1852
1856
|
"""
|
|
1853
1857
|
Override the save_checkpoint method to provide guidance on the correct method to use.
|
|
1854
1858
|
:param path: Location to save checkpoint at
|
|
1855
1859
|
:type path: string
|
|
1860
|
+
:param weights_only: If True, only save the weights of the model, defaults to False
|
|
1861
|
+
:type weights_only: bool, optional
|
|
1856
1862
|
"""
|
|
1863
|
+
|
|
1864
|
+
warnings.warn("weights_only default will be changed to True in the future.")
|
|
1865
|
+
|
|
1857
1866
|
if self.accelerator is not None:
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1867
|
+
if not weights_only:
|
|
1868
|
+
self._save_distributed_actor(path, tag="save_checkpoint")
|
|
1869
|
+
else:
|
|
1870
|
+
selected_adapters = (
|
|
1871
|
+
["actor", "reference"]
|
|
1872
|
+
if self.use_separate_reference_adapter
|
|
1873
|
+
else ["actor"]
|
|
1874
|
+
)
|
|
1875
|
+
self.actor.save_pretrained(
|
|
1876
|
+
save_directory=path,
|
|
1877
|
+
selected_adapters=selected_adapters,
|
|
1878
|
+
is_main_process=self.accelerator.is_main_process,
|
|
1879
|
+
)
|
|
1880
|
+
checkpoint_dict = get_checkpoint_dict(
|
|
1881
|
+
self, using_deepspeed=self.accelerator is not None
|
|
1863
1882
|
)
|
|
1883
|
+
checkpoint_dict["_weights_only"] = weights_only
|
|
1884
|
+
checkpoint_dict.pop("llm", None)
|
|
1885
|
+
checkpoint_dict.pop("tp_group", None)
|
|
1886
|
+
|
|
1887
|
+
if self.accelerator is None or self.accelerator.is_main_process:
|
|
1888
|
+
torch.save(
|
|
1889
|
+
checkpoint_dict,
|
|
1890
|
+
path + "/attributes.pt",
|
|
1891
|
+
pickle_module=dill,
|
|
1892
|
+
)
|
|
1864
1893
|
|
|
1865
1894
|
# TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
|
|
1866
1895
|
# handle _load_distributed_actor if deepspeed is used.
|
|
@@ -1872,8 +1901,27 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1872
1901
|
:type path: string
|
|
1873
1902
|
"""
|
|
1874
1903
|
if self.accelerator is not None:
|
|
1875
|
-
self._load_distributed_actor(path, tag="save_checkpoint")
|
|
1876
1904
|
checkpoint = torch.load(path + "/attributes.pt", weights_only=False)
|
|
1905
|
+
weights_only = checkpoint.get("_weights_only", False)
|
|
1906
|
+
|
|
1907
|
+
if weights_only:
|
|
1908
|
+
if self.use_separate_reference_adapter:
|
|
1909
|
+
self._update_existing_adapter(
|
|
1910
|
+
self.accelerator,
|
|
1911
|
+
self.actor,
|
|
1912
|
+
path,
|
|
1913
|
+
"reference",
|
|
1914
|
+
)
|
|
1915
|
+
|
|
1916
|
+
self._update_existing_adapter(
|
|
1917
|
+
self.accelerator,
|
|
1918
|
+
self.actor,
|
|
1919
|
+
path,
|
|
1920
|
+
"actor",
|
|
1921
|
+
)
|
|
1922
|
+
else:
|
|
1923
|
+
self._load_distributed_actor(path, tag="save_checkpoint")
|
|
1924
|
+
|
|
1877
1925
|
checkpoint["accelerator"] = (
|
|
1878
1926
|
Accelerator() if self.accelerator is not None else None
|
|
1879
1927
|
)
|
|
@@ -2040,6 +2088,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2040
2088
|
None,
|
|
2041
2089
|
None,
|
|
2042
2090
|
)
|
|
2091
|
+
if self.use_vllm:
|
|
2092
|
+
destroy_model_parallel()
|
|
2093
|
+
del self.llm.llm_engine.model_executor.driver_worker
|
|
2094
|
+
self.llm = None
|
|
2043
2095
|
gc.collect()
|
|
2044
2096
|
torch.cuda.empty_cache()
|
|
2045
2097
|
|
|
@@ -2106,11 +2158,19 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2106
2158
|
original_lr_scheduler = self.lr_scheduler
|
|
2107
2159
|
clone.lr_scheduler = None
|
|
2108
2160
|
self.lr_scheduler = None
|
|
2161
|
+
if self.use_vllm:
|
|
2162
|
+
original_llm = self.llm
|
|
2163
|
+
cloned_llm = clone.llm
|
|
2164
|
+
clone.llm = None
|
|
2165
|
+
self.llm = None
|
|
2109
2166
|
clone = EvolvableAlgorithm.copy_attributes(self, clone)
|
|
2110
2167
|
clone.accelerator = accelerator
|
|
2111
2168
|
clone.lr_scheduler = lr_scheduler
|
|
2112
2169
|
clone.lr_scheduler = cloned_lr_scheduler
|
|
2113
2170
|
self.lr_scheduler = original_lr_scheduler
|
|
2171
|
+
if self.use_vllm:
|
|
2172
|
+
clone.llm = cloned_llm
|
|
2173
|
+
self.llm = original_llm
|
|
2114
2174
|
|
|
2115
2175
|
if self.accelerator is None:
|
|
2116
2176
|
clone.optimizer.optimizer.load_state_dict(
|
|
@@ -2201,3 +2261,49 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2201
2261
|
] = lr
|
|
2202
2262
|
|
|
2203
2263
|
return accelerator, None
|
|
2264
|
+
|
|
2265
|
+
def recompile(self) -> None:
|
|
2266
|
+
"""Recompiles the algorithm."""
|
|
2267
|
+
raise NotImplementedError(
|
|
2268
|
+
"Recompile method is not available for LLM finetuning algorithms."
|
|
2269
|
+
)
|
|
2270
|
+
|
|
2271
|
+
@staticmethod
|
|
2272
|
+
def _update_existing_adapter(
|
|
2273
|
+
accelerator: Accelerator,
|
|
2274
|
+
wrapped_model: DeepSpeedEngine,
|
|
2275
|
+
checkpoint_dir: str,
|
|
2276
|
+
adapter_name: str,
|
|
2277
|
+
) -> None:
|
|
2278
|
+
"""
|
|
2279
|
+
Overwrite weights of an existing adapter in-place without creating new parameters.
|
|
2280
|
+
|
|
2281
|
+
:param accelerator: Accelerator
|
|
2282
|
+
:type accelerator: Accelerator
|
|
2283
|
+
:param wrapped_model: Wrapped model
|
|
2284
|
+
:type wrapped_model: DeepSpeedEngine
|
|
2285
|
+
:param checkpoint_dir: Checkpoint directory
|
|
2286
|
+
:type checkpoint_dir: str
|
|
2287
|
+
:param adapter_name: Adapter name
|
|
2288
|
+
:type adapter_name: str
|
|
2289
|
+
|
|
2290
|
+
:return: None
|
|
2291
|
+
:rtype: None
|
|
2292
|
+
"""
|
|
2293
|
+
base_model = accelerator.unwrap_model(wrapped_model)
|
|
2294
|
+
if hasattr(base_model, "module"):
|
|
2295
|
+
base_model = base_model.module
|
|
2296
|
+
|
|
2297
|
+
adapter_path = f"{checkpoint_dir}/{adapter_name}/adapter_model.safetensors"
|
|
2298
|
+
adapter_state = load_file(adapter_path, device="cpu")
|
|
2299
|
+
|
|
2300
|
+
with torch.no_grad():
|
|
2301
|
+
set_peft_model_state_dict(
|
|
2302
|
+
base_model, adapter_state, adapter_name=adapter_name
|
|
2303
|
+
)
|
|
2304
|
+
base_model.set_adapter(adapter_name)
|
|
2305
|
+
|
|
2306
|
+
# Make reference weights not trainable
|
|
2307
|
+
for name, param in base_model.named_parameters():
|
|
2308
|
+
if "reference" in name:
|
|
2309
|
+
param.requires_grad = False
|
|
@@ -275,23 +275,23 @@ class RainbowDQN(RLAlgorithm):
|
|
|
275
275
|
|
|
276
276
|
def _dqn_loss(
|
|
277
277
|
self,
|
|
278
|
-
|
|
278
|
+
obs: TorchObsType,
|
|
279
279
|
actions: torch.Tensor,
|
|
280
280
|
rewards: torch.Tensor,
|
|
281
|
-
|
|
281
|
+
next_obs: torch.Tensor,
|
|
282
282
|
dones: torch.Tensor,
|
|
283
283
|
gamma: float,
|
|
284
284
|
) -> torch.Tensor:
|
|
285
285
|
"""Calculates the DQN loss.
|
|
286
286
|
|
|
287
|
-
:param
|
|
288
|
-
:type
|
|
287
|
+
:param obs: Batch of current states
|
|
288
|
+
:type obs: torch.Tensor
|
|
289
289
|
:param actions: Batch of actions taken
|
|
290
290
|
:type actions: torch.Tensor
|
|
291
291
|
:param rewards: Batch of rewards received
|
|
292
292
|
:type rewards: torch.Tensor
|
|
293
|
-
:param
|
|
294
|
-
:type
|
|
293
|
+
:param next_obs: Batch of next states
|
|
294
|
+
:type next_obs: torch.Tensor
|
|
295
295
|
:param dones: Batch of done flags indicating episode termination
|
|
296
296
|
:type dones: torch.Tensor
|
|
297
297
|
:param gamma: Discount factor
|
|
@@ -299,16 +299,15 @@ class RainbowDQN(RLAlgorithm):
|
|
|
299
299
|
:return: Element-wise loss
|
|
300
300
|
:rtype: torch.Tensor
|
|
301
301
|
"""
|
|
302
|
-
|
|
303
|
-
|
|
302
|
+
obs = self.preprocess_observation(obs)
|
|
303
|
+
next_obs = self.preprocess_observation(next_obs)
|
|
304
304
|
|
|
305
305
|
with torch.no_grad():
|
|
306
|
+
# Predict next actions from next_obs
|
|
307
|
+
next_actions = self.actor(next_obs).argmax(1)
|
|
306
308
|
|
|
307
|
-
# Predict
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
# Predict the target q distribution for the same next states
|
|
311
|
-
target_q_dist = self.actor_target(next_states, q=False)
|
|
309
|
+
# Predict the target q distribution for the same next obs
|
|
310
|
+
target_q_dist = self.actor_target(next_obs, q=False)
|
|
312
311
|
|
|
313
312
|
# Index the target q_dist to select the distributions corresponding to next_actions
|
|
314
313
|
target_q_dist = target_q_dist[range(self.batch_size), next_actions]
|
|
@@ -349,7 +348,7 @@ class RainbowDQN(RLAlgorithm):
|
|
|
349
348
|
)
|
|
350
349
|
|
|
351
350
|
# Calculate the current obs
|
|
352
|
-
log_q_dist = self.actor(
|
|
351
|
+
log_q_dist = self.actor(obs, q=False, log=True)
|
|
353
352
|
log_p = log_q_dist[range(self.batch_size), actions.squeeze().long()]
|
|
354
353
|
|
|
355
354
|
# loss
|
|
@@ -375,29 +374,29 @@ class RainbowDQN(RLAlgorithm):
|
|
|
375
374
|
:rtype: Tuple[float, numpy.ndarray, numpy.ndarray]
|
|
376
375
|
"""
|
|
377
376
|
n_step = n_experiences is not None
|
|
378
|
-
|
|
377
|
+
obs = experiences["obs"]
|
|
379
378
|
actions = experiences["action"]
|
|
380
379
|
rewards = experiences["reward"]
|
|
381
|
-
|
|
380
|
+
next_obs = experiences["next_obs"]
|
|
382
381
|
dones = experiences["done"]
|
|
383
382
|
if per:
|
|
384
383
|
weights = experiences["weights"]
|
|
385
384
|
idxs = experiences["idxs"]
|
|
386
385
|
if n_step:
|
|
387
|
-
|
|
386
|
+
n_obs = n_experiences["obs"]
|
|
388
387
|
n_actions = n_experiences["action"]
|
|
389
388
|
n_rewards = n_experiences["reward"]
|
|
390
|
-
|
|
389
|
+
n_next_obs = n_experiences["next_obs"]
|
|
391
390
|
n_dones = n_experiences["done"]
|
|
392
391
|
|
|
393
392
|
if self.combined_reward or not n_step:
|
|
394
393
|
elementwise_loss = self._dqn_loss(
|
|
395
|
-
|
|
394
|
+
obs, actions, rewards, next_obs, dones, self.gamma
|
|
396
395
|
)
|
|
397
396
|
if n_step:
|
|
398
397
|
n_gamma = self.gamma**self.n_step
|
|
399
398
|
n_step_elementwise_loss = self._dqn_loss(
|
|
400
|
-
|
|
399
|
+
n_obs, n_actions, n_rewards, n_next_obs, n_dones, n_gamma
|
|
401
400
|
)
|
|
402
401
|
if self.combined_reward:
|
|
403
402
|
elementwise_loss += n_step_elementwise_loss
|
|
@@ -409,10 +408,10 @@ class RainbowDQN(RLAlgorithm):
|
|
|
409
408
|
else:
|
|
410
409
|
if n_step:
|
|
411
410
|
idxs = experiences["idxs"]
|
|
412
|
-
|
|
411
|
+
n_obs = n_experiences["obs"]
|
|
413
412
|
n_actions = n_experiences["action"]
|
|
414
413
|
n_rewards = n_experiences["reward"]
|
|
415
|
-
|
|
414
|
+
n_next_obs = n_experiences["next_obs"]
|
|
416
415
|
n_dones = n_experiences["done"]
|
|
417
416
|
else:
|
|
418
417
|
idxs = None
|
|
@@ -420,13 +419,13 @@ class RainbowDQN(RLAlgorithm):
|
|
|
420
419
|
new_priorities = None
|
|
421
420
|
if self.combined_reward or not n_step:
|
|
422
421
|
elementwise_loss = self._dqn_loss(
|
|
423
|
-
|
|
422
|
+
obs, actions, rewards, next_obs, dones, self.gamma
|
|
424
423
|
)
|
|
425
424
|
|
|
426
425
|
if n_step:
|
|
427
426
|
n_gamma = self.gamma**self.n_step
|
|
428
427
|
n_step_elementwise_loss = self._dqn_loss(
|
|
429
|
-
|
|
428
|
+
n_obs, n_actions, n_rewards, n_next_obs, n_dones, n_gamma
|
|
430
429
|
)
|
|
431
430
|
if self.combined_reward:
|
|
432
431
|
elementwise_loss += n_step_elementwise_loss
|
|
@@ -508,7 +507,9 @@ class RainbowDQN(RLAlgorithm):
|
|
|
508
507
|
) and not finished[idx]:
|
|
509
508
|
completed_episode_scores[idx] = scores[idx]
|
|
510
509
|
finished[idx] = 1
|
|
510
|
+
|
|
511
511
|
rewards.append(np.mean(completed_episode_scores))
|
|
512
|
+
|
|
512
513
|
mean_fit = np.mean(rewards)
|
|
513
514
|
self.fitness.append(mean_fit)
|
|
514
515
|
return mean_fit
|