agilerl 2.3.2.dev0__tar.gz → 2.3.3.dev1__tar.gz
This diff shows the changes between two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/PKG-INFO +1 -1
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/base.py +5 -2
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/cqn.py +0 -1
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ippo.py +3 -1
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ppo.py +45 -61
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/rollout_buffer.py +52 -125
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/actors.py +8 -4
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/distributions.py +2 -1
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/rollouts/on_policy.py +3 -4
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_multi_agent_on_policy.py +7 -1
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/algo_utils.py +234 -149
- agilerl-2.3.3.dev1/agilerl/utils/cache.py +129 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/evolvable_networks.py +27 -105
- agilerl-2.3.3.dev1/agilerl/utils/ilql_utils.py +83 -0
- agilerl-2.3.3.dev1/agilerl/utils/log_utils.py +138 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/pyproject.toml +1 -1
- agilerl-2.3.2.dev0/agilerl/utils/cache.py +0 -56
- agilerl-2.3.2.dev0/agilerl/utils/ilql_utils.py +0 -34
- agilerl-2.3.2.dev0/agilerl/utils/log_utils.py +0 -70
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/LICENSE +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/README.md +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/registry.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/grpo.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ilql.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/data.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/sampler.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/base.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/bert.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/configs.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/simba.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/base.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/protocols.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_llm.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/typing.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/llm_utils.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/utils.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/utils.py +0 -0
agilerl/algorithms/core/base.py
@@ -73,6 +73,8 @@ from agilerl.utils.algo_utils import (
     chkpt_attribute_to_device,
     clone_llm,
     create_warmup_cosine_scheduler,
+    get_input_size_from_space,
+    get_output_size_from_space,
     isroutine,
     key_in_nested_dict,
     module_checkpoint_dict,
@@ -84,8 +86,6 @@ from agilerl.utils.evolvable_networks import (
     compile_model,
     config_from_dict,
     get_default_encoder_config,
-    get_input_size_from_space,
-    get_output_size_from_space,
     is_image_space,
     is_vector_space,
 )
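The two import hunks above move `get_input_size_from_space` and `get_output_size_from_space` from `agilerl.utils.evolvable_networks` to `agilerl.utils.algo_utils`. A minimal sketch of the import path this release expects (only the import itself is taken from the diff; whether the old module still re-exports the helpers is not shown here):

```python
# New import location used throughout 2.3.3.dev1:
from agilerl.utils.algo_utils import (
    get_input_size_from_space,
    get_output_size_from_space,
)
```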
@@ -144,6 +144,9 @@ def get_checkpoint_dict(
         attribute_dict.pop("actor", None)
         return attribute_dict
 
+    if "rollout_buffer" in attribute_dict:
+        attribute_dict.pop("rollout_buffer")
+
     # Get checkpoint dictionaries for evolvable modules and optimizers
     network_info: Dict[str, Dict[str, Any]] = {"modules": {}, "optimizers": {}}
     for attr in agent.evolvable_attributes():
agilerl/algorithms/ippo.py
@@ -722,7 +722,9 @@ class IPPO(MultiAgentRLAlgorithm):
         returns = advantages + values
 
         states = concatenate_experiences_into_batches(states, obs_space)
-        actions = concatenate_experiences_into_batches(
+        actions = concatenate_experiences_into_batches(
+            actions, action_space, actions=True
+        )
         log_probs = log_probs.reshape((-1,))
         experiences = (states, actions, log_probs, advantages, returns, values)
 
agilerl/algorithms/ppo.py
@@ -464,20 +464,27 @@ class PPO(RLAlgorithm):
         self,
         obs: ArrayOrTensor,
         actions: ArrayOrTensor,
-
+        hidden_state: Optional[Dict[str, ArrayOrTensor]] = None,
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor, Optional[Dict[str, ArrayOrTensor]]
+    ]:
         """Evaluates the actions.
 
         :param obs: Environment observation, or multiple observations in a batch
         :type obs: ArrayOrTensor
         :param actions: Actions to evaluate
         :type actions: ArrayOrTensor
-        :
-        :
+        :param hidden_state: Hidden state for recurrent policies, defaults to None. Expected shape: dict with tensors of shape (batch_size, 1, hidden_size).
+        :type hidden_state: Optional[Dict[str, ArrayOrTensor]]
+        :return: Log probability, entropy, state values, and next hidden state
+        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[Dict[str, ArrayOrTensor]]]
         """
         obs = self.preprocess_observation(obs)
 
         # Get values from actor-critic
-        _, _, entropy, values,
+        _, _, entropy, values, next_hidden_state = self._get_action_and_values(
+            obs, hidden_state=hidden_state, sample=False
+        )
 
         log_prob = self.actor.action_log_prob(actions)
 
@@ -485,7 +492,7 @@ class PPO(RLAlgorithm):
         if entropy is None:
             entropy = -log_prob.mean()
 
-        return log_prob, entropy, values
+        return log_prob, entropy, values, next_hidden_state
 
     def get_action(
         self,
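With the change above, `PPO.evaluate_actions` accepts an optional `hidden_state` and returns a four-tuple. A minimal usage sketch, assuming a PPO agent can be constructed directly from gymnasium spaces as in the AgileRL docs (constructor defaults and tensor shapes here are assumptions, not taken from the diff):

```python
import numpy as np
import torch
from gymnasium import spaces

from agilerl.algorithms import PPO

# Assumed construction from spaces; other constructor arguments left at defaults.
observation_space = spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
action_space = spaces.Discrete(2)
agent = PPO(observation_space, action_space)

obs = torch.randn(8, 4)              # batch of 8 observations
actions = torch.randint(0, 2, (8,))  # batch of 8 discrete actions

# The method now returns (log_prob, entropy, values, next_hidden_state);
# next_hidden_state is only meaningful for recurrent policies.
log_prob, entropy, values, next_hidden = agent.evaluate_actions(
    obs=obs, actions=actions, hidden_state=None
)
```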
@@ -659,7 +666,7 @@ class PPO(RLAlgorithm):
         num_samples = experiences[4].size(0)
         batch_idxs = np.arange(num_samples)
         mean_loss = 0
-        for
+        for _ in range(self.update_epochs):
             np.random.shuffle(batch_idxs)
             for start in range(0, num_samples, self.batch_size):
                 minibatch_idxs = batch_idxs[start : start + self.batch_size]
@@ -679,8 +686,8 @@ class PPO(RLAlgorithm):
                 batch_values = batch_values.squeeze()
 
                 if len(minibatch_idxs) > 1:
-                    log_prob, entropy, value = self.evaluate_actions(
-                        obs=batch_observations, actions=batch_actions
+                    log_prob, entropy, value, _ = self.evaluate_actions(
+                        obs=batch_observations, actions=batch_actions, hidden_state=None
                     )
 
                     logratio = log_prob - batch_log_probs
@@ -754,12 +761,8 @@ class PPO(RLAlgorithm):
             warnings.warn("Buffer data is empty. Skipping learning step.")
             return 0.0
 
-        observations = buffer_td["observations"]
-        advantages = buffer_td["advantages"]
-
         batch_size = self.batch_size
-        num_samples =
-
+        num_samples = self.rollout_buffer.size()
         indices = np.arange(num_samples)
        mean_loss = 0.0
         approx_kl_divs = []
@@ -775,7 +778,7 @@ class PPO(RLAlgorithm):
                 mb_obs = minibatch_td["observations"]
                 mb_actions = minibatch_td["actions"]
                 mb_log_probs = minibatch_td["log_probs"]
-                mb_advantages = advantages
+                mb_advantages = minibatch_td["advantages"]
                 mb_returns = minibatch_td["returns"]
                 mb_old_values = minibatch_td["values"]
 
@@ -799,23 +802,19 @@ class PPO(RLAlgorithm):
                         "Recurrent policy, but no hidden_states found in minibatch_td for flat learning."
                     )
 
-
-
-                    hidden_state=eval_hidden_state,
-                    sample=False,  # No sampling during evaluation for loss calculation
-                )
-
-                log_probs = self.actor.action_log_prob(mb_actions)
+                if isinstance(self.action_space, spaces.Discrete):
+                    mb_actions = mb_actions.squeeze(-1)
 
-
-
+                log_probs, entropy, values, _ = self.evaluate_actions(
+                    obs=mb_obs, actions=mb_actions, hidden_state=eval_hidden_state
+                )
 
                 # Normalize advantages
                 mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                     mb_advantages.std() + 1e-8
                 )
 
-                # Policy
+                # Policy loss
                 ratio = torch.exp(log_probs - mb_log_probs)
                 policy_loss1 = -mb_advantages * ratio
                 policy_loss2 = -mb_advantages * torch.clamp(
@@ -943,12 +942,10 @@ class PPO(RLAlgorithm):
             warnings.warn("No BPTT sequences to sample. Skipping learning.")
             return 0.0
 
-
-
-        )  # Here, batch_size means number of sequences per minibatch
+        # Here, batch_size means number of sequences per minibatch
+        sequences_per_minibatch = self.batch_size
         mean_loss = 0.0
         total_minibatch_updates_total = 0
-
         for epoch in range(self.update_epochs):
             approx_kl_divs_epoch = []  # KL divergences for this epoch's minibatches
             np.random.shuffle(all_start_coords)
@@ -982,24 +979,18 @@ class PPO(RLAlgorithm):
                     warnings.warn("Skipping empty or invalid minibatch of sequences.")
                     continue
 
-
-
-
-
-
-
-
-
-                ]
-
-
-
-                mb_returns_seq = current_minibatch_td[
-                    "returns"
-                ]  # Shape: (batch_seq, seq_len)
-
-                mb_initial_hidden_states_dict = current_minibatch_td.get_non_tensor(
-                    "initial_hidden_states", default=None
+                # Obs shape: (batch_seq, seq_len, *obs_dims) or nested TD
+                # Actions shape: (batch_seq, seq_len, *act_dims)
+                # Other tensors shape: (batch_seq, seq_len)
+                mb_obs_seq = current_minibatch_td["observations"]
+                mb_actions_seq = current_minibatch_td["actions"]
+                mb_old_log_probs_seq = current_minibatch_td["log_probs"]
+                mb_advantages_seq = current_minibatch_td["advantages"]
+                mb_returns_seq = current_minibatch_td["returns"]
+                mb_initial_hidden_states_dict: Optional[TensorDict] = (
+                    current_minibatch_td.get_non_tensor(
+                        "initial_hidden_states", default=None
+                    )
                 )
 
                 policy_loss_total, value_loss_total, entropy_loss_total = 0.0, 0.0, 0.0
@@ -1027,26 +1018,19 @@ class PPO(RLAlgorithm):
                         )
                         adv_t, return_t = mb_advantages_seq[:, t], mb_returns_seq[:, t]
 
+                        # new_value_t: (batch_seq,), entropy_t: (batch_seq,) or scalar, log_prob_t: (batch_seq,)
                         (
-
-                            _,
+                            new_log_prob_t,
                             entropy_t,
                             new_value_t,
                             next_hidden_state_for_actor_step,
-                        ) = self.
-                            obs_t,
+                        ) = self.evaluate_actions(
+                            obs=obs_t,
+                            actions=actions_t,
                             hidden_state=current_step_hidden_state_actor,
-
-
-
-                        new_log_prob_t = self.actor.action_log_prob(
-                            actions_t
-                        )  # Shape: (batch_seq,)
-                        entropy_t = (
-                            (-new_log_prob_t.mean())
-                            if entropy_t is None
-                            else entropy_t.mean()
-                        )  # Ensure scalar
+                        )
+                        if isinstance(entropy_t, torch.Tensor):
+                            entropy_t = entropy_t.mean()
 
                         ratio = torch.exp(new_log_prob_t - old_log_prob_t)
                         policy_loss1 = -adv_t * ratio
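The BPTT hunks above route each timestep through `evaluate_actions` while threading the recurrent hidden state forward. The toy below only illustrates that hidden-state threading pattern over a `(batch_seq, seq_len, obs_dim)` minibatch; the module and names are hypothetical stand-ins, not AgileRL classes:

```python
import torch
import torch.nn as nn

# Hypothetical toy recurrent evaluator: not part of AgileRL, only an
# illustration of carrying a hidden state across the timesteps of a
# (batch_seq, seq_len, obs_dim) minibatch, as the BPTT loop above does.
class ToyRecurrentPolicy(nn.Module):
    def __init__(self, obs_dim: int = 4, hidden_size: int = 8):
        super().__init__()
        self.gru = nn.GRUCell(obs_dim, hidden_size)
        self.value_head = nn.Linear(hidden_size, 1)

    def evaluate_step(self, obs_t, hidden):
        # One timestep: update the hidden state and produce a value estimate.
        next_hidden = self.gru(obs_t, hidden)
        value_t = self.value_head(next_hidden).squeeze(-1)
        return value_t, next_hidden

policy = ToyRecurrentPolicy()
obs_seq = torch.randn(3, 5, 4)   # (batch_seq, seq_len, obs_dim)
hidden = torch.zeros(3, 8)       # initial hidden state per sequence
values = []
for t in range(obs_seq.size(1)):
    value_t, hidden = policy.evaluate_step(obs_seq[:, t], hidden)
    values.append(value_t)
values = torch.stack(values, dim=1)  # (batch_seq, seq_len)
```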
agilerl/components/rollout_buffer.py
@@ -1,34 +1,15 @@
 import random  # Added to support random sequence sampling for BPTT
 import warnings
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 from gymnasium import spaces
-from tensordict import TensorDict
-
-from agilerl.typing import ArrayOrTensor, ObservationType
-
-
-# Define the utility function locally to avoid circular import
-def convert_np_to_torch_dtype(np_dtype):
-    """Converts a numpy dtype to a torch dtype."""
-    if np_dtype == np.float32:
-        return torch.float32
-    elif np_dtype == np.float64:
-        return torch.float64
-    elif np_dtype == np.int32:
-        return torch.int32
-    elif np_dtype == np.int64:
-        return torch.int64
-    elif np_dtype == np.uint8:
-        return torch.uint8
-    elif np_dtype == np.bool_:
-        return torch.bool
-    else:
-        # Fallback or raise error for unhandled dtypes
-        warnings.warn(f"Unhandled numpy dtype {np_dtype}, defaulting to torch.float32")
-        return torch.float32
+from tensordict import TensorDict
+
+from agilerl.typing import ArrayOrTensor, ObservationType, TorchObsType
+from agilerl.utils.algo_utils import get_num_actions, get_obs_shape, maybe_add_batch_dim
 
 
 class RolloutBuffer:
@@ -97,69 +78,39 @@ class RolloutBuffer:
         self.full = False
         self._initialize_buffers()
 
+    def _maybe_reshape_obs(
+        self, obs: TorchObsType, space: spaces.Space
+    ) -> TorchObsType:
+        """Reshape observation to the correct shape.
+
+        :param obs: Observation to reshape.
+        :type obs: TorchObsType
+        :param space: Observation space.
+        :type space: spaces.Space
+        :return: Reshaped observation.
+        :rtype: TorchObsType
+        """
+        if isinstance(space, spaces.Discrete) and obs.ndim < 2:
+            obs = obs.unsqueeze(-1)
+
+        return maybe_add_batch_dim(obs, space)
+
     def _initialize_buffers(self) -> None:
         """Initialize buffer arrays with correct shapes for vectorized environments."""
         # Determine shapes and dtypes for all expected fields
-
-
-        elif isinstance(self.observation_space, spaces.MultiDiscrete):
-            obs_shape = (len(self.observation_space.nvec),)
-        elif isinstance(self.observation_space, spaces.Box):
-            obs_shape = self.observation_space.shape
-        elif isinstance(self.observation_space, spaces.Dict):
-            # For Dict observation spaces, we'll create a nested structure
-            # The observations will be stored as nested TensorDicts
-            obs_shape = None  # Will be handled as nested TensorDict
-        elif isinstance(self.observation_space, spaces.Tuple):
-            # For Tuple, we'll flatten or handle as multiple entries
-            # For now, let's assume we'll pre-allocate based on flattened structure
-            obs_shape = ()  # Placeholder, will be determined by actual data
-        else:
-            obs_shape = self.observation_space.shape
-
-        if isinstance(self.action_space, spaces.Discrete):
-            action_shape = ()
-            action_dtype = torch.int64
-        elif isinstance(self.action_space, spaces.Box):
-            action_shape = self.action_space.shape
-            action_dtype = convert_np_to_torch_dtype(
-                self.action_space.dtype
-            )  # Convert numpy dtype to torch dtype
-        elif isinstance(self.action_space, spaces.MultiDiscrete):
-            action_shape = (len(self.action_space.nvec),)
-            action_dtype = torch.int64
-        elif isinstance(self.action_space, spaces.MultiBinary):
-            action_shape = (self.action_space.n,)
-            action_dtype = torch.int64
-        else:
-            try:
-                action_shape = self.action_space.shape
-                action_dtype = convert_np_to_torch_dtype(
-                    getattr(self.action_space, "dtype", np.float32)
-                )  # Convert numpy dtype to torch dtype
-            except AttributeError:
-                raise TypeError(
-                    f"Unsupported action space type without shape: {type(self.action_space)}"
-                )
+        obs_shape = get_obs_shape(self.observation_space)
+        num_actions = get_num_actions(self.action_space)
 
         # Create a source TensorDict with appropriately sized tensors
         # The tensors will be on the CPU by default, can be moved to device later if needed.
-        source_dict =
-
-
-
-
-
-            for key, subspace in self.observation_space.spaces.items():
-                if isinstance(subspace, spaces.Discrete):
-                    sub_shape = (1,)
-                elif isinstance(subspace, spaces.Box):
-                    sub_shape = subspace.shape
-                else:
-                    sub_shape = subspace.shape if hasattr(subspace, "shape") else ()
-
+        source_dict = OrderedDict()
+        if isinstance(
+            self.observation_space, spaces.Dict
+        ):  # Nested structure for Dict spaces
+            obs_dict = OrderedDict()
+            for key, shape in obs_shape.items():
                 obs_dict[key] = torch.zeros(
-                    (self.capacity, self.num_envs, *
+                    (self.capacity, self.num_envs, *shape), dtype=torch.float32
                 )
 
             source_dict["observations"] = obs_dict
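The refactor above replaces the buffer's ad-hoc space handling with `get_obs_shape`, `get_num_actions`, and `maybe_add_batch_dim`, plus the `_maybe_reshape_obs` helper shown in the hunk. The toy function below is not the AgileRL implementation; it only sketches the target layout, `(num_envs, *obs_shape)`, with Discrete observations stored with a trailing dimension of 1:

```python
import torch
from gymnasium import spaces

# Toy stand-in for the reshaping the buffer now delegates to helpers; an
# illustration of the assumed target shapes only, not AgileRL code.
def toy_reshape_obs(obs: torch.Tensor, space: spaces.Space) -> torch.Tensor:
    if isinstance(space, spaces.Discrete) and obs.ndim < 2:
        obs = obs.unsqueeze(-1)            # Discrete obs kept with a trailing dim of 1
    expected_ndim = len(space.shape) + 1 if space.shape else 2
    if obs.ndim < expected_ndim:
        obs = obs.unsqueeze(0)             # add the missing num_envs dimension
    return obs

print(toy_reshape_obs(torch.zeros(3), spaces.Box(-1, 1, shape=(3,))).shape)  # (1, 3)
print(toy_reshape_obs(torch.tensor(2), spaces.Discrete(5)).shape)            # (1, 1)
```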
@@ -179,7 +130,7 @@ class RolloutBuffer:
         source_dict.update(
             {
                 "actions": torch.zeros(
-                    (self.capacity, self.num_envs,
+                    (self.capacity, self.num_envs, num_actions), dtype=torch.float32
                 ),
                 "rewards": torch.zeros(
                     (self.capacity, self.num_envs), dtype=torch.float32
@@ -283,44 +234,27 @@ class RolloutBuffer:
         )
 
         # Prepare data as a dictionary of tensors for the current time step
-        current_step_data =
+        current_step_data = OrderedDict()
 
         # Convert inputs to tensors and ensure correct device (CPU for buffer storage)
         # Also ensure they have the (num_envs, ...) shape
-
-
-        if isinstance(obs, dict):  # Dict observation space
-            obs_dict = {}
+        if isinstance(self.observation_space, spaces.Dict):
+            obs_dict = OrderedDict()
             for key, item in obs.items():
+                sub_space = self.observation_space.spaces[key]
                 obs_tensor = torch.as_tensor(item, device="cpu")
-
-                    obs_tensor = obs_tensor.unsqueeze(0)
-                elif (
-                    self.num_envs == 1
-                    and len(obs_tensor.shape)
-                    < len(self.observation_space.spaces[key].shape) + 1
-                ):
-                    obs_tensor = obs_tensor.unsqueeze(0)
-
-                obs_dict[key] = obs_tensor
+                obs_dict[key] = self._maybe_reshape_obs(obs_tensor, sub_space)
 
             current_step_data["observations"] = obs_dict
         else:
             obs_tensor = torch.as_tensor(obs, device="cpu")
-
-            self.
-
-            ):  # Add batch dim for single env
-                obs_tensor = obs_tensor.unsqueeze(0)
-
-            current_step_data["observations"] = obs_tensor
+            current_step_data["observations"] = self._maybe_reshape_obs(
+                obs_tensor, self.observation_space
+            )
 
         # Actions
         action_tensor = torch.as_tensor(action, device="cpu")
-
-            action_tensor = action_tensor.unsqueeze(0)
-
-        current_step_data["actions"] = action_tensor
+        current_step_data["actions"] = action_tensor.reshape(self.num_envs, -1)
 
         # Rewards
         reward_tensor = torch.as_tensor(reward, dtype=torch.float32, device="cpu")
@@ -340,28 +274,21 @@ class RolloutBuffer:
 
         # Next Observations
         if next_obs is not None:
-            if isinstance(
-                next_obs_dict =
+            if isinstance(self.observation_space, spaces.Dict):
+                next_obs_dict = OrderedDict()
                 for key, item in next_obs.items():
+                    sub_space = self.observation_space.spaces[key]
                     next_obs_tensor = torch.as_tensor(item, device="cpu")
-
-                        next_obs_tensor
-
-
-                        and len(next_obs_tensor.shape)
-                        < len(self.observation_space.spaces[key].shape) + 1
-                    ):
-                        next_obs_tensor = next_obs_tensor.unsqueeze(0)
-                    next_obs_dict[key] = next_obs_tensor
+                    next_obs_dict[key] = self._maybe_reshape_obs(
+                        next_obs_tensor, sub_space
+                    )
+
                 current_step_data["next_observations"] = next_obs_dict
             else:
                 next_obs_tensor = torch.as_tensor(next_obs, device="cpu")
-
-                self.
-
-                ):  # Add batch dim
-                    next_obs_tensor = next_obs_tensor.unsqueeze(0)
-                current_step_data["next_observations"] = next_obs_tensor
+                current_step_data["next_observations"] = self._maybe_reshape_obs(
+                    next_obs_tensor, self.observation_space
+                )
 
         # Episode Starts
         if episode_start is not None:
@@ -493,7 +420,7 @@ class RolloutBuffer:
 
         # Get a view of the buffer up to the current position and for all envs
         # This slice will have batch_size [buffer_size, num_envs]
-        valid_buffer_data = self.buffer[:buffer_size]
+        valid_buffer_data: TensorDict = self.buffer[:buffer_size]
 
         # Reshape to flatten the num_envs dimension into the first batch dimension
         # New batch_size will be [buffer_size * num_envs]
agilerl/networks/actors.py
@@ -8,6 +8,7 @@ from agilerl.modules.configs import MlpNetConfig
 from agilerl.networks.base import EvolvableNetwork
 from agilerl.networks.distributions import EvolvableDistribution
 from agilerl.typing import ArrayOrTensor, NetConfigType, TorchObsType
+from agilerl.utils.algo_utils import get_output_size_from_space
 
 
 class DeterministicActor(EvolvableNetwork):
@@ -105,6 +106,8 @@ class DeterministicActor(EvolvableNetwork):
         else:
             head_config["output_activation"] = output_activation
 
+        self.output_size = get_output_size_from_space(self.action_space)
+
         self.build_network_head(head_config)
         self.output_activation = head_config.get("output_activation", output_activation)
 
@@ -155,7 +158,7 @@ class DeterministicActor(EvolvableNetwork):
         """
         self.head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=
+            num_outputs=self.output_size,
             name="actor",
             net_config=net_config,
         )
@@ -188,7 +191,7 @@ class DeterministicActor(EvolvableNetwork):
 
         head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=
+            num_outputs=self.output_size,
             name="actor",
             net_config=self.head_net.net_config,
         )
@@ -290,6 +293,7 @@ class StochasticActor(EvolvableNetwork):
         self.squash_output = squash_output
         self.action_space = action_space
         self.use_experimental_distribution = use_experimental_distribution
+        self.output_size = get_output_size_from_space(self.action_space)
 
         self.build_network_head(head_config)
         self.output_activation = None
@@ -327,7 +331,7 @@ class StochasticActor(EvolvableNetwork):
         """
         self.head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=
+            num_outputs=self.output_size,
             name="actor",
             net_config=net_config,
         )
@@ -389,7 +393,7 @@ class StochasticActor(EvolvableNetwork):
 
         head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=
+            num_outputs=self.output_size,
             name="actor",
             net_config=self.head_net.net_config,
         )
agilerl/networks/distributions.py
@@ -7,6 +7,7 @@ from torch.distributions import Bernoulli, Categorical, Distribution, Normal
 
 from agilerl.modules.base import EvolvableModule, EvolvableWrapper
 from agilerl.typing import ArrayOrTensor, DeviceType, NetConfigType
+from agilerl.utils.algo_utils import get_output_size_from_space
 
 DistributionType = Union[Distribution, List[Distribution]]
 
@@ -328,7 +329,7 @@ class EvolvableDistribution(EvolvableWrapper):
         super().__init__(network)
 
         self.action_space = action_space
-        self.action_dim =
+        self.action_dim = get_output_size_from_space(action_space)
         self.action_std_init = action_std_init
         self.device = device
         self.squash_output = squash_output and isinstance(action_space, spaces.Box)
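`get_output_size_from_space` now supplies both the actor head width and `EvolvableDistribution.action_dim`. The mapping below is only an illustrative assumption of what such a helper typically returns for common gymnasium spaces; it is not the actual body of the AgileRL function:

```python
import numpy as np
from gymnasium import spaces

# Illustrative expectation only; NOT agilerl.utils.algo_utils.get_output_size_from_space.
def toy_output_size(space: spaces.Space) -> int:
    if isinstance(space, spaces.Discrete):
        return int(space.n)
    if isinstance(space, spaces.MultiDiscrete):
        return int(np.sum(space.nvec))
    if isinstance(space, spaces.MultiBinary):
        return int(space.n)
    if isinstance(space, spaces.Box):
        return int(np.prod(space.shape))
    raise TypeError(f"Unsupported action space: {type(space)}")

print(toy_output_size(spaces.Discrete(4)))               # 4
print(toy_output_size(spaces.Box(-1, 1, shape=(2, 3))))  # 6
```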
agilerl/rollouts/on_policy.py
@@ -55,9 +55,9 @@ def _collect_rollouts(
 
     if (
         last_obs is None
-
-
-
+        and last_done is None
+        and last_scores is None
+        and last_info is None
     ):
         obs, info = env.reset()
         scores = np.zeros(agent.num_envs)
@@ -169,7 +169,6 @@ def _collect_rollouts(
                 scores[idx] = 0
 
     # Calculate last value to compute returns and advantages properly
-    # TODO: We shouldn't access a hidden method here...
     with torch.no_grad():
         if recurrent:
             _, _, _, last_value, _ = agent._get_action_and_values(
agilerl/training/train_multi_agent_on_policy.py
@@ -14,6 +14,7 @@ from agilerl.algorithms import IPPO
 from agilerl.hpo.mutation import Mutations
 from agilerl.hpo.tournament import TournamentSelection
 from agilerl.networks import StochasticActor
+from agilerl.typing import SingleAgentModule
 from agilerl.utils.algo_utils import obs_channels_to_first
 from agilerl.utils.utils import (
     default_progress_bar,
@@ -192,6 +193,7 @@ def train_multi_agent_on_policy(
     pop_episode_scores = []
     pop_fps = []
     for agent_idx, agent in enumerate(pop):  # Loop through population
+        compiled_agent = agent.torch_compiler is not None
         agent.set_training_mode(True)
 
         obs, info = env.reset()  # Reset environment at start of episode
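The next hunk unwraps compiled policies via `_orig_mod` before the `isinstance(agent_policy, StochasticActor)` check. This is needed because `torch.compile` wraps an `nn.Module` in an `OptimizedModule`, so `isinstance` checks against the original class fail on the wrapper, as the small sketch below shows (the class here is a hypothetical stand-in):

```python
import torch
import torch.nn as nn

# Minimal illustration: torch.compile returns an OptimizedModule wrapper, and the
# original module is reachable through its _orig_mod attribute.
class TinyActor(nn.Module):
    def forward(self, x):
        return torch.tanh(x)

actor = TinyActor()
compiled = torch.compile(actor)

print(isinstance(compiled, TinyActor))            # False: it is an OptimizedModule
print(isinstance(compiled._orig_mod, TinyActor))  # True: the original module
```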
@@ -244,7 +246,11 @@ def train_multi_agent_on_policy(
                     )
                     agent_space = agent.possible_action_spaces[agent_id]
                     policy = getattr(agent, agent.registry.policy())
-                    agent_policy = policy[network_id]
+                    agent_policy: SingleAgentModule = policy[network_id]
+
+                    if compiled_agent:
+                        agent_policy = agent_policy._orig_mod
+
                     if isinstance(agent_policy, StochasticActor) and isinstance(
                         agent_space, spaces.Box
                     ):