agilerl 2.4.1__tar.gz → 2.4.1.dev0__tar.gz
This diff shows the content of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/PKG-INFO +10 -25
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/README.md +1 -12
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py +37 -97
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +3 -11
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/registry.py +1 -1
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py +6 -5
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py +16 -15
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/ilql.py +0 -14
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/protocols.py +0 -131
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/algo_utils.py +5 -52
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/llm_utils.py +49 -15
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/utils.py +2 -2
- agilerl-2.4.1.dev0/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/pyproject.toml +8 -25
- agilerl-2.4.1/agilerl/__init__.py +0 -18
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/LICENSE +0 -0
- {agilerl-2.4.1/agilerl/data → agilerl-2.4.1.dev0/agilerl}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/ippo.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/data.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/sampler.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.4.1/agilerl/hpo → agilerl-2.4.1.dev0/agilerl/data}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.4.1/agilerl/training → agilerl-2.4.1.dev0/agilerl/hpo}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/base.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/bert.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/configs.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/modules/simba.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/actors.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/base.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.4.1/agilerl/utils → agilerl-2.4.1.dev0/agilerl/training}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_llm.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/typing.py +0 -0
- {agilerl-2.4.1/agilerl/vector → agilerl-2.4.1.dev0/agilerl/utils}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/cache.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.4.1/agilerl/wrappers → agilerl-2.4.1.dev0/agilerl/vector}/__init__.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/wrappers/utils.py +0 -0
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/PKG-INFO

@@ -1,23 +1,20 @@
-Metadata-Version: 2.
+Metadata-Version: 2.3
 Name: agilerl
-Version: 2.4.1
+Version: 2.4.1.dev0
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
-License-File: LICENSE
 Author: Nick Ustaran-Anderegg
 Author-email: dev@agilerl.com
-Requires-Python: >=3.10,<
+Requires-Python: >=3.10,<4.0
 Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-
-Provides-Extra: llm
+Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
 Requires-Dist: accelerate (>=1.7.0,<2.0.0)
-Requires-Dist:
-Requires-Dist: deepspeed (>=0.17.1,<0.18.0) ; extra == "llm" or extra == "all"
+Requires-Dist: deepspeed (>=0.17.1,<0.18.0)
 Requires-Dist: dill (>=0.3.7,<0.4.0)
 Requires-Dist: fastrand (>=1.3.0,<2.0.0)
 Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)

@@ -27,12 +24,11 @@ Requires-Dist: h5py (>=3.8.0,<4.0.0)
 Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
 Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
 Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
-Requires-Dist: minari
+Requires-Dist: minari (>=0.5.2,<0.6.0)
 Requires-Dist: numpy (>=1.26.4,<2.0.0)
 Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
-Requires-Dist: packaging (>=20.0)
 Requires-Dist: pandas (>=2.2.3,<3.0.0)
-Requires-Dist: peft (>=0.
+Requires-Dist: peft (>=0.15.2,<0.16.0)
 Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
 Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
 Requires-Dist: pygame (>=2.6.0,<3.0.0)

@@ -43,9 +39,9 @@ Requires-Dist: tensordict (>=0.8,<0.9)
 Requires-Dist: termcolor (>=1.1.0,<2.0.0)
 Requires-Dist: torch (==2.7.1)
 Requires-Dist: tqdm (>=4.66.4,<5.0.0)
-Requires-Dist: transformers (>=4.
+Requires-Dist: transformers (>=4.48.1,<5.0.0)
 Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
-Requires-Dist: vllm (==0.10.0)
+Requires-Dist: vllm (==0.10.0)
 Requires-Dist: wandb (>=0.17.6,<0.18.0)
 Description-Content-Type: text/markdown

@@ -99,16 +95,6 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```

-If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
-
-```bash
-pip install agilerl[all]
-```
-Or in development mode:
-```bash
-pip install -e ".[all]"
-```
-
 To install the ``nightly`` version of AgileRL with the latest features, use:

 ```bash

@@ -167,12 +153,11 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

-### LLM
+### LLM Reasoning Algorithms

 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
-| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


 ## Train an Agent to Beat a Gym Environment
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/README.md

@@ -48,16 +48,6 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```

-If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
-
-```bash
-pip install agilerl[all]
-```
-Or in development mode:
-```bash
-pip install -e ".[all]"
-```
-
 To install the ``nightly`` version of AgileRL with the latest features, use:

 ```bash

@@ -116,12 +106,11 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

-### LLM
+### LLM Reasoning Algorithms

 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
-| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


 ## Train an Agent to Beat a Gym Environment
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py

@@ -27,14 +27,20 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, set_seed
+from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
+from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 from gymnasium import spaces
+from peft import LoraConfig, PeftModel, get_peft_model, set_peft_model_state_dict
+from safetensors.torch import load_file
 from tensordict import TensorDict
 from torch._dynamo import OptimizedModule
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import SequentialLR
+from transformers import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+from vllm import LLM, SamplingParams

-from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
 from agilerl.algorithms.core.registry import (
     HyperparameterConfig,

@@ -49,11 +55,7 @@ from agilerl.protocols import (
     EvolvableAttributeDict,
     EvolvableAttributeType,
     EvolvableModule,
-    LoraConfigProtocol,
     ModuleDict,
-    PeftModelProtocol,
-    PretrainedConfigProtocol,
-    PreTrainedModelProtocol,
 )
 from agilerl.typing import (
     ActionType,

@@ -72,7 +74,6 @@ from agilerl.typing import (
 )
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,
-    DummyOptimizer,
     VLLMConfig,
     check_supported_space,
     chkpt_attribute_to_device,

@@ -95,18 +96,11 @@ from agilerl.utils.evolvable_networks import (
     is_image_space,
     is_vector_space,
 )
-
-
-
-
-
-from safetensors.torch import load_file
-from vllm import LLM, SamplingParams
-
-from agilerl.utils.llm_utils import (
-    create_model_from_name_or_path,
-    gather_if_zero3,
-)
+from agilerl.utils.llm_utils import (
+    DummyOptimizer,
+    create_model_from_name_or_path,
+    gather_if_zero3,
+)

 __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]

@@ -607,16 +601,14 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
            )
            optimizer = opt.optimizer if hasattr(opt, "optimizer") else None

-           if isinstance(
-           if
-
+           if isinstance(opt, DeepSpeedOptimizerWrapper):
+               if isinstance(opt.optimizer, DummyOptimizer):
+                   opt = getattr(
                        getattr(self, "actor"), "optimizer"
                    )  # If the optimizer is defined in the deepspeed config, we do this
-           else:
-               optimizer = opt.optimizer

            self.accelerator, self.lr_scheduler = LLMAlgorithm.update_lr(
-
+               opt,
                lr=getattr(self, config.lr),
                accelerator=self.accelerator,
                scheduler_config=self.cosine_lr_schedule_config,

@@ -1151,16 +1143,6 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):

        return self

-    def clean_up(self) -> None:
-        """
-        Clean up the algorithm by deleting the networks and optimizers.
-
-        :return: None
-        :rtype: None
-        """
-        for evo_attr in self.evolvable_attributes().values():
-            del evo_attr
-


 class RLAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all single-agent algorithms in the AgileRL framework.
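In the `@@ -607` hunk above, the optimizer handed back by Accelerate may be a `DeepSpeedOptimizerWrapper` around the placeholder `DummyOptimizer`, in which case the real optimizer lives on the DeepSpeed engine attached to the actor. A minimal sketch of that resolution order, assuming the names from the hunk; the helper function is hypothetical and not part of the package:

```python
# Sketch only: decide which object actually owns the parameter updates.
from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper

from agilerl.utils.llm_utils import DummyOptimizer  # new location per this diff


def resolve_optimizer(opt, actor):
    """Return the optimizer that is really stepping the parameters."""
    if isinstance(opt, DeepSpeedOptimizerWrapper) and isinstance(
        getattr(opt, "optimizer", None), DummyOptimizer
    ):
        # The optimizer was declared in the DeepSpeed config, so the engine owns it.
        return actor.optimizer
    return opt
```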
@@ -1817,10 +1799,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     :type accelerator: Optional[Accelerator]
     :param name: The name of the algorithm.
     :type name: Optional[str]
-    :param model_config: The configuration for the model.
-    :type model_config: dict[str, Any] | PretrainedConfig | None
-    :param gradient_checkpointing: Whether to use gradient checkpointing.
-    :type gradient_checkpointing: bool
     """

     def __init__(

@@ -1835,10 +1813,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         seed: int,
         pad_token_id: int,
         pad_token: str,
-        lora_config:
+        lora_config: LoraConfig | None,
         use_separate_reference_adapter: bool,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModel | None = None,
         micro_batch_size_per_gpu: int | None = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         hp_config: Optional[HyperparameterConfig] = None,

@@ -1846,14 +1824,9 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         device: Union[str, torch.device] = "cpu",
         accelerator: Optional[Accelerator] = None,
         name: Optional[str] = None,
-        model_config: dict[str, Any] |
+        model_config: dict[str, Any] | PretrainedConfig | None = None,
         gradient_checkpointing: bool = True,
     ):
-        if not HAS_LLM_DEPENDENCIES:
-            raise ImportError(
-                "LLM dependencies are not installed. Please install them using `pip install agilerl[llm]`."
-            )
-
         if model_name is None and actor_network is None:
             raise ValueError(
                 "At least one of model_name or actor_network must be provided."

@@ -1908,7 +1881,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         )
         lr = optim_lr

-        if lora_config is None and not isinstance(actor_network,
+        if lora_config is None and not isinstance(actor_network, PeftModel):
             warnings.warn(
                 "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
             )
@@ -1925,21 +1898,15 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.use_separate_reference_adapter = use_separate_reference_adapter
         self.cosine_lr_schedule_config = cosine_lr_schedule_config

-        if max_grad_norm and (accelerator is not None):
-
-
-
-
-
-
-            ] = max_grad_norm
-
-        self.max_grad_norm = max_grad_norm
+        if max_grad_norm and (accelerator is not None) and accelerator.is_main_process:
+            warnings.warn(
+                "Argument 'max_grad_norm' will be overwritten by the 'gradient_clipping' value set in the deepspeed config."
+            )
+            self.max_grad_norm = None
+        else:
+            self.max_grad_norm = max_grad_norm
         self.reduce_memory_peak = reduce_memory_peak

-        if self.accelerator is not None:
-            self.register_mutation_hook(self._sync_deepspeed_gradient_clipping)
-
         if self.accelerator is not None:
             self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
                 "zero_optimization"
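With this change, gradient clipping is owned entirely by the DeepSpeed configuration whenever an `Accelerator` is supplied: a user-provided `max_grad_norm` is dropped with a warning rather than synced back into the config (the old `_sync_deepspeed_gradient_clipping` hook is removed later in this file). A minimal sketch of the intended setup, assuming Accelerate's `DeepSpeedPlugin` carries the config and using illustrative values:

```python
# Sketch only: clipping is configured on the DeepSpeed side, so any max_grad_norm
# passed to the algorithm alongside this accelerator is ignored with a warning.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0)  # clipping lives here
accelerator = Accelerator(deepspeed_plugin=plugin)  # typically run under `accelerate launch`

# e.g. constructing an LLM algorithm with max_grad_norm=0.5 and this accelerator
# now warns and sets max_grad_norm to None, deferring to gradient_clipping=1.0.
```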
@@ -2074,7 +2041,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 device_map="auto"
             )
             tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
-            model =
+            model = PeftModel.from_pretrained(base_model, path)
         """
         )

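The snippet above sits inside a message string and appears only partially in this diff. A self-contained sketch of the same reload pattern follows; it assumes the base model is loaded with `AutoModelForCausalLM` (the truncated line above does not show the loader) and uses a placeholder checkpoint directory for `path`:

```python
# Sketch of reloading saved LoRA adapters onto a fresh base model.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "path/to/saved/adapters"  # placeholder: wherever the adapters were saved

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
model = PeftModel.from_pretrained(base_model, path)  # attaches the saved adapters
```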
@@ -2186,11 +2153,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def clean_up(self) -> None:
         """Clean up the algorithm."""
         if self.accelerator is not None:
-            # Free up GPU memory occupied by parameters
-            if hasattr(self.actor, "empty_partition_cache"):
-                self.actor.empty_partition_cache()
-            if hasattr(self.actor, "destroy"):
-                self.actor.destroy()
             (
                 self.actor,
                 self.optimizer,

@@ -2214,8 +2176,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         if hasattr(self, "llm"):
             del self.llm.llm_engine.model_executor
             del self.llm
+
         gc.collect()
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()

     def clone(self, index: Optional[int] = None, wrap: bool = True):

@@ -2250,8 +2214,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         input_args["wrap"] = False
         input_args["clone"] = True

-        actor:
-
+        actor: PeftModel = cast(
+            PeftModel,
             (
                 self.accelerator.unwrap_model(self.actor)
                 if self.accelerator is not None

@@ -2443,12 +2407,12 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.reference_update_tracker += 1

     def _initialize_actors(
-        self, base_model:
+        self, base_model: PreTrainedModel | None, add_adapters: bool = True
     ):
         """Initialize the actor network.

         :param base_model: Base model
-        :type base_model:
+        :type base_model: PreTrainedModel
         :param add_adapters: Flag to indicate if adapters should be added to the model, defaults to True
         :type add_adapters: bool, optional
         """

@@ -2458,7 +2422,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 self.pretrained_model_name_or_path
             )

-        if isinstance(base_model,
+        if isinstance(base_model, PeftModel) and add_adapters:
             # Handles backwards compatibility with user providing a peft model as the actor network
             if self.lora_config is None:
                 adapter_name = list(base_model.peft_config.keys())

@@ -2468,7 +2432,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             if "default" in list(base_model.peft_config.keys()):
                 base_model.peft_config.pop("default")

-        self.actor:
+        self.actor: PeftModel = (
             get_peft_model(base_model, self.lora_config, adapter_name="actor")
             if add_adapters
             else base_model

@@ -2617,6 +2581,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def _move_model_to_vllm(self) -> None:
         """Move the deepspeed model to vllm."""

+        # TODO: Add support for ZeRO Stage 3
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
             model_ref = self.accelerator.unwrap_model(self.actor)

@@ -2984,28 +2949,3 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
-
-    def _sync_deepspeed_gradient_clipping(self) -> None:
-        """Synchronizes max_grad_norm with DeepSpeed gradient_clipping config.
-        Registered as a mutation hook to ensure consistency after mutations.
-        """
-        if self.accelerator is None:
-            return
-
-        if (
-            "gradient_clipping"
-            not in self.accelerator.state.deepspeed_plugin.deepspeed_config
-        ):
-            return
-
-        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
-        if ds_config["gradient_clipping"] != self.max_grad_norm:
-            self.accelerator.state.deepspeed_plugin.deepspeed_config[
-                "gradient_clipping"
-            ] = self.max_grad_norm
-
-        if hasattr(self.actor, "optimizer"):
-            if hasattr(self.actor.optimizer, "grad_clip"):
-                self.actor.optimizer.grad_clip = self.max_grad_norm
-            if hasattr(self.actor.optimizer, "clip_grad"):
-                self.actor.optimizer.clip_grad = self.max_grad_norm
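In the `@@ -2214` hunk above, `clean_up` gains `torch.cuda.reset_peak_memory_stats()` in its teardown. A minimal standalone sketch of that sequence; the helper name is illustrative and not part of the package:

```python
import gc

import torch


def release_gpu_memory() -> None:
    """Reclaim CUDA memory after dropping large references (e.g. a vLLM engine)."""
    gc.collect()                          # drop unreachable Python objects first
    torch.cuda.empty_cache()              # return cached blocks to the allocator/driver
    torch.cuda.reset_peak_memory_stats()  # start peak-memory tracking fresh
    torch.cuda.synchronize()              # ensure pending kernels have finished
```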
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/optimizer_wrapper.py

@@ -2,27 +2,19 @@ import inspect
 from typing import Any, Optional, Union

 import torch.nn as nn
+from peft import PeftModel
 from torch.optim import Optimizer

-from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.modules import EvolvableModule, ModuleDict
 from agilerl.protocols import EvolvableAlgorithm
 from agilerl.typing import OptimizerType, StateDict
-from agilerl.utils.
-
-if HAS_LLM_DEPENDENCIES:
-    from peft import PeftModel
-
-    PeftModelType = PeftModel
-else:
-    PeftModelType = "PeftModel"
-
+from agilerl.utils.llm_utils import DummyOptimizer

 ModuleList = list[EvolvableModule]
 _Optimizer = Union[
     type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
 ]
-_Module = Union[EvolvableModule, ModuleDict, ModuleList,
+_Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModel]


 def init_from_multiple(
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py

@@ -5,10 +5,11 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
+from peft import LoraConfig
+from transformers import PreTrainedModel

 from agilerl.algorithms.core.base import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
-from agilerl.protocols import LoraConfigProtocol, PreTrainedModelProtocol
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import get_experiences_samples
 from agilerl.utils.llm_utils import PreferenceGym

@@ -24,7 +25,7 @@ class DPO(LLMAlgorithm):
     :param model_name: Model name
     :type model_name: str, optional
     :param actor_network: HuggingFace LLM
-    :type actor_network:
+    :type actor_network: PreTrainedModel
     :param model_config: Model configuration, to be used when creating the model from a name or path
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
     :type hp_config: HyperparameterConfig, optional

@@ -49,7 +50,7 @@ class DPO(LLMAlgorithm):
     :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
     :type device: str, optional
     :param lora_config: Config for LoRA, defaults to None
-    :type lora_config:
+    :type lora_config: LoraConfig, optional
     :param accelerator: Accelerator for distributed computing, defaults to None
     :type accelerator: accelerate.Accelerator(), optional
     :param wrap: Wrap models for distributed training upon creation, defaults to True

@@ -69,7 +70,7 @@ class DPO(LLMAlgorithm):
         pad_token_id: int,
         pad_token: str,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModel | None = None,
         model_config: dict[str, Any] | None = None,
         hp_config: HyperparameterConfig | None = None,
         index: int = 0,

@@ -82,7 +83,7 @@ class DPO(LLMAlgorithm):
         micro_batch_size_per_gpu: int | None = None,
         reduce_memory_peak: bool = False,
         device: str = "cpu",
-        lora_config:
+        lora_config: LoraConfig | None = None,
         accelerator: Accelerator | None = None,
         wrap: bool = True,
         clone: bool = False,
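With the protocol aliases removed, `lora_config` on DPO (and GRPO below) is now typed directly as peft's `LoraConfig`. A minimal sketch with illustrative adapter hyperparameters, not defaults taken from AgileRL:

```python
# Sketch only: values below are illustrative.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,                                 # adapter rank
    lora_alpha=32,                        # scaling factor
    target_modules=["q_proj", "v_proj"],  # projections that receive adapters
    task_type="CAUSAL_LM",
)
# e.g. DPO(..., lora_config=lora_config) -- remaining constructor arguments omitted
```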
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py

@@ -1,18 +1,17 @@
 import gc
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import numpy as np
 import torch
 from accelerate import Accelerator
+from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
+from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+from peft import LoraConfig, PeftModel
+from transformers import GenerationConfig
+from transformers.modeling_utils import PreTrainedModel

-from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
-from agilerl.protocols import (
-    LoraConfigProtocol,
-    PeftModelProtocol,
-    PreTrainedModelProtocol,
-)
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,

@@ -24,8 +23,10 @@ from agilerl.utils.llm_utils import (
     ReasoningGym,
 )

-
-
+DeepSpeedOptimizerType = Union[
+    DeepSpeedZeroOptimizer,  # ZeRO Stage 1 & 2 optimizer
+    DeepSpeedZeroOptimizer_Stage3,  # ZeRO Stage 3 optimizer
+]


 class GRPO(LLMAlgorithm):

@@ -38,7 +39,7 @@ class GRPO(LLMAlgorithm):
     :param model_name: Model name
     :type model_name: str, optional
     :param actor_network: HuggingFace LLM
-    :type actor_network:
+    :type actor_network: PreTrainedModel
     :param model_config: Model configuration, to be used when creating the model from a name or path
     :type model_config: dict[str, Any], optional
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.

@@ -76,7 +77,7 @@ class GRPO(LLMAlgorithm):
     :param max_model_len: Maximum context window length, defaults to None
     :type max_model_len: int, optional
     :param lora_config: Config for LoRA, defaults to None
-    :type lora_config:
+    :type lora_config: LoraConfig, optional
     :param cosine_lr_schedule_config: Config for cosine lr scheduling, defaults to None
     :type cosine_lr_schedule_config: CosineLRScheduleConfig, optional
     :param accelerator: Accelerator for distributed computing, defaults to None

@@ -104,7 +105,7 @@ class GRPO(LLMAlgorithm):
         pad_token_id: int,
         pad_token: str,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModel | None = None,
         model_config: dict[str, Any] | None = None,
         hp_config: Optional[HyperparameterConfig] = None,
         index: int = 0,

@@ -126,7 +127,7 @@ class GRPO(LLMAlgorithm):
         max_output_tokens: int | None = 1024,
         min_output_tokens: Optional[int] = None,
         max_model_len: Optional[int] = None,
-        lora_config: Optional[
+        lora_config: Optional[LoraConfig] = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         accelerator: Optional[Accelerator] = None,
         device: str = "cpu",

@@ -187,8 +188,8 @@ class GRPO(LLMAlgorithm):
         ), "Policy update epochs must be greater than or equal to one."
         if actor_network is not None:
             assert isinstance(
-                actor_network, (
-            ), "Actor network must be a
+                actor_network, (PeftModel, PreTrainedModel)
+            ), "Actor network must be a PeftModel or PreTrainedModel"

         self.clip_coef = clip_coef
         self.update_epochs = update_epochs
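The new `DeepSpeedOptimizerType` alias introduced at the top of grpo.py groups the optimizers DeepSpeed builds for ZeRO stages 1/2 and 3. A short sketch of how such an alias is typically used for type narrowing; the helper function is hypothetical, not part of the package:

```python
from typing import Union

from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

DeepSpeedOptimizerType = Union[DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer_Stage3]


def is_zero_optimizer(optimizer: object) -> bool:
    """True if the optimizer was produced by a ZeRO stage 1/2 or stage 3 engine."""
    return isinstance(optimizer, (DeepSpeedZeroOptimizer, DeepSpeedZeroOptimizer_Stage3))
```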
{agilerl-2.4.1 → agilerl-2.4.1.dev0}/agilerl/algorithms/ilql.py

@@ -1223,20 +1223,6 @@ class ILQL(nn.Module):
         self.fitness = checkpoint["fitness"]
         self.steps = checkpoint["steps"]

-    def clean_up(self) -> None:
-        """Clean up the networks"""
-        del self.model
-        del self.actor
-        del self.actor_target
-        del self.v
-        del self.q
-        del self.target_q
-        del self.pi
-        del self.optimizer
-        if self.double_q:
-            del self.q2
-            del self.target_q2
-


 class ILQL_Policy:
     def __init__(self, iql_model: ILQL, kind: str, **generation_kwargs) -> None: