agilerl 2.3.5.dev0__py3-none-any.whl → 2.4.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agilerl/algorithms/__init__.py +2 -0
- agilerl/algorithms/bc_lm.py +3 -3
- agilerl/algorithms/core/base.py +763 -125
- agilerl/algorithms/core/optimizer_wrapper.py +16 -16
- agilerl/algorithms/core/registry.py +77 -45
- agilerl/algorithms/cqn.py +5 -6
- agilerl/algorithms/ddpg.py +14 -14
- agilerl/algorithms/dpo.py +311 -0
- agilerl/algorithms/dqn.py +2 -2
- agilerl/algorithms/dqn_rainbow.py +5 -6
- agilerl/algorithms/grpo.py +45 -624
- agilerl/algorithms/ilql.py +2 -2
- agilerl/algorithms/ippo.py +25 -25
- agilerl/algorithms/maddpg.py +22 -22
- agilerl/algorithms/matd3.py +28 -28
- agilerl/algorithms/neural_ts_bandit.py +4 -4
- agilerl/algorithms/neural_ucb_bandit.py +4 -4
- agilerl/algorithms/ppo.py +29 -29
- agilerl/algorithms/td3.py +6 -4
- agilerl/components/multi_agent_replay_buffer.py +23 -24
- agilerl/components/replay_buffer.py +3 -3
- agilerl/components/rollout_buffer.py +30 -30
- agilerl/components/sampler.py +5 -5
- agilerl/data/rl_data.py +8 -8
- agilerl/data/tokenizer.py +5 -5
- agilerl/hpo/mutation.py +20 -20
- agilerl/hpo/tournament.py +6 -8
- agilerl/modules/base.py +33 -38
- agilerl/modules/bert.py +11 -11
- agilerl/modules/cnn.py +43 -43
- agilerl/modules/configs.py +11 -11
- agilerl/modules/dummy.py +4 -4
- agilerl/modules/gpt.py +14 -14
- agilerl/modules/lstm.py +11 -11
- agilerl/modules/mlp.py +13 -13
- agilerl/modules/multi_input.py +18 -18
- agilerl/modules/resnet.py +12 -12
- agilerl/modules/simba.py +4 -4
- agilerl/networks/actors.py +7 -7
- agilerl/networks/base.py +27 -27
- agilerl/networks/custom_modules.py +4 -4
- agilerl/networks/distributions.py +12 -12
- agilerl/networks/distributions_experimental.py +3 -3
- agilerl/networks/q_networks.py +10 -10
- agilerl/networks/value_networks.py +4 -4
- agilerl/protocols.py +41 -45
- agilerl/rollouts/on_policy.py +10 -10
- agilerl/training/train_bandits.py +4 -4
- agilerl/training/train_llm.py +296 -14
- agilerl/training/train_multi_agent_off_policy.py +4 -4
- agilerl/training/train_multi_agent_on_policy.py +4 -4
- agilerl/training/train_off_policy.py +5 -5
- agilerl/training/train_offline.py +4 -4
- agilerl/training/train_on_policy.py +5 -5
- agilerl/typing.py +38 -29
- agilerl/utils/algo_utils.py +98 -93
- agilerl/utils/evolvable_networks.py +26 -26
- agilerl/utils/ilql_utils.py +6 -6
- agilerl/utils/llm_utils.py +439 -79
- agilerl/utils/torch_utils.py +4 -4
- agilerl/utils/utils.py +79 -11
- agilerl/vector/pz_async_vec_env.py +48 -48
- agilerl/vector/pz_vec_env.py +10 -10
- agilerl/wrappers/agent.py +14 -14
- agilerl/wrappers/make_evolvable.py +17 -17
- {agilerl-2.3.5.dev0.dist-info → agilerl-2.4.0.dev0.dist-info}/METADATA +5 -3
- agilerl-2.4.0.dev0.dist-info/RECORD +95 -0
- {agilerl-2.3.5.dev0.dist-info → agilerl-2.4.0.dev0.dist-info}/WHEEL +1 -1
- {agilerl-2.3.5.dev0.dist-info → agilerl-2.4.0.dev0.dist-info/licenses}/LICENSE +13 -0
- agilerl-2.3.5.dev0.dist-info/RECORD +0 -94
agilerl/algorithms/__init__.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from .bc_lm import BC_LM, BC_Evaluator, BC_Policy
|
|
2
2
|
from .cqn import CQN
|
|
3
3
|
from .ddpg import DDPG
|
|
4
|
+
from .dpo import DPO
|
|
4
5
|
from .dqn import DQN
|
|
5
6
|
from .dqn_rainbow import RainbowDQN
|
|
6
7
|
from .grpo import GRPO
|
|
@@ -30,4 +31,5 @@ __all__ = [
|
|
|
30
31
|
"PPO",
|
|
31
32
|
"TD3",
|
|
32
33
|
"GRPO",
|
|
34
|
+
"DPO",
|
|
33
35
|
]
|
agilerl/algorithms/bc_lm.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable, Optional,
|
|
1
|
+
from typing import Any, Callable, Optional, Union
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -167,7 +167,7 @@ class BC_LM(nn.Module):
|
|
|
167
167
|
temp: float = 1.0,
|
|
168
168
|
top_k: Optional[int] = None,
|
|
169
169
|
top_p: Optional[float] = None,
|
|
170
|
-
) ->
|
|
170
|
+
) -> tuple[torch.Tensor, Any]:
|
|
171
171
|
prepared_inputs = self.prepare_inputs(items)
|
|
172
172
|
tokens = prepared_inputs["tokens"]
|
|
173
173
|
scores, model_outputs = self.score(
|
|
@@ -189,7 +189,7 @@ class BC_LM(nn.Module):
|
|
|
189
189
|
temp: float = 1.0,
|
|
190
190
|
top_k: Optional[int] = None,
|
|
191
191
|
top_p: Optional[float] = None,
|
|
192
|
-
) ->
|
|
192
|
+
) -> tuple[torch.Tensor, Any]:
|
|
193
193
|
scores, model_outputs = self.score(
|
|
194
194
|
(
|
|
195
195
|
tokens.unsqueeze(1),
|