agilerl 2.4.1.dev0__tar.gz → 2.4.1.dev2__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/PKG-INFO +25 -10
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/README.md +12 -1
- agilerl-2.4.1.dev2/agilerl/__init__.py +18 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/base.py +97 -37
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/optimizer_wrapper.py +11 -3
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/registry.py +1 -1
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/dpo.py +5 -6
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/grpo.py +15 -16
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/ilql.py +14 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/protocols.py +131 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/algo_utils.py +51 -4
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/llm_utils.py +15 -46
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/utils.py +2 -2
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/pyproject.toml +15 -8
- agilerl-2.4.1.dev0/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/LICENSE +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/ippo.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/data.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/sampler.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.4.1.dev0/agilerl → agilerl-2.4.1.dev2/agilerl/data}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.4.1.dev0/agilerl/data → agilerl-2.4.1.dev2/agilerl/hpo}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/base.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/bert.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/configs.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/modules/simba.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/actors.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/base.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.4.1.dev0/agilerl/hpo → agilerl-2.4.1.dev2/agilerl/training}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_llm.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/typing.py +0 -0
- {agilerl-2.4.1.dev0/agilerl/training → agilerl-2.4.1.dev2/agilerl/utils}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/cache.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.4.1.dev0/agilerl/utils → agilerl-2.4.1.dev2/agilerl/vector}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.4.1.dev0/agilerl/vector → agilerl-2.4.1.dev2/agilerl/wrappers}/__init__.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/wrappers/utils.py +0 -0

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/PKG-INFO

@@ -1,20 +1,23 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: agilerl
-Version: 2.4.1.dev0
+Version: 2.4.1.dev2
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
+License-File: LICENSE
 Author: Nick Ustaran-Anderegg
 Author-email: dev@agilerl.com
-Requires-Python: >=3.10,<
+Requires-Python: >=3.10,<3.13
 Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-
+Provides-Extra: all
+Provides-Extra: llm
 Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
 Requires-Dist: accelerate (>=1.7.0,<2.0.0)
-Requires-Dist:
+Requires-Dist: datasets (==4.4.1) ; extra == "llm" or extra == "all"
+Requires-Dist: deepspeed (>=0.17.1,<0.18.0) ; extra == "llm" or extra == "all"
 Requires-Dist: dill (>=0.3.7,<0.4.0)
 Requires-Dist: fastrand (>=1.3.0,<2.0.0)
 Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)

@@ -24,11 +27,12 @@ Requires-Dist: h5py (>=3.8.0,<4.0.0)
 Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
 Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
 Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
-Requires-Dist: minari (
+Requires-Dist: minari[all] (==0.5.2)
 Requires-Dist: numpy (>=1.26.4,<2.0.0)
 Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
+Requires-Dist: packaging (==25.0)
 Requires-Dist: pandas (>=2.2.3,<3.0.0)
-Requires-Dist: peft (>=0.
+Requires-Dist: peft (>=0.18.0,<0.19.0) ; extra == "llm" or extra == "all"
 Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
 Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
 Requires-Dist: pygame (>=2.6.0,<3.0.0)

@@ -39,9 +43,9 @@ Requires-Dist: tensordict (>=0.8,<0.9)
 Requires-Dist: termcolor (>=1.1.0,<2.0.0)
 Requires-Dist: torch (==2.7.1)
 Requires-Dist: tqdm (>=4.66.4,<5.0.0)
-Requires-Dist: transformers (>=4.
+Requires-Dist: transformers (>=4.57.1,<5.0.0) ; extra == "llm" or extra == "all"
 Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
-Requires-Dist: vllm (==0.10.0)
+Requires-Dist: vllm (==0.10.0) ; extra == "llm" or extra == "all"
 Requires-Dist: wandb (>=0.17.6,<0.18.0)
 Description-Content-Type: text/markdown

@@ -95,6 +99,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```

+If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+```bash
+pip install agilerl[all]
+```
+Or in development mode:
+```bash
+pip install -e ".[all]"
+```
+
 To install the ``nightly`` version of AgileRL with the latest features, use:

 ```bash

@@ -153,11 +167,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

-### LLM
+### LLM Fine-tuning Algorithms

 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


 ## Train an Agent to Beat a Gym Environment

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/README.md

@@ -48,6 +48,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```

+If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+```bash
+pip install agilerl[all]
+```
+Or in development mode:
+```bash
+pip install -e ".[all]"
+```
+
 To install the ``nightly`` version of AgileRL with the latest features, use:

 ```bash

@@ -106,11 +116,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

-### LLM
+### LLM Fine-tuning Algorithms

 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


 ## Train an Agent to Beat a Gym Environment

agilerl-2.4.1.dev2/agilerl/__init__.py (new file)

@@ -0,0 +1,18 @@
+from importlib.metadata import metadata
+from importlib.util import find_spec
+
+from packaging.requirements import Requirement
+
+
+def get_extra_dependencies(package: str, extra: str) -> list[str]:
+    requires = metadata(package).get_all("Requires-Dist") or []
+    deps = []
+    for req in requires:
+        r = Requirement(req)
+        if r.marker and r.marker.evaluate({"extra": extra}):
+            deps.append(r.name)
+    return deps
+
+
+LLM_PACKAGES = get_extra_dependencies("agilerl", "llm")
+HAS_LLM_DEPENDENCIES = all(find_spec(pkg) is not None for pkg in LLM_PACKAGES)
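
The `HAS_LLM_DEPENDENCIES` flag above is what the rest of the package now branches on: the helper reads the installed metadata and keeps only the requirements whose PEP 508 marker matches the `llm` extra. A minimal illustration of that marker evaluation, using `packaging` directly (the requirement strings below are examples in the same style as PKG-INFO, not the exact pins):

```python
from packaging.requirements import Requirement

# Illustrative requirement strings (same style as the metadata, not the exact pins).
examples = [
    'numpy (>=1.26.4,<2.0.0)',
    'vllm (==0.10.0) ; extra == "llm" or extra == "all"',
]

for raw in examples:
    req = Requirement(raw)
    # A dependency belongs to the "llm" extra when its marker evaluates to True
    # in an environment where extra == "llm".
    in_llm_extra = req.marker is not None and req.marker.evaluate({"extra": "llm"})
    print(req.name, in_llm_extra)
# numpy False
# vllm True
```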

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/base.py

@@ -27,20 +27,14 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, set_seed
-from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
-from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 from gymnasium import spaces
-from peft import LoraConfig, PeftModel, get_peft_model, set_peft_model_state_dict
-from safetensors.torch import load_file
 from tensordict import TensorDict
 from torch._dynamo import OptimizedModule
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import SequentialLR
-from transformers import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
-from vllm import LLM, SamplingParams

+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
 from agilerl.algorithms.core.registry import (
     HyperparameterConfig,

@@ -55,7 +49,11 @@ from agilerl.protocols import (
     EvolvableAttributeDict,
     EvolvableAttributeType,
     EvolvableModule,
+    LoraConfigProtocol,
     ModuleDict,
+    PeftModelProtocol,
+    PretrainedConfigProtocol,
+    PreTrainedModelProtocol,
 )
 from agilerl.typing import (
     ActionType,

@@ -74,6 +72,7 @@ from agilerl.typing import (
 )
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,
+    DummyOptimizer,
     VLLMConfig,
     check_supported_space,
     chkpt_attribute_to_device,

@@ -96,11 +95,18 @@ from agilerl.utils.evolvable_networks import (
     is_image_space,
     is_vector_space,
 )
-
-
-
-
-
+
+if HAS_LLM_DEPENDENCIES:
+    from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
+    from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+    from safetensors.torch import load_file
+    from vllm import LLM, SamplingParams
+
+    from agilerl.utils.llm_utils import (
+        create_model_from_name_or_path,
+        gather_if_zero3,
+    )

 __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]


@@ -601,14 +607,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
         )
         optimizer = opt.optimizer if hasattr(opt, "optimizer") else None

-        if isinstance(
-        if
-
+        if isinstance(self, LLMAlgorithm):
+            if hasattr(self.actor, "optimizer"):
+                optimizer = getattr(
                     getattr(self, "actor"), "optimizer"
                 )  # If the optimizer is defined in the deepspeed config, we do this
+            else:
+                optimizer = opt.optimizer

             self.accelerator, self.lr_scheduler = LLMAlgorithm.update_lr(
-
+                optimizer,
                 lr=getattr(self, config.lr),
                 accelerator=self.accelerator,
                 scheduler_config=self.cosine_lr_schedule_config,

@@ -1143,6 +1151,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):

         return self

+    def clean_up(self) -> None:
+        """
+        Clean up the algorithm by deleting the networks and optimizers.
+
+        :return: None
+        :rtype: None
+        """
+        for evo_attr in self.evolvable_attributes().values():
+            del evo_attr
+

 class RLAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all single-agent algorithms in the AgileRL framework.

@@ -1799,6 +1817,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     :type accelerator: Optional[Accelerator]
     :param name: The name of the algorithm.
     :type name: Optional[str]
+    :param model_config: The configuration for the model.
+    :type model_config: dict[str, Any] | PretrainedConfig | None
+    :param gradient_checkpointing: Whether to use gradient checkpointing.
+    :type gradient_checkpointing: bool
     """

     def __init__(

@@ -1813,10 +1835,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         seed: int,
         pad_token_id: int,
         pad_token: str,
-        lora_config:
+        lora_config: LoraConfigProtocol | None,
         use_separate_reference_adapter: bool,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModelProtocol | None = None,
         micro_batch_size_per_gpu: int | None = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         hp_config: Optional[HyperparameterConfig] = None,

@@ -1824,9 +1846,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         device: Union[str, torch.device] = "cpu",
         accelerator: Optional[Accelerator] = None,
         name: Optional[str] = None,
-        model_config: dict[str, Any] |
+        model_config: dict[str, Any] | PretrainedConfigProtocol | None = None,
         gradient_checkpointing: bool = True,
     ):
+        if not HAS_LLM_DEPENDENCIES:
+            raise ImportError(
+                "LLM dependencies are not installed. Please install them using `pip install agilerl[llm]`."
+            )
+
         if model_name is None and actor_network is None:
             raise ValueError(
                 "At least one of model_name or actor_network must be provided."

@@ -1881,7 +1908,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             )
             lr = optim_lr

-        if lora_config is None and not isinstance(actor_network,
+        if lora_config is None and not isinstance(actor_network, PeftModelProtocol):
             warnings.warn(
                 "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
             )
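
Taken together with the gated import block earlier in the file, this is a standard optional-dependency pattern: heavy imports are attempted only when the extra is installed, and anything that needs them fails fast with an actionable message. A generic sketch of the same pattern (illustrative names, not AgileRL code):

```python
from importlib.util import find_spec

# Module level: only import the optional stack if it is actually installed.
HAS_LLM_DEPENDENCIES = all(
    find_spec(pkg) is not None for pkg in ("peft", "vllm", "deepspeed")
)

if HAS_LLM_DEPENDENCIES:
    from peft import LoraConfig  # noqa: F401  # only used when the extra is present


class MyLLMAlgorithm:
    def __init__(self) -> None:
        # Constructor level: fail fast and point the user at the missing extra.
        if not HAS_LLM_DEPENDENCIES:
            raise ImportError(
                "LLM dependencies are not installed. "
                "Install them with `pip install agilerl[llm]`."
            )
```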

@@ -1898,15 +1925,21 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.use_separate_reference_adapter = use_separate_reference_adapter
         self.cosine_lr_schedule_config = cosine_lr_schedule_config

-        if max_grad_norm and (accelerator is not None)
-
-
-
-
-
-
+        if max_grad_norm and (accelerator is not None):
+            if accelerator.is_main_process:
+                warnings.warn(
+                    "Argument 'max_grad_norm' will overwrite the equivalent value set for 'gradient_clipping' in the deepspeed config."
+                )
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = max_grad_norm
+
+        self.max_grad_norm = max_grad_norm
         self.reduce_memory_peak = reduce_memory_peak

+        if self.accelerator is not None:
+            self.register_mutation_hook(self._sync_deepspeed_gradient_clipping)
+
         if self.accelerator is not None:
             self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
                 "zero_optimization"
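
For context, `gradient_clipping` is a top-level key of the DeepSpeed config that Accelerate exposes via `accelerator.state.deepspeed_plugin.deepspeed_config`; the constructor above simply overwrites it when `max_grad_norm` is supplied. An illustrative fragment (values are examples only):

```python
# Illustrative DeepSpeed config fragment; only "gradient_clipping" matters here.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_clipping": 1.0,
    "zero_optimization": {"stage": 2},
}

max_grad_norm = 0.5  # example value passed to the algorithm
if max_grad_norm:
    # Mirrors the overwrite performed in __init__ above.
    deepspeed_config["gradient_clipping"] = max_grad_norm
```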

@@ -2041,7 +2074,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             device_map="auto"
         )
         tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
-        model =
+        model = PeftModelProtocol.from_pretrained(base_model, path)
         """
         )

@@ -2153,6 +2186,11 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def clean_up(self) -> None:
         """Clean up the algorithm."""
         if self.accelerator is not None:
+            # Free up GPU memory occupied by parameters
+            if hasattr(self.actor, "empty_partition_cache"):
+                self.actor.empty_partition_cache()
+            if hasattr(self.actor, "destroy"):
+                self.actor.destroy()
             (
                 self.actor,
                 self.optimizer,

@@ -2176,10 +2214,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         if hasattr(self, "llm"):
             del self.llm.llm_engine.model_executor
             del self.llm
-
         gc.collect()
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()

     def clone(self, index: Optional[int] = None, wrap: bool = True):

@@ -2214,8 +2250,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         input_args["wrap"] = False
         input_args["clone"] = True

-        actor:
-
+        actor: PeftModelProtocol = cast(
+            PeftModelProtocol,
             (
                 self.accelerator.unwrap_model(self.actor)
                 if self.accelerator is not None

@@ -2407,12 +2443,12 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             self.reference_update_tracker += 1

     def _initialize_actors(
-        self, base_model:
+        self, base_model: PreTrainedModelProtocol | None, add_adapters: bool = True
     ):
         """Initialize the actor network.

         :param base_model: Base model
-        :type base_model:
+        :type base_model: PreTrainedModelProtocol
         :param add_adapters: Flag to indicate if adapters should be added to the model, defaults to True
         :type add_adapters: bool, optional
         """

@@ -2422,7 +2458,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 self.pretrained_model_name_or_path
             )

-        if isinstance(base_model,
+        if isinstance(base_model, PeftModelProtocol) and add_adapters:
             # Handles backwards compatibility with user providing a peft model as the actor network
             if self.lora_config is None:
                 adapter_name = list(base_model.peft_config.keys())

@@ -2432,7 +2468,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             if "default" in list(base_model.peft_config.keys()):
                 base_model.peft_config.pop("default")

-        self.actor:
+        self.actor: PeftModelProtocol = (
             get_peft_model(base_model, self.lora_config, adapter_name="actor")
             if add_adapters
             else base_model

@@ -2581,7 +2617,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def _move_model_to_vllm(self) -> None:
         """Move the deepspeed model to vllm."""

-        # TODO: Add support for ZeRO Stage 3
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
             model_ref = self.accelerator.unwrap_model(self.actor)

@@ -2949,3 +2984,28 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
+
+    def _sync_deepspeed_gradient_clipping(self) -> None:
+        """Synchronizes max_grad_norm with DeepSpeed gradient_clipping config.
+        Registered as a mutation hook to ensure consistency after mutations.
+        """
+        if self.accelerator is None:
+            return
+
+        if (
+            "gradient_clipping"
+            not in self.accelerator.state.deepspeed_plugin.deepspeed_config
+        ):
+            return
+
+        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
+        if ds_config["gradient_clipping"] != self.max_grad_norm:
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = self.max_grad_norm
+
+        if hasattr(self.actor, "optimizer"):
+            if hasattr(self.actor.optimizer, "grad_clip"):
+                self.actor.optimizer.grad_clip = self.max_grad_norm
+            if hasattr(self.actor.optimizer, "clip_grad"):
+                self.actor.optimizer.clip_grad = self.max_grad_norm
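
The method above is registered via `register_mutation_hook`, so it runs after each mutation and keeps the DeepSpeed config in step with the (possibly mutated) `max_grad_norm`. The hook mechanism itself amounts to a list of callables replayed after a mutation; a stripped-down sketch of the idea (illustrative, not the AgileRL implementation):

```python
from typing import Callable


class HookedAlgorithm:
    """Minimal illustration of a mutation-hook registry."""

    def __init__(self) -> None:
        self.max_grad_norm = 1.0
        self.config = {"gradient_clipping": 1.0}
        self._mutation_hooks: list[Callable[[], None]] = []
        self.register_mutation_hook(self._sync_clipping)

    def register_mutation_hook(self, hook: Callable[[], None]) -> None:
        self._mutation_hooks.append(hook)

    def mutate(self) -> None:
        # Example hyperparameter mutation changing max_grad_norm.
        self.max_grad_norm *= 0.5
        for hook in self._mutation_hooks:
            hook()  # replay hooks so dependent config stays consistent

    def _sync_clipping(self) -> None:
        self.config["gradient_clipping"] = self.max_grad_norm


algo = HookedAlgorithm()
algo.mutate()
print(algo.config["gradient_clipping"])  # 0.5
```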

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/core/optimizer_wrapper.py

@@ -2,19 +2,27 @@ import inspect
 from typing import Any, Optional, Union

 import torch.nn as nn
-from peft import PeftModel
 from torch.optim import Optimizer

+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.modules import EvolvableModule, ModuleDict
 from agilerl.protocols import EvolvableAlgorithm
 from agilerl.typing import OptimizerType, StateDict
-from agilerl.utils.
+from agilerl.utils.algo_utils import DummyOptimizer
+
+if HAS_LLM_DEPENDENCIES:
+    from peft import PeftModel
+
+    PeftModelType = PeftModel
+else:
+    PeftModelType = "PeftModel"
+

 ModuleList = list[EvolvableModule]
 _Optimizer = Union[
     type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
 ]
-_Module = Union[EvolvableModule, ModuleDict, ModuleList,
+_Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModelType]


 def init_from_multiple(
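
The `else` branch falls back to the string `"PeftModel"` so the `_Module` alias can still be constructed when peft is absent: `typing` treats a string inside `Union[...]` as a forward reference rather than a concrete type. A small check of that behaviour (illustrative):

```python
from typing import Union


class EvolvableModule:  # stand-in for the real class
    ...


# Without peft installed, the alias is built from a forward-reference string.
PeftModelType = "PeftModel"
_Module = Union[EvolvableModule, PeftModelType]

print(_Module)
# e.g. typing.Union[__main__.EvolvableModule, ForwardRef('PeftModel')]
```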

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/dpo.py

@@ -5,11 +5,10 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
-from peft import LoraConfig
-from transformers import PreTrainedModel

 from agilerl.algorithms.core.base import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
+from agilerl.protocols import LoraConfigProtocol, PreTrainedModelProtocol
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import get_experiences_samples
 from agilerl.utils.llm_utils import PreferenceGym

@@ -25,7 +24,7 @@ class DPO(LLMAlgorithm):
     :param model_name: Model name
     :type model_name: str, optional
     :param actor_network: HuggingFace LLM
-    :type actor_network:
+    :type actor_network: PreTrainedModelProtocol
     :param model_config: Model configuration, to be used when creating the model from a name or path
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
     :type hp_config: HyperparameterConfig, optional

@@ -50,7 +49,7 @@ class DPO(LLMAlgorithm):
     :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
     :type device: str, optional
     :param lora_config: Config for LoRA, defaults to None
-    :type lora_config:
+    :type lora_config: LoraConfigProtocol, optional
     :param accelerator: Accelerator for distributed computing, defaults to None
     :type accelerator: accelerate.Accelerator(), optional
     :param wrap: Wrap models for distributed training upon creation, defaults to True

@@ -70,7 +69,7 @@ class DPO(LLMAlgorithm):
         pad_token_id: int,
         pad_token: str,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModelProtocol | None = None,
         model_config: dict[str, Any] | None = None,
         hp_config: HyperparameterConfig | None = None,
         index: int = 0,

@@ -83,7 +82,7 @@ class DPO(LLMAlgorithm):
         micro_batch_size_per_gpu: int | None = None,
         reduce_memory_peak: bool = False,
         device: str = "cpu",
-        lora_config:
+        lora_config: LoraConfigProtocol | None = None,
         accelerator: Accelerator | None = None,
         wrap: bool = True,
         clone: bool = False,

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/grpo.py

@@ -1,17 +1,18 @@
 import gc
-from typing import Any, Optional
+from typing import Any, Optional

 import numpy as np
 import torch
 from accelerate import Accelerator
-from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
-from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
-from peft import LoraConfig, PeftModel
-from transformers import GenerationConfig
-from transformers.modeling_utils import PreTrainedModel

+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
+from agilerl.protocols import (
+    LoraConfigProtocol,
+    PeftModelProtocol,
+    PreTrainedModelProtocol,
+)
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,

@@ -23,10 +24,8 @@ from agilerl.utils.llm_utils import (
     ReasoningGym,
 )

-
-
-    DeepSpeedZeroOptimizer_Stage3,  # ZeRO Stage 3 optimizer
-]
+if HAS_LLM_DEPENDENCIES:
+    from transformers import GenerationConfig


 class GRPO(LLMAlgorithm):

@@ -39,7 +38,7 @@ class GRPO(LLMAlgorithm):
     :param model_name: Model name
     :type model_name: str, optional
     :param actor_network: HuggingFace LLM
-    :type actor_network:
+    :type actor_network: PreTrainedModelProtocol
     :param model_config: Model configuration, to be used when creating the model from a name or path
     :type model_config: dict[str, Any], optional
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.

@@ -77,7 +76,7 @@ class GRPO(LLMAlgorithm):
     :param max_model_len: Maximum context window length, defaults to None
     :type max_model_len: int, optional
     :param lora_config: Config for LoRA, defaults to None
-    :type lora_config:
+    :type lora_config: LoraConfigProtocol, optional
     :param cosine_lr_schedule_config: Config for cosine lr scheduling, defaults to None
     :type cosine_lr_schedule_config: CosineLRScheduleConfig, optional
     :param accelerator: Accelerator for distributed computing, defaults to None

@@ -105,7 +104,7 @@ class GRPO(LLMAlgorithm):
         pad_token_id: int,
         pad_token: str,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModelProtocol | None = None,
         model_config: dict[str, Any] | None = None,
         hp_config: Optional[HyperparameterConfig] = None,
         index: int = 0,

@@ -127,7 +126,7 @@ class GRPO(LLMAlgorithm):
         max_output_tokens: int | None = 1024,
         min_output_tokens: Optional[int] = None,
         max_model_len: Optional[int] = None,
-        lora_config: Optional[
+        lora_config: Optional[LoraConfigProtocol] = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         accelerator: Optional[Accelerator] = None,
         device: str = "cpu",

@@ -188,8 +187,8 @@ class GRPO(LLMAlgorithm):
         ), "Policy update epochs must be greater than or equal to one."
         if actor_network is not None:
             assert isinstance(
-                actor_network, (
-            ), "Actor network must be a
+                actor_network, (PeftModelProtocol, PreTrainedModelProtocol)
+            ), "Actor network must be a PeftModelProtocol or PreTrainedModelProtocol"

         self.clip_coef = clip_coef
         self.update_epochs = update_epochs
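
The `isinstance` assertion above (and the earlier `isinstance(base_model, PeftModelProtocol)` checks in base.py) can only work if the protocol classes in `agilerl/protocols.py` (added in this release but not shown in this diff) are declared `@runtime_checkable`. A minimal sketch of that pattern, using a hypothetical protocol and class:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class PreTrainedModelLike(Protocol):
    """Hypothetical structural stand-in for a HuggingFace-style model."""

    def forward(self, *args, **kwargs): ...
    def save_pretrained(self, save_directory: str) -> None: ...


class TinyModel:  # hypothetical class that happens to satisfy the protocol
    def forward(self, *args, **kwargs):
        return None

    def save_pretrained(self, save_directory: str) -> None:
        return None


# runtime_checkable protocols check method *presence*, not signatures or types.
print(isinstance(TinyModel(), PreTrainedModelLike))  # True
```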

{agilerl-2.4.1.dev0 → agilerl-2.4.1.dev2}/agilerl/algorithms/ilql.py

@@ -1223,6 +1223,20 @@ class ILQL(nn.Module):
         self.fitness = checkpoint["fitness"]
         self.steps = checkpoint["steps"]

+    def clean_up(self) -> None:
+        """Clean up the networks"""
+        del self.model
+        del self.actor
+        del self.actor_target
+        del self.v
+        del self.q
+        del self.target_q
+        del self.pi
+        del self.optimizer
+        if self.double_q:
+            del self.q2
+            del self.target_q2
+

 class ILQL_Policy:
     def __init__(self, iql_model: ILQL, kind: str, **generation_kwargs) -> None: