agilerl 2.3.4.dev1__tar.gz → 2.3.5.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/PKG-INFO +3 -3
  2. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/README.md +0 -1
  3. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/base.py +129 -23
  4. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/dqn_rainbow.py +25 -24
  5. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/grpo.py +527 -153
  6. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ppo.py +49 -13
  7. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/dummy.py +20 -13
  8. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/base.py +1 -1
  9. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/q_networks.py +2 -0
  10. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_llm.py +43 -31
  11. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/typing.py +20 -2
  12. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/algo_utils.py +109 -25
  13. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/llm_utils.py +37 -12
  14. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/log_utils.py +1 -2
  15. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/utils.py +20 -6
  16. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/pyproject.toml +4 -2
  17. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/LICENSE +0 -0
  18. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/__init__.py +0 -0
  19. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/__init__.py +0 -0
  20. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/bc_lm.py +0 -0
  21. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/__init__.py +0 -0
  22. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
  23. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/core/registry.py +0 -0
  24. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/cqn.py +0 -0
  25. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ddpg.py +0 -0
  26. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/dqn.py +0 -0
  27. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ilql.py +1 -1
  28. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/ippo.py +0 -0
  29. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/maddpg.py +0 -0
  30. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/matd3.py +0 -0
  31. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  32. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  33. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/algorithms/td3.py +0 -0
  34. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/__init__.py +0 -0
  35. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/data.py +0 -0
  36. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  37. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/replay_buffer.py +0 -0
  38. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/rollout_buffer.py +0 -0
  39. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/sampler.py +0 -0
  40. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/components/segment_tree.py +0 -0
  41. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/__init__.py +0 -0
  42. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/language_environment.py +0 -0
  43. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/rl_data.py +0 -0
  44. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/tokenizer.py +0 -0
  45. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/data/torch_datasets.py +0 -0
  46. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/__init__.py +0 -0
  47. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/mutation.py +0 -0
  48. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/hpo/tournament.py +0 -0
  49. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/__init__.py +0 -0
  50. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/base.py +0 -0
  51. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/bert.py +0 -0
  52. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/cnn.py +0 -0
  53. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/configs.py +0 -0
  54. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/custom_components.py +0 -0
  55. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/gpt.py +0 -0
  56. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/lstm.py +0 -0
  57. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/mlp.py +0 -0
  58. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/multi_input.py +0 -0
  59. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/resnet.py +0 -0
  60. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/modules/simba.py +0 -0
  61. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/__init__.py +0 -0
  62. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/actors.py +0 -0
  63. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/custom_modules.py +0 -0
  64. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/distributions.py +0 -0
  65. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/distributions_experimental.py +0 -0
  66. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/networks/value_networks.py +0 -0
  67. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/protocols.py +0 -0
  68. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/rollouts/__init__.py +0 -0
  69. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/rollouts/on_policy.py +0 -0
  70. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/__init__.py +0 -0
  71. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_bandits.py +1 -1
  72. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_multi_agent_off_policy.py +1 -1
  73. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_multi_agent_on_policy.py +1 -1
  74. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_off_policy.py +1 -1
  75. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_offline.py +1 -1
  76. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/training/train_on_policy.py +1 -1
  77. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/__init__.py +0 -0
  78. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/cache.py +0 -0
  79. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/evolvable_networks.py +0 -0
  80. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/ilql_utils.py +0 -0
  81. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/minari_utils.py +0 -0
  82. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/probe_envs.py +0 -0
  83. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
  84. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/sampling_utils.py +0 -0
  85. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/utils/torch_utils.py +0 -0
  86. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/__init__.py +0 -0
  87. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
  88. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/vector/pz_vec_env.py +0 -0
  89. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/__init__.py +0 -0
  90. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/agent.py +0 -0
  91. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/learning.py +0 -0
  92. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
  93. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  94. {agilerl-2.3.4.dev1 → agilerl-2.3.5.dev0}/agilerl/wrappers/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: agilerl
3
- Version: 2.3.4.dev1
3
+ Version: 2.3.5.dev0
4
4
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
5
5
  License: Apache 2.0
6
6
  Author: Nick Ustaran-Anderegg
@@ -37,14 +37,14 @@ Requires-Dist: redis (>=4.4.4,<5.0.0)
37
37
  Requires-Dist: scipy (>=1.12.0,<2.0.0)
38
38
  Requires-Dist: tensordict (>=0.8,<0.9)
39
39
  Requires-Dist: termcolor (>=1.1.0,<2.0.0)
40
- Requires-Dist: torch (==2.5.1)
40
+ Requires-Dist: torch (==2.7.1)
41
41
  Requires-Dist: tqdm (>=4.66.4,<5.0.0)
42
42
  Requires-Dist: transformers (>=4.48.1,<5.0.0)
43
43
  Requires-Dist: ucimlrepo (>=0.0.3,<0.0.4)
44
+ Requires-Dist: vllm (==0.10.0)
44
45
  Requires-Dist: wandb (>=0.17.6,<0.18.0)
45
46
  Description-Content-Type: text/markdown
46
47
 
47
- # AgileRL
48
48
  <p align="center">
49
49
  <img src=https://user-images.githubusercontent.com/47857277/222710068-e09a4e3c-368c-458a-9e01-b68674806887.png height="120">
50
50
  </p>
@@ -1,4 +1,3 @@
1
- # AgileRL
2
1
  <p align="center">
3
2
  <img src=https://user-images.githubusercontent.com/47857277/222710068-e09a4e3c-368c-458a-9e01-b68674806887.png height="120">
4
3
  </p>
@@ -30,13 +30,16 @@ from accelerate import Accelerator
30
30
  from accelerate.utils import broadcast_object_list
31
31
  from accelerate.utils.deepspeed import DeepSpeedOptimizerWrapper
32
32
  from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
33
+ from deepspeed.runtime.engine import DeepSpeedEngine
33
34
  from gymnasium import spaces
34
35
  from numpy.typing import ArrayLike
35
- from peft import PeftModel
36
+ from peft import PeftModel, set_peft_model_state_dict
37
+ from safetensors.torch import load_file
36
38
  from tensordict import TensorDict
37
39
  from torch._dynamo import OptimizedModule
38
40
  from torch.optim import AdamW
39
41
  from torch.optim.lr_scheduler import SequentialLR
42
+ from vllm.distributed.parallel_state import destroy_model_parallel
40
43
 
41
44
  from agilerl.algorithms.core.optimizer_wrapper import OptimizerWrapper
42
45
  from agilerl.algorithms.core.registry import (
@@ -70,6 +73,7 @@ from agilerl.typing import (
70
73
  )
71
74
  from agilerl.utils.algo_utils import (
72
75
  CosineLRScheduleConfig,
76
+ check_supported_space,
73
77
  chkpt_attribute_to_device,
74
78
  clone_llm,
75
79
  create_warmup_cosine_scheduler,
@@ -869,7 +873,11 @@ class EvolvableAlgorithm(ABC, metaclass=RegistryMeta):
869
873
  k: v for k, v in network_info["modules"].items() if k.startswith(name)
870
874
  }
871
875
 
872
- module_cls = net_dict[f"{name}_cls"]
876
+ module_cls = net_dict.get(f"{name}_cls", None)
877
+ if module_cls is None:
878
+ # This allows us to super this method in the LLMAlgorithm class
879
+ # as we don't want to reinstantiate the network in this class
880
+ break
873
881
  init_dict = net_dict[f"{name}_init_dict"]
874
882
 
875
883
  module_dict_cls = net_dict.get(f"{name}_module_dict_cls", None)
@@ -1164,12 +1172,8 @@ class RLAlgorithm(EvolvableAlgorithm, ABC):
1164
1172
 
1165
1173
  super().__init__(index, hp_config, device, accelerator, torch_compiler, name)
1166
1174
 
1167
- assert isinstance(
1168
- observation_space, spaces.Space
1169
- ), "Observation space must be an instance of gymnasium.spaces.Space."
1170
- assert isinstance(
1171
- action_space, spaces.Space
1172
- ), "Action space must be an instance of gymnasium.spaces.Space."
1175
+ check_supported_space(observation_space)
1176
+ check_supported_space(action_space)
1173
1177
 
1174
1178
  self.observation_space = observation_space
1175
1179
  self.action_space = action_space
@@ -1257,12 +1261,7 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1257
1261
  assert len(agent_ids) == len(
1258
1262
  observation_spaces
1259
1263
  ), "Number of agent IDs must match number of observation spaces."
1260
- assert all(
1261
- isinstance(_space, spaces.Space) for _space in observation_spaces
1262
- ), "Observation spaces must be instances of gymnasium.spaces.Space."
1263
- assert all(
1264
- isinstance(_space, spaces.Space) for _space in action_spaces
1265
- ), "Action spaces must be instances of gymnasium.spaces.Space."
1264
+
1266
1265
  self.possible_observation_spaces = spaces.Dict(
1267
1266
  {
1268
1267
  agent_id: space
@@ -1284,6 +1283,11 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1284
1283
  f"Observation spaces must be a list or dictionary of spaces.Space objects. Got {type(observation_spaces)}."
1285
1284
  )
1286
1285
 
1286
+ for obs_space in self.possible_observation_spaces.values():
1287
+ check_supported_space(obs_space)
1288
+ for action_space in self.possible_action_spaces.values():
1289
+ check_supported_space(action_space)
1290
+
1287
1291
  self.agent_ids = list(self.possible_observation_spaces.keys())
1288
1292
  self.n_agents = len(self.agent_ids)
1289
1293
  self.placeholder_value = placeholder_value
@@ -1823,8 +1827,8 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1823
1827
  and self.zero_stage > 2
1824
1828
  and self.accelerator.is_main_process
1825
1829
  ):
1826
- warnings.warn(
1827
- "Zero stage 3 support is nascent and has not been thoroughly tested. It may be unstable or subject to change. We recommend caution in production environments."
1830
+ raise NotImplementedError(
1831
+ "DeepSpeed ZeRO Stage 3 is not yet supported in AgileRL. This feature is in development and will be available in a future release."
1828
1832
  )
1829
1833
 
1830
1834
  seed = 42
@@ -1848,19 +1852,44 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1848
1852
 
1849
1853
  # TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
1850
1854
  # handle _save_distributed_actor if deepspeed is used.
1851
- def save_checkpoint(self, path: str) -> None:
1855
+ def save_checkpoint(self, path: str, weights_only: bool = False) -> None:
1852
1856
  """
1853
1857
  Override the save_checkpoint method to provide guidance on the correct method to use.
1854
1858
  :param path: Location to save checkpoint at
1855
1859
  :type path: string
1860
+ :param weights_only: If True, only save the weights of the model, defaults to False
1861
+ :type weights_only: bool, optional
1856
1862
  """
1863
+
1864
+ warnings.warn("weights_only default will be changed to True in the future.")
1865
+
1857
1866
  if self.accelerator is not None:
1858
- self._save_distributed_actor(path, tag="save_checkpoint")
1859
- torch.save(
1860
- get_checkpoint_dict(self, using_deepspeed=self.accelerator is not None),
1861
- path + "/attributes.pt",
1862
- pickle_module=dill,
1867
+ if not weights_only:
1868
+ self._save_distributed_actor(path, tag="save_checkpoint")
1869
+ else:
1870
+ selected_adapters = (
1871
+ ["actor", "reference"]
1872
+ if self.use_separate_reference_adapter
1873
+ else ["actor"]
1874
+ )
1875
+ self.actor.save_pretrained(
1876
+ save_directory=path,
1877
+ selected_adapters=selected_adapters,
1878
+ is_main_process=self.accelerator.is_main_process,
1879
+ )
1880
+ checkpoint_dict = get_checkpoint_dict(
1881
+ self, using_deepspeed=self.accelerator is not None
1863
1882
  )
1883
+ checkpoint_dict["_weights_only"] = weights_only
1884
+ checkpoint_dict.pop("llm", None)
1885
+ checkpoint_dict.pop("tp_group", None)
1886
+
1887
+ if self.accelerator is None or self.accelerator.is_main_process:
1888
+ torch.save(
1889
+ checkpoint_dict,
1890
+ path + "/attributes.pt",
1891
+ pickle_module=dill,
1892
+ )
1864
1893
 
1865
1894
  # TODO: This could hopefully be abstracted into EvolvableAlgorithm with a decorator to
1866
1895
  # handle _load_distributed_actor if deepspeed is used.
@@ -1872,8 +1901,27 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
1872
1901
  :type path: string
1873
1902
  """
1874
1903
  if self.accelerator is not None:
1875
- self._load_distributed_actor(path, tag="save_checkpoint")
1876
1904
  checkpoint = torch.load(path + "/attributes.pt", weights_only=False)
1905
+ weights_only = checkpoint.get("_weights_only", False)
1906
+
1907
+ if weights_only:
1908
+ if self.use_separate_reference_adapter:
1909
+ self._update_existing_adapter(
1910
+ self.accelerator,
1911
+ self.actor,
1912
+ path,
1913
+ "reference",
1914
+ )
1915
+
1916
+ self._update_existing_adapter(
1917
+ self.accelerator,
1918
+ self.actor,
1919
+ path,
1920
+ "actor",
1921
+ )
1922
+ else:
1923
+ self._load_distributed_actor(path, tag="save_checkpoint")
1924
+
1877
1925
  checkpoint["accelerator"] = (
1878
1926
  Accelerator() if self.accelerator is not None else None
1879
1927
  )
@@ -2040,6 +2088,10 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2040
2088
  None,
2041
2089
  None,
2042
2090
  )
2091
+ if self.use_vllm:
2092
+ destroy_model_parallel()
2093
+ del self.llm.llm_engine.model_executor.driver_worker
2094
+ self.llm = None
2043
2095
  gc.collect()
2044
2096
  torch.cuda.empty_cache()
2045
2097
 
@@ -2106,11 +2158,19 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2106
2158
  original_lr_scheduler = self.lr_scheduler
2107
2159
  clone.lr_scheduler = None
2108
2160
  self.lr_scheduler = None
2161
+ if self.use_vllm:
2162
+ original_llm = self.llm
2163
+ cloned_llm = clone.llm
2164
+ clone.llm = None
2165
+ self.llm = None
2109
2166
  clone = EvolvableAlgorithm.copy_attributes(self, clone)
2110
2167
  clone.accelerator = accelerator
2111
2168
  clone.lr_scheduler = lr_scheduler
2112
2169
  clone.lr_scheduler = cloned_lr_scheduler
2113
2170
  self.lr_scheduler = original_lr_scheduler
2171
+ if self.use_vllm:
2172
+ clone.llm = cloned_llm
2173
+ self.llm = original_llm
2114
2174
 
2115
2175
  if self.accelerator is None:
2116
2176
  clone.optimizer.optimizer.load_state_dict(
@@ -2201,3 +2261,49 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2201
2261
  ] = lr
2202
2262
 
2203
2263
  return accelerator, None
2264
+
2265
+ def recompile(self) -> None:
2266
+ """Recompiles the algorithm."""
2267
+ raise NotImplementedError(
2268
+ "Recompile method is not available for LLM finetuning algorithms."
2269
+ )
2270
+
2271
+ @staticmethod
2272
+ def _update_existing_adapter(
2273
+ accelerator: Accelerator,
2274
+ wrapped_model: DeepSpeedEngine,
2275
+ checkpoint_dir: str,
2276
+ adapter_name: str,
2277
+ ) -> None:
2278
+ """
2279
+ Overwrite weights of an existing adapter in-place without creating new parameters.
2280
+
2281
+ :param accelerator: Accelerator
2282
+ :type accelerator: Accelerator
2283
+ :param wrapped_model: Wrapped model
2284
+ :type wrapped_model: DeepSpeedEngine
2285
+ :param checkpoint_dir: Checkpoint directory
2286
+ :type checkpoint_dir: str
2287
+ :param adapter_name: Adapter name
2288
+ :type adapter_name: str
2289
+
2290
+ :return: None
2291
+ :rtype: None
2292
+ """
2293
+ base_model = accelerator.unwrap_model(wrapped_model)
2294
+ if hasattr(base_model, "module"):
2295
+ base_model = base_model.module
2296
+
2297
+ adapter_path = f"{checkpoint_dir}/{adapter_name}/adapter_model.safetensors"
2298
+ adapter_state = load_file(adapter_path, device="cpu")
2299
+
2300
+ with torch.no_grad():
2301
+ set_peft_model_state_dict(
2302
+ base_model, adapter_state, adapter_name=adapter_name
2303
+ )
2304
+ base_model.set_adapter(adapter_name)
2305
+
2306
+ # Make reference weights not trainable
2307
+ for name, param in base_model.named_parameters():
2308
+ if "reference" in name:
2309
+ param.requires_grad = False
@@ -275,23 +275,23 @@ class RainbowDQN(RLAlgorithm):
275
275
 
276
276
  def _dqn_loss(
277
277
  self,
278
- states: TorchObsType,
278
+ obs: TorchObsType,
279
279
  actions: torch.Tensor,
280
280
  rewards: torch.Tensor,
281
- next_states: torch.Tensor,
281
+ next_obs: torch.Tensor,
282
282
  dones: torch.Tensor,
283
283
  gamma: float,
284
284
  ) -> torch.Tensor:
285
285
  """Calculates the DQN loss.
286
286
 
287
- :param states: Batch of current states
288
- :type states: torch.Tensor
287
+ :param obs: Batch of current states
288
+ :type obs: torch.Tensor
289
289
  :param actions: Batch of actions taken
290
290
  :type actions: torch.Tensor
291
291
  :param rewards: Batch of rewards received
292
292
  :type rewards: torch.Tensor
293
- :param next_states: Batch of next states
294
- :type next_states: torch.Tensor
293
+ :param next_obs: Batch of next states
294
+ :type next_obs: torch.Tensor
295
295
  :param dones: Batch of done flags indicating episode termination
296
296
  :type dones: torch.Tensor
297
297
  :param gamma: Discount factor
@@ -299,16 +299,15 @@ class RainbowDQN(RLAlgorithm):
299
299
  :return: Element-wise loss
300
300
  :rtype: torch.Tensor
301
301
  """
302
- states = self.preprocess_observation(states)
303
- next_states = self.preprocess_observation(next_states)
302
+ obs = self.preprocess_observation(obs)
303
+ next_obs = self.preprocess_observation(next_obs)
304
304
 
305
305
  with torch.no_grad():
306
+ # Predict next actions from next_obs
307
+ next_actions = self.actor(next_obs).argmax(1)
306
308
 
307
- # Predict next actions from next_states
308
- next_actions = self.actor(next_states).argmax(1)
309
-
310
- # Predict the target q distribution for the same next states
311
- target_q_dist = self.actor_target(next_states, q=False)
309
+ # Predict the target q distribution for the same next obs
310
+ target_q_dist = self.actor_target(next_obs, q=False)
312
311
 
313
312
  # Index the target q_dist to select the distributions corresponding to next_actions
314
313
  target_q_dist = target_q_dist[range(self.batch_size), next_actions]
@@ -349,7 +348,7 @@ class RainbowDQN(RLAlgorithm):
349
348
  )
350
349
 
351
350
  # Calculate the current obs
352
- log_q_dist = self.actor(states, q=False, log=True)
351
+ log_q_dist = self.actor(obs, q=False, log=True)
353
352
  log_p = log_q_dist[range(self.batch_size), actions.squeeze().long()]
354
353
 
355
354
  # loss
@@ -375,29 +374,29 @@ class RainbowDQN(RLAlgorithm):
375
374
  :rtype: Tuple[float, numpy.ndarray, numpy.ndarray]
376
375
  """
377
376
  n_step = n_experiences is not None
378
- states = experiences["obs"]
377
+ obs = experiences["obs"]
379
378
  actions = experiences["action"]
380
379
  rewards = experiences["reward"]
381
- next_states = experiences["next_obs"]
380
+ next_obs = experiences["next_obs"]
382
381
  dones = experiences["done"]
383
382
  if per:
384
383
  weights = experiences["weights"]
385
384
  idxs = experiences["idxs"]
386
385
  if n_step:
387
- n_states = n_experiences["obs"]
386
+ n_obs = n_experiences["obs"]
388
387
  n_actions = n_experiences["action"]
389
388
  n_rewards = n_experiences["reward"]
390
- n_next_states = n_experiences["next_obs"]
389
+ n_next_obs = n_experiences["next_obs"]
391
390
  n_dones = n_experiences["done"]
392
391
 
393
392
  if self.combined_reward or not n_step:
394
393
  elementwise_loss = self._dqn_loss(
395
- states, actions, rewards, next_states, dones, self.gamma
394
+ obs, actions, rewards, next_obs, dones, self.gamma
396
395
  )
397
396
  if n_step:
398
397
  n_gamma = self.gamma**self.n_step
399
398
  n_step_elementwise_loss = self._dqn_loss(
400
- n_states, n_actions, n_rewards, n_next_states, n_dones, n_gamma
399
+ n_obs, n_actions, n_rewards, n_next_obs, n_dones, n_gamma
401
400
  )
402
401
  if self.combined_reward:
403
402
  elementwise_loss += n_step_elementwise_loss
@@ -409,10 +408,10 @@ class RainbowDQN(RLAlgorithm):
409
408
  else:
410
409
  if n_step:
411
410
  idxs = experiences["idxs"]
412
- n_states = n_experiences["obs"]
411
+ n_obs = n_experiences["obs"]
413
412
  n_actions = n_experiences["action"]
414
413
  n_rewards = n_experiences["reward"]
415
- n_next_states = n_experiences["next_obs"]
414
+ n_next_obs = n_experiences["next_obs"]
416
415
  n_dones = n_experiences["done"]
417
416
  else:
418
417
  idxs = None
@@ -420,13 +419,13 @@ class RainbowDQN(RLAlgorithm):
420
419
  new_priorities = None
421
420
  if self.combined_reward or not n_step:
422
421
  elementwise_loss = self._dqn_loss(
423
- states, actions, rewards, next_states, dones, self.gamma
422
+ obs, actions, rewards, next_obs, dones, self.gamma
424
423
  )
425
424
 
426
425
  if n_step:
427
426
  n_gamma = self.gamma**self.n_step
428
427
  n_step_elementwise_loss = self._dqn_loss(
429
- n_states, n_actions, n_rewards, n_next_states, n_dones, n_gamma
428
+ n_obs, n_actions, n_rewards, n_next_obs, n_dones, n_gamma
430
429
  )
431
430
  if self.combined_reward:
432
431
  elementwise_loss += n_step_elementwise_loss
@@ -508,7 +507,9 @@ class RainbowDQN(RLAlgorithm):
508
507
  ) and not finished[idx]:
509
508
  completed_episode_scores[idx] = scores[idx]
510
509
  finished[idx] = 1
510
+
511
511
  rewards.append(np.mean(completed_episode_scores))
512
+
512
513
  mean_fit = np.mean(rewards)
513
514
  self.fitness.append(mean_fit)
514
515
  return mean_fit