agilerl 2.4.0.dev0__tar.gz → 2.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/PKG-INFO +23 -10
  2. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/README.md +12 -1
  3. agilerl-2.4.1/agilerl/__init__.py +18 -0
  4. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/base.py +125 -61
  5. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/optimizer_wrapper.py +11 -3
  6. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/registry.py +1 -1
  7. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dpo.py +60 -10
  8. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/grpo.py +34 -27
  9. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ilql.py +14 -0
  10. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/protocols.py +131 -0
  11. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_llm.py +2 -2
  12. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/algo_utils.py +59 -5
  13. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/llm_utils.py +94 -80
  14. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/utils.py +23 -35
  15. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/pyproject.toml +25 -8
  16. agilerl-2.4.0.dev0/agilerl/wrappers/__init__.py +0 -0
  17. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/LICENSE +0 -0
  18. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/__init__.py +0 -0
  19. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/bc_lm.py +0 -0
  20. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/__init__.py +0 -0
  21. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/cqn.py +0 -0
  22. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ddpg.py +0 -0
  23. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dqn.py +0 -0
  24. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dqn_rainbow.py +0 -0
  25. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ippo.py +0 -0
  26. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/maddpg.py +0 -0
  27. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/matd3.py +0 -0
  28. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  29. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  30. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/ppo.py +0 -0
  31. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/td3.py +0 -0
  32. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/__init__.py +0 -0
  33. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/data.py +0 -0
  34. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  35. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/replay_buffer.py +0 -0
  36. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/rollout_buffer.py +0 -0
  37. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/sampler.py +0 -0
  38. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/components/segment_tree.py +0 -0
  39. {agilerl-2.4.0.dev0/agilerl → agilerl-2.4.1/agilerl/data}/__init__.py +0 -0
  40. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/language_environment.py +0 -0
  41. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/rl_data.py +0 -0
  42. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/tokenizer.py +0 -0
  43. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/data/torch_datasets.py +0 -0
  44. {agilerl-2.4.0.dev0/agilerl/data → agilerl-2.4.1/agilerl/hpo}/__init__.py +0 -0
  45. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/hpo/mutation.py +0 -0
  46. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/hpo/tournament.py +0 -0
  47. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/__init__.py +0 -0
  48. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/base.py +0 -0
  49. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/bert.py +0 -0
  50. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/cnn.py +0 -0
  51. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/configs.py +0 -0
  52. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/custom_components.py +0 -0
  53. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/dummy.py +0 -0
  54. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/gpt.py +0 -0
  55. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/lstm.py +0 -0
  56. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/mlp.py +0 -0
  57. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/multi_input.py +0 -0
  58. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/resnet.py +0 -0
  59. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/modules/simba.py +0 -0
  60. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/__init__.py +0 -0
  61. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/actors.py +0 -0
  62. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/base.py +0 -0
  63. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/custom_modules.py +0 -0
  64. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/distributions.py +0 -0
  65. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/distributions_experimental.py +0 -0
  66. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/q_networks.py +0 -0
  67. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/networks/value_networks.py +0 -0
  68. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/rollouts/__init__.py +0 -0
  69. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/rollouts/on_policy.py +0 -0
  70. {agilerl-2.4.0.dev0/agilerl/hpo → agilerl-2.4.1/agilerl/training}/__init__.py +0 -0
  71. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_bandits.py +0 -0
  72. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_multi_agent_off_policy.py +0 -0
  73. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_multi_agent_on_policy.py +0 -0
  74. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_off_policy.py +0 -0
  75. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_offline.py +0 -0
  76. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/training/train_on_policy.py +0 -0
  77. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/typing.py +0 -0
  78. {agilerl-2.4.0.dev0/agilerl/training → agilerl-2.4.1/agilerl/utils}/__init__.py +0 -0
  79. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/cache.py +0 -0
  80. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/evolvable_networks.py +0 -0
  81. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/ilql_utils.py +0 -0
  82. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/log_utils.py +0 -0
  83. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/minari_utils.py +0 -0
  84. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/probe_envs.py +0 -0
  85. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/probe_envs_ma.py +0 -0
  86. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/sampling_utils.py +0 -0
  87. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/utils/torch_utils.py +0 -0
  88. {agilerl-2.4.0.dev0/agilerl/utils → agilerl-2.4.1/agilerl/vector}/__init__.py +0 -0
  89. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/vector/pz_async_vec_env.py +0 -0
  90. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/vector/pz_vec_env.py +0 -0
  91. {agilerl-2.4.0.dev0/agilerl/vector → agilerl-2.4.1/agilerl/wrappers}/__init__.py +0 -0
  92. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/agent.py +0 -0
  93. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/learning.py +0 -0
  94. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/make_evolvable.py +0 -0
  95. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  96. {agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/wrappers/utils.py +0 -0
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/PKG-INFO
@@ -1,22 +1,23 @@
  Metadata-Version: 2.4
  Name: agilerl
- Version: 2.4.0.dev0
+ Version: 2.4.1
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
  License: Apache 2.0
  License-File: LICENSE
  Author: Nick Ustaran-Anderegg
  Author-email: dev@agilerl.com
- Requires-Python: >=3.10,<4.0
+ Requires-Python: >=3.10,<3.13
  Classifier: License :: Other/Proprietary License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
- Classifier: Programming Language :: Python :: 3.13
- Classifier: Programming Language :: Python :: 3.14
+ Provides-Extra: all
+ Provides-Extra: llm
  Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
  Requires-Dist: accelerate (>=1.7.0,<2.0.0)
- Requires-Dist: deepspeed (>=0.17.1,<0.18.0)
+ Requires-Dist: datasets (==4.4.1) ; extra == "llm" or extra == "all"
+ Requires-Dist: deepspeed (>=0.17.1,<0.18.0) ; extra == "llm" or extra == "all"
  Requires-Dist: dill (>=0.3.7,<0.4.0)
  Requires-Dist: fastrand (>=1.3.0,<2.0.0)
  Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)
@@ -26,11 +27,12 @@ Requires-Dist: h5py (>=3.8.0,<4.0.0)
  Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
  Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
  Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
- Requires-Dist: minari (>=0.5.2,<0.6.0)
+ Requires-Dist: minari[all] (==0.5.2)
  Requires-Dist: numpy (>=1.26.4,<2.0.0)
  Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
+ Requires-Dist: packaging (>=20.0)
  Requires-Dist: pandas (>=2.2.3,<3.0.0)
- Requires-Dist: peft (>=0.15.2,<0.16.0)
+ Requires-Dist: peft (>=0.18.0,<0.19.0) ; extra == "llm" or extra == "all"
  Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
  Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
  Requires-Dist: pygame (>=2.6.0,<3.0.0)
@@ -41,9 +43,9 @@ Requires-Dist: tensordict (>=0.8,<0.9)
  Requires-Dist: termcolor (>=1.1.0,<2.0.0)
  Requires-Dist: torch (==2.7.1)
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
- Requires-Dist: transformers (>=4.48.1,<5.0.0)
+ Requires-Dist: transformers (>=4.57.1,<5.0.0) ; extra == "llm" or extra == "all"
  Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
- Requires-Dist: vllm (==0.10.0)
+ Requires-Dist: vllm (==0.10.0) ; extra == "llm" or extra == "all"
  Requires-Dist: wandb (>=0.17.6,<0.18.0)
  Description-Content-Type: text/markdown

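The markers above make the LLM stack opt-in rather than a default dependency: `pip install "agilerl[llm]"` pulls in just the LLM fine-tuning packages (datasets, deepspeed, peft, transformers, vllm), while `pip install "agilerl[all]"` installs every optional family. (Illustrative commands based on the metadata shown; quoting the extra simply avoids shell globbing.)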
@@ -97,6 +99,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
  pip install -e .
  ```

+ If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+ ```bash
+ pip install agilerl[all]
+ ```
+ Or in development mode:
+ ```bash
+ pip install -e ".[all]"
+ ```
+
  To install the ``nightly`` version of AgileRL with the latest features, use:

  ```bash
@@ -155,11 +167,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
  | ---------- | --------- |
  | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

- ### LLM Reasoning Algorithms
+ ### LLM Fine-tuning Algorithms

  | RL | Algorithm |
  | ---------- | --------- |
  | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+ | [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


  ## Train an Agent to Beat a Gym Environment
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/README.md
@@ -48,6 +48,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
  pip install -e .
  ```

+ If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+ ```bash
+ pip install agilerl[all]
+ ```
+ Or in development mode:
+ ```bash
+ pip install -e ".[all]"
+ ```
+
  To install the ``nightly`` version of AgileRL with the latest features, use:

  ```bash
@@ -106,11 +116,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
  | ---------- | --------- |
  | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |

- ### LLM Reasoning Algorithms
+ ### LLM Fine-tuning Algorithms

  | RL | Algorithm |
  | ---------- | --------- |
  | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+ | [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)


  ## Train an Agent to Beat a Gym Environment
agilerl-2.4.1/agilerl/__init__.py (new file)
@@ -0,0 +1,18 @@
+ from importlib.metadata import metadata
+ from importlib.util import find_spec
+
+ from packaging.requirements import Requirement
+
+
+ def get_extra_dependencies(package: str, extra: str) -> list[str]:
+     requires = metadata(package).get_all("Requires-Dist") or []
+     deps = []
+     for req in requires:
+         r = Requirement(req)
+         if r.marker and r.marker.evaluate({"extra": extra}):
+             deps.append(r.name)
+     return deps
+
+
+ LLM_PACKAGES = get_extra_dependencies("agilerl", "llm")
+ HAS_LLM_DEPENDENCIES = all(find_spec(pkg) is not None for pkg in LLM_PACKAGES)
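For orientation, a small usage sketch of the new module-level flags; the printed package list is illustrative and simply mirrors the `llm` extra markers in the PKG-INFO diff above:

```python
# Hypothetical check in user code, assuming agilerl 2.4.1 is installed.
from agilerl import HAS_LLM_DEPENDENCIES, LLM_PACKAGES

print(LLM_PACKAGES)  # e.g. ['datasets', 'deepspeed', 'peft', 'transformers', 'vllm']
if not HAS_LLM_DEPENDENCIES:
    raise SystemExit("LLM extras missing - install with: pip install 'agilerl[llm]'")
```

The diff to `base.py` below shows the same flag guarding the heavyweight imports inside the library itself.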
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/base.py
@@ -27,19 +27,14 @@ import torch
  import torch.nn.functional as F
  from accelerate import Accelerator
  from accelerate.utils import broadcast_object_list, set_seed
- from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
- from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
  from gymnasium import spaces
- from peft import LoraConfig, PeftModel, get_peft_model, set_peft_model_state_dict
- from safetensors.torch import load_file
  from tensordict import TensorDict
  from torch._dynamo import OptimizedModule
  from torch.nn.utils import clip_grad_norm_
  from torch.optim import AdamW
  from torch.optim.lr_scheduler import SequentialLR
- from transformers.modeling_utils import PreTrainedModel
- from vllm import LLM, SamplingParams

+ from agilerl import HAS_LLM_DEPENDENCIES
  from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
  from agilerl.algorithms.core.registry import (
  HyperparameterConfig,
@@ -54,7 +49,11 @@ from agilerl.protocols import (
  EvolvableAttributeDict,
  EvolvableAttributeType,
  EvolvableModule,
+ LoraConfigProtocol,
  ModuleDict,
+ PeftModelProtocol,
+ PretrainedConfigProtocol,
+ PreTrainedModelProtocol,
  )
  from agilerl.typing import (
  ActionType,
@@ -73,6 +72,7 @@ from agilerl.typing import (
  )
  from agilerl.utils.algo_utils import (
  CosineLRScheduleConfig,
+ DummyOptimizer,
  VLLMConfig,
  check_supported_space,
  chkpt_attribute_to_device,
@@ -95,7 +95,18 @@ from agilerl.utils.evolvable_networks import (
  is_image_space,
  is_vector_space,
  )
- from agilerl.utils.llm_utils import DummyOptimizer, gather_if_zero3
+
+ if HAS_LLM_DEPENDENCIES:
+ from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
+ from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+ from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+ from safetensors.torch import load_file
+ from vllm import LLM, SamplingParams
+
+ from agilerl.utils.llm_utils import (
+ create_model_from_name_or_path,
+ gather_if_zero3,
+ )

  __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]

@@ -596,14 +607,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
  )
  optimizer = opt.optimizer if hasattr(opt, "optimizer") else None

- if isinstance(opt, DeepSpeedOptimizerWrapper):
- if isinstance(opt.optimizer, DummyOptimizer):
- opt = getattr(
+ if isinstance(self, LLMAlgorithm):
+ if hasattr(self.actor, "optimizer"):
+ optimizer = getattr(
  getattr(self, "actor"), "optimizer"
  ) # If the optimizer is defined in the deepspeed config, we do this
+ else:
+ optimizer = opt.optimizer

  self.accelerator, self.lr_scheduler = LLMAlgorithm.update_lr(
- opt,
+ optimizer,
  lr=getattr(self, config.lr),
  accelerator=self.accelerator,
  scheduler_config=self.cosine_lr_schedule_config,
@@ -1138,6 +1151,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):

  return self

+ def clean_up(self) -> None:
+ """
+ Clean up the algorithm by deleting the networks and optimizers.
+
+ :return: None
+ :rtype: None
+ """
+ for evo_attr in self.evolvable_attributes().values():
+ del evo_attr
+

  class RLAlgorithm(EvolvableAlgorithm, ABC):
  """Base object for all single-agent algorithms in the AgileRL framework.
@@ -1782,8 +1805,6 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
  class LLMAlgorithm(EvolvableAlgorithm, ABC):
  """Base object for all LLM algorithms in the AgileRL framework.

- :param observation_space: The observation space of the environment.
- :type observation_space: gymnasium.spaces.Space
  :param action_space: The action space of the environment.
  :type action_space: gymnasium.spaces.Space
  :param index: The index of the algorithm.
@@ -1796,13 +1817,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  :type accelerator: Optional[Accelerator]
  :param name: The name of the algorithm.
  :type name: Optional[str]
+ :param model_config: The configuration for the model.
+ :type model_config: dict[str, Any] | PretrainedConfig | None
+ :param gradient_checkpointing: Whether to use gradient checkpointing.
+ :type gradient_checkpointing: bool
  """

  def __init__(
  self,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- actor_network: PreTrainedModel,
  index: int,
  batch_size: int,
  lr: float,
@@ -1813,8 +1835,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  seed: int,
  pad_token_id: int,
  pad_token: str,
- lora_config: LoraConfig | None,
+ lora_config: LoraConfigProtocol | None,
  use_separate_reference_adapter: bool,
+ model_name: str | None = None,
+ actor_network: PreTrainedModelProtocol | None = None,
  micro_batch_size_per_gpu: int | None = None,
  cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
  hp_config: Optional[HyperparameterConfig] = None,
@@ -1822,7 +1846,18 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  device: Union[str, torch.device] = "cpu",
  accelerator: Optional[Accelerator] = None,
  name: Optional[str] = None,
+ model_config: dict[str, Any] | PretrainedConfigProtocol | None = None,
+ gradient_checkpointing: bool = True,
  ):
+ if not HAS_LLM_DEPENDENCIES:
+ raise ImportError(
+ "LLM dependencies are not installed. Please install them using `pip install agilerl[llm]`."
+ )
+
+ if model_name is None and actor_network is None:
+ raise ValueError(
+ "At least one of model_name or actor_network must be provided."
+ )
  if (
  accelerator is not None
  and cosine_lr_schedule_config is not None
@@ -1835,20 +1870,16 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  cosine_lr_schedule_config = None

  super().__init__(index, hp_config, device, accelerator, None, name)
- assert isinstance(
- observation_space, spaces.Space
- ), "Observation space must be an instance of gymnasium.spaces.Space."
- assert isinstance(
- action_space, spaces.Space
- ), "Action space must be an instance of gymnasium.spaces.Space."
-
- self.observation_space = observation_space
- self.action_space = action_space
+ self.gradient_checkpointing = gradient_checkpointing
  self.zero_stage = None
  self.reference_update_tracker = 0 # Updated every time the reference policy is updated which is updated each time we pass through the train dataset
  self.calc_position_embeddings = calc_position_embeddings
  self.pad_token_id = pad_token_id
  self.pad_token = pad_token
+ self.pretrained_model_name_or_path = (
+ model_name if model_name is not None else actor_network.name_or_path
+ )
+ self.model_config = model_config

  if not clone and reduce_memory_peak and micro_batch_size_per_gpu is not None:
  raise ValueError(
@@ -1858,7 +1889,9 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  self._configure_batch_size(
  batch_size, clone, reduce_memory_peak, micro_batch_size_per_gpu
  )
-
+ self.batch_size = self.batch_size_per_process * (
+ self.accelerator.num_processes if self.accelerator is not None else 1
+ )
  if self.accelerator is not None:
  if (
  self.accelerator.state.deepspeed_plugin.deepspeed_config.get(
@@ -1875,22 +1908,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  )
  lr = optim_lr

- if lora_config is None and not isinstance(actor_network, PeftModel):
+ if lora_config is None and not isinstance(actor_network, PeftModelProtocol):
  warnings.warn(
- "No LoRA config provided. Using default LoRA configuration for RL finetuning."
+ "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
  )
  lora_config = LoraConfig(
  r=16,
- lora_alpha=64,
- target_modules=[
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "up_proj",
- "down_proj",
- "gate_proj",
- ],
+ lora_alpha=32,
+ target_modules="all-linear",
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  )
@@ -1900,15 +1925,20 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  self.use_separate_reference_adapter = use_separate_reference_adapter
  self.cosine_lr_schedule_config = cosine_lr_schedule_config

- if max_grad_norm and (accelerator is not None) and accelerator.is_main_process:
- warnings.warn(
- "Argument 'max_grad_norm' will be overwritten by the 'gradient_clipping' value set in the deepspeed config."
- )
- self.max_grad_norm = None
- else:
- self.max_grad_norm = max_grad_norm
+ if max_grad_norm and (accelerator is not None):
+ if accelerator.is_main_process:
+ warnings.warn(
+ "Argument 'max_grad_norm' will overwrite the equivalent value set for 'gradient_clipping' in the deepspeed config."
+ )
+ self.accelerator.state.deepspeed_plugin.deepspeed_config[
+ "gradient_clipping"
+ ] = max_grad_norm
+
+ self.max_grad_norm = max_grad_norm
  self.reduce_memory_peak = reduce_memory_peak
- self.pretrained_model_name_or_path = actor_network.name_or_path
+
+ if self.accelerator is not None:
+ self.register_mutation_hook(self._sync_deepspeed_gradient_clipping)

  if self.accelerator is not None:
  self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
@@ -2044,7 +2074,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  device_map="auto"
  )
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
- model = PeftModel.from_pretrained(base_model, path)
+ model = PeftModelProtocol.from_pretrained(base_model, path)
  """
  )

@@ -2141,19 +2171,26 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  if not is_dummy_optimizer
  else type(self.actor.optimizer)
  )
- self.actor.module.gradient_checkpointing_enable(
- gradient_checkpointing_kwargs={"use_reentrant": False}
- )
+ if self.gradient_checkpointing:
+ self.actor.module.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": False}
+ )
  else:
  assert (
  self.actor is not None
  ), "Actor is set to None, please check that the actor is defined."
  self.actor = self.actor.to(self.device)
- self.actor.gradient_checkpointing_enable()
+ if self.gradient_checkpointing:
+ self.actor.gradient_checkpointing_enable()

  def clean_up(self) -> None:
  """Clean up the algorithm."""
  if self.accelerator is not None:
+ # Free up GPU memory occupied by parameters
+ if hasattr(self.actor, "empty_partition_cache"):
+ self.actor.empty_partition_cache()
+ if hasattr(self.actor, "destroy"):
+ self.actor.destroy()
  (
  self.actor,
  self.optimizer,
@@ -2177,10 +2214,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  if hasattr(self, "llm"):
  del self.llm.llm_engine.model_executor
  del self.llm
-
  gc.collect()
  torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
  torch.cuda.synchronize()

  def clone(self, index: Optional[int] = None, wrap: bool = True):
@@ -2215,8 +2250,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  input_args["wrap"] = False
  input_args["clone"] = True

- actor: PeftModel = cast(
- PeftModel,
+ actor: PeftModelProtocol = cast(
+ PeftModelProtocol,
  (
  self.accelerator.unwrap_model(self.actor)
  if self.accelerator is not None
@@ -2408,17 +2443,22 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  self.reference_update_tracker += 1

  def _initialize_actors(
- self, base_model: PreTrainedModel, add_adapters: bool = True
+ self, base_model: PreTrainedModelProtocol | None, add_adapters: bool = True
  ):
  """Initialize the actor network.

  :param base_model: Base model
- :type base_model: PreTrainedModel
+ :type base_model: PreTrainedModelProtocol
  :param add_adapters: Flag to indicate if adapters should be added to the model, defaults to True
  :type add_adapters: bool, optional
  """

- if isinstance(base_model, PeftModel) and add_adapters:
+ if base_model is None:
+ base_model = create_model_from_name_or_path(
+ self.pretrained_model_name_or_path
+ )
+
+ if isinstance(base_model, PeftModelProtocol) and add_adapters:
  # Handles backwards compatibility with user providing a peft model as the actor network
  if self.lora_config is None:
  adapter_name = list(base_model.peft_config.keys())
@@ -2428,7 +2468,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  if "default" in list(base_model.peft_config.keys()):
  base_model.peft_config.pop("default")

- self.actor: PeftModel = (
+ self.actor: PeftModelProtocol = (
  get_peft_model(base_model, self.lora_config, adapter_name="actor")
  if add_adapters
  else base_model
@@ -2577,7 +2617,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
  def _move_model_to_vllm(self) -> None:
  """Move the deepspeed model to vllm."""

- # TODO: Add support for ZeRO Stage 3
  if self.accelerator is not None:
  self.accelerator.wait_for_everyone()
  model_ref = self.accelerator.unwrap_model(self.actor)
@@ -2945,3 +2984,28 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

  if self.accelerator is not None:
  self.accelerator.wait_for_everyone()
+
+ def _sync_deepspeed_gradient_clipping(self) -> None:
+ """Synchronizes max_grad_norm with DeepSpeed gradient_clipping config.
+ Registered as a mutation hook to ensure consistency after mutations.
+ """
+ if self.accelerator is None:
+ return
+
+ if (
+ "gradient_clipping"
+ not in self.accelerator.state.deepspeed_plugin.deepspeed_config
+ ):
+ return
+
+ ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
+ if ds_config["gradient_clipping"] != self.max_grad_norm:
+ self.accelerator.state.deepspeed_plugin.deepspeed_config[
+ "gradient_clipping"
+ ] = self.max_grad_norm
+
+ if hasattr(self.actor, "optimizer"):
+ if hasattr(self.actor.optimizer, "grad_clip"):
+ self.actor.optimizer.grad_clip = self.max_grad_norm
+ if hasattr(self.actor.optimizer, "clip_grad"):
+ self.actor.optimizer.clip_grad = self.max_grad_norm
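To make the clipping change concrete: an explicitly passed `max_grad_norm` now wins over whatever the DeepSpeed config sets, and the hook above re-syncs the two after HPO mutations. A minimal sketch of the interaction (the dict below is an illustrative DeepSpeed config, not one shipped with AgileRL):

```python
# Illustrative DeepSpeed config; only "gradient_clipping" matters here.
deepspeed_config = {
    "zero_optimization": {"stage": 2},
    "gradient_clipping": 1.0,  # value from the user's DeepSpeed config
    "train_micro_batch_size_per_gpu": "auto",
}

# With the change above, constructing an LLMAlgorithm with max_grad_norm=0.5
# overwrites "gradient_clipping" with 0.5, and _sync_deepspeed_gradient_clipping
# (registered as a mutation hook) re-applies 0.5 after each mutation.
```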
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/optimizer_wrapper.py
@@ -2,19 +2,27 @@ import inspect
  from typing import Any, Optional, Union

  import torch.nn as nn
- from peft import PeftModel
  from torch.optim import Optimizer

+ from agilerl import HAS_LLM_DEPENDENCIES
  from agilerl.modules import EvolvableModule, ModuleDict
  from agilerl.protocols import EvolvableAlgorithm
  from agilerl.typing import OptimizerType, StateDict
- from agilerl.utils.llm_utils import DummyOptimizer
+ from agilerl.utils.algo_utils import DummyOptimizer
+
+ if HAS_LLM_DEPENDENCIES:
+ from peft import PeftModel
+
+ PeftModelType = PeftModel
+ else:
+ PeftModelType = "PeftModel"
+

  ModuleList = list[EvolvableModule]
  _Optimizer = Union[
  type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
  ]
- _Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModel]
+ _Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModelType]


  def init_from_multiple(
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/core/registry.py
@@ -9,7 +9,7 @@ from torch.optim import Optimizer

  from agilerl.protocols import EvolvableAlgorithm
  from agilerl.typing import NetworkType
- from agilerl.utils.llm_utils import DummyOptimizer
+ from agilerl.utils.algo_utils import DummyOptimizer


  @dataclass
{agilerl-2.4.0.dev0 → agilerl-2.4.1}/agilerl/algorithms/dpo.py
@@ -1,28 +1,76 @@
  import gc
+ from typing import Any

  import numpy as np
  import torch
  import torch.nn.functional as F
  from accelerate import Accelerator
- from gymnasium import spaces
- from peft import LoraConfig
- from transformers import PreTrainedModel

  from agilerl.algorithms.core.base import LLMAlgorithm
  from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
+ from agilerl.protocols import LoraConfigProtocol, PreTrainedModelProtocol
  from agilerl.typing import ExperiencesType, LLMObsType
  from agilerl.utils.algo_utils import get_experiences_samples
  from agilerl.utils.llm_utils import PreferenceGym


  class DPO(LLMAlgorithm):
+ """The DPO algorithm class. DPO paper: https://arxiv.org/pdf/2305.18290
+
+ :param pad_token_id: Pad token id
+ :type pad_token_id: int
+ :param pad_token: Pad token
+ :type pad_token: str
+ :param model_name: Model name
+ :type model_name: str, optional
+ :param actor_network: HuggingFace LLM
+ :type actor_network: PreTrainedModelProtocol
+ :param model_config: Model configuration, to be used when creating the model from a name or path
+ :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
+ :type hp_config: HyperparameterConfig, optional
+ :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0
+ :type index: int, optional
+ :param batch_size: Batch size for training, defaults to 16
+ :type batch_size: int, optional
+ :param lr: Learning rate, defaults to 0.000005
+ :type lr: float, optional
+ :param beta: Beta parameter for DPO, defaults to 0.001
+ :type beta: float, optional
+ :param max_grad_norm: Maximum gradient norm, defaults to 0.1
+ :type max_grad_norm: float, optional
+ :param update_epochs: Number of update epochs, defaults to 1
+ :type update_epochs: int, optional
+ :param calc_position_embeddings: Flag to indicate if position embeddings should be calculated, defaults to True
+ :type calc_position_embeddings: bool, optional
+ :param micro_batch_size_per_gpu: Micro batch size per GPU, defaults to None
+ :type micro_batch_size_per_gpu: int, optional
+ :param reduce_memory_peak: Flag to indicate if memory peak should be reduced, defaults to False
+ :type reduce_memory_peak: bool, optional
+ :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
+ :type device: str, optional
+ :param lora_config: Config for LoRA, defaults to None
+ :type lora_config: LoraConfigProtocol, optional
+ :param accelerator: Accelerator for distributed computing, defaults to None
+ :type accelerator: accelerate.Accelerator(), optional
+ :param wrap: Wrap models for distributed training upon creation, defaults to True
+ :type wrap: bool, optional
+ :param clone: Flag to indicate if the instantiation is a cloning, defaults to False
+ :type clone: bool, optional
+ :param use_separate_reference_adapter: Flag to indicate if the reference policy should have a separate adapter, defaults to False
+ :type use_separate_reference_adapter: bool, optional
+ :param seed: Seed for the random number generator, defaults to 42
+ :type seed: int, optional
+ :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+ :type gradient_checkpointing: bool, optional
+ """
+
  def __init__(
  self,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- actor_network: PreTrainedModel,
  pad_token_id: int,
  pad_token: str,
+ model_name: str | None = None,
+ actor_network: PreTrainedModelProtocol | None = None,
+ model_config: dict[str, Any] | None = None,
  hp_config: HyperparameterConfig | None = None,
  index: int = 0,
  batch_size: int = 16,
@@ -34,12 +82,13 @@ class DPO(LLMAlgorithm):
  micro_batch_size_per_gpu: int | None = None,
  reduce_memory_peak: bool = False,
  device: str = "cpu",
- lora_config: LoraConfig | None = None,
+ lora_config: LoraConfigProtocol | None = None,
  accelerator: Accelerator | None = None,
  wrap: bool = True,
  clone: bool = False,
  use_separate_reference_adapter: bool = False,
  seed: int = 42,
+ gradient_checkpointing: bool = True,
  ):
  device = (
  f"cuda:{accelerator.process_index}"
@@ -47,9 +96,6 @@ class DPO(LLMAlgorithm):
  else ("cuda" if torch.cuda.is_available() else "cpu")
  )
  super().__init__(
- observation_space,
- action_space,
- actor_network,
  index=index,
  batch_size=batch_size,
  lr=lr,
@@ -62,6 +108,9 @@ class DPO(LLMAlgorithm):
  pad_token=pad_token,
  lora_config=lora_config,
  use_separate_reference_adapter=use_separate_reference_adapter,
+ model_name=model_name,
+ actor_network=actor_network,
+ model_config=model_config,
  micro_batch_size_per_gpu=micro_batch_size_per_gpu,
  cosine_lr_schedule_config=None,
  hp_config=hp_config,
@@ -69,6 +118,7 @@ class DPO(LLMAlgorithm):
  device=device,
  accelerator=accelerator,
  name="DPO",
+ gradient_checkpointing=gradient_checkpointing,
  )
  self.beta = beta
  self.temperature = (
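Putting the signature change together: `observation_space`, `action_space`, and a pre-built `actor_network` are no longer required, and a `model_name` alone is enough. A minimal construction sketch (argument values are illustrative; the import path follows the module layout in the file list above, and the `llm` extras are assumed to be installed):

```python
# Hypothetical usage of the new DPO signature in agilerl 2.4.1.
from agilerl.algorithms.dpo import DPO

agent = DPO(
    pad_token_id=0,                            # pad token details depend on your tokenizer
    pad_token="<pad>",
    model_name="Qwen/Qwen2.5-0.5B-Instruct",   # actor_network is now optional
    batch_size=16,
    lr=5e-6,
    beta=0.001,
    gradient_checkpointing=True,
)
```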