agilerl 2.4.0.dev0__tar.gz → 2.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/PKG-INFO +23 -10
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/README.md +12 -1
- agilerl-2.4.1/agilerl/__init__.py +18 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/base.py +125 -61
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/optimizer_wrapper.py +11 -3
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/registry.py +1 -1
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dpo.py +60 -10
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/grpo.py +34 -27
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ilql.py +14 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/protocols.py +131 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_llm.py +2 -2
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/algo_utils.py +59 -5
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/llm_utils.py +94 -80
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/utils.py +23 -35
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/pyproject.toml +25 -8
- agilerl-2.4.0.dev0/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/LICENSE +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ippo.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/data.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/sampler.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.4.0.dev0/agilerl → agilerl-2.4.1/agilerl/data}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.4.0.dev0/agilerl/data → agilerl-2.4.1/agilerl/hpo}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/base.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/bert.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/configs.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/simba.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/actors.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/base.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.4.0.dev0/agilerl/hpo → agilerl-2.4.1/agilerl/training}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/typing.py +0 -0
- {agilerl-2.4.0.dev0/agilerl/training → agilerl-2.4.1/agilerl/utils}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/cache.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.4.0.dev0/agilerl/utils → agilerl-2.4.1/agilerl/vector}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.4.0.dev0/agilerl/vector → agilerl-2.4.1/agilerl/wrappers}/__init__.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/utils.py +0 -0
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/PKG-INFO

````diff
@@ -1,22 +1,23 @@
 Metadata-Version: 2.4
 Name: agilerl
-Version: 2.4.0.dev0
+Version: 2.4.1
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
 License-File: LICENSE
 Author: Nick Ustaran-Anderegg
 Author-email: dev@agilerl.com
-Requires-Python: >=3.10,<
+Requires-Python: >=3.10,<3.13
 Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-
-
+Provides-Extra: all
+Provides-Extra: llm
 Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
 Requires-Dist: accelerate (>=1.7.0,<2.0.0)
-Requires-Dist:
+Requires-Dist: datasets (==4.4.1) ; extra == "llm" or extra == "all"
+Requires-Dist: deepspeed (>=0.17.1,<0.18.0) ; extra == "llm" or extra == "all"
 Requires-Dist: dill (>=0.3.7,<0.4.0)
 Requires-Dist: fastrand (>=1.3.0,<2.0.0)
 Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)
@@ -26,11 +27,12 @@ Requires-Dist: h5py (>=3.8.0,<4.0.0)
 Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
 Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
 Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
-Requires-Dist: minari (
+Requires-Dist: minari[all] (==0.5.2)
 Requires-Dist: numpy (>=1.26.4,<2.0.0)
 Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
+Requires-Dist: packaging (>=20.0)
 Requires-Dist: pandas (>=2.2.3,<3.0.0)
-Requires-Dist: peft (>=0.
+Requires-Dist: peft (>=0.18.0,<0.19.0) ; extra == "llm" or extra == "all"
 Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
 Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
 Requires-Dist: pygame (>=2.6.0,<3.0.0)
@@ -41,9 +43,9 @@ Requires-Dist: tensordict (>=0.8,<0.9)
 Requires-Dist: termcolor (>=1.1.0,<2.0.0)
 Requires-Dist: torch (==2.7.1)
 Requires-Dist: tqdm (>=4.66.4,<5.0.0)
-Requires-Dist: transformers (>=4.
+Requires-Dist: transformers (>=4.57.1,<5.0.0) ; extra == "llm" or extra == "all"
 Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
-Requires-Dist: vllm (==0.10.0)
+Requires-Dist: vllm (==0.10.0) ; extra == "llm" or extra == "all"
 Requires-Dist: wandb (>=0.17.6,<0.18.0)
 Description-Content-Type: text/markdown
 
@@ -97,6 +99,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```
 
+If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+```bash
+pip install agilerl[all]
+```
+Or in development mode:
+```bash
+pip install -e ".[all]"
+```
+
 To install the ``nightly`` version of AgileRL with the latest features, use:
 
 ```bash
@@ -155,11 +167,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |
 
-### LLM
+### LLM Fine-tuning Algorithms
 
 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)
 
 
 ## Train an Agent to Beat a Gym Environment
````
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/README.md

````diff
@@ -48,6 +48,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```
 
+If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+```bash
+pip install agilerl[all]
+```
+Or in development mode:
+```bash
+pip install -e ".[all]"
+```
+
 To install the ``nightly`` version of AgileRL with the latest features, use:
 
 ```bash
@@ -106,11 +116,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |
 
-### LLM
+### LLM Fine-tuning Algorithms
 
 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)
 
 
 ## Train an Agent to Beat a Gym Environment
````
agilerl-2.4.1/agilerl/__init__.py (new file)

```diff
@@ -0,0 +1,18 @@
+from importlib.metadata import metadata
+from importlib.util import find_spec
+
+from packaging.requirements import Requirement
+
+
+def get_extra_dependencies(package: str, extra: str) -> list[str]:
+    requires = metadata(package).get_all("Requires-Dist") or []
+    deps = []
+    for req in requires:
+        r = Requirement(req)
+        if r.marker and r.marker.evaluate({"extra": extra}):
+            deps.append(r.name)
+    return deps
+
+
+LLM_PACKAGES = get_extra_dependencies("agilerl", "llm")
+HAS_LLM_DEPENDENCIES = all(find_spec(pkg) is not None for pkg in LLM_PACKAGES)
```
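This new top-level module resolves the `llm` extra's requirements from package metadata and exposes a single `HAS_LLM_DEPENDENCIES` flag for gating optional imports. A minimal sketch of how the flag might be consumed downstream, assuming agilerl 2.4.1 is installed; `load_llm_stack` is a hypothetical helper used only for illustration, mirroring the guarded-import pattern that `base.py` adopts below:

```python
# Minimal sketch, assuming agilerl 2.4.1 is installed.
# `load_llm_stack` is a hypothetical helper used purely for illustration.
from agilerl import HAS_LLM_DEPENDENCIES, LLM_PACKAGES


def load_llm_stack():
    if not HAS_LLM_DEPENDENCIES:
        # Same remediation message the library raises in LLMAlgorithm.__init__.
        raise ImportError(
            "LLM dependencies are not installed. "
            "Please install them using `pip install agilerl[llm]`. "
            f"Expected packages: {', '.join(LLM_PACKAGES)}"
        )
    # Heavy optional imports are only reached once the check passes.
    from peft import LoraConfig  # noqa: F401
    from vllm import LLM  # noqa: F401
```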
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/base.py

```diff
@@ -27,19 +27,14 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, set_seed
-from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
-from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 from gymnasium import spaces
-from peft import LoraConfig, PeftModel, get_peft_model, set_peft_model_state_dict
-from safetensors.torch import load_file
 from tensordict import TensorDict
 from torch._dynamo import OptimizedModule
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import SequentialLR
-from transformers.modeling_utils import PreTrainedModel
-from vllm import LLM, SamplingParams
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
 from agilerl.algorithms.core.registry import (
     HyperparameterConfig,
@@ -54,7 +49,11 @@ from agilerl.protocols import (
     EvolvableAttributeDict,
     EvolvableAttributeType,
     EvolvableModule,
+    LoraConfigProtocol,
     ModuleDict,
+    PeftModelProtocol,
+    PretrainedConfigProtocol,
+    PreTrainedModelProtocol,
 )
 from agilerl.typing import (
     ActionType,
@@ -73,6 +72,7 @@ from agilerl.typing import (
 )
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,
+    DummyOptimizer,
     VLLMConfig,
     check_supported_space,
     chkpt_attribute_to_device,
@@ -95,7 +95,18 @@ from agilerl.utils.evolvable_networks import (
     is_image_space,
     is_vector_space,
 )
-
+
+if HAS_LLM_DEPENDENCIES:
+    from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
+    from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+    from safetensors.torch import load_file
+    from vllm import LLM, SamplingParams
+
+    from agilerl.utils.llm_utils import (
+        create_model_from_name_or_path,
+        gather_if_zero3,
+    )
 
 __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]
 
@@ -596,14 +607,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
             )
             optimizer = opt.optimizer if hasattr(opt, "optimizer") else None
 
-            if isinstance(
-                if
-
+            if isinstance(self, LLMAlgorithm):
+                if hasattr(self.actor, "optimizer"):
+                    optimizer = getattr(
                         getattr(self, "actor"), "optimizer"
                     )  # If the optimizer is defined in the deepspeed config, we do this
+                else:
+                    optimizer = opt.optimizer
 
                 self.accelerator, self.lr_scheduler = LLMAlgorithm.update_lr(
-
+                    optimizer,
                     lr=getattr(self, config.lr),
                     accelerator=self.accelerator,
                     scheduler_config=self.cosine_lr_schedule_config,
@@ -1138,6 +1151,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
 
         return self
 
+    def clean_up(self) -> None:
+        """
+        Clean up the algorithm by deleting the networks and optimizers.
+
+        :return: None
+        :rtype: None
+        """
+        for evo_attr in self.evolvable_attributes().values():
+            del evo_attr
+
 
 class RLAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all single-agent algorithms in the AgileRL framework.
@@ -1782,8 +1805,6 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
 class LLMAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all LLM algorithms in the AgileRL framework.
 
-    :param observation_space: The observation space of the environment.
-    :type observation_space: gymnasium.spaces.Space
     :param action_space: The action space of the environment.
     :type action_space: gymnasium.spaces.Space
     :param index: The index of the algorithm.
@@ -1796,13 +1817,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     :type accelerator: Optional[Accelerator]
     :param name: The name of the algorithm.
     :type name: Optional[str]
+    :param model_config: The configuration for the model.
+    :type model_config: dict[str, Any] | PretrainedConfig | None
+    :param gradient_checkpointing: Whether to use gradient checkpointing.
+    :type gradient_checkpointing: bool
     """
 
     def __init__(
         self,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        actor_network: PreTrainedModel,
         index: int,
         batch_size: int,
         lr: float,
@@ -1813,8 +1835,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         seed: int,
         pad_token_id: int,
         pad_token: str,
-        lora_config:
+        lora_config: LoraConfigProtocol | None,
         use_separate_reference_adapter: bool,
+        model_name: str | None = None,
+        actor_network: PreTrainedModelProtocol | None = None,
         micro_batch_size_per_gpu: int | None = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         hp_config: Optional[HyperparameterConfig] = None,
@@ -1822,7 +1846,18 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         device: Union[str, torch.device] = "cpu",
         accelerator: Optional[Accelerator] = None,
         name: Optional[str] = None,
+        model_config: dict[str, Any] | PretrainedConfigProtocol | None = None,
+        gradient_checkpointing: bool = True,
     ):
+        if not HAS_LLM_DEPENDENCIES:
+            raise ImportError(
+                "LLM dependencies are not installed. Please install them using `pip install agilerl[llm]`."
+            )
+
+        if model_name is None and actor_network is None:
+            raise ValueError(
+                "At least one of model_name or actor_network must be provided."
+            )
         if (
             accelerator is not None
             and cosine_lr_schedule_config is not None
@@ -1835,20 +1870,16 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             cosine_lr_schedule_config = None
 
         super().__init__(index, hp_config, device, accelerator, None, name)
-
-            observation_space, spaces.Space
-        ), "Observation space must be an instance of gymnasium.spaces.Space."
-        assert isinstance(
-            action_space, spaces.Space
-        ), "Action space must be an instance of gymnasium.spaces.Space."
-
-        self.observation_space = observation_space
-        self.action_space = action_space
+        self.gradient_checkpointing = gradient_checkpointing
         self.zero_stage = None
         self.reference_update_tracker = 0  # Updated every time the reference policy is updated which is updated each time we pass through the train dataset
         self.calc_position_embeddings = calc_position_embeddings
         self.pad_token_id = pad_token_id
         self.pad_token = pad_token
+        self.pretrained_model_name_or_path = (
+            model_name if model_name is not None else actor_network.name_or_path
+        )
+        self.model_config = model_config
 
         if not clone and reduce_memory_peak and micro_batch_size_per_gpu is not None:
             raise ValueError(
@@ -1858,7 +1889,9 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self._configure_batch_size(
             batch_size, clone, reduce_memory_peak, micro_batch_size_per_gpu
         )
-
+        self.batch_size = self.batch_size_per_process * (
+            self.accelerator.num_processes if self.accelerator is not None else 1
+        )
         if self.accelerator is not None:
             if (
                 self.accelerator.state.deepspeed_plugin.deepspeed_config.get(
@@ -1875,22 +1908,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             )
             lr = optim_lr
 
-        if lora_config is None and not isinstance(actor_network,
+        if lora_config is None and not isinstance(actor_network, PeftModelProtocol):
            warnings.warn(
-                "No LoRA config provided. Using default LoRA configuration for RL finetuning."
+                "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
            )
            lora_config = LoraConfig(
                r=16,
-                lora_alpha=
-                target_modules=
-                    "q_proj",
-                    "k_proj",
-                    "v_proj",
-                    "o_proj",
-                    "up_proj",
-                    "down_proj",
-                    "gate_proj",
-                ],
+                lora_alpha=32,
+                target_modules="all-linear",
                task_type="CAUSAL_LM",
                lora_dropout=0.05,
            )
@@ -1900,15 +1925,20 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.use_separate_reference_adapter = use_separate_reference_adapter
         self.cosine_lr_schedule_config = cosine_lr_schedule_config
 
-        if max_grad_norm and (accelerator is not None)
-
-
-
-
-
-
+        if max_grad_norm and (accelerator is not None):
+            if accelerator.is_main_process:
+                warnings.warn(
+                    "Argument 'max_grad_norm' will overwrite the equivalent value set for 'gradient_clipping' in the deepspeed config."
+                )
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = max_grad_norm
+
+        self.max_grad_norm = max_grad_norm
         self.reduce_memory_peak = reduce_memory_peak
-
+
+        if self.accelerator is not None:
+            self.register_mutation_hook(self._sync_deepspeed_gradient_clipping)
 
         if self.accelerator is not None:
             self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
@@ -2044,7 +2074,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 device_map="auto"
             )
             tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
-            model =
+            model = PeftModelProtocol.from_pretrained(base_model, path)
             """
         )
 
@@ -2141,19 +2171,26 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 if not is_dummy_optimizer
                 else type(self.actor.optimizer)
             )
-            self.
-
-
+            if self.gradient_checkpointing:
+                self.actor.module.gradient_checkpointing_enable(
+                    gradient_checkpointing_kwargs={"use_reentrant": False}
+                )
        else:
            assert (
                self.actor is not None
            ), "Actor is set to None, please check that the actor is defined."
            self.actor = self.actor.to(self.device)
-            self.
+            if self.gradient_checkpointing:
+                self.actor.gradient_checkpointing_enable()
 
    def clean_up(self) -> None:
        """Clean up the algorithm."""
        if self.accelerator is not None:
+            # Free up GPU memory occupied by parameters
+            if hasattr(self.actor, "empty_partition_cache"):
+                self.actor.empty_partition_cache()
+            if hasattr(self.actor, "destroy"):
+                self.actor.destroy()
            (
                self.actor,
                self.optimizer,
@@ -2177,10 +2214,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         if hasattr(self, "llm"):
             del self.llm.llm_engine.model_executor
             del self.llm
-
         gc.collect()
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
 
     def clone(self, index: Optional[int] = None, wrap: bool = True):
@@ -2215,8 +2250,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         input_args["wrap"] = False
         input_args["clone"] = True
 
-        actor:
-
+        actor: PeftModelProtocol = cast(
+            PeftModelProtocol,
             (
                 self.accelerator.unwrap_model(self.actor)
                 if self.accelerator is not None
@@ -2408,17 +2443,22 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.reference_update_tracker += 1
 
     def _initialize_actors(
-        self, base_model:
+        self, base_model: PreTrainedModelProtocol | None, add_adapters: bool = True
     ):
         """Initialize the actor network.
 
         :param base_model: Base model
-        :type base_model:
+        :type base_model: PreTrainedModelProtocol
         :param add_adapters: Flag to indicate if adapters should be added to the model, defaults to True
         :type add_adapters: bool, optional
         """
 
-        if
+        if base_model is None:
+            base_model = create_model_from_name_or_path(
+                self.pretrained_model_name_or_path
+            )
+
+        if isinstance(base_model, PeftModelProtocol) and add_adapters:
            # Handles backwards compatibility with user providing a peft model as the actor network
            if self.lora_config is None:
                adapter_name = list(base_model.peft_config.keys())
@@ -2428,7 +2468,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 if "default" in list(base_model.peft_config.keys()):
                     base_model.peft_config.pop("default")
 
-        self.actor:
+        self.actor: PeftModelProtocol = (
             get_peft_model(base_model, self.lora_config, adapter_name="actor")
             if add_adapters
             else base_model
@@ -2577,7 +2617,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def _move_model_to_vllm(self) -> None:
         """Move the deepspeed model to vllm."""
 
-        # TODO: Add support for ZeRO Stage 3
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
             model_ref = self.accelerator.unwrap_model(self.actor)
@@ -2945,3 +2984,28 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
 
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
+
+    def _sync_deepspeed_gradient_clipping(self) -> None:
+        """Synchronizes max_grad_norm with DeepSpeed gradient_clipping config.
+        Registered as a mutation hook to ensure consistency after mutations.
+        """
+        if self.accelerator is None:
+            return
+
+        if (
+            "gradient_clipping"
+            not in self.accelerator.state.deepspeed_plugin.deepspeed_config
+        ):
+            return
+
+        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
+        if ds_config["gradient_clipping"] != self.max_grad_norm:
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = self.max_grad_norm
+
+        if hasattr(self.actor, "optimizer"):
+            if hasattr(self.actor.optimizer, "grad_clip"):
+                self.actor.optimizer.grad_clip = self.max_grad_norm
+            if hasattr(self.actor.optimizer, "clip_grad"):
+                self.actor.optimizer.clip_grad = self.max_grad_norm
```
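The new `_sync_deepspeed_gradient_clipping` hook re-applies `max_grad_norm` to the DeepSpeed plugin config whenever a mutation changes it. A standalone sketch of the same rule, with a plain dict standing in for `accelerator.state.deepspeed_plugin.deepspeed_config` (the helper name here is illustrative, not an AgileRL API):

```python
# Standalone sketch of the sync rule; `ds_config` stands in for the DeepSpeed
# config dict held by accelerate, which agilerl updates via a mutation hook.
def sync_gradient_clipping(ds_config: dict, max_grad_norm: float) -> dict:
    if "gradient_clipping" in ds_config and ds_config["gradient_clipping"] != max_grad_norm:
        ds_config["gradient_clipping"] = max_grad_norm
    return ds_config


# After an HPO mutation bumps max_grad_norm from 0.1 to 0.5, the config follows it.
assert sync_gradient_clipping({"gradient_clipping": 0.1}, 0.5) == {"gradient_clipping": 0.5}
```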
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/optimizer_wrapper.py

```diff
@@ -2,19 +2,27 @@ import inspect
 from typing import Any, Optional, Union
 
 import torch.nn as nn
-from peft import PeftModel
 from torch.optim import Optimizer
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.modules import EvolvableModule, ModuleDict
 from agilerl.protocols import EvolvableAlgorithm
 from agilerl.typing import OptimizerType, StateDict
-from agilerl.utils.
+from agilerl.utils.algo_utils import DummyOptimizer
+
+if HAS_LLM_DEPENDENCIES:
+    from peft import PeftModel
+
+    PeftModelType = PeftModel
+else:
+    PeftModelType = "PeftModel"
+
 
 ModuleList = list[EvolvableModule]
 _Optimizer = Union[
     type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
 ]
-_Module = Union[EvolvableModule, ModuleDict, ModuleList,
+_Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModelType]
 
 
 def init_from_multiple(
```
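The `PeftModelType = "PeftModel"` fallback works because `typing.Union` accepts a string as a forward reference, so the `_Module` alias can still be defined when peft is absent. A small sketch of that pattern under simplified assumptions (a try/except stands in for the `HAS_LLM_DEPENDENCIES` gate, and `LocalModule` is an illustrative stand-in class, not an AgileRL type):

```python
# Sketch: when an optional dependency is missing, a string forward reference
# keeps a module-level Union alias importable without the package.
from typing import Union


class LocalModule:  # illustrative stand-in for EvolvableModule
    ...


try:
    from peft import PeftModel as PeftModelType  # optional dependency
except ImportError:
    PeftModelType = "PeftModel"  # string forward reference keeps the alias valid

_Module = Union[LocalModule, PeftModelType]  # defined either way, no hard peft import
print(_Module)
```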
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dpo.py

```diff
@@ -1,28 +1,76 @@
 import gc
+from typing import Any
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
-from gymnasium import spaces
-from peft import LoraConfig
-from transformers import PreTrainedModel
 
 from agilerl.algorithms.core.base import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
+from agilerl.protocols import LoraConfigProtocol, PreTrainedModelProtocol
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import get_experiences_samples
 from agilerl.utils.llm_utils import PreferenceGym
 
 
 class DPO(LLMAlgorithm):
+    """The DPO algorithm class. DPO paper: https://arxiv.org/pdf/2305.18290
+
+    :param pad_token_id: Pad token id
+    :type pad_token_id: int
+    :param pad_token: Pad token
+    :type pad_token: str
+    :param model_name: Model name
+    :type model_name: str, optional
+    :param actor_network: HuggingFace LLM
+    :type actor_network: PreTrainedModelProtocol
+    :param model_config: Model configuration, to be used when creating the model from a name or path
+    :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
+    :type hp_config: HyperparameterConfig, optional
+    :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0
+    :type index: int, optional
+    :param batch_size: Batch size for training, defaults to 16
+    :type batch_size: int, optional
+    :param lr: Learning rate, defaults to 0.000005
+    :type lr: float, optional
+    :param beta: Beta parameter for DPO, defaults to 0.001
+    :type beta: float, optional
+    :param max_grad_norm: Maximum gradient norm, defaults to 0.1
+    :type max_grad_norm: float, optional
+    :param update_epochs: Number of update epochs, defaults to 1
+    :type update_epochs: int, optional
+    :param calc_position_embeddings: Flag to indicate if position embeddings should be calculated, defaults to True
+    :type calc_position_embeddings: bool, optional
+    :param micro_batch_size_per_gpu: Micro batch size per GPU, defaults to None
+    :type micro_batch_size_per_gpu: int, optional
+    :param reduce_memory_peak: Flag to indicate if memory peak should be reduced, defaults to False
+    :type reduce_memory_peak: bool, optional
+    :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
+    :type device: str, optional
+    :param lora_config: Config for LoRA, defaults to None
+    :type lora_config: LoraConfigProtocol, optional
+    :param accelerator: Accelerator for distributed computing, defaults to None
+    :type accelerator: accelerate.Accelerator(), optional
+    :param wrap: Wrap models for distributed training upon creation, defaults to True
+    :type wrap: bool, optional
+    :param clone: Flag to indicate if the instantiation is a cloning, defaults to False
+    :type clone: bool, optional
+    :param use_separate_reference_adapter: Flag to indicate if the reference policy should have a separate adapter, defaults to False
+    :type use_separate_reference_adapter: bool, optional
+    :param seed: Seed for the random number generator, defaults to 42
+    :type seed: int, optional
+    :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+    :type gradient_checkpointing: bool, optional
+    """
+
     def __init__(
         self,
-        observation_space: spaces.Space,
-        action_space: spaces.Space,
-        actor_network: PreTrainedModel,
         pad_token_id: int,
         pad_token: str,
+        model_name: str | None = None,
+        actor_network: PreTrainedModelProtocol | None = None,
+        model_config: dict[str, Any] | None = None,
         hp_config: HyperparameterConfig | None = None,
         index: int = 0,
         batch_size: int = 16,
@@ -34,12 +82,13 @@ class DPO(LLMAlgorithm):
         micro_batch_size_per_gpu: int | None = None,
         reduce_memory_peak: bool = False,
         device: str = "cpu",
-        lora_config:
+        lora_config: LoraConfigProtocol | None = None,
         accelerator: Accelerator | None = None,
         wrap: bool = True,
         clone: bool = False,
         use_separate_reference_adapter: bool = False,
         seed: int = 42,
+        gradient_checkpointing: bool = True,
     ):
         device = (
             f"cuda:{accelerator.process_index}"
@@ -47,9 +96,6 @@ class DPO(LLMAlgorithm):
             else ("cuda" if torch.cuda.is_available() else "cpu")
         )
         super().__init__(
-            observation_space,
-            action_space,
-            actor_network,
             index=index,
             batch_size=batch_size,
             lr=lr,
@@ -62,6 +108,9 @@ class DPO(LLMAlgorithm):
             pad_token=pad_token,
             lora_config=lora_config,
             use_separate_reference_adapter=use_separate_reference_adapter,
+            model_name=model_name,
+            actor_network=actor_network,
+            model_config=model_config,
             micro_batch_size_per_gpu=micro_batch_size_per_gpu,
             cosine_lr_schedule_config=None,
             hp_config=hp_config,
@@ -69,6 +118,7 @@ class DPO(LLMAlgorithm):
             device=device,
             accelerator=accelerator,
             name="DPO",
+            gradient_checkpointing=gradient_checkpointing,
         )
         self.beta = beta
         self.temperature = (
```
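With the reworked signature, a DPO agent can now be built from a model name instead of a pre-loaded `actor_network`. A minimal sketch using the defaults documented in the docstring above, assuming `agilerl[llm]` is installed; the model name and pad-token handling are illustrative, and the keyword names follow the documented parameters rather than a verified full signature:

```python
# Minimal sketch, assuming agilerl[llm] is installed; values are illustrative only.
from transformers import AutoTokenizer

from agilerl.algorithms.dpo import DPO

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")

agent = DPO(
    pad_token_id=tokenizer.pad_token_id,
    pad_token=tokenizer.pad_token,
    model_name="Qwen/Qwen2.5-3B",  # new in 2.4.1: build the actor from a name or path
    batch_size=16,   # documented default
    lr=5e-6,         # documented default
    beta=0.001,      # documented default
    device="cuda",
)
```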