agilerl-2.4.1.dev0-py3-none-any.whl → agilerl-2.4.1.dev2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agilerl/__init__.py +18 -0
- agilerl/algorithms/core/base.py +97 -37
- agilerl/algorithms/core/optimizer_wrapper.py +11 -3
- agilerl/algorithms/core/registry.py +1 -1
- agilerl/algorithms/dpo.py +5 -6
- agilerl/algorithms/grpo.py +15 -16
- agilerl/algorithms/ilql.py +14 -0
- agilerl/protocols.py +131 -0
- agilerl/utils/algo_utils.py +51 -4
- agilerl/utils/llm_utils.py +15 -46
- agilerl/utils/utils.py +2 -2
- {agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info}/METADATA +25 -10
- {agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info}/RECORD +15 -15
- {agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info}/WHEEL +1 -1
- {agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info/licenses}/LICENSE +0 -0
agilerl/__init__.py
CHANGED
@@ -0,0 +1,18 @@
+from importlib.metadata import metadata
+from importlib.util import find_spec
+
+from packaging.requirements import Requirement
+
+
+def get_extra_dependencies(package: str, extra: str) -> list[str]:
+    requires = metadata(package).get_all("Requires-Dist") or []
+    deps = []
+    for req in requires:
+        r = Requirement(req)
+        if r.marker and r.marker.evaluate({"extra": extra}):
+            deps.append(r.name)
+    return deps
+
+
+LLM_PACKAGES = get_extra_dependencies("agilerl", "llm")
+HAS_LLM_DEPENDENCIES = all(find_spec(pkg) is not None for pkg in LLM_PACKAGES)
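Note: the `HAS_LLM_DEPENDENCIES` flag above is computed by evaluating PEP 508 environment markers via `packaging`. A minimal standalone sketch of that mechanism (the requirement strings here are hypothetical, not copied from agilerl's real metadata):

```python
from packaging.requirements import Requirement

# Hypothetical Requires-Dist strings, for illustration only.
requires = [
    'peft (>=0.18.0,<0.19.0) ; extra == "llm"',
    "numpy (>=1.26.4,<2.0.0)",
]

for req in requires:
    r = Requirement(req)
    # Markers like 'extra == "llm"' evaluate to True only when that extra is
    # supplied in the evaluation environment; unmarked core requirements have
    # no marker and are skipped, exactly as in get_extra_dependencies above.
    if r.marker and r.marker.evaluate({"extra": "llm"}):
        print(r.name)  # -> peft
```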
agilerl/algorithms/core/base.py
CHANGED
@@ -27,20 +27,14 @@ import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import broadcast_object_list, set_seed
-from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
-from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
 from gymnasium import spaces
-from peft import LoraConfig, PeftModel, get_peft_model, set_peft_model_state_dict
-from safetensors.torch import load_file
 from tensordict import TensorDict
 from torch._dynamo import OptimizedModule
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import SequentialLR
-from transformers import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
-from vllm import LLM, SamplingParams
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
 from agilerl.algorithms.core.registry import (
     HyperparameterConfig,
@@ -55,7 +49,11 @@ from agilerl.protocols import (
     EvolvableAttributeDict,
     EvolvableAttributeType,
     EvolvableModule,
+    LoraConfigProtocol,
     ModuleDict,
+    PeftModelProtocol,
+    PretrainedConfigProtocol,
+    PreTrainedModelProtocol,
 )
 from agilerl.typing import (
     ActionType,
@@ -74,6 +72,7 @@ from agilerl.typing import (
 )
 from agilerl.utils.algo_utils import (
     CosineLRScheduleConfig,
+    DummyOptimizer,
     VLLMConfig,
     check_supported_space,
     chkpt_attribute_to_device,
@@ -96,11 +95,18 @@ from agilerl.utils.evolvable_networks import (
     is_image_space,
     is_vector_space,
 )
-
-
-
-
-
+
+if HAS_LLM_DEPENDENCIES:
+    from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
+    from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
+    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+    from safetensors.torch import load_file
+    from vllm import LLM, SamplingParams
+
+    from agilerl.utils.llm_utils import (
+        create_model_from_name_or_path,
+        gather_if_zero3,
+    )
 
 __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]
 
@@ -601,14 +607,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
             )
             optimizer = opt.optimizer if hasattr(opt, "optimizer") else None
 
-            if isinstance(
-            if
-
+            if isinstance(self, LLMAlgorithm):
+                if hasattr(self.actor, "optimizer"):
+                    optimizer = getattr(
                         getattr(self, "actor"), "optimizer"
                     )  # If the optimizer is defined in the deepspeed config, we do this
+                else:
+                    optimizer = opt.optimizer
 
             self.accelerator, self.lr_scheduler = LLMAlgorithm.update_lr(
-
+                optimizer,
                 lr=getattr(self, config.lr),
                 accelerator=self.accelerator,
                 scheduler_config=self.cosine_lr_schedule_config,
@@ -1143,6 +1151,16 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
 
         return self
 
+    def clean_up(self) -> None:
+        """
+        Clean up the algorithm by deleting the networks and optimizers.
+
+        :return: None
+        :rtype: None
+        """
+        for evo_attr in self.evolvable_attributes().values():
+            del evo_attr
+
 
 class RLAlgorithm(EvolvableAlgorithm, ABC):
     """Base object for all single-agent algorithms in the AgileRL framework.
@@ -1799,6 +1817,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     :type accelerator: Optional[Accelerator]
     :param name: The name of the algorithm.
     :type name: Optional[str]
+    :param model_config: The configuration for the model.
+    :type model_config: dict[str, Any] | PretrainedConfig | None
+    :param gradient_checkpointing: Whether to use gradient checkpointing.
+    :type gradient_checkpointing: bool
     """
 
     def __init__(
@@ -1813,10 +1835,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         seed: int,
         pad_token_id: int,
         pad_token: str,
-        lora_config:
+        lora_config: LoraConfigProtocol | None,
         use_separate_reference_adapter: bool,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModelProtocol | None = None,
         micro_batch_size_per_gpu: int | None = None,
         cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
         hp_config: Optional[HyperparameterConfig] = None,
@@ -1824,9 +1846,14 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         device: Union[str, torch.device] = "cpu",
         accelerator: Optional[Accelerator] = None,
         name: Optional[str] = None,
-        model_config: dict[str, Any] |
+        model_config: dict[str, Any] | PretrainedConfigProtocol | None = None,
         gradient_checkpointing: bool = True,
     ):
+        if not HAS_LLM_DEPENDENCIES:
+            raise ImportError(
+                "LLM dependencies are not installed. Please install them using `pip install agilerl[llm]`."
+            )
+
         if model_name is None and actor_network is None:
             raise ValueError(
                 "At least one of model_name or actor_network must be provided."
@@ -1881,7 +1908,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             )
             lr = optim_lr
 
-        if lora_config is None and not isinstance(actor_network,
+        if lora_config is None and not isinstance(actor_network, PeftModelProtocol):
             warnings.warn(
                 "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
             )
@@ -1898,15 +1925,21 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.use_separate_reference_adapter = use_separate_reference_adapter
         self.cosine_lr_schedule_config = cosine_lr_schedule_config
 
-        if max_grad_norm and (accelerator is not None)
-
-
-
-
-
-
+        if max_grad_norm and (accelerator is not None):
+            if accelerator.is_main_process:
+                warnings.warn(
+                    "Argument 'max_grad_norm' will overwrite the equivalent value set for 'gradient_clipping' in the deepspeed config."
+                )
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = max_grad_norm
+
+        self.max_grad_norm = max_grad_norm
         self.reduce_memory_peak = reduce_memory_peak
 
+        if self.accelerator is not None:
+            self.register_mutation_hook(self._sync_deepspeed_gradient_clipping)
+
         if self.accelerator is not None:
             self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
                 "zero_optimization"
@@ -2041,7 +2074,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
                 device_map="auto"
             )
             tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
-            model =
+            model = PeftModelProtocol.from_pretrained(base_model, path)
             """
         )
 
@@ -2153,6 +2186,11 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def clean_up(self) -> None:
        """Clean up the algorithm."""
        if self.accelerator is not None:
+            # Free up GPU memory occupied by parameters
+            if hasattr(self.actor, "empty_partition_cache"):
+                self.actor.empty_partition_cache()
+            if hasattr(self.actor, "destroy"):
+                self.actor.destroy()
             (
                 self.actor,
                 self.optimizer,
@@ -2176,10 +2214,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         if hasattr(self, "llm"):
             del self.llm.llm_engine.model_executor
             del self.llm
-
         gc.collect()
         torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
 
     def clone(self, index: Optional[int] = None, wrap: bool = True):
@@ -2214,8 +2250,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         input_args["wrap"] = False
         input_args["clone"] = True
 
-        actor:
-
+        actor: PeftModelProtocol = cast(
+            PeftModelProtocol,
             (
                 self.accelerator.unwrap_model(self.actor)
                 if self.accelerator is not None
@@ -2407,12 +2443,12 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         self.reference_update_tracker += 1
 
     def _initialize_actors(
-        self, base_model:
+        self, base_model: PreTrainedModelProtocol | None, add_adapters: bool = True
     ):
         """Initialize the actor network.
 
         :param base_model: Base model
-        :type base_model:
+        :type base_model: PreTrainedModelProtocol
         :param add_adapters: Flag to indicate if adapters should be added to the model, defaults to True
         :type add_adapters: bool, optional
         """
@@ -2422,7 +2458,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
             self.pretrained_model_name_or_path
         )
 
-        if isinstance(base_model,
+        if isinstance(base_model, PeftModelProtocol) and add_adapters:
            # Handles backwards compatibility with user providing a peft model as the actor network
            if self.lora_config is None:
                adapter_name = list(base_model.peft_config.keys())
@@ -2432,7 +2468,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
            if "default" in list(base_model.peft_config.keys()):
                base_model.peft_config.pop("default")
 
-        self.actor:
+        self.actor: PeftModelProtocol = (
             get_peft_model(base_model, self.lora_config, adapter_name="actor")
             if add_adapters
             else base_model
@@ -2581,7 +2617,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
     def _move_model_to_vllm(self) -> None:
         """Move the deepspeed model to vllm."""
 
-        # TODO: Add support for ZeRO Stage 3
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
             model_ref = self.accelerator.unwrap_model(self.actor)
@@ -2949,3 +2984,28 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
 
         if self.accelerator is not None:
             self.accelerator.wait_for_everyone()
+
+    def _sync_deepspeed_gradient_clipping(self) -> None:
+        """Synchronizes max_grad_norm with DeepSpeed gradient_clipping config.
+        Registered as a mutation hook to ensure consistency after mutations.
+        """
+        if self.accelerator is None:
+            return
+
+        if (
+            "gradient_clipping"
+            not in self.accelerator.state.deepspeed_plugin.deepspeed_config
+        ):
+            return
+
+        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
+        if ds_config["gradient_clipping"] != self.max_grad_norm:
+            self.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "gradient_clipping"
+            ] = self.max_grad_norm
+
+        if hasattr(self.actor, "optimizer"):
+            if hasattr(self.actor.optimizer, "grad_clip"):
+                self.actor.optimizer.grad_clip = self.max_grad_norm
+            if hasattr(self.actor.optimizer, "clip_grad"):
+                self.actor.optimizer.clip_grad = self.max_grad_norm
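Note: `_sync_deepspeed_gradient_clipping` above is registered through `register_mutation_hook`, so the DeepSpeed `gradient_clipping` value is re-applied after every mutation. A minimal sketch of that hook pattern in isolation, assuming zero-argument hooks as in the registration shown above (illustrative only, not AgileRL's actual implementation):

```python
from typing import Callable


class MutationHookMixin:
    """Illustrative registry: zero-argument callables replayed after each mutation."""

    def __init__(self) -> None:
        self._mutation_hooks: list[Callable[[], None]] = []

    def register_mutation_hook(self, hook: Callable[[], None]) -> None:
        self._mutation_hooks.append(hook)

    def mutate(self) -> None:
        # ... apply a hyperparameter or architecture mutation here ...
        for hook in self._mutation_hooks:
            hook()  # e.g. push max_grad_norm back into the DeepSpeed config
```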
agilerl/algorithms/core/optimizer_wrapper.py
CHANGED

@@ -2,19 +2,27 @@ import inspect
 from typing import Any, Optional, Union
 
 import torch.nn as nn
-from peft import PeftModel
 from torch.optim import Optimizer
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.modules import EvolvableModule, ModuleDict
 from agilerl.protocols import EvolvableAlgorithm
 from agilerl.typing import OptimizerType, StateDict
-from agilerl.utils.
+from agilerl.utils.algo_utils import DummyOptimizer
+
+if HAS_LLM_DEPENDENCIES:
+    from peft import PeftModel
+
+    PeftModelType = PeftModel
+else:
+    PeftModelType = "PeftModel"
+
 
 ModuleList = list[EvolvableModule]
 _Optimizer = Union[
     type[OptimizerType], dict[str, type[OptimizerType]], type[DummyOptimizer]
 ]
-_Module = Union[EvolvableModule, ModuleDict, ModuleList,
+_Module = Union[EvolvableModule, ModuleDict, ModuleList, PeftModelType]
 
 
 def init_from_multiple(
agilerl/algorithms/dpo.py
CHANGED
@@ -5,11 +5,10 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
-from peft import LoraConfig
-from transformers import PreTrainedModel
 
 from agilerl.algorithms.core.base import LLMAlgorithm
 from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
+from agilerl.protocols import LoraConfigProtocol, PreTrainedModelProtocol
 from agilerl.typing import ExperiencesType, LLMObsType
 from agilerl.utils.algo_utils import get_experiences_samples
 from agilerl.utils.llm_utils import PreferenceGym
@@ -25,7 +24,7 @@ class DPO(LLMAlgorithm):
     :param model_name: Model name
     :type model_name: str, optional
     :param actor_network: HuggingFace LLM
-    :type actor_network:
+    :type actor_network: PreTrainedModelProtocol
     :param model_config: Model configuration, to be used when creating the model from a name or path
     :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
     :type hp_config: HyperparameterConfig, optional
@@ -50,7 +49,7 @@ class DPO(LLMAlgorithm):
     :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
     :type device: str, optional
     :param lora_config: Config for LoRA, defaults to None
-    :type lora_config:
+    :type lora_config: LoraConfigProtocol, optional
     :param accelerator: Accelerator for distributed computing, defaults to None
     :type accelerator: accelerate.Accelerator(), optional
     :param wrap: Wrap models for distributed training upon creation, defaults to True
@@ -70,7 +69,7 @@ class DPO(LLMAlgorithm):
         pad_token_id: int,
         pad_token: str,
         model_name: str | None = None,
-        actor_network:
+        actor_network: PreTrainedModelProtocol | None = None,
         model_config: dict[str, Any] | None = None,
         hp_config: HyperparameterConfig | None = None,
         index: int = 0,
@@ -83,7 +82,7 @@ class DPO(LLMAlgorithm):
         micro_batch_size_per_gpu: int | None = None,
         reduce_memory_peak: bool = False,
         device: str = "cpu",
-        lora_config:
+        lora_config: LoraConfigProtocol | None = None,
         accelerator: Accelerator | None = None,
         wrap: bool = True,
         clone: bool = False,
CHANGED
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
import gc
|
|
2
|
-
from typing import Any, Optional
|
|
2
|
+
from typing import Any, Optional
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import torch
|
|
6
6
|
from accelerate import Accelerator
|
|
7
|
-
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
|
|
8
|
-
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
|
|
9
|
-
from peft import LoraConfig, PeftModel
|
|
10
|
-
from transformers import GenerationConfig
|
|
11
|
-
from transformers.modeling_utils import PreTrainedModel
|
|
12
7
|
|
|
8
|
+
from agilerl import HAS_LLM_DEPENDENCIES
|
|
13
9
|
from agilerl.algorithms.core import LLMAlgorithm
|
|
14
10
|
from agilerl.algorithms.core.registry import HyperparameterConfig, NetworkGroup
|
|
11
|
+
from agilerl.protocols import (
|
|
12
|
+
LoraConfigProtocol,
|
|
13
|
+
PeftModelProtocol,
|
|
14
|
+
PreTrainedModelProtocol,
|
|
15
|
+
)
|
|
15
16
|
from agilerl.typing import ExperiencesType, LLMObsType
|
|
16
17
|
from agilerl.utils.algo_utils import (
|
|
17
18
|
CosineLRScheduleConfig,
|
|
@@ -23,10 +24,8 @@ from agilerl.utils.llm_utils import (
|
|
|
23
24
|
ReasoningGym,
|
|
24
25
|
)
|
|
25
26
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
DeepSpeedZeroOptimizer_Stage3, # ZeRO Stage 3 optimizer
|
|
29
|
-
]
|
|
27
|
+
if HAS_LLM_DEPENDENCIES:
|
|
28
|
+
from transformers import GenerationConfig
|
|
30
29
|
|
|
31
30
|
|
|
32
31
|
class GRPO(LLMAlgorithm):
|
|
@@ -39,7 +38,7 @@ class GRPO(LLMAlgorithm):
|
|
|
39
38
|
:param model_name: Model name
|
|
40
39
|
:type model_name: str, optional
|
|
41
40
|
:param actor_network: HuggingFace LLM
|
|
42
|
-
:type actor_network:
|
|
41
|
+
:type actor_network: PreTrainedModelProtocol
|
|
43
42
|
:param model_config: Model configuration, to be used when creating the model from a name or path
|
|
44
43
|
:type model_config: dict[str, Any], optional
|
|
45
44
|
:param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
|
|
@@ -77,7 +76,7 @@ class GRPO(LLMAlgorithm):
|
|
|
77
76
|
:param max_model_len: Maximum context window length, defaults to None
|
|
78
77
|
:type max_model_len: int, optional
|
|
79
78
|
:param lora_config: Config for LoRA, defaults to None
|
|
80
|
-
:type lora_config:
|
|
79
|
+
:type lora_config: LoraConfigProtocol, optional
|
|
81
80
|
:param cosine_lr_schedule_config: Config for cosine lr scheduling, defaults to None
|
|
82
81
|
:type cosine_lr_schedule_config: CosineLRScheduleConfig, optional
|
|
83
82
|
:param accelerator: Accelerator for distributed computing, defaults to None
|
|
@@ -105,7 +104,7 @@ class GRPO(LLMAlgorithm):
|
|
|
105
104
|
pad_token_id: int,
|
|
106
105
|
pad_token: str,
|
|
107
106
|
model_name: str | None = None,
|
|
108
|
-
actor_network:
|
|
107
|
+
actor_network: PreTrainedModelProtocol | None = None,
|
|
109
108
|
model_config: dict[str, Any] | None = None,
|
|
110
109
|
hp_config: Optional[HyperparameterConfig] = None,
|
|
111
110
|
index: int = 0,
|
|
@@ -127,7 +126,7 @@ class GRPO(LLMAlgorithm):
|
|
|
127
126
|
max_output_tokens: int | None = 1024,
|
|
128
127
|
min_output_tokens: Optional[int] = None,
|
|
129
128
|
max_model_len: Optional[int] = None,
|
|
130
|
-
lora_config: Optional[
|
|
129
|
+
lora_config: Optional[LoraConfigProtocol] = None,
|
|
131
130
|
cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
|
|
132
131
|
accelerator: Optional[Accelerator] = None,
|
|
133
132
|
device: str = "cpu",
|
|
@@ -188,8 +187,8 @@ class GRPO(LLMAlgorithm):
|
|
|
188
187
|
), "Policy update epochs must be greater than or equal to one."
|
|
189
188
|
if actor_network is not None:
|
|
190
189
|
assert isinstance(
|
|
191
|
-
actor_network, (
|
|
192
|
-
), "Actor network must be a
|
|
190
|
+
actor_network, (PeftModelProtocol, PreTrainedModelProtocol)
|
|
191
|
+
), "Actor network must be a PeftModelProtocol or PreTrainedModelProtocol"
|
|
193
192
|
|
|
194
193
|
self.clip_coef = clip_coef
|
|
195
194
|
self.update_epochs = update_epochs
|
agilerl/algorithms/ilql.py
CHANGED
@@ -1223,6 +1223,20 @@ class ILQL(nn.Module):
         self.fitness = checkpoint["fitness"]
         self.steps = checkpoint["steps"]
 
+    def clean_up(self) -> None:
+        """Clean up the networks"""
+        del self.model
+        del self.actor
+        del self.actor_target
+        del self.v
+        del self.q
+        del self.target_q
+        del self.pi
+        del self.optimizer
+        if self.double_q:
+            del self.q2
+            del self.target_q2
+
 
 class ILQL_Policy:
     def __init__(self, iql_model: ILQL, kind: str, **generation_kwargs) -> None:
agilerl/protocols.py
CHANGED
@@ -299,3 +299,134 @@ class AgentWrapper(Protocol, Generic[T_EvolvableAlgorithm]):
     def learn(
         self, experiences: tuple[Iterable[ObservationType], ...], **kwargs
     ) -> None: ...
+
+
+@runtime_checkable
+class LoraConfigProtocol(Protocol):
+    """
+    "Protocol for LoRA configuration.
+
+    LoRA configuration is used to configure the LoRA module.
+    """
+
+    r: int
+    lora_alpha: int
+    target_modules: str
+    task_type: str
+    lora_dropout: float
+
+
+@runtime_checkable
+class PretrainedConfigProtocol(Protocol):
+    """Protocol for HuggingFace pre-trained model configuration.
+
+    Defines the interface for model configuration objects from HuggingFace transformers.
+    These configs store model architecture parameters and can be converted to/from dictionaries.
+    """
+
+    # Common model architecture attributes (these are examples - actual configs may have more)
+    vocab_size: int
+    hidden_size: int
+    num_attention_heads: int
+    num_hidden_layers: int
+
+    def to_dict(self) -> dict[str, Any]: ...
+    def to_json_string(self) -> str: ...
+    def save_pretrained(self, save_directory: str, **kwargs: Any) -> None: ...
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: str, **kwargs: Any
+    ) -> "PretrainedConfigProtocol": ...
+
+    @classmethod
+    def from_dict(
+        cls, config_dict: dict[str, Any], **kwargs: Any
+    ) -> "PretrainedConfigProtocol": ...
+
+    @classmethod
+    def from_json_file(cls, json_file: str) -> "PretrainedConfigProtocol": ...
+
+
+@runtime_checkable
+class GenerationConfigProtocol(Protocol):
+    """Protocol for text generation configuration.
+
+    Used to configure parameters for text generation in language models.
+    """
+
+    do_sample: bool
+    temperature: float
+    max_length: Optional[int]
+    max_new_tokens: Optional[int]
+    min_new_tokens: Optional[int]
+    pad_token_id: int
+    repetition_penalty: float
+    top_p: float
+    top_k: int
+    min_p: float
+
+
+@runtime_checkable
+class PreTrainedModelProtocol(Protocol):
+    """Protocol for HuggingFace pre-trained models.
+
+    Defines the interface for pre-trained transformer models from HuggingFace.
+    These models support text generation, state management, and device operations.
+    """
+
+    device: DeviceType
+    config: Any
+
+    def eval(self) -> "PreTrainedModelProtocol": ...
+    def train(self, mode: bool = True) -> "PreTrainedModelProtocol": ...
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional["GenerationConfigProtocol"] = None,
+        **kwargs: Any
+    ) -> torch.Tensor: ...
+    def forward(self, *args: Any, **kwargs: Any) -> Any: ...
+    def parameters(self) -> Generator: ...
+    def state_dict(self) -> dict[str, Any]: ...
+    def load_state_dict(
+        self, state_dict: dict[str, Any], strict: bool = True
+    ) -> None: ...
+    def to(self, device: DeviceType) -> "PreTrainedModelProtocol": ...
+
+
+@runtime_checkable
+class PeftModelProtocol(Protocol):
+    """Protocol for PEFT (Parameter-Efficient Fine-Tuning) models.
+
+    PEFT models wrap pre-trained models with adapters for efficient fine-tuning.
+    They extend PreTrainedModel functionality with adapter-specific operations.
+    """
+
+    device: DeviceType
+    config: Any
+    peft_config: dict[str, Any]
+    base_model: PreTrainedModelProtocol
+
+    def eval(self) -> "PeftModelProtocol": ...
+    def train(self, mode: bool = True) -> "PeftModelProtocol": ...
+    def generate(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional["GenerationConfigProtocol"] = None,
+        **kwargs: Any
+    ) -> torch.Tensor: ...
+    def forward(self, *args: Any, **kwargs: Any) -> Any: ...
+    def parameters(self) -> Generator: ...
+    def state_dict(self) -> dict[str, Any]: ...
+    def load_state_dict(
+        self, state_dict: dict[str, Any], strict: bool = True
+    ) -> None: ...
+    def to(self, device: DeviceType) -> "PeftModelProtocol": ...
+
+    @classmethod
+    def from_pretrained(
+        cls, base_model: PreTrainedModelProtocol, adapter_path: str, **kwargs: Any
+    ) -> "PeftModelProtocol": ...
agilerl/utils/algo_utils.py
CHANGED
@@ -13,14 +13,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from gymnasium import spaces
-from peft import PeftModel, get_peft_model
 from tensordict import TensorDict, from_module
 from tensordict.nn import CudaGraphModule
 from torch._dynamo import OptimizedModule
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
-from transformers import PreTrainedModel
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.modules.dummy import DummyEvolvable
 from agilerl.protocols import (
     EvolvableAttributeType,
@@ -42,9 +41,16 @@ from agilerl.typing import (
     SupportedObsSpaces,
     TorchObsType,
 )
-from agilerl.utils.llm_utils import gather_if_zero3
 
-
+if HAS_LLM_DEPENDENCIES:
+    from peft import PeftModel, get_peft_model
+    from transformers import PreTrainedModel
+
+    from agilerl.utils.llm_utils import gather_if_zero3
+
+    PreTrainedModelType = Union[PeftModel, PreTrainedModel]
+else:
+    PreTrainedModelType = Union["PeftModel", "PreTrainedModel"]
 
 
 def check_supported_space(observation_space: GymSpaceType) -> None:
@@ -1629,3 +1635,44 @@ def clone_llm(
     if state_dict is not None:
         model.load_state_dict(state_dict, strict=False)
     return model
+
+
+class DummyOptimizer:
+    """
+    Placeholder optimizer class to pass to the OptimizerWrapper when the optimizer is defined in the deepspeed config.
+    """
+
+    def __init__(self, params: list[torch.Tensor], lr: float, **kwargs) -> None:
+        """
+        Sentinel class to use for the optimizer when the optimizer is defined in the deepspeed config.
+
+        :param params: Parameters to optimize.
+        :type params: list[torch.Tensor]
+        :param lr: Learning rate.
+        :type lr: float
+        """
+        pass
+
+    def step(self, closure=None):
+        raise RuntimeError(
+            "DummyOptimizer is a placeholder optimizer and should not be used."
+            "Please ensure you are calling accelerator.prepare() on the optimizer."
+        )
+
+    def zero_grad(self):
+        raise RuntimeError(
+            "DummyOptimizer is a placeholder optimizer and should not be used."
+            "Please ensure you are calling accelerator.prepare() on the optimizer."
+        )
+
+    def state_dict(self):
+        raise RuntimeError(
+            "DummyOptimizer is a placeholder optimizer and should not be used."
+            "Please ensure you are calling accelerator.prepare() on the optimizer."
+        )
+
+    def load_state_dict(self, state_dict):
+        raise RuntimeError(
+            "DummyOptimizer is a placeholder optimizer and should not be used."
+            "Please ensure you are calling accelerator.prepare() on the optimizer."
+        )
agilerl/utils/llm_utils.py
CHANGED
@@ -4,19 +4,29 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Callable, Generator
 
-import deepspeed
 import gymnasium as gym
 import torch
 import torch.nn as nn
 from accelerate import Accelerator
-from datasets import Dataset
 from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils_base import BatchEncoding
 
+from agilerl import HAS_LLM_DEPENDENCIES
 from agilerl.typing import PreferencePrompts, ReasoningPrompts
 
+if HAS_LLM_DEPENDENCIES:
+    import deepspeed
+    from datasets import Dataset
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils_base import BatchEncoding
+
+    AutoTokenizer = AutoTokenizer
+else:
+    AutoTokenizer = Any
+    PreTrainedModel = Any
+    BatchEncoding = Any
+    Dataset = Any
+
 
 def apply_chat_template(
     conversation_template: list[dict[str, str]],
@@ -614,47 +624,6 @@ class PreferenceGym(HuggingFaceGym):
         return collate_fn
 
 
-class DummyOptimizer:
-    """
-    Placeholder optimizer class to pass to the OptimizerWrapper when the optimizer is defined in the deepspeed config.
-    """
-
-    def __init__(self, params: list[torch.Tensor], lr: float, **kwargs) -> None:
-        """
-        Sentinel class to use for the optimizer when the optimizer is defined in the deepspeed config.
-
-        :param params: Parameters to optimize.
-        :type params: list[torch.Tensor]
-        :param lr: Learning rate.
-        :type lr: float
-        """
-        pass
-
-    def step(self, closure=None):
-        raise RuntimeError(
-            "DummyOptimizer is a placeholder optimizer and should not be used."
-            "Please ensure you are calling accelerator.prepare() on the optimizer."
-        )
-
-    def zero_grad(self):
-        raise RuntimeError(
-            "DummyOptimizer is a placeholder optimizer and should not be used."
-            "Please ensure you are calling accelerator.prepare() on the optimizer."
-        )
-
-    def state_dict(self):
-        raise RuntimeError(
-            "DummyOptimizer is a placeholder optimizer and should not be used."
-            "Please ensure you are calling accelerator.prepare() on the optimizer."
-        )
-
-    def load_state_dict(self, state_dict):
-        raise RuntimeError(
-            "DummyOptimizer is a placeholder optimizer and should not be used."
-            "Please ensure you are calling accelerator.prepare() on the optimizer."
-        )
-
-
 @contextmanager
 def gather_if_zero3(
     zero_stage: int, params: list[torch.Tensor], modifier_rank: int | None = None
agilerl/utils/utils.py
CHANGED
@@ -36,8 +36,8 @@ from agilerl.hpo.mutation import Mutations
 from agilerl.hpo.tournament import TournamentSelection
 from agilerl.modules import EvolvableModule
 from agilerl.typing import BPTTSequenceType, GymSpaceType, PopulationType
-from agilerl.utils.algo_utils import CosineLRScheduleConfig, clone_llm
-from agilerl.utils.llm_utils import
+from agilerl.utils.algo_utils import CosineLRScheduleConfig, DummyOptimizer, clone_llm
+from agilerl.utils.llm_utils import get_state_dict
 from agilerl.vector.pz_async_vec_env import AsyncPettingZooVecEnv
 
 SupportedObservationSpace = Union[
{agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info}/METADATA
CHANGED

@@ -1,20 +1,23 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: agilerl
-Version: 2.4.1.
+Version: 2.4.1.dev2
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
+License-File: LICENSE
 Author: Nick Ustaran-Anderegg
 Author-email: dev@agilerl.com
-Requires-Python: >=3.10,<
+Requires-Python: >=3.10,<3.13
 Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-
+Provides-Extra: all
+Provides-Extra: llm
 Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
 Requires-Dist: accelerate (>=1.7.0,<2.0.0)
-Requires-Dist:
+Requires-Dist: datasets (==4.4.1) ; extra == "llm" or extra == "all"
+Requires-Dist: deepspeed (>=0.17.1,<0.18.0) ; extra == "llm" or extra == "all"
 Requires-Dist: dill (>=0.3.7,<0.4.0)
 Requires-Dist: fastrand (>=1.3.0,<2.0.0)
 Requires-Dist: flatten_dict (>=0.4.2,<0.5.0)
@@ -24,11 +27,12 @@ Requires-Dist: h5py (>=3.8.0,<4.0.0)
 Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
 Requires-Dist: jax[cpu] (>=0.4.31,<0.5.0)
 Requires-Dist: matplotlib (>=3.9.4,<3.10.0)
-Requires-Dist: minari (
+Requires-Dist: minari[all] (==0.5.2)
 Requires-Dist: numpy (>=1.26.4,<2.0.0)
 Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
+Requires-Dist: packaging (==25.0)
 Requires-Dist: pandas (>=2.2.3,<3.0.0)
-Requires-Dist: peft (>=0.
+Requires-Dist: peft (>=0.18.0,<0.19.0) ; extra == "llm" or extra == "all"
 Requires-Dist: pettingzoo (>=1.23.1,<2.0.0)
 Requires-Dist: pre-commit (>=3.4.0,<4.0.0)
 Requires-Dist: pygame (>=2.6.0,<3.0.0)
@@ -39,9 +43,9 @@ Requires-Dist: tensordict (>=0.8,<0.9)
 Requires-Dist: termcolor (>=1.1.0,<2.0.0)
 Requires-Dist: torch (==2.7.1)
 Requires-Dist: tqdm (>=4.66.4,<5.0.0)
-Requires-Dist: transformers (>=4.
+Requires-Dist: transformers (>=4.57.1,<5.0.0) ; extra == "llm" or extra == "all"
 Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
-Requires-Dist: vllm (==0.10.0)
+Requires-Dist: vllm (==0.10.0) ; extra == "llm" or extra == "all"
 Requires-Dist: wandb (>=0.17.6,<0.18.0)
 Description-Content-Type: text/markdown
 
@@ -95,6 +99,16 @@ git clone https://github.com/AgileRL/AgileRL.git && cd AgileRL
 pip install -e .
 ```
 
+If you wish to install all additional dependencies please specify `[all]` or if you want to install a specific family of dependencies specify that family directly. At present, we have just one family, `[llm]`, which contains the dependencies related to our LLM RFT algorithms (datasets, deepspeed, peft, transformers, vllm).
+
+```bash
+pip install agilerl[all]
+```
+Or in development mode:
+```bash
+pip install -e ".[all]"
+```
+
 To install the ``nightly`` version of AgileRL with the latest features, use:
 
 ```bash
@@ -153,11 +167,12 @@ We are constantly updating our tutorials to showcase the latest features of Agil
 | ---------- | --------- |
 | [Bandits](https://docs.agilerl.com/en/latest/bandits/index.html) | [Neural Contextual Bandits with UCB-based Exploration (NeuralUCB)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ucb.html) <br> [Neural Contextual Bandits with Thompson Sampling (NeuralTS)](https://docs.agilerl.com/en/latest/api/algorithms/neural_ts.html) |
 
-### LLM
+### LLM Fine-tuning Algorithms
 
 | RL | Algorithm |
 | ---------- | --------- |
 | [On-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Group Relative Policy Optimization (GRPO)](https://docs.agilerl.com/en/latest/api/algorithms/grpo.html)
+| [Off-Policy](https://docs.agilerl.com/en/latest/llm_finetuning/index.html) | [Direct Preference Optimization (DPO)](https://docs.agilerl.com/en/latest/api/algorithms/dpo.html)
 
 
 ## Train an Agent to Beat a Gym Environment
{agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info}/RECORD
CHANGED

@@ -1,17 +1,17 @@
-agilerl/__init__.py,sha256=
+agilerl/__init__.py,sha256=0hZjnAULURFWpshG_mhNdaHhf8nlc7h2sR7CLEqup54,572
 agilerl/algorithms/__init__.py,sha256=5N4DqCEETuFBlhnzf7XEQzIClRXX9e-FxQqQHgLh3Es,661
 agilerl/algorithms/bc_lm.py,sha256=dDCN--Y49wJA_msVB_r8XYgLYXSYeJItYyhSD41bFFk,22946
 agilerl/algorithms/core/__init__.py,sha256=kKGnzj4TGRZKk2J6jcaKkK3s1LjCYu979o8u8OJUZjI,268
-agilerl/algorithms/core/base.py,sha256=
-agilerl/algorithms/core/optimizer_wrapper.py,sha256=
-agilerl/algorithms/core/registry.py,sha256=
+agilerl/algorithms/core/base.py,sha256=LeFN0l17oCUxp23zFayq8tr9RFbSw--68TPa1FwobuA,121970
+agilerl/algorithms/core/optimizer_wrapper.py,sha256=UQTlnv-mbNGlQ3RX9ocHtczXhTZq1MBKO6OdoQ879uM,13086
+agilerl/algorithms/core/registry.py,sha256=ndaw9U814tHrPBhEPO9kLIDNKmLStTwLXPsnu-nnj8c,19991
 agilerl/algorithms/cqn.py,sha256=3zE6LPWPV8ut5hLPllw3yhY_amonbiSmbBXJU0-7Zo4,12583
 agilerl/algorithms/ddpg.py,sha256=uau1E37D9SARlf_bTswfZQGQRobh9tOcB6hoRpszx_g,21365
-agilerl/algorithms/dpo.py,sha256=
+agilerl/algorithms/dpo.py,sha256=kN2wp2Ms_2sFiJcmqpVPxG4XHoJis6l6BQlSCsj07pk,15777
 agilerl/algorithms/dqn.py,sha256=P05AspMruXghyqWGzXj4t0x6m6Pl9MKt8EKh3RP2yBU,17105
 agilerl/algorithms/dqn_rainbow.py,sha256=HyP-jkiVOkBUJmvpUlrB6VHo8m-AO2Z84M3Zb_ZP6fQ,20483
-agilerl/algorithms/grpo.py,sha256=
-agilerl/algorithms/ilql.py,sha256=
+agilerl/algorithms/grpo.py,sha256=9VvRf4jQNDOfUlkKDZBNiiBACUybgeOxSQgnszjm2BM,19237
+agilerl/algorithms/ilql.py,sha256=vX070xfPFxNKWh6oEc_LERUJx80JQq8oMzZ8ESBOUgE,79844
 agilerl/algorithms/ippo.py,sha256=W9FDLf5bznG-RvfJs8Gqpa2ARGReqmPB9xW9mu2Mj-c,39085
 agilerl/algorithms/maddpg.py,sha256=qVXDyb_W51lZtvst4K3yiosSy58BEBYbck8wF8CViBA,33908
 agilerl/algorithms/matd3.py,sha256=n17y6PvM51r290Def_QeFT4p7TMo54MIDLN30XqlMk8,37926
@@ -55,7 +55,7 @@ agilerl/networks/distributions.py,sha256=mzntWgwoEdZKAspInbmvfc6_0rGuPdquqQyQkVS
 agilerl/networks/distributions_experimental.py,sha256=K6_EYflAlR6qRouRr6SJXnT19w7QhOA1bwN7kCl3DJ8,18890
 agilerl/networks/q_networks.py,sha256=a1Arze6GypKprxUQObbpJQbikmY5LtrvAAnEyoTrcLM,17284
 agilerl/networks/value_networks.py,sha256=ZLX5vQIxeV65uxOzv2r5QMxF_-fzFT8N1et3lHdQP7E,4630
-agilerl/protocols.py,sha256=
+agilerl/protocols.py,sha256=SQ8T79jmZAqlm2fJ1Qo0kefU5w2c4Mh_wUk9RtiPego,14052
 agilerl/rollouts/__init__.py,sha256=dGR9BnXliQI6yvXPwecV7g5TCtCEPbyIB-W1a5evBBY,130
 agilerl/rollouts/on_policy.py,sha256=VOxUjwzyYngzrTEW9asXsgz1O6lRTUn_PijmjqtzGwQ,8036
 agilerl/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -68,18 +68,18 @@ agilerl/training/train_offline.py,sha256=qAlr3lGQf7EfSSmTtmohi80rUN4HMha955q3pae
 agilerl/training/train_on_policy.py,sha256=iQEIHq_JgBIBH2GPJeLN6QmPRho-_beUdro1H9DPkUA,19360
 agilerl/typing.py,sha256=JtLhZMNyFzrnSeos6ltWyD_8yWFkc8Zx-OIC3d1CPQc,5442
 agilerl/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-agilerl/utils/algo_utils.py,sha256=
+agilerl/utils/algo_utils.py,sha256=Ue9uR5R_QywZbO7jvnQPTVAn6STLT9f-_nwrygs4Iz4,60376
 agilerl/utils/cache.py,sha256=8Q1SYbTxQYzIn40UMy32EWMvtgaduY1k5jqwPihxJ_Q,3418
 agilerl/utils/evolvable_networks.py,sha256=cIJHzadFOaK0aAqwn96HvnuH4atLBxrQ3cwpR1nxvUo,23265
 agilerl/utils/ilql_utils.py,sha256=dU_vbwOB6VsODGGu_hOyDN_xRtFKVhZbxMISFlAUM5s,2293
-agilerl/utils/llm_utils.py,sha256=
+agilerl/utils/llm_utils.py,sha256=rc4fnqw3z1RvKdDUisX4THbRTkAWeg84SPt7VTd_hJY,26594
 agilerl/utils/log_utils.py,sha256=OIhj86V97-ijlUENic2WKIWipB5ITJyBIGM_ZPZg5Vo,4401
 agilerl/utils/minari_utils.py,sha256=WNFzt9ZQuvWy3w84MFhhGkA0e9MAgc4KSI_cmPgFTBo,5109
 agilerl/utils/probe_envs.py,sha256=q2uyPQW7mbo9x4c_Yq9vi2Yu1X9qyLm43adET9SFf9Y,39796
 agilerl/utils/probe_envs_ma.py,sha256=vvUY6lUBJfKGOVZtiFBKQ7Nwmsoj8aFnXD2W8-7rw8A,75686
 agilerl/utils/sampling_utils.py,sha256=Sc2G178eB5_hQEPiMnrMUDt8WdmRI7CVbRZPVg0NDTE,2336
 agilerl/utils/torch_utils.py,sha256=V3W9q3Y8x_eTYk83JORutOalAcZryKrlzq1_-7VxxdU,3424
-agilerl/utils/utils.py,sha256=
+agilerl/utils/utils.py,sha256=bLCBDIEv4xBAC49yqWWoeiTFgYrFBAtcca6F6sFoD7c,39846
 agilerl/vector/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 agilerl/vector/pz_async_vec_env.py,sha256=uj9TyCn0SWksTUOW84RGspMkXqdGG-wjr86w08uCMb0,36742
 agilerl/vector/pz_vec_env.py,sha256=sFVqm8eecxVHahTpFZEE3fvyZrmp2vMu0GECik8el6M,5978
@@ -89,7 +89,7 @@ agilerl/wrappers/learning.py,sha256=nSVMg6eUBWn13NNdIFgCEHj31CaN_dGryQa13SmMvBw,
 agilerl/wrappers/make_evolvable.py,sha256=sb9oAorGAayrD_6lNbyvHhefA_RKO4bSSNjqS6u9UhI,51079
 agilerl/wrappers/pettingzoo_wrappers.py,sha256=Pw8VzabxfYCw5ad15y5J3rAH1teA6nVVo0RHCTTdOPQ,2063
 agilerl/wrappers/utils.py,sha256=pENFH2AxsXd22s8HGUeM-jRowC0tmjHLWjqDwIq12l8,2194
-agilerl-2.4.1.
-agilerl-2.4.1.
-agilerl-2.4.1.
-agilerl-2.4.1.
+agilerl-2.4.1.dev2.dist-info/METADATA,sha256=Qcy1RTLnsmVvEfUMYFBIQcKTvEgH8n9Zv9vfBanHYXM,20565
+agilerl-2.4.1.dev2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+agilerl-2.4.1.dev2.dist-info/licenses/LICENSE,sha256=vPX_VnIseflXJ30mQvwbXZoe208EtIr9ZVrl6cfdQNs,11720
+agilerl-2.4.1.dev2.dist-info/RECORD,,
{agilerl-2.4.1.dev0.dist-info → agilerl-2.4.1.dev2.dist-info/licenses}/LICENSE
File without changes
|