agilerl 2.4.0.dev0__tar.gz → 2.4.1.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/PKG-INFO +2 -4
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py +37 -33
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py +58 -7
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py +23 -15
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_llm.py +2 -2
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/algo_utils.py +7 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/llm_utils.py +81 -33
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/utils.py +22 -34
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/pyproject.toml +1 -1
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/LICENSE +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/README.md +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/registry.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ilql.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ippo.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/data.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/sampler.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/base.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/bert.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/configs.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/simba.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/actors.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/base.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/protocols.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/typing.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/cache.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/utils.py +0 -0
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/PKG-INFO

@@ -1,9 +1,8 @@
-Metadata-Version: 2.
+Metadata-Version: 2.3
 Name: agilerl
-Version: 2.4.0.dev0
+Version: 2.4.1.dev0
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
-License-File: LICENSE
 Author: Nick Ustaran-Anderegg
 Author-email: dev@agilerl.com
 Requires-Python: >=3.10,<4.0

@@ -13,7 +12,6 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
-Classifier: Programming Language :: Python :: 3.14
 Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
 Requires-Dist: accelerate (>=1.7.0,<2.0.0)
 Requires-Dist: deepspeed (>=0.17.1,<0.18.0)
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py

@@ -37,6 +37,7 @@ from torch._dynamo import OptimizedModule
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import SequentialLR
+from transformers import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
 from vllm import LLM, SamplingParams

@@ -95,7 +96,11 @@ from agilerl.utils.evolvable_networks import (
     is_image_space,
     is_vector_space,
 )
-from agilerl.utils.llm_utils import
+from agilerl.utils.llm_utils import (
+    DummyOptimizer,
+    create_model_from_name_or_path,
+    gather_if_zero3,
+)

 __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]

@@ -1782,8 +1787,6 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
 class LLMAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all LLM algorithms in the AgileRL framework.

-    :param observation_space: The observation space of the environment.
-    :type observation_space: gymnasium.spaces.Space
     :param action_space: The action space of the environment.
     :type action_space: gymnasium.spaces.Space
     :param index: The index of the algorithm.

@@ -1800,9 +1803,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

     def __init__(
         self,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        actor_network: PreTrainedModel,
         index: int,
         batch_size: int,
         lr: float,

@@ -1815,6 +1815,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         pad_token: str,
         lora_config: LoraConfig | None,
         use_separate_reference_adapter: bool,
+        model_name: str | None = None,
+        actor_network: PreTrainedModel | None = None,
         micro_batch_size_per_gpu: int | None = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         hp_config: Optional[HyperparameterConfig] = None,

@@ -1822,7 +1824,13 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         device: Union[str, torch.device] = "cpu",
         accelerator: Optional[Accelerator] = None,
         name: Optional[str] = None,
+        model_config: dict[str, Any] | PretrainedConfig | None = None,
+        gradient_checkpointing: bool = True,
     ):
+        if model_name is None and actor_network is None:
+            raise ValueError(
+                "At least one of model_name or actor_network must be provided."
+            )
         if (
             accelerator is not None
             and cosine_lr_schedule_config is not None

@@ -1835,20 +1843,16 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             cosine_lr_schedule_config = None

         super().__init__(index, hp_config, device, accelerator, None, name)
-        assert isinstance(
-            observation_space, spaces.Space
-        ), "Observation space must be an instance of gymnasium.spaces.Space."
-        assert isinstance(
-            action_space, spaces.Space
-        ), "Action space must be an instance of gymnasium.spaces.Space."
-
-        self.observation_space = observation_space
-        self.action_space = action_space
+        self.gradient_checkpointing = gradient_checkpointing
         self.zero_stage = None
         self.reference_update_tracker = 0  # Updated every time the reference policy is updated which is updated each time we pass through the train dataset
         self.calc_position_embeddings = calc_position_embeddings
         self.pad_token_id = pad_token_id
         self.pad_token = pad_token
+        self.pretrained_model_name_or_path = (
+            model_name if model_name is not None else actor_network.name_or_path
+        )
+        self.model_config = model_config

         if not clone and reduce_memory_peak and micro_batch_size_per_gpu is not None:
             raise ValueError(

@@ -1858,7 +1862,9 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self._configure_batch_size(
             batch_size, clone, reduce_memory_peak, micro_batch_size_per_gpu
         )
-
+        self.batch_size = self.batch_size_per_process * (
+            self.accelerator.num_processes if self.accelerator is not None else 1
+        )
         if self.accelerator is not None:
             if (
                 self.accelerator.state.deepspeed_plugin.deepspeed_config.get(

@@ -1877,20 +1883,12 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

         if lora_config is None and not isinstance(actor_network, PeftModel):
             warnings.warn(
-                "No LoRA config provided. Using default LoRA configuration for RL finetuning."
+                "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
             )
             lora_config = LoraConfig(
                 r=16,
-                lora_alpha=
-                target_modules=[
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "down_proj",
-                    "gate_proj",
-                ],
+                lora_alpha=32,
+                target_modules="all-linear",
                 task_type="CAUSAL_LM",
                 lora_dropout=0.05,
             )

@@ -1908,7 +1906,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         else:
             self.max_grad_norm = max_grad_norm
         self.reduce_memory_peak = reduce_memory_peak
-        self.pretrained_model_name_or_path = actor_network.name_or_path

         if self.accelerator is not None:
             self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[

@@ -2141,15 +2138,17 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 if not is_dummy_optimizer
                 else type(self.actor.optimizer)
             )
-            self.
-
-
+            if self.gradient_checkpointing:
+                self.actor.module.gradient_checkpointing_enable(
+                    gradient_checkpointing_kwargs={"use_reentrant": False}
+                )
         else:
             assert (
                 self.actor is not None
             ), "Actor is set to None, please check that the actor is defined."
             self.actor = self.actor.to(self.device)
-            self.
+            if self.gradient_checkpointing:
+                self.actor.gradient_checkpointing_enable()

     def clean_up(self) -> None:
         """Clean up the algorithm."""

@@ -2408,7 +2407,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.reference_update_tracker += 1

     def _initialize_actors(
-        self, base_model: PreTrainedModel, add_adapters: bool = True
+        self, base_model: PreTrainedModel | None, add_adapters: bool = True
     ):
         """Initialize the actor network.

@@ -2418,6 +2417,11 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         :type add_adapters: bool, optional
         """

+        if base_model is None:
+            base_model = create_model_from_name_or_path(
+                self.pretrained_model_name_or_path
+            )
+
         if isinstance(base_model, PeftModel) and add_adapters:
             # Handles backwards compatibility with user providing a peft model as the actor network
             if self.lora_config is None:
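The reworked LLMAlgorithm.__init__ now accepts either a model_name or a preloaded actor_network and resolves the checkpoint reference from whichever is supplied. A minimal standalone sketch of that precedence rule (not AgileRL code; the helper function and checkpoint name below are illustrative placeholders):

```python
# Standalone sketch of the precedence rule added to LLMAlgorithm.__init__;
# resolve_model_reference and the checkpoint name are illustrative, not AgileRL code.
from transformers import AutoModelForCausalLM


def resolve_model_reference(model_name=None, actor_network=None) -> str:
    if model_name is None and actor_network is None:
        raise ValueError("At least one of model_name or actor_network must be provided.")
    # transformers records the checkpoint a model was loaded from in `name_or_path`
    return model_name if model_name is not None else actor_network.name_or_path


# Either pass the name directly...
print(resolve_model_reference(model_name="Qwen/Qwen2.5-0.5B-Instruct"))
# ...or let it be recovered from a preloaded model.
actor = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
print(resolve_model_reference(actor_network=actor))
```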
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py

@@ -1,10 +1,10 @@
 import gc
+from typing import Any

 import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
-from gymnasium import spaces
 from peft import LoraConfig
 from transformers import PreTrainedModel

@@ -16,13 +16,62 @@ from agilerl.utils.llm_utils import PreferenceGym


 class DPO(LLMAlgorithm):
+    """The DPO algorithm class. DPO paper: https://arxiv.org/pdf/2305.18290
+
+    :param pad_token_id: Pad token id
+    :type pad_token_id: int
+    :param pad_token: Pad token
+    :type pad_token: str
+    :param model_name: Model name
+    :type model_name: str, optional
+    :param actor_network: HuggingFace LLM
+    :type actor_network: PreTrainedModel
+    :param model_config: Model configuration, to be used when creating the model from a name or path
+    :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
+    :type hp_config: HyperparameterConfig, optional
+    :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0
+    :type index: int, optional
+    :param batch_size: Batch size for training, defaults to 16
+    :type batch_size: int, optional
+    :param lr: Learning rate, defaults to 0.000005
+    :type lr: float, optional
+    :param beta: Beta parameter for DPO, defaults to 0.001
+    :type beta: float, optional
+    :param max_grad_norm: Maximum gradient norm, defaults to 0.1
+    :type max_grad_norm: float, optional
+    :param update_epochs: Number of update epochs, defaults to 1
+    :type update_epochs: int, optional
+    :param calc_position_embeddings: Flag to indicate if position embeddings should be calculated, defaults to True
+    :type calc_position_embeddings: bool, optional
+    :param micro_batch_size_per_gpu: Micro batch size per GPU, defaults to None
+    :type micro_batch_size_per_gpu: int, optional
+    :param reduce_memory_peak: Flag to indicate if memory peak should be reduced, defaults to False
+    :type reduce_memory_peak: bool, optional
+    :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
+    :type device: str, optional
+    :param lora_config: Config for LoRA, defaults to None
+    :type lora_config: LoraConfig, optional
+    :param accelerator: Accelerator for distributed computing, defaults to None
+    :type accelerator: accelerate.Accelerator(), optional
+    :param wrap: Wrap models for distributed training upon creation, defaults to True
+    :type wrap: bool, optional
+    :param clone: Flag to indicate if the instantiation is a cloning, defaults to False
+    :type clone: bool, optional
+    :param use_separate_reference_adapter: Flag to indicate if the reference policy should have a separate adapter, defaults to False
+    :type use_separate_reference_adapter: bool, optional
+    :param seed: Seed for the random number generator, defaults to 42
+    :type seed: int, optional
+    :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+    :type gradient_checkpointing: bool, optional
+    """
+
     def __init__(
         self,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        actor_network: PreTrainedModel,
         pad_token_id: int,
         pad_token: str,
+        model_name: str | None = None,
+        actor_network: PreTrainedModel | None = None,
+        model_config: dict[str, Any] | None = None,
         hp_config: HyperparameterConfig | None = None,
         index: int = 0,
         batch_size: int = 16,

@@ -40,6 +89,7 @@ class DPO(LLMAlgorithm):
         clone: bool = False,
         use_separate_reference_adapter: bool = False,
         seed: int = 42,
+        gradient_checkpointing: bool = True,
     ):
         device = (
             f"cuda:{accelerator.process_index}"

@@ -47,9 +97,6 @@ class DPO(LLMAlgorithm):
             else ("cuda" if torch.cuda.is_available() else "cpu")
         )
         super().__init__(
-            observation_space,
-            action_space,
-            actor_network,
             index=index,
             batch_size=batch_size,
             lr=lr,

@@ -62,6 +109,9 @@ class DPO(LLMAlgorithm):
             pad_token=pad_token,
             lora_config=lora_config,
             use_separate_reference_adapter=use_separate_reference_adapter,
+            model_name=model_name,
+            actor_network=actor_network,
+            model_config=model_config,
             micro_batch_size_per_gpu=micro_batch_size_per_gpu,
             cosine_lr_schedule_config=None,
             hp_config=hp_config,

@@ -69,6 +119,7 @@ class DPO(LLMAlgorithm):
             device=device,
             accelerator=accelerator,
             name="DPO",
+            gradient_checkpointing=gradient_checkpointing,
         )
         self.beta = beta
         self.temperature = (
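With the gym spaces removed from the constructor, a DPO agent can now be built from a tokenizer and a model name alone. A hedged usage sketch of the new signature (the checkpoint name and the pad-token fallback are illustrative assumptions, not AgileRL documentation):

```python
# Hedged usage sketch of the new DPO signature; the checkpoint name is a placeholder.
from transformers import AutoTokenizer

from agilerl.algorithms.dpo import DPO

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback when no pad token is defined

agent = DPO(
    pad_token_id=tokenizer.pad_token_id,
    pad_token=tokenizer.pad_token,
    model_name="Qwen/Qwen2.5-0.5B-Instruct",  # new: load the policy lazily by name
    batch_size=16,
    beta=0.001,
    gradient_checkpointing=True,  # new flag, defaults to True
)
```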
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py

@@ -1,12 +1,11 @@
 import gc
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import numpy as np
 import torch
 from accelerate import Accelerator
 from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
-from gymnasium import spaces
 from peft import LoraConfig, PeftModel
 from transformers import GenerationConfig
 from transformers.modeling_utils import PreTrainedModel

@@ -33,12 +32,16 @@ DeepSpeedOptimizerType = Union[
 class GRPO(LLMAlgorithm):
     """The GRPO algorithm class. GRPO paper: https://arxiv.org/pdf/2402.03300

-    :param
-    :type
-    :param
-    :type
+    :param pad_token_id: Pad token id
+    :type pad_token_id: int
+    :param pad_token: Pad token
+    :type pad_token: str
+    :param model_name: Model name
+    :type model_name: str, optional
     :param actor_network: HuggingFace LLM
     :type actor_network: PreTrainedModel
+    :param model_config: Model configuration, to be used when creating the model from a name or path
+    :type model_config: dict[str, Any], optional
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
     :type hp_config: HyperparameterConfig, optional
     :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0

@@ -93,15 +96,17 @@ class GRPO(LLMAlgorithm):
     :type vllm_config: VLLMConfig, optional
     :param seed: Seed for the random number generator, defaults to 42
     :type seed: int, optional
+    :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+    :type gradient_checkpointing: bool, optional
     """

     def __init__(
         self,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        actor_network: PreTrainedModel,
         pad_token_id: int,
         pad_token: str,
+        model_name: str | None = None,
+        actor_network: PreTrainedModel | None = None,
+        model_config: dict[str, Any] | None = None,
         hp_config: Optional[HyperparameterConfig] = None,
         index: int = 0,
         batch_size: int = 16,

@@ -132,6 +137,7 @@ class GRPO(LLMAlgorithm):
         use_vllm: bool = False,
         vllm_config: Optional[VLLMConfig] = None,
         seed: int = 42,
+        gradient_checkpointing: bool = True,
     ) -> None:

         device = (

@@ -140,9 +146,6 @@ class GRPO(LLMAlgorithm):
             else ("cuda" if torch.cuda.is_available() else "cpu")
         )
         super().__init__(
-            observation_space,
-            action_space,
-            actor_network,
             index=index,
             batch_size=batch_size,
             lr=lr,

@@ -155,6 +158,9 @@ class GRPO(LLMAlgorithm):
             pad_token=pad_token,
             lora_config=lora_config,
             use_separate_reference_adapter=use_separate_reference_adapter,
+            model_name=model_name,
+            actor_network=actor_network,
+            model_config=model_config,
             micro_batch_size_per_gpu=micro_batch_size_per_gpu,
             cosine_lr_schedule_config=cosine_lr_schedule_config,
             wrap=wrap,

@@ -162,6 +168,7 @@ class GRPO(LLMAlgorithm):
             device=device,
             accelerator=accelerator,
             name="GRPO",
+            gradient_checkpointing=gradient_checkpointing,
         )
         assert isinstance(batch_size, int), "Batch size must be an integer."
         assert batch_size >= 1, "Batch size must be greater than or equal to one."

@@ -179,9 +186,10 @@ class GRPO(LLMAlgorithm):
         assert (
             update_epochs >= 1
         ), "Policy update epochs must be greater than or equal to one."
-
-
-
+        if actor_network is not None:
+            assert isinstance(
+                actor_network, (PeftModel, PreTrainedModel)
+            ), "Actor network must be a PeftModel or PreTrainedModel"

         self.clip_coef = clip_coef
         self.update_epochs = update_epochs
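The preloaded-model path is still supported; when actor_network is given it must now be a PeftModel or PreTrainedModel. A hedged sketch of that path (the checkpoint name is a placeholder, not an AgileRL recommendation):

```python
# Hedged sketch of constructing GRPO from a preloaded HuggingFace model.
from transformers import AutoModelForCausalLM, AutoTokenizer

from agilerl.algorithms.grpo import GRPO

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

agent = GRPO(
    pad_token_id=tokenizer.eos_token_id,
    pad_token=tokenizer.eos_token,
    actor_network=model,  # must be a PeftModel or PreTrainedModel when supplied
    gradient_checkpointing=True,
)
```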
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_llm.py

@@ -115,7 +115,7 @@ def finetune_llm_reasoning(

     if init_hp is None:
         init_hp = {}
-    init_hp["BATCH_SIZE_PER_GPU"] = pop[0].
+    init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size_per_process
     init_hp["ALGO"] = pop[0].algo
     data_increment = (
         getattr(dist, "get_world_size", lambda: 1)() if dist.is_initialized() else 1

@@ -463,7 +463,7 @@ def finetune_llm_preference(

     if init_hp is None:
         init_hp = {}
-    init_hp["BATCH_SIZE_PER_GPU"] = pop[0].
+    init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size_per_process
     init_hp["ALGO"] = pop[0].algo

     data_increment = accelerator.num_processes if accelerator is not None else 1
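Both training loops now read the per-process batch size from the agents, and the base class derives the global batch size from it. A small illustration of the relationship introduced in base.py (the world-size value is illustrative):

```python
# batch_size is now derived per agent as batch_size_per_process * num_processes
# (see the base.py hunk above); the numbers here are illustrative only.
num_processes = 4            # e.g. accelerator.num_processes under Accelerate/DeepSpeed
batch_size_per_process = 2   # what each GPU actually steps through
batch_size = batch_size_per_process * num_processes
assert batch_size == 8       # the effective global batch size reported by the agent
```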
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/algo_utils.py

@@ -1328,6 +1328,13 @@ class VLLMConfig:
     max_num_seqs: int = 8
     sleep_mode: bool = False

+    def __post_init__(self):
+        if self.sleep_mode:
+            warnings.warn(
+                """VLLM sleep mode cannot be used with populations of agents on a single device. To use sleep mode, ensure,
+                you are training a single agent or, alternatively, use a different device for each agent."""
+            )
+

 def create_warmup_cosine_scheduler(
     optimizer: torch.optim.Optimizer,
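The new __post_init__ hook means the warning fires as soon as a config with sleep_mode=True is constructed. A self-contained sketch of the same pattern (a stand-in dataclass, not the AgileRL VLLMConfig itself):

```python
# Stand-in dataclass illustrating the __post_init__ validation pattern added to
# VLLMConfig; field names mirror the diff, but this is not the AgileRL class.
import warnings
from dataclasses import dataclass


@dataclass
class SleepModeConfig:
    max_num_seqs: int = 8
    sleep_mode: bool = False

    def __post_init__(self) -> None:
        if self.sleep_mode:
            warnings.warn(
                "Sleep mode cannot be used with populations of agents on a single device."
            )


cfg = SleepModeConfig(sleep_mode=True)  # the warning is emitted at construction time
```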
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/llm_utils.py

@@ -11,12 +11,53 @@ import torch.nn as nn
 from accelerate import Accelerator
 from datasets import Dataset
 from torch.utils.data import DataLoader
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils_base import BatchEncoding

 from agilerl.typing import PreferencePrompts, ReasoningPrompts


+def apply_chat_template(
+    conversation_template: list[dict[str, str]],
+    question: str,
+    answer: str,
+    tokenizer: AutoTokenizer,
+) -> BatchEncoding:
+    """
+    Create and tokenize a chat template for a reasoning task.
+
+    :param conversation_template: The conversation template to be tokenized.
+    :type conversation_template: list[dict[str, str]]
+    :param question: The question to be tokenized.
+    :type question: str
+    :param answer: The answer to be tokenized.
+    :type answer: str
+    :param tokenizer: The tokenizer to be used.
+    :type tokenizer: AutoTokenizer
+    :return: The tokenized prompt.
+    :rtype: BatchEncoding
+    """
+    formatted_conversation = [
+        {
+            "role": msg["role"],
+            "content": msg["content"].format(question=question, answer=answer),
+        }
+        for msg in conversation_template
+    ]
+    updated_prompt = tokenizer.apply_chat_template(
+        formatted_conversation, tokenize=False, continue_final_message=True
+    )
+    tokenized_prompt = tokenizer(
+        [updated_prompt],
+        return_tensors="pt",
+        padding=True,
+        padding_side="left",
+        return_attention_mask=True,
+    )
+    return tokenized_prompt
+
+
 class HuggingFaceGym(gym.Env, ABC):
     """Abstract base class for HuggingFace Gymnasium environments.

@@ -28,8 +69,8 @@ class HuggingFaceGym(gym.Env, ABC):
     :type tokenizer: AutoTokenizer
     :param custom_collate_fn: Custom collate function to be used for creating the batch, defaults to None
     :type custom_collate_fn: Callable, optional
-    :param
-    :type
+    :param conversation_template: A structured conversation that acts as a base pattern for each data point.
+    :type conversation_template: list[dict[str, str]]
     :param data_batch_size_per_gpu: DataLoader batch size, defaults to 8
     :type data_batch_size_per_gpu: int, optional
     :param max_context_length: Maximum context length, defaults to None

@@ -47,12 +88,7 @@ class HuggingFaceGym(gym.Env, ABC):
         train_dataset: Dataset,
         test_dataset: Dataset,
         tokenizer: AutoTokenizer,
-        custom_collate_fn: (
-            Callable[[list[dict[str, Any]]], dict[str, Any]] | None
-        ) = None,
-        apply_chat_template_fn: (
-            Callable[[str, str, AutoTokenizer], BatchEncoding] | None
-        ) = None,
+        conversation_template: list[dict[str, str]],
         data_batch_size_per_gpu: int = 8,
         max_context_length: int | None = None,
         min_completion_length: int = None,

@@ -70,11 +106,8 @@ class HuggingFaceGym(gym.Env, ABC):
         self.max_context_length = max_context_length
         self.seed = seed
         generator = torch.Generator().manual_seed(seed)
-
-
-        if apply_chat_template_fn is not None:
-            collate_kwargs["apply_chat_template_fn"] = apply_chat_template_fn
-        custom_collate_fn = self.create_collate_fn(**collate_kwargs)
+        self.conversation_template = conversation_template
+        custom_collate_fn = self.create_collate_fn(tokenizer)
         dataloader_kwargs = {"collate_fn": custom_collate_fn}
         train_dataset = self._filter_dataset_by_max_context_length(
             train_dataset, "train dataset"

@@ -107,11 +140,6 @@ class HuggingFaceGym(gym.Env, ABC):
         self.test_dataloader_iter = iter(self.test_dataloader)
         self.dataloader = self.train_dataloader_iter
         self.reset_called = False
-        self.observation_space = gym.spaces.Box(low=0, high=tokenizer.vocab_size - 1)
-        self.action_space = gym.spaces.Box(
-            low=0,
-            high=tokenizer.vocab_size - 1,
-        )
         self.evaluation_mode = False
         self.num_epochs = 0

@@ -196,9 +224,11 @@ class HuggingFaceGym(gym.Env, ABC):
         :rtype: tuple[Dataset, Dataset]
         """
         dataset_type = "dataset" if dataset_type is None else dataset_type
-        if self.max_context_length is None:
-            return dataset
         filter_keyword = "prompt" if "prompt" in dataset.features.keys() else "question"
+        if self.max_context_length is None or not isinstance(
+            dataset[0][filter_keyword], str
+        ):
+            return dataset
         filtered_dataset = dataset.filter(
             lambda x: len(self.tokenizer.encode(x[filter_keyword]))
             <= self.max_context_length - self.min_completion_length

@@ -225,10 +255,10 @@ class ReasoningGym(HuggingFaceGym):
     :type tokenizer: AutoTokenizer
     :param reward_fn: Reward function for evaluating completions.
     :type reward_fn: Callable[..., float]
+    :param conversation_template: A structured conversation that acts as a base pattern for each data point.
+    :type conversation_template: list[dict[str, str]]
     :param data_batch_size_per_gpu: DataLoader batch size, defaults to 8
     :type data_batch_size_per_gpu: int, optional
-    :param custom_collate_fn: Custom collate function to be used for creating the batch, defaults to None
-    :type custom_collate_fn: Callable, optional
     :param accelerator: Accelerator to be used for training, defaults to None
     :type accelerator: Accelerator, optional
     :param max_context_length: Maximum context length, defaults to None

@@ -245,9 +275,8 @@ class ReasoningGym(HuggingFaceGym):
         test_dataset: Dataset,
         tokenizer: AutoTokenizer,
         reward_fn: Callable[[str, str, str], float],
-
+        conversation_template: list[dict[str, str]],
         data_batch_size_per_gpu: int = 8,
-        custom_collate_fn: Callable | None = None,
         accelerator: Accelerator | None = None,
         return_raw_completions: bool = False,
         max_context_length: int | None = None,

@@ -264,8 +293,7 @@ class ReasoningGym(HuggingFaceGym):
             train_dataset=train_dataset,
             test_dataset=test_dataset,
             tokenizer=tokenizer,
-
-            apply_chat_template_fn=apply_chat_template_fn,
+            conversation_template=conversation_template,
             data_batch_size_per_gpu=data_batch_size_per_gpu,
             max_context_length=max_context_length,
             min_completion_length=0,

@@ -382,15 +410,12 @@ class ReasoningGym(HuggingFaceGym):
     def create_collate_fn(
         self,
         tokenizer: AutoTokenizer,
-        apply_chat_template_fn: Callable[[str, str, AutoTokenizer], BatchEncoding],
     ) -> Callable[[list[dict[str, Any]]], dict[str, Any]]:
         """
         Create a collate function that applies the chat template to the batch of questions and answers.

         :param tokenizer: Tokenizer to be used for encoding and decoding the prompts.
         :type tokenizer: AutoTokenizer
-        :param apply_chat_template_fn: Function to apply the chat template to the batch of questions and answers.
-        :type apply_chat_template_fn: Callable[[str, str, AutoTokenizer], BatchEncoding]
         :return: Collate function that applies the chat template to the batch of questions and answers.
         :rtype: Callable[[list[dict[str, Any]]], dict[str, Any]]
         """

@@ -402,7 +427,7 @@ class ReasoningGym(HuggingFaceGym):

             # Apply chat template to all samples
             tokenized_prompts = [
-
+                apply_chat_template(self.conversation_template, q, a, tokenizer)
                 for q, a in zip(questions, answers)
             ]

@@ -451,8 +476,7 @@ class PreferenceGym(HuggingFaceGym):
             train_dataset=train_dataset,
             test_dataset=test_dataset,
             tokenizer=tokenizer,
-
-            apply_chat_template_fn=None,
+            conversation_template=None,
             data_batch_size_per_gpu=data_batch_size_per_gpu,
             max_context_length=max_context_length,
             min_completion_length=min_completion_length,

@@ -667,3 +691,27 @@ def get_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:

     with gather_if_zero3(3, list(model.parameters()), modifier_rank=0):
         return model.state_dict()
+
+
+def create_model_from_name_or_path(
+    model_name_or_path: str, model_config: dict[str, Any] | None = None
+) -> PreTrainedModel:
+    """
+    Create a model from a name or path.
+
+    :param model_name_or_path: The name or path of the model to create.
+    :type model_name_or_path: str
+    :param model_config: The configuration of the model to create.
+    :type model_config: dict[str, Any] | None
+    :return: The created model.
+    :rtype: PreTrainedModel
+    """
+    if model_config is None:
+        model_config = {
+            "torch_dtype": torch.bfloat16,
+            "attn_implementation": "sdpa",
+        }
+    model = AutoModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_name_or_path, **model_config
+    )
+    return model
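The per-sample apply_chat_template_fn callback is replaced by a declarative conversation_template: a list of role/content messages whose {question} and {answer} placeholders are filled for every sample. A hedged usage sketch (the template wording and checkpoint name are illustrative assumptions, not AgileRL defaults):

```python
# Hedged sketch of the new conversation_template mechanism; template wording and
# the checkpoint name are illustrative assumptions, not AgileRL defaults.
from transformers import AutoTokenizer

from agilerl.utils.llm_utils import apply_chat_template, create_model_from_name_or_path

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

conversation_template = [
    {"role": "system", "content": "Answer the question and show your reasoning."},
    {"role": "user", "content": "{question}"},
    {"role": "assistant", "content": "<think>"},
]

batch = apply_chat_template(
    conversation_template,
    question="What is 17 * 3?",
    answer="51",  # substituted only where a template message references {answer}
    tokenizer=tokenizer,
)
print(batch["input_ids"].shape)  # (1, prompt_length), left-padded

# Models can also be created lazily by name, with a bf16 / SDPA config by default.
model = create_model_from_name_or_path("Qwen/Qwen2.5-0.5B-Instruct")
```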
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/utils.py

@@ -36,7 +36,7 @@ from agilerl.hpo.mutation import Mutations
 from agilerl.hpo.tournament import TournamentSelection
 from agilerl.modules import EvolvableModule
 from agilerl.typing import BPTTSequenceType, GymSpaceType, PopulationType
-from agilerl.utils.algo_utils import CosineLRScheduleConfig,
+from agilerl.utils.algo_utils import CosineLRScheduleConfig, clone_llm
 from agilerl.utils.llm_utils import DummyOptimizer, get_state_dict
 from agilerl.vector.pz_async_vec_env import AsyncPettingZooVecEnv

@@ -213,10 +213,10 @@ def default_progress_bar(

 def create_population(
     algo: str,
-    observation_space: GymSpaceType,
-    action_space: GymSpaceType,
     net_config: Optional[dict[str, Any]],
     INIT_HP: dict[str, Any],
+    observation_space: GymSpaceType | None = None,
+    action_space: GymSpaceType | None = None,
     hp_config: Optional[HyperparameterConfig] = None,
     actor_network: Optional[EvolvableModule] = None,
     critic_network: Optional[EvolvableModule] = None,

@@ -233,14 +233,14 @@ def create_population(

     :param algo: RL algorithm
     :type algo: str
-    :param observation_space: Observation space
-    :type observation_space: spaces.Space
-    :param action_space: Action space
-    :type action_space: spaces.Space
     :param net_config: Network configuration
     :type net_config: dict or None
     :param INIT_HP: Initial hyperparameters
     :type INIT_HP: dict
+    :param observation_space: Observation space
+    :type observation_space: spaces.Space
+    :param action_space: Action space
+    :type action_space: spaces.Space
     :param hp_config: Choice of algorithm hyperparameters to mutate during training, defaults to None
     :type hp_config: HyperparameterConfig, optional
     :param actor_network: Custom actor network, defaults to None

@@ -572,23 +572,23 @@ def create_population(
     elif algo == "GRPO":
         for idx in range(population_size):
             agent = GRPO(
-                observation_space=observation_space,
-                action_space=action_space,
                 actor_network=(
-
-
-
-
-
-
-
-
+                    (
+                        clone_llm(
+                            actor_network,
+                            zero_stage=INIT_HP.get("ZERO_STAGE", 0),
+                            state_dict=(
+                                actor_network.state_dict()
+                                if accelerator is None
+                                else get_state_dict(actor_network)
+                            ),
+                        )
+                        if idx != 0
+                        else actor_network
                     )
-                    if
-                    else
+                    if actor_network is not None
+                    else None
                 ),
-                pad_token_id=INIT_HP.get("PAD_TOKEN_ID"),
-                pad_token=INIT_HP.get("PAD_TOKEN"),
                 hp_config=hp_config,
                 index=idx,
                 batch_size=INIT_HP.get("BATCH_SIZE", 2),

@@ -601,7 +601,7 @@ def create_population(
                 temperature=INIT_HP.get("TEMPERATURE", 0.9),
                 calc_position_embeddings=INIT_HP.get("CALC_POSITION_EMBEDDINGS", True),
                 reduce_memory_peak=INIT_HP.get("REDUCE_MEMORY_PEAK", False),
-                max_output_tokens=INIT_HP.get("MAX_OUTPUT_TOKENS",
+                max_output_tokens=INIT_HP.get("MAX_OUTPUT_TOKENS", None),
                 min_output_tokens=INIT_HP.get("MIN_OUTPUT_TOKENS", None),
                 cosine_lr_schedule_config=(
                     CosineLRScheduleConfig(**INIT_HP.get("COSINE_lR_SCHEDULER", None))

@@ -610,23 +610,13 @@ def create_population(
                 ),
                 accelerator=Accelerator() if accelerator else None,
                 device=device,
-                use_separate_reference_adapter=False,
                 max_model_len=INIT_HP.get("MAX_MODEL_LEN", None),
-                use_vllm=INIT_HP.get("USE_VLLM", False),
-                vllm_config=(
-                    VLLMConfig(**INIT_HP.get("VLLM_CONFIG"))
-                    if INIT_HP.get("VLLM_CONFIG", None) is not None
-                    and INIT_HP.get("USE_VLLM", False)
-                    else None
-                ),
                 **algo_kwargs,
             )
             population.append(agent)
     elif algo == "DPO":
         for idx in range(population_size):
             agent = DPO(
-                observation_space=observation_space,
-                action_space=action_space,
                 actor_network=(
                     clone_llm(
                         actor_network,

@@ -640,8 +630,6 @@ def create_population(
                     if idx != 0
                     else actor_network
                 ),
-                pad_token_id=INIT_HP.get("PAD_TOKEN_ID"),
-                pad_token=INIT_HP.get("PAD_TOKEN"),
                 hp_config=hp_config,
                 index=idx,
                 batch_size=INIT_HP.get("BATCH_SIZE", 2),
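For LLM algorithms, create_population now clones the supplied model for every agent after the first and no longer takes gym spaces or pad-token arguments. A hedged call sketch (the checkpoint name and the population_size keyword are assumptions for illustration; check the full signature before relying on them):

```python
# Hedged sketch of the updated create_population call for an LLM algorithm;
# the checkpoint name and the population_size keyword are assumptions.
from transformers import AutoModelForCausalLM

from agilerl.utils.utils import create_population

actor = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
INIT_HP = {"BATCH_SIZE": 2, "ZERO_STAGE": 0, "REDUCE_MEMORY_PEAK": False}

population = create_population(
    algo="GRPO",
    net_config=None,
    INIT_HP=INIT_HP,
    # observation_space / action_space are now optional and unused for LLM algorithms
    actor_network=actor,   # agent 0 keeps this model; later agents receive clone_llm copies
    population_size=4,     # assumed keyword name for the number of agents
)
```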
All remaining files listed above with +0 -0 are unchanged between the two versions.