agilerl 2.4.0.dev0__tar.gz → 2.4.1.dev0__tar.gz

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (95)
  1. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/PKG-INFO +2 -4
  2. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py +37 -33
  3. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py +58 -7
  4. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py +23 -15
  5. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_llm.py +2 -2
  6. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/algo_utils.py +7 -0
  7. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/llm_utils.py +81 -33
  8. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/utils.py +22 -34
  9. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/pyproject.toml +1 -1
  10. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/LICENSE +0 -0
  11. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/README.md +0 -0
  12. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/__init__.py +0 -0
  13. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/__init__.py +0 -0
  14. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/bc_lm.py +0 -0
  15. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/__init__.py +0 -0
  16. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
  17. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/registry.py +0 -0
  18. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/cqn.py +0 -0
  19. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ddpg.py +0 -0
  20. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn.py +0 -0
  21. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dqn_rainbow.py +0 -0
  22. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ilql.py +0 -0
  23. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ippo.py +0 -0
  24. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/maddpg.py +0 -0
  25. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/matd3.py +0 -0
  26. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  27. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  28. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/ppo.py +0 -0
  29. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/td3.py +0 -0
  30. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/__init__.py +0 -0
  31. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/data.py +0 -0
  32. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  33. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/replay_buffer.py +0 -0
  34. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/rollout_buffer.py +0 -0
  35. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/sampler.py +0 -0
  36. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/components/segment_tree.py +0 -0
  37. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/__init__.py +0 -0
  38. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/language_environment.py +0 -0
  39. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/rl_data.py +0 -0
  40. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/tokenizer.py +0 -0
  41. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/data/torch_datasets.py +0 -0
  42. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/__init__.py +0 -0
  43. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/mutation.py +0 -0
  44. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/hpo/tournament.py +0 -0
  45. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/__init__.py +0 -0
  46. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/base.py +0 -0
  47. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/bert.py +0 -0
  48. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/cnn.py +0 -0
  49. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/configs.py +0 -0
  50. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/custom_components.py +0 -0
  51. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/dummy.py +0 -0
  52. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/gpt.py +0 -0
  53. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/lstm.py +0 -0
  54. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/mlp.py +0 -0
  55. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/multi_input.py +0 -0
  56. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/resnet.py +0 -0
  57. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/modules/simba.py +0 -0
  58. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/__init__.py +0 -0
  59. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/actors.py +0 -0
  60. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/base.py +0 -0
  61. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/custom_modules.py +0 -0
  62. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/distributions.py +0 -0
  63. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/distributions_experimental.py +0 -0
  64. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/q_networks.py +0 -0
  65. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/networks/value_networks.py +0 -0
  66. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/protocols.py +0 -0
  67. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/rollouts/__init__.py +0 -0
  68. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/rollouts/on_policy.py +0 -0
  69. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/__init__.py +0 -0
  70. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_bandits.py +0 -0
  71. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_off_policy.py +0 -0
  72. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_multi_agent_on_policy.py +0 -0
  73. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_off_policy.py +0 -0
  74. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_offline.py +0 -0
  75. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_on_policy.py +0 -0
  76. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/typing.py +0 -0
  77. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/__init__.py +0 -0
  78. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/cache.py +0 -0
  79. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/evolvable_networks.py +0 -0
  80. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/ilql_utils.py +0 -0
  81. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/log_utils.py +0 -0
  82. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/minari_utils.py +0 -0
  83. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs.py +0 -0
  84. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
  85. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/sampling_utils.py +0 -0
  86. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/torch_utils.py +0 -0
  87. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/__init__.py +0 -0
  88. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
  89. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/vector/pz_vec_env.py +0 -0
  90. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/__init__.py +0 -0
  91. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/agent.py +0 -0
  92. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/learning.py +0 -0
  93. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
  94. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  95. {agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/wrappers/utils.py +0 -0
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/PKG-INFO
@@ -1,9 +1,8 @@
- Metadata-Version: 2.4
+ Metadata-Version: 2.3
  Name: agilerl
- Version: 2.4.0.dev0
+ Version: 2.4.1.dev0
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
  License: Apache 2.0
- License-File: LICENSE
  Author: Nick Ustaran-Anderegg
  Author-email: dev@agilerl.com
  Requires-Python: >=3.10,<4.0
@@ -13,7 +12,6 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
- Classifier: Programming Language :: Python :: 3.14
  Requires-Dist: SuperSuit (>=3.9.0,<4.0.0)
  Requires-Dist: accelerate (>=1.7.0,<2.0.0)
  Requires-Dist: deepspeed (>=0.17.1,<0.18.0)
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/core/base.py
@@ -37,6 +37,7 @@ from torch._dynamo import OptimizedModule
  from torch.nn.utils import clip_grad_norm_
  from torch.optim import AdamW
  from torch.optim.lr_scheduler import SequentialLR
+ from transformers import PretrainedConfig
  from transformers.modeling_utils import PreTrainedModel
  from vllm import LLM, SamplingParams

@@ -95,7 +96,11 @@ from agilerl.utils.evolvable_networks import (
  is_image_space,
  is_vector_space,
  )
- from agilerl.utils.llm_utils import DummyOptimizer, gather_if_zero3
+ from agilerl.utils.llm_utils import (
+ DummyOptimizer,
+ create_model_from_name_or_path,
+ gather_if_zero3,
+ )

  __all__ = ["EvolvableAlgorithm", "RLAlgorithm", "MultiAgentRLAlgorithm"]

@@ -1782,8 +1787,6 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
  class LLMAlgorithm(EvolvableAlgorithm, ABC):
  """Base object for all LLM algorithms in the AgileRL framework.

- :param observation_space: The observation space of the environment.
- :type observation_space: gymnasium.spaces.Space
  :param action_space: The action space of the environment.
  :type action_space: gymnasium.spaces.Space
  :param index: The index of the algorithm.
@@ -1800,9 +1803,6 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):

  def __init__(
  self,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- actor_network: PreTrainedModel,
  index: int,
  batch_size: int,
  lr: float,
@@ -1815,6 +1815,8 @@
  pad_token: str,
  lora_config: LoraConfig | None,
  use_separate_reference_adapter: bool,
+ model_name: str | None = None,
+ actor_network: PreTrainedModel | None = None,
  micro_batch_size_per_gpu: int | None = None,
  cosine_lr_schedule_config: Optional[CosineLRScheduleConfig] = None,
  hp_config: Optional[HyperparameterConfig] = None,
@@ -1822,7 +1824,13 @@
  device: Union[str, torch.device] = "cpu",
  accelerator: Optional[Accelerator] = None,
  name: Optional[str] = None,
+ model_config: dict[str, Any] | PretrainedConfig | None = None,
+ gradient_checkpointing: bool = True,
  ):
+ if model_name is None and actor_network is None:
+ raise ValueError(
+ "At least one of model_name or actor_network must be provided."
+ )
  if (
  accelerator is not None
  and cosine_lr_schedule_config is not None
@@ -1835,20 +1843,16 @@
  cosine_lr_schedule_config = None

  super().__init__(index, hp_config, device, accelerator, None, name)
- assert isinstance(
- observation_space, spaces.Space
- ), "Observation space must be an instance of gymnasium.spaces.Space."
- assert isinstance(
- action_space, spaces.Space
- ), "Action space must be an instance of gymnasium.spaces.Space."
-
- self.observation_space = observation_space
- self.action_space = action_space
+ self.gradient_checkpointing = gradient_checkpointing
  self.zero_stage = None
  self.reference_update_tracker = 0 # Updated every time the reference policy is updated which is updated each time we pass through the train dataset
  self.calc_position_embeddings = calc_position_embeddings
  self.pad_token_id = pad_token_id
  self.pad_token = pad_token
+ self.pretrained_model_name_or_path = (
+ model_name if model_name is not None else actor_network.name_or_path
+ )
+ self.model_config = model_config

  if not clone and reduce_memory_peak and micro_batch_size_per_gpu is not None:
  raise ValueError(
@@ -1858,7 +1862,9 @@
  self._configure_batch_size(
  batch_size, clone, reduce_memory_peak, micro_batch_size_per_gpu
  )
-
+ self.batch_size = self.batch_size_per_process * (
+ self.accelerator.num_processes if self.accelerator is not None else 1
+ )
  if self.accelerator is not None:
  if (
  self.accelerator.state.deepspeed_plugin.deepspeed_config.get(
@@ -1877,20 +1883,12 @@

  if lora_config is None and not isinstance(actor_network, PeftModel):
  warnings.warn(
- "No LoRA config provided. Using default LoRA configuration for RL finetuning."
+ "No LoRA config provided. AgileRL can only be used to finetune adapters at present. Using default LoRA configuration for RL finetuning."
  )
  lora_config = LoraConfig(
  r=16,
- lora_alpha=64,
- target_modules=[
- "q_proj",
- "k_proj",
- "v_proj",
- "o_proj",
- "up_proj",
- "down_proj",
- "gate_proj",
- ],
+ lora_alpha=32,
+ target_modules="all-linear",
  task_type="CAUSAL_LM",
  lora_dropout=0.05,
  )
@@ -1908,7 +1906,6 @@
  else:
  self.max_grad_norm = max_grad_norm
  self.reduce_memory_peak = reduce_memory_peak
- self.pretrained_model_name_or_path = actor_network.name_or_path

  if self.accelerator is not None:
  self.zero_stage = self.accelerator.state.deepspeed_plugin.deepspeed_config[
@@ -2141,15 +2138,17 @@
  if not is_dummy_optimizer
  else type(self.actor.optimizer)
  )
- self.actor.module.gradient_checkpointing_enable(
- gradient_checkpointing_kwargs={"use_reentrant": False}
- )
+ if self.gradient_checkpointing:
+ self.actor.module.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": False}
+ )
  else:
  assert (
  self.actor is not None
  ), "Actor is set to None, please check that the actor is defined."
  self.actor = self.actor.to(self.device)
- self.actor.gradient_checkpointing_enable()
+ if self.gradient_checkpointing:
+ self.actor.gradient_checkpointing_enable()

  def clean_up(self) -> None:
  """Clean up the algorithm."""
@@ -2408,7 +2407,7 @@
  self.reference_update_tracker += 1

  def _initialize_actors(
- self, base_model: PreTrainedModel, add_adapters: bool = True
+ self, base_model: PreTrainedModel | None, add_adapters: bool = True
  ):
  """Initialize the actor network.

@@ -2418,6 +2417,11 @@
  :type add_adapters: bool, optional
  """

+ if base_model is None:
+ base_model = create_model_from_name_or_path(
+ self.pretrained_model_name_or_path
+ )
+
  if isinstance(base_model, PeftModel) and add_adapters:
  # Handles backwards compatibility with user providing a peft model as the actor network
  if self.lora_config is None:
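The net effect of the base.py changes above is that LLMAlgorithm can now be built from either a model name or a preloaded network, with the base model loaded lazily when only a name was given. A minimal sketch of that contract, using a hypothetical standalone helper (resolve_base_model is not part of the library):

```python
# Hypothetical helper mirroring the new model_name / actor_network contract in base.py.
import torch
from transformers import AutoModelForCausalLM, PreTrainedModel


def resolve_base_model(
    model_name: str | None = None,
    actor_network: PreTrainedModel | None = None,
) -> PreTrainedModel:
    # Same validation the constructor now performs.
    if model_name is None and actor_network is None:
        raise ValueError("At least one of model_name or actor_network must be provided.")
    if actor_network is not None:
        return actor_network
    # Deferred load, as _initialize_actors now does via create_model_from_name_or_path.
    return AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
```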
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/dpo.py
@@ -1,10 +1,10 @@
  import gc
+ from typing import Any

  import numpy as np
  import torch
  import torch.nn.functional as F
  from accelerate import Accelerator
- from gymnasium import spaces
  from peft import LoraConfig
  from transformers import PreTrainedModel

@@ -16,13 +16,62 @@ from agilerl.utils.llm_utils import PreferenceGym


  class DPO(LLMAlgorithm):
+ """The DPO algorithm class. DPO paper: https://arxiv.org/pdf/2305.18290
+
+ :param pad_token_id: Pad token id
+ :type pad_token_id: int
+ :param pad_token: Pad token
+ :type pad_token: str
+ :param model_name: Model name
+ :type model_name: str, optional
+ :param actor_network: HuggingFace LLM
+ :type actor_network: PreTrainedModel
+ :param model_config: Model configuration, to be used when creating the model from a name or path
+ :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
+ :type hp_config: HyperparameterConfig, optional
+ :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0
+ :type index: int, optional
+ :param batch_size: Batch size for training, defaults to 16
+ :type batch_size: int, optional
+ :param lr: Learning rate, defaults to 0.000005
+ :type lr: float, optional
+ :param beta: Beta parameter for DPO, defaults to 0.001
+ :type beta: float, optional
+ :param max_grad_norm: Maximum gradient norm, defaults to 0.1
+ :type max_grad_norm: float, optional
+ :param update_epochs: Number of update epochs, defaults to 1
+ :type update_epochs: int, optional
+ :param calc_position_embeddings: Flag to indicate if position embeddings should be calculated, defaults to True
+ :type calc_position_embeddings: bool, optional
+ :param micro_batch_size_per_gpu: Micro batch size per GPU, defaults to None
+ :type micro_batch_size_per_gpu: int, optional
+ :param reduce_memory_peak: Flag to indicate if memory peak should be reduced, defaults to False
+ :type reduce_memory_peak: bool, optional
+ :param device: Device for accelerated computing, 'cpu' or 'cuda', defaults to 'cpu'
+ :type device: str, optional
+ :param lora_config: Config for LoRA, defaults to None
+ :type lora_config: LoraConfig, optional
+ :param accelerator: Accelerator for distributed computing, defaults to None
+ :type accelerator: accelerate.Accelerator(), optional
+ :param wrap: Wrap models for distributed training upon creation, defaults to True
+ :type wrap: bool, optional
+ :param clone: Flag to indicate if the instantiation is a cloning, defaults to False
+ :type clone: bool, optional
+ :param use_separate_reference_adapter: Flag to indicate if the reference policy should have a separate adapter, defaults to False
+ :type use_separate_reference_adapter: bool, optional
+ :param seed: Seed for the random number generator, defaults to 42
+ :type seed: int, optional
+ :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+ :type gradient_checkpointing: bool, optional
+ """
+
  def __init__(
  self,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- actor_network: PreTrainedModel,
  pad_token_id: int,
  pad_token: str,
+ model_name: str | None = None,
+ actor_network: PreTrainedModel | None = None,
+ model_config: dict[str, Any] | None = None,
  hp_config: HyperparameterConfig | None = None,
  index: int = 0,
  batch_size: int = 16,
@@ -40,6 +89,7 @@ class DPO(LLMAlgorithm):
  clone: bool = False,
  use_separate_reference_adapter: bool = False,
  seed: int = 42,
+ gradient_checkpointing: bool = True,
  ):
  device = (
  f"cuda:{accelerator.process_index}"
@@ -47,9 +97,6 @@
  else ("cuda" if torch.cuda.is_available() else "cpu")
  )
  super().__init__(
- observation_space,
- action_space,
- actor_network,
  index=index,
  batch_size=batch_size,
  lr=lr,
@@ -62,6 +109,9 @@
  pad_token=pad_token,
  lora_config=lora_config,
  use_separate_reference_adapter=use_separate_reference_adapter,
+ model_name=model_name,
+ actor_network=actor_network,
+ model_config=model_config,
  micro_batch_size_per_gpu=micro_batch_size_per_gpu,
  cosine_lr_schedule_config=None,
  hp_config=hp_config,
@@ -69,6 +119,7 @@
  device=device,
  accelerator=accelerator,
  name="DPO",
+ gradient_checkpointing=gradient_checkpointing,
  )
  self.beta = beta
  self.temperature = (
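Given the new signature, a DPO agent can be built from a model name alone. A hedged usage sketch (the checkpoint id is a placeholder; only the keyword names and defaults come from the diff above):

```python
from transformers import AutoTokenizer

from agilerl.algorithms.dpo import DPO

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
agent = DPO(
    pad_token_id=tokenizer.pad_token_id,
    pad_token=tokenizer.pad_token,
    model_name="Qwen/Qwen2.5-0.5B-Instruct",  # replaces the old mandatory actor_network
    batch_size=16,
    lr=5e-6,
    beta=0.001,
    gradient_checkpointing=True,  # new flag, on by default
)
```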
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/algorithms/grpo.py
@@ -1,12 +1,11 @@
  import gc
- from typing import Optional, Union
+ from typing import Any, Optional, Union

  import numpy as np
  import torch
  from accelerate import Accelerator
  from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
  from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
- from gymnasium import spaces
  from peft import LoraConfig, PeftModel
  from transformers import GenerationConfig
  from transformers.modeling_utils import PreTrainedModel
@@ -33,12 +32,16 @@ DeepSpeedOptimizerType = Union[
  class GRPO(LLMAlgorithm):
  """The GRPO algorithm class. GRPO paper: https://arxiv.org/pdf/2402.03300

- :param observation_space: Observation space of the environment
- :type observation_space: gym.spaces.Space
- :param action_space: Action space of the environment
- :type action_space: gym.spaces.Space
+ :param pad_token_id: Pad token id
+ :type pad_token_id: int
+ :param pad_token: Pad token
+ :type pad_token: str
+ :param model_name: Model name
+ :type model_name: str, optional
  :param actor_network: HuggingFace LLM
  :type actor_network: PreTrainedModel
+ :param model_config: Model configuration, to be used when creating the model from a name or path
+ :type model_config: dict[str, Any], optional
  :param hp_config: RL hyperparameter mutation configuration, defaults to None, whereby algorithm mutations are disabled.
  :type hp_config: HyperparameterConfig, optional
  :param index: Index to keep track of object instance during tournament selection and mutation, defaults to 0
@@ -93,15 +96,17 @@ class GRPO(LLMAlgorithm):
  :type vllm_config: VLLMConfig, optional
  :param seed: Seed for the random number generator, defaults to 42
  :type seed: int, optional
+ :param gradient_checkpointing: Flag to indicate if gradient checkpointing should be used, defaults to True
+ :type gradient_checkpointing: bool, optional
  """

  def __init__(
  self,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- actor_network: PreTrainedModel,
  pad_token_id: int,
  pad_token: str,
+ model_name: str | None = None,
+ actor_network: PreTrainedModel | None = None,
+ model_config: dict[str, Any] | None = None,
  hp_config: Optional[HyperparameterConfig] = None,
  index: int = 0,
  batch_size: int = 16,
@@ -132,6 +137,7 @@ class GRPO(LLMAlgorithm):
  use_vllm: bool = False,
  vllm_config: Optional[VLLMConfig] = None,
  seed: int = 42,
+ gradient_checkpointing: bool = True,
  ) -> None:

  device = (
@@ -140,9 +146,6 @@
  else ("cuda" if torch.cuda.is_available() else "cpu")
  )
  super().__init__(
- observation_space,
- action_space,
- actor_network,
  index=index,
  batch_size=batch_size,
  lr=lr,
@@ -155,6 +158,9 @@
  pad_token=pad_token,
  lora_config=lora_config,
  use_separate_reference_adapter=use_separate_reference_adapter,
+ model_name=model_name,
+ actor_network=actor_network,
+ model_config=model_config,
  micro_batch_size_per_gpu=micro_batch_size_per_gpu,
  cosine_lr_schedule_config=cosine_lr_schedule_config,
  wrap=wrap,
@@ -162,6 +168,7 @@
  device=device,
  accelerator=accelerator,
  name="GRPO",
+ gradient_checkpointing=gradient_checkpointing,
  )
  assert isinstance(batch_size, int), "Batch size must be an integer."
  assert batch_size >= 1, "Batch size must be greater than or equal to one."
@@ -179,9 +186,10 @@
  assert (
  update_epochs >= 1
  ), "Policy update epochs must be greater than or equal to one."
- assert isinstance(
- actor_network, (PeftModel, PreTrainedModel)
- ), "Actor network must be a PeftModel or PreTrainedModel"
+ if actor_network is not None:
+ assert isinstance(
+ actor_network, (PeftModel, PreTrainedModel)
+ ), "Actor network must be a PeftModel or PreTrainedModel"

  self.clip_coef = clip_coef
  self.update_epochs = update_epochs
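GRPO keeps the older entry point: a preloaded HuggingFace model can still be supplied and is only type-checked when it is not None. A sketch of that backward-compatible path, with a placeholder checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from agilerl.algorithms.grpo import GRPO

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
actor = AutoModelForCausalLM.from_pretrained(model_id)

agent = GRPO(
    pad_token_id=tokenizer.pad_token_id,
    pad_token=tokenizer.pad_token,
    actor_network=actor,           # still accepted; validated only when provided
    gradient_checkpointing=False,  # new flag; disable to trade memory for speed
)
```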
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/training/train_llm.py
@@ -115,7 +115,7 @@ def finetune_llm_reasoning(

  if init_hp is None:
  init_hp = {}
- init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size
+ init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size_per_process
  init_hp["ALGO"] = pop[0].algo
  data_increment = (
  getattr(dist, "get_world_size", lambda: 1)() if dist.is_initialized() else 1
@@ -463,7 +463,7 @@ def finetune_llm_preference(

  if init_hp is None:
  init_hp = {}
- init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size
+ init_hp["BATCH_SIZE_PER_GPU"] = pop[0].batch_size_per_process
  init_hp["ALGO"] = pop[0].algo

  data_increment = accelerator.num_processes if accelerator is not None else 1
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/algo_utils.py
@@ -1328,6 +1328,13 @@ class VLLMConfig:
  max_num_seqs: int = 8
  sleep_mode: bool = False

+ def __post_init__(self):
+ if self.sleep_mode:
+ warnings.warn(
+ """VLLM sleep mode cannot be used with populations of agents on a single device. To use sleep mode, ensure,
+ you are training a single agent or, alternatively, use a different device for each agent."""
+ )
+

  def create_warmup_cosine_scheduler(
  optimizer: torch.optim.Optimizer,
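A quick illustration of the new guard, assuming the remaining VLLMConfig fields keep their defaults:

```python
import warnings

from agilerl.utils.algo_utils import VLLMConfig

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = VLLMConfig(sleep_mode=True)  # __post_init__ now warns about shared devices
    print(len(caught))  # 1
```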
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/llm_utils.py
@@ -11,12 +11,53 @@ import torch.nn as nn
  from accelerate import Accelerator
  from datasets import Dataset
  from torch.utils.data import DataLoader
- from transformers import AutoTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.modeling_utils import PreTrainedModel
  from transformers.tokenization_utils_base import BatchEncoding

  from agilerl.typing import PreferencePrompts, ReasoningPrompts


+ def apply_chat_template(
+ conversation_template: list[dict[str, str]],
+ question: str,
+ answer: str,
+ tokenizer: AutoTokenizer,
+ ) -> BatchEncoding:
+ """
+ Create and tokenize a chat template for a reaosning task.
+
+ :param conversation_template: The conversation template to be tokenized.
+ :type conversation_template: list[dict[str, str]]
+ :param question: The question to be tokenized.
+ :type question: str
+ :param answer: The answer to be tokenized.
+ :type answer: str
+ :param tokenizer: The tokenizer to be used.
+ :type tokenizer: AutoTokenizer
+ :return: The tokenized prompt.
+ :rtype: BatchEncoding
+ """
+ formatted_conversation = [
+ {
+ "role": msg["role"],
+ "content": msg["content"].format(question=question, answer=answer),
+ }
+ for msg in conversation_template
+ ]
+ updated_prompt = tokenizer.apply_chat_template(
+ formatted_conversation, tokenize=False, continue_final_message=True
+ )
+ tokenized_prompt = tokenizer(
+ [updated_prompt],
+ return_tensors="pt",
+ padding=True,
+ padding_side="left",
+ return_attention_mask=True,
+ )
+ return tokenized_prompt
+
+
  class HuggingFaceGym(gym.Env, ABC):
  """Abstract base class for HuggingFace Gymnasium environments.

@@ -28,8 +69,8 @@ class HuggingFaceGym(gym.Env, ABC):
  :type tokenizer: AutoTokenizer
  :param custom_collate_fn: Custom collate function to be used for creating the batch, defaults to None
  :type custom_collate_fn: Callable, optional
- :param apply_chat_template_fn: Function to apply the chat template to the batch of questions and answers, defaults to None
- :type apply_chat_template_fn: Callable, optional
+ :param conversation_template: A structured conversation that acts as a base pattern for each data point.
+ :type conversation_template: list[dict[str, str]]
  :param data_batch_size_per_gpu: DataLoader batch size, defaults to 8
  :type data_batch_size_per_gpu: int, optional
  :param max_context_length: Maximum context length, defaults to None
@@ -47,12 +88,7 @@ class HuggingFaceGym(gym.Env, ABC):
  train_dataset: Dataset,
  test_dataset: Dataset,
  tokenizer: AutoTokenizer,
- custom_collate_fn: (
- Callable[[list[dict[str, Any]]], dict[str, Any]] | None
- ) = None,
- apply_chat_template_fn: (
- Callable[[str, str, AutoTokenizer], BatchEncoding] | None
- ) = None,
+ conversation_template: list[dict[str, str]],
  data_batch_size_per_gpu: int = 8,
  max_context_length: int | None = None,
  min_completion_length: int = None,
@@ -70,11 +106,8 @@ class HuggingFaceGym(gym.Env, ABC):
  self.max_context_length = max_context_length
  self.seed = seed
  generator = torch.Generator().manual_seed(seed)
- if custom_collate_fn is None:
- collate_kwargs = {"tokenizer": tokenizer}
- if apply_chat_template_fn is not None:
- collate_kwargs["apply_chat_template_fn"] = apply_chat_template_fn
- custom_collate_fn = self.create_collate_fn(**collate_kwargs)
+ self.conversation_template = conversation_template
+ custom_collate_fn = self.create_collate_fn(tokenizer)
  dataloader_kwargs = {"collate_fn": custom_collate_fn}
  train_dataset = self._filter_dataset_by_max_context_length(
  train_dataset, "train dataset"
@@ -107,11 +140,6 @@ class HuggingFaceGym(gym.Env, ABC):
  self.test_dataloader_iter = iter(self.test_dataloader)
  self.dataloader = self.train_dataloader_iter
  self.reset_called = False
- self.observation_space = gym.spaces.Box(low=0, high=tokenizer.vocab_size - 1)
- self.action_space = gym.spaces.Box(
- low=0,
- high=tokenizer.vocab_size - 1,
- )
  self.evaluation_mode = False
  self.num_epochs = 0

@@ -196,9 +224,11 @@ class HuggingFaceGym(gym.Env, ABC):
  :rtype: tuple[Dataset, Dataset]
  """
  dataset_type = "dataset" if dataset_type is None else dataset_type
- if self.max_context_length is None:
- return dataset
  filter_keyword = "prompt" if "prompt" in dataset.features.keys() else "question"
+ if self.max_context_length is None or not isinstance(
+ dataset[0][filter_keyword], str
+ ):
+ return dataset
  filtered_dataset = dataset.filter(
  lambda x: len(self.tokenizer.encode(x[filter_keyword]))
  <= self.max_context_length - self.min_completion_length
@@ -225,10 +255,10 @@ class ReasoningGym(HuggingFaceGym):
  :type tokenizer: AutoTokenizer
  :param reward_fn: Reward function for evaluating completions.
  :type reward_fn: Callable[..., float]
+ :param conversation_template: A structured conversation that acts as a base pattern for each data point.
+ :type conversation_template: list[dict[str, str]]
  :param data_batch_size_per_gpu: DataLoader batch size, defaults to 8
  :type data_batch_size_per_gpu: int, optional
- :param custom_collate_fn: Custom collate fxwunction to be used for creating the batch, defaults to None
- :type custom_collate_fn: Callable, optional
  :param accelerator: Accelerator to be used for training, defaults to None
  :type accelerator: Accelerator, optional
  :param max_context_length: Maximum context length, defaults to None
@@ -245,9 +275,8 @@ class ReasoningGym(HuggingFaceGym):
  test_dataset: Dataset,
  tokenizer: AutoTokenizer,
  reward_fn: Callable[[str, str, str], float],
- apply_chat_template_fn: Callable[[str, str, AutoTokenizer], BatchEncoding],
+ conversation_template: list[dict[str, str]],
  data_batch_size_per_gpu: int = 8,
- custom_collate_fn: Callable | None = None,
  accelerator: Accelerator | None = None,
  return_raw_completions: bool = False,
  max_context_length: int | None = None,
@@ -264,8 +293,7 @@ class ReasoningGym(HuggingFaceGym):
  train_dataset=train_dataset,
  test_dataset=test_dataset,
  tokenizer=tokenizer,
- custom_collate_fn=custom_collate_fn,
- apply_chat_template_fn=apply_chat_template_fn,
+ conversation_template=conversation_template,
  data_batch_size_per_gpu=data_batch_size_per_gpu,
  max_context_length=max_context_length,
  min_completion_length=0,
@@ -382,15 +410,12 @@ class ReasoningGym(HuggingFaceGym):
  def create_collate_fn(
  self,
  tokenizer: AutoTokenizer,
- apply_chat_template_fn: Callable[[str, str, AutoTokenizer], BatchEncoding],
  ) -> Callable[[list[dict[str, Any]]], dict[str, Any]]:
  """
  Create a collate function that applies the chat template to the batch of questions and answers.

  :param tokenizer: Tokenizer to be used for encoding and decoding the prompts.
  :type tokenizer: AutoTokenizer
- :param apply_chat_template_fn: Function to apply the chat template to the batch of questions and answers.
- :type apply_chat_template_fn: Callable[[str, str, AutoTokenizer], BatchEncoding]
  :return: Collate function that applies the chat template to the batch of questions and answers.
  :rtype: Callable[[list[dict[str, Any]]], dict[str, Any]]
  """
@@ -402,7 +427,7 @@ class ReasoningGym(HuggingFaceGym):

  # Apply chat template to all samples
  tokenized_prompts = [
- apply_chat_template_fn(q, a, tokenizer)
+ apply_chat_template(self.conversation_template, q, a, tokenizer)
  for q, a in zip(questions, answers)
  ]

@@ -451,8 +476,7 @@ class PreferenceGym(HuggingFaceGym):
  train_dataset=train_dataset,
  test_dataset=test_dataset,
  tokenizer=tokenizer,
- custom_collate_fn=None,
- apply_chat_template_fn=None,
+ conversation_template=None,
  data_batch_size_per_gpu=data_batch_size_per_gpu,
  max_context_length=max_context_length,
  min_completion_length=min_completion_length,
@@ -667,3 +691,27 @@ def get_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:

  with gather_if_zero3(3, list(model.parameters()), modifier_rank=0):
  return model.state_dict()
+
+
+ def create_model_from_name_or_path(
+ model_name_or_path: str, model_config: dict[str, Any] | None = None
+ ) -> PreTrainedModel:
+ """
+ Create a model from a name or path.
+
+ :param model_name_or_path: The name or path of the model to create.
+ :type model_name_or_path: str
+ :param model_config: The configuration of the model to create.
+ :type model_config: dict[str, Any ] | None
+ :return: The created model.
+ :rtype: PreTrainedModel
+ """
+ if model_config is None:
+ model_config = {
+ "torch_dtype": torch.bfloat16,
+ "attn_implementation": "sdpa",
+ }
+ model = AutoModelForCausalLM.from_pretrained(
+ pretrained_model_name_or_path=model_name_or_path, **model_config
+ )
+ return model
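To make the new conversation_template plumbing concrete, here is a hedged sketch of calling the helpers directly; the checkpoint id and message wording are placeholders, and only the {question}/{answer} substitution and call shapes come from the code above:

```python
from transformers import AutoTokenizer

from agilerl.utils.llm_utils import apply_chat_template, create_model_from_name_or_path

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

conversation_template = [
    {"role": "system", "content": "Solve the problem and show your reasoning."},
    {"role": "user", "content": "{question}"},     # filled in per sample
    {"role": "assistant", "content": "<think>"},   # completion continues from here
]

batch = apply_chat_template(conversation_template, "What is 2 + 2?", "4", tokenizer)
print(batch["input_ids"].shape)  # left-padded prompt tokens, batch of one

# The same module now also loads a causal LM lazily from its name or path:
model = create_model_from_name_or_path(model_id)
```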
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/agilerl/utils/utils.py
@@ -36,7 +36,7 @@ from agilerl.hpo.mutation import Mutations
  from agilerl.hpo.tournament import TournamentSelection
  from agilerl.modules import EvolvableModule
  from agilerl.typing import BPTTSequenceType, GymSpaceType, PopulationType
- from agilerl.utils.algo_utils import CosineLRScheduleConfig, VLLMConfig, clone_llm
+ from agilerl.utils.algo_utils import CosineLRScheduleConfig, clone_llm
  from agilerl.utils.llm_utils import DummyOptimizer, get_state_dict
  from agilerl.vector.pz_async_vec_env import AsyncPettingZooVecEnv

@@ -213,10 +213,10 @@ def default_progress_bar(

  def create_population(
  algo: str,
- observation_space: GymSpaceType,
- action_space: GymSpaceType,
  net_config: Optional[dict[str, Any]],
  INIT_HP: dict[str, Any],
+ observation_space: GymSpaceType | None = None,
+ action_space: GymSpaceType | None = None,
  hp_config: Optional[HyperparameterConfig] = None,
  actor_network: Optional[EvolvableModule] = None,
  critic_network: Optional[EvolvableModule] = None,
@@ -233,14 +233,14 @@ def create_population(

  :param algo: RL algorithm
  :type algo: str
- :param observation_space: Observation space
- :type observation_space: spaces.Space
- :param action_space: Action space
- :type action_space: spaces.Space
  :param net_config: Network configuration
  :type net_config: dict or None
  :param INIT_HP: Initial hyperparameters
  :type INIT_HP: dict
+ :param observation_space: Observation space
+ :type observation_space: spaces.Space
+ :param action_space: Action space
+ :type action_space: spaces.Space
  :param hp_config: Choice of algorithm hyperparameters to mutate during training, defaults to None
  :type hp_config: HyperparameterConfig, optional
  :param actor_network: Custom actor network, defaults to None
@@ -572,23 +572,23 @@ def create_population(
  elif algo == "GRPO":
  for idx in range(population_size):
  agent = GRPO(
- observation_space=observation_space,
- action_space=action_space,
  actor_network=(
- clone_llm(
- actor_network,
- zero_stage=INIT_HP.get("ZERO_STAGE", 0),
- state_dict=(
- actor_network.state_dict()
- if accelerator is None
- else get_state_dict(actor_network)
- ),
+ (
+ clone_llm(
+ actor_network,
+ zero_stage=INIT_HP.get("ZERO_STAGE", 0),
+ state_dict=(
+ actor_network.state_dict()
+ if accelerator is None
+ else get_state_dict(actor_network)
+ ),
+ )
+ if idx != 0
+ else actor_network
  )
- if idx != 0
- else actor_network
+ if actor_network is not None
+ else None
  ),
- pad_token_id=INIT_HP.get("PAD_TOKEN_ID"),
- pad_token=INIT_HP.get("PAD_TOKEN"),
  hp_config=hp_config,
  index=idx,
  batch_size=INIT_HP.get("BATCH_SIZE", 2),
@@ -601,7 +601,7 @@ def create_population(
  temperature=INIT_HP.get("TEMPERATURE", 0.9),
  calc_position_embeddings=INIT_HP.get("CALC_POSITION_EMBEDDINGS", True),
  reduce_memory_peak=INIT_HP.get("REDUCE_MEMORY_PEAK", False),
- max_output_tokens=INIT_HP.get("MAX_OUTPUT_TOKENS", 1024),
+ max_output_tokens=INIT_HP.get("MAX_OUTPUT_TOKENS", None),
  min_output_tokens=INIT_HP.get("MIN_OUTPUT_TOKENS", None),
  cosine_lr_schedule_config=(
  CosineLRScheduleConfig(**INIT_HP.get("COSINE_lR_SCHEDULER", None))
@@ -610,23 +610,13 @@ def create_population(
  ),
  accelerator=Accelerator() if accelerator else None,
  device=device,
- use_separate_reference_adapter=False,
  max_model_len=INIT_HP.get("MAX_MODEL_LEN", None),
- use_vllm=INIT_HP.get("USE_VLLM", False),
- vllm_config=(
- VLLMConfig(**INIT_HP.get("VLLM_CONFIG"))
- if INIT_HP.get("VLLM_CONFIG", None) is not None
- and INIT_HP.get("USE_VLLM", False)
- else None
- ),
  **algo_kwargs,
  )
  population.append(agent)
  elif algo == "DPO":
  for idx in range(population_size):
  agent = DPO(
- observation_space=observation_space,
- action_space=action_space,
  actor_network=(
  clone_llm(
  actor_network,
@@ -640,8 +630,6 @@ def create_population(
  if idx != 0
  else actor_network
  ),
- pad_token_id=INIT_HP.get("PAD_TOKEN_ID"),
- pad_token=INIT_HP.get("PAD_TOKEN"),
  hp_config=hp_config,
  index=idx,
  batch_size=INIT_HP.get("BATCH_SIZE", 2),
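With the reordered signature, the gym spaces become optional keyword arguments. A hedged sketch of a conventional call (the INIT_HP keys and the extra keywords such as population_size and device are illustrative typical usage, not taken from this diff):

```python
import gymnasium as gym

from agilerl.utils.utils import create_population

env = gym.make("CartPole-v1")
population = create_population(
    algo="DQN",
    net_config=None,                          # positional arguments now end at INIT_HP
    INIT_HP={"BATCH_SIZE": 64},
    observation_space=env.observation_space,  # moved to optional keyword arguments
    action_space=env.action_space,
    population_size=4,
    device="cpu",
)
```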
{agilerl-2.4.0.dev0 → agilerl-2.4.1.dev0}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "agilerl"
- version = "2.4.0.dev0"
+ version = "2.4.1.dev0"

  description = "AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps."
  authors = ["Nick Ustaran-Anderegg <dev@agilerl.com>"]