agilerl 2.3.2.dev0__tar.gz → 2.3.3.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/PKG-INFO +1 -1
  2. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/base.py +5 -2
  3. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/cqn.py +0 -1
  4. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ippo.py +3 -1
  5. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ppo.py +45 -61
  6. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/rollout_buffer.py +52 -125
  7. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/actors.py +8 -4
  8. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/distributions.py +2 -1
  9. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/rollouts/on_policy.py +3 -4
  10. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_multi_agent_on_policy.py +7 -1
  11. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/algo_utils.py +234 -149
  12. agilerl-2.3.3.dev1/agilerl/utils/cache.py +129 -0
  13. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/evolvable_networks.py +27 -105
  14. agilerl-2.3.3.dev1/agilerl/utils/ilql_utils.py +83 -0
  15. agilerl-2.3.3.dev1/agilerl/utils/log_utils.py +138 -0
  16. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/pyproject.toml +1 -1
  17. agilerl-2.3.2.dev0/agilerl/utils/cache.py +0 -56
  18. agilerl-2.3.2.dev0/agilerl/utils/ilql_utils.py +0 -34
  19. agilerl-2.3.2.dev0/agilerl/utils/log_utils.py +0 -70
  20. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/LICENSE +0 -0
  21. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/README.md +0 -0
  22. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/__init__.py +0 -0
  23. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/__init__.py +0 -0
  24. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/bc_lm.py +0 -0
  25. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/__init__.py +0 -0
  26. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
  27. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/core/registry.py +0 -0
  28. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ddpg.py +0 -0
  29. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/dqn.py +0 -0
  30. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/dqn_rainbow.py +0 -0
  31. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/grpo.py +0 -0
  32. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/ilql.py +0 -0
  33. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/maddpg.py +0 -0
  34. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/matd3.py +0 -0
  35. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  36. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  37. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/algorithms/td3.py +0 -0
  38. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/__init__.py +0 -0
  39. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/data.py +0 -0
  40. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  41. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/replay_buffer.py +0 -0
  42. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/sampler.py +0 -0
  43. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/components/segment_tree.py +0 -0
  44. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/__init__.py +0 -0
  45. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/language_environment.py +0 -0
  46. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/rl_data.py +0 -0
  47. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/tokenizer.py +0 -0
  48. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/data/torch_datasets.py +0 -0
  49. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/__init__.py +0 -0
  50. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/mutation.py +0 -0
  51. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/hpo/tournament.py +0 -0
  52. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/__init__.py +0 -0
  53. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/base.py +0 -0
  54. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/bert.py +0 -0
  55. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/cnn.py +0 -0
  56. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/configs.py +0 -0
  57. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/custom_components.py +0 -0
  58. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/dummy.py +0 -0
  59. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/gpt.py +0 -0
  60. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/lstm.py +0 -0
  61. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/mlp.py +0 -0
  62. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/multi_input.py +0 -0
  63. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/resnet.py +0 -0
  64. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/modules/simba.py +0 -0
  65. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/__init__.py +0 -0
  66. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/base.py +0 -0
  67. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/custom_modules.py +0 -0
  68. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/distributions_experimental.py +0 -0
  69. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/q_networks.py +0 -0
  70. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/networks/value_networks.py +0 -0
  71. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/protocols.py +0 -0
  72. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/rollouts/__init__.py +0 -0
  73. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/__init__.py +0 -0
  74. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_bandits.py +0 -0
  75. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_llm.py +0 -0
  76. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_multi_agent_off_policy.py +0 -0
  77. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_off_policy.py +0 -0
  78. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_offline.py +0 -0
  79. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/training/train_on_policy.py +0 -0
  80. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/typing.py +0 -0
  81. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/__init__.py +0 -0
  82. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/llm_utils.py +0 -0
  83. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/minari_utils.py +0 -0
  84. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/probe_envs.py +0 -0
  85. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/probe_envs_ma.py +0 -0
  86. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/sampling_utils.py +0 -0
  87. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/torch_utils.py +0 -0
  88. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/utils/utils.py +0 -0
  89. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/__init__.py +0 -0
  90. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/pz_async_vec_env.py +0 -0
  91. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/vector/pz_vec_env.py +0 -0
  92. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/__init__.py +0 -0
  93. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/agent.py +0 -0
  94. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/learning.py +0 -0
  95. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/make_evolvable.py +0 -0
  96. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  97. {agilerl-2.3.2.dev0 → agilerl-2.3.3.dev1}/agilerl/wrappers/utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: agilerl
-Version: 2.3.2.dev0
+Version: 2.3.3.dev1
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 License: Apache 2.0
 Author: Nick Ustaran-Anderegg

agilerl/algorithms/core/base.py
@@ -73,6 +73,8 @@ from agilerl.utils.algo_utils import (
     chkpt_attribute_to_device,
     clone_llm,
     create_warmup_cosine_scheduler,
+    get_input_size_from_space,
+    get_output_size_from_space,
     isroutine,
     key_in_nested_dict,
     module_checkpoint_dict,
@@ -84,8 +86,6 @@ from agilerl.utils.evolvable_networks import (
     compile_model,
     config_from_dict,
     get_default_encoder_config,
-    get_input_size_from_space,
-    get_output_size_from_space,
     is_image_space,
     is_vector_space,
 )
@@ -144,6 +144,9 @@ def get_checkpoint_dict(
         attribute_dict.pop("actor", None)
         return attribute_dict

+    if "rollout_buffer" in attribute_dict:
+        attribute_dict.pop("rollout_buffer")
+
     # Get checkpoint dictionaries for evolvable modules and optimizers
     network_info: Dict[str, Dict[str, Any]] = {"modules": {}, "optimizers": {}}
     for attr in agent.evolvable_attributes():

agilerl/algorithms/cqn.py
@@ -196,7 +196,6 @@ class CQN(RLAlgorithm):
                 ),
                 axis=1,
             )
-
         else:
             self.actor.eval()
             with torch.no_grad():

agilerl/algorithms/ippo.py
@@ -722,7 +722,9 @@ class IPPO(MultiAgentRLAlgorithm):
         returns = advantages + values

         states = concatenate_experiences_into_batches(states, obs_space)
-        actions = concatenate_experiences_into_batches(actions, action_space)
+        actions = concatenate_experiences_into_batches(
+            actions, action_space, actions=True
+        )
         log_probs = log_probs.reshape((-1,))
         experiences = (states, actions, log_probs, advantages, returns, values)


agilerl/algorithms/ppo.py
@@ -464,20 +464,27 @@ class PPO(RLAlgorithm):
         self,
         obs: ArrayOrTensor,
         actions: ArrayOrTensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        hidden_state: Optional[Dict[str, ArrayOrTensor]] = None,
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, torch.Tensor, Optional[Dict[str, ArrayOrTensor]]
+    ]:
         """Evaluates the actions.

         :param obs: Environment observation, or multiple observations in a batch
         :type obs: ArrayOrTensor
         :param actions: Actions to evaluate
         :type actions: ArrayOrTensor
-        :return: Log probability, entropy, and state values
-        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+        :param hidden_state: Hidden state for recurrent policies, defaults to None. Expected shape: dict with tensors of shape (batch_size, 1, hidden_size).
+        :type hidden_state: Optional[Dict[str, ArrayOrTensor]]
+        :return: Log probability, entropy, state values, and next hidden state
+        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[Dict[str, ArrayOrTensor]]]
         """
         obs = self.preprocess_observation(obs)

         # Get values from actor-critic
-        _, _, entropy, values, _ = self._get_action_and_values(obs, sample=False)
+        _, _, entropy, values, next_hidden_state = self._get_action_and_values(
+            obs, hidden_state=hidden_state, sample=False
+        )

         log_prob = self.actor.action_log_prob(actions)

@@ -485,7 +492,7 @@ class PPO(RLAlgorithm):
         if entropy is None:
             entropy = -log_prob.mean()

-        return log_prob, entropy, values
+        return log_prob, entropy, values, next_hidden_state

     def get_action(
         self,
@@ -659,7 +666,7 @@ class PPO(RLAlgorithm):
         num_samples = experiences[4].size(0)
         batch_idxs = np.arange(num_samples)
         mean_loss = 0
-        for epoch in range(self.update_epochs):
+        for _ in range(self.update_epochs):
            np.random.shuffle(batch_idxs)
            for start in range(0, num_samples, self.batch_size):
                minibatch_idxs = batch_idxs[start : start + self.batch_size]
@@ -679,8 +686,8 @@ class PPO(RLAlgorithm):
                 batch_values = batch_values.squeeze()

                 if len(minibatch_idxs) > 1:
-                    log_prob, entropy, value = self.evaluate_actions(
-                        obs=batch_observations, actions=batch_actions
+                    log_prob, entropy, value, _ = self.evaluate_actions(
+                        obs=batch_observations, actions=batch_actions, hidden_state=None
                     )

                     logratio = log_prob - batch_log_probs
@@ -754,12 +761,8 @@ class PPO(RLAlgorithm):
             warnings.warn("Buffer data is empty. Skipping learning step.")
             return 0.0

-        observations = buffer_td["observations"]
-        advantages = buffer_td["advantages"]
-
         batch_size = self.batch_size
-        num_samples = observations.size(0)  # Total number of samples in the buffer
-
+        num_samples = self.rollout_buffer.size()
         indices = np.arange(num_samples)
         mean_loss = 0.0
         approx_kl_divs = []
@@ -775,7 +778,7 @@ class PPO(RLAlgorithm):
             mb_obs = minibatch_td["observations"]
             mb_actions = minibatch_td["actions"]
             mb_log_probs = minibatch_td["log_probs"]
-            mb_advantages = advantages[minibatch_indices]
+            mb_advantages = minibatch_td["advantages"]
             mb_returns = minibatch_td["returns"]
             mb_old_values = minibatch_td["values"]

@@ -799,23 +802,19 @@ class PPO(RLAlgorithm):
                    "Recurrent policy, but no hidden_states found in minibatch_td for flat learning."
                )

-            _, _, entropy, values, _ = self._get_action_and_values(
-                mb_obs,
-                hidden_state=eval_hidden_state,
-                sample=False,  # No sampling during evaluation for loss calculation
-            )
-
-            log_probs = self.actor.action_log_prob(mb_actions)
+            if isinstance(self.action_space, spaces.Discrete):
+                mb_actions = mb_actions.squeeze(-1)

-            if entropy is None:  # For continuous squashed actions
-                entropy = -log_probs
+            log_probs, entropy, values, _ = self.evaluate_actions(
+                obs=mb_obs, actions=mb_actions, hidden_state=eval_hidden_state
+            )

             # Normalize advantages
             mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                 mb_advantages.std() + 1e-8
             )

-            # Policy loss
+            # Policy los
             ratio = torch.exp(log_probs - mb_log_probs)
             policy_loss1 = -mb_advantages * ratio
             policy_loss2 = -mb_advantages * torch.clamp(
@@ -943,12 +942,10 @@ class PPO(RLAlgorithm):
             warnings.warn("No BPTT sequences to sample. Skipping learning.")
             return 0.0

-        sequences_per_minibatch = (
-            self.batch_size
-        )  # Here, batch_size means number of sequences per minibatch
+        # Here, batch_size means number of sequences per minibatch
+        sequences_per_minibatch = self.batch_size
         mean_loss = 0.0
         total_minibatch_updates_total = 0
-
         for epoch in range(self.update_epochs):
             approx_kl_divs_epoch = []  # KL divergences for this epoch's minibatches
             np.random.shuffle(all_start_coords)
@@ -982,24 +979,18 @@ class PPO(RLAlgorithm):
                    warnings.warn("Skipping empty or invalid minibatch of sequences.")
                    continue

-                mb_obs_seq = current_minibatch_td[
-                    "observations"
-                ]  # Shape: (batch_seq, seq_len, *obs_dims) or nested TD
-                mb_actions_seq = current_minibatch_td[
-                    "actions"
-                ]  # Shape: (batch_seq, seq_len, *act_dims)
-                mb_old_log_probs_seq = current_minibatch_td[
-                    "log_probs"
-                ]  # Shape: (batch_seq, seq_len)
-                mb_advantages_seq = current_minibatch_td[
-                    "advantages"
-                ]  # Shape: (batch_seq, seq_len) (already normalized)
-                mb_returns_seq = current_minibatch_td[
-                    "returns"
-                ]  # Shape: (batch_seq, seq_len)
-
-                mb_initial_hidden_states_dict = current_minibatch_td.get_non_tensor(
-                    "initial_hidden_states", default=None
+                # Obs shape: (batch_seq, seq_len, *obs_dims) or nested TD
+                # Actions shape: (batch_seq, seq_len, *act_dims)
+                # Other tensors shape: (batch_seq, seq_len)
+                mb_obs_seq = current_minibatch_td["observations"]
+                mb_actions_seq = current_minibatch_td["actions"]
+                mb_old_log_probs_seq = current_minibatch_td["log_probs"]
+                mb_advantages_seq = current_minibatch_td["advantages"]
+                mb_returns_seq = current_minibatch_td["returns"]
+                mb_initial_hidden_states_dict: Optional[TensorDict] = (
+                    current_minibatch_td.get_non_tensor(
+                        "initial_hidden_states", default=None
+                    )
                 )

                policy_loss_total, value_loss_total, entropy_loss_total = 0.0, 0.0, 0.0
@@ -1027,26 +1018,19 @@ class PPO(RLAlgorithm):
                    )
                    adv_t, return_t = mb_advantages_seq[:, t], mb_returns_seq[:, t]

+                    # new_value_t: (batch_seq,), entropy_t: (batch_seq,) or scalar, log_prob_t: (batch_seq,)
                    (
-                        _,
-                        _,
+                        new_log_prob_t,
                        entropy_t,
                        new_value_t,
                        next_hidden_state_for_actor_step,
-                    ) = self._get_action_and_values(
-                        obs_t,
+                    ) = self.evaluate_actions(
+                        obs=obs_t,
+                        actions=actions_t,
                        hidden_state=current_step_hidden_state_actor,
-                        sample=False,
-                    )  # new_value_t: (batch_seq,), entropy_t: (batch_seq,) or scalar
-
-                    new_log_prob_t = self.actor.action_log_prob(
-                        actions_t
-                    )  # Shape: (batch_seq,)
-                    entropy_t = (
-                        (-new_log_prob_t.mean())
-                        if entropy_t is None
-                        else entropy_t.mean()
-                    )  # Ensure scalar
+                    )
+                    if isinstance(entropy_t, torch.Tensor):
+                        entropy_t = entropy_t.mean()

                    ratio = torch.exp(new_log_prob_t - old_log_prob_t)
                    policy_loss1 = -adv_t * ratio

agilerl/components/rollout_buffer.py
@@ -1,34 +1,15 @@
 import random  # Added to support random sequence sampling for BPTT
 import warnings
+from collections import OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from gymnasium import spaces
-from tensordict import TensorDict  # Add import
-
-from agilerl.typing import ArrayOrTensor, ObservationType
-
-
-# Define the utility function locally to avoid circular import
-def convert_np_to_torch_dtype(np_dtype):
-    """Converts a numpy dtype to a torch dtype."""
-    if np_dtype == np.float32:
-        return torch.float32
-    elif np_dtype == np.float64:
-        return torch.float64
-    elif np_dtype == np.int32:
-        return torch.int32
-    elif np_dtype == np.int64:
-        return torch.int64
-    elif np_dtype == np.uint8:
-        return torch.uint8
-    elif np_dtype == np.bool_:
-        return torch.bool
-    else:
-        # Fallback or raise error for unhandled dtypes
-        warnings.warn(f"Unhandled numpy dtype {np_dtype}, defaulting to torch.float32")
-        return torch.float32
+from tensordict import TensorDict
+
+from agilerl.typing import ArrayOrTensor, ObservationType, TorchObsType
+from agilerl.utils.algo_utils import get_num_actions, get_obs_shape, maybe_add_batch_dim


 class RolloutBuffer:
@@ -97,69 +78,39 @@ class RolloutBuffer:
         self.full = False
         self._initialize_buffers()

+    def _maybe_reshape_obs(
+        self, obs: TorchObsType, space: spaces.Space
+    ) -> TorchObsType:
+        """Reshape observation to the correct shape.
+
+        :param obs: Observation to reshape.
+        :type obs: TorchObsType
+        :param space: Observation space.
+        :type space: spaces.Space
+        :return: Reshaped observation.
+        :rtype: TorchObsType
+        """
+        if isinstance(space, spaces.Discrete) and obs.ndim < 2:
+            obs = obs.unsqueeze(-1)
+
+        return maybe_add_batch_dim(obs, space)
+
     def _initialize_buffers(self) -> None:
         """Initialize buffer arrays with correct shapes for vectorized environments."""
         # Determine shapes and dtypes for all expected fields
-        if isinstance(self.observation_space, spaces.Discrete):
-            obs_shape = (1,)
-        elif isinstance(self.observation_space, spaces.MultiDiscrete):
-            obs_shape = (len(self.observation_space.nvec),)
-        elif isinstance(self.observation_space, spaces.Box):
-            obs_shape = self.observation_space.shape
-        elif isinstance(self.observation_space, spaces.Dict):
-            # For Dict observation spaces, we'll create a nested structure
-            # The observations will be stored as nested TensorDicts
-            obs_shape = None  # Will be handled as nested TensorDict
-        elif isinstance(self.observation_space, spaces.Tuple):
-            # For Tuple, we'll flatten or handle as multiple entries
-            # For now, let's assume we'll pre-allocate based on flattened structure
-            obs_shape = ()  # Placeholder, will be determined by actual data
-        else:
-            obs_shape = self.observation_space.shape
-
-        if isinstance(self.action_space, spaces.Discrete):
-            action_shape = ()
-            action_dtype = torch.int64
-        elif isinstance(self.action_space, spaces.Box):
-            action_shape = self.action_space.shape
-            action_dtype = convert_np_to_torch_dtype(
-                self.action_space.dtype
-            )  # Convert numpy dtype to torch dtype
-        elif isinstance(self.action_space, spaces.MultiDiscrete):
-            action_shape = (len(self.action_space.nvec),)
-            action_dtype = torch.int64
-        elif isinstance(self.action_space, spaces.MultiBinary):
-            action_shape = (self.action_space.n,)
-            action_dtype = torch.int64
-        else:
-            try:
-                action_shape = self.action_space.shape
-                action_dtype = convert_np_to_torch_dtype(
-                    getattr(self.action_space, "dtype", np.float32)
-                )  # Convert numpy dtype to torch dtype
-            except AttributeError:
-                raise TypeError(
-                    f"Unsupported action space type without shape: {type(self.action_space)}"
-                )
+        obs_shape = get_obs_shape(self.observation_space)
+        num_actions = get_num_actions(self.action_space)

         # Create a source TensorDict with appropriately sized tensors
         # The tensors will be on the CPU by default, can be moved to device later if needed.
-        source_dict = {}
-
-        # Handle observations based on space type
-        if isinstance(self.observation_space, spaces.Dict):
-            # For Dict spaces, create nested structure
-            obs_dict = {}
-            for key, subspace in self.observation_space.spaces.items():
-                if isinstance(subspace, spaces.Discrete):
-                    sub_shape = (1,)
-                elif isinstance(subspace, spaces.Box):
-                    sub_shape = subspace.shape
-                else:
-                    sub_shape = subspace.shape if hasattr(subspace, "shape") else ()
-
+        source_dict = OrderedDict()
+        if isinstance(
+            self.observation_space, spaces.Dict
+        ):  # Nested structure for Dict spaces
+            obs_dict = OrderedDict()
+            for key, shape in obs_shape.items():
                 obs_dict[key] = torch.zeros(
-                    (self.capacity, self.num_envs, *sub_shape), dtype=torch.float32
+                    (self.capacity, self.num_envs, *shape), dtype=torch.float32
                 )

            source_dict["observations"] = obs_dict
@@ -179,7 +130,7 @@ class RolloutBuffer:
         source_dict.update(
             {
                 "actions": torch.zeros(
-                    (self.capacity, self.num_envs, *action_shape), dtype=action_dtype
+                    (self.capacity, self.num_envs, num_actions), dtype=torch.float32
                 ),
                 "rewards": torch.zeros(
                     (self.capacity, self.num_envs), dtype=torch.float32
283
234
  )
284
235
 
285
236
  # Prepare data as a dictionary of tensors for the current time step
286
- current_step_data = {}
237
+ current_step_data = OrderedDict()
287
238
 
288
239
  # Convert inputs to tensors and ensure correct device (CPU for buffer storage)
289
240
  # Also ensure they have the (num_envs, ...) shape
290
-
291
- # Observations
292
- if isinstance(obs, dict): # Dict observation space
293
- obs_dict = {}
241
+ if isinstance(self.observation_space, spaces.Dict):
242
+ obs_dict = OrderedDict()
294
243
  for key, item in obs.items():
244
+ sub_space = self.observation_space.spaces[key]
295
245
  obs_tensor = torch.as_tensor(item, device="cpu")
296
- if self.num_envs == 1 and obs_tensor.ndim == 0:
297
- obs_tensor = obs_tensor.unsqueeze(0)
298
- elif (
299
- self.num_envs == 1
300
- and len(obs_tensor.shape)
301
- < len(self.observation_space.spaces[key].shape) + 1
302
- ):
303
- obs_tensor = obs_tensor.unsqueeze(0)
304
-
305
- obs_dict[key] = obs_tensor
246
+ obs_dict[key] = self._maybe_reshape_obs(obs_tensor, sub_space)
306
247
 
307
248
  current_step_data["observations"] = obs_dict
308
249
  else:
309
250
  obs_tensor = torch.as_tensor(obs, device="cpu")
310
- if (
311
- self.num_envs == 1
312
- and obs_tensor.ndim < len(self.observation_space.shape) + 1
313
- ): # Add batch dim for single env
314
- obs_tensor = obs_tensor.unsqueeze(0)
315
-
316
- current_step_data["observations"] = obs_tensor
251
+ current_step_data["observations"] = self._maybe_reshape_obs(
252
+ obs_tensor, self.observation_space
253
+ )
317
254
 
318
255
  # Actions
319
256
  action_tensor = torch.as_tensor(action, device="cpu")
320
- if self.num_envs == 1 and action_tensor.ndim < len(self.action_space.shape) + 1:
321
- action_tensor = action_tensor.unsqueeze(0)
322
-
323
- current_step_data["actions"] = action_tensor
257
+ current_step_data["actions"] = action_tensor.reshape(self.num_envs, -1)
324
258
 
325
259
  # Rewards
326
260
  reward_tensor = torch.as_tensor(reward, dtype=torch.float32, device="cpu")
@@ -340,28 +274,21 @@ class RolloutBuffer:

         # Next Observations
         if next_obs is not None:
-            if isinstance(next_obs, dict):  # Dict observation space
-                next_obs_dict = {}
+            if isinstance(self.observation_space, spaces.Dict):
+                next_obs_dict = OrderedDict()
                 for key, item in next_obs.items():
+                    sub_space = self.observation_space.spaces[key]
                     next_obs_tensor = torch.as_tensor(item, device="cpu")
-                    if self.num_envs == 1 and next_obs_tensor.ndim == 0:
-                        next_obs_tensor = next_obs_tensor.unsqueeze(0)
-                    elif (
-                        self.num_envs == 1
-                        and len(next_obs_tensor.shape)
-                        < len(self.observation_space.spaces[key].shape) + 1
-                    ):
-                        next_obs_tensor = next_obs_tensor.unsqueeze(0)
-                    next_obs_dict[key] = next_obs_tensor
+                    next_obs_dict[key] = self._maybe_reshape_obs(
+                        next_obs_tensor, sub_space
+                    )
+
                 current_step_data["next_observations"] = next_obs_dict
             else:
                 next_obs_tensor = torch.as_tensor(next_obs, device="cpu")
-                if (
-                    self.num_envs == 1
-                    and next_obs_tensor.ndim < len(self.observation_space.shape) + 1
-                ):  # Add batch dim
-                    next_obs_tensor = next_obs_tensor.unsqueeze(0)
-                current_step_data["next_observations"] = next_obs_tensor
+                current_step_data["next_observations"] = self._maybe_reshape_obs(
+                    next_obs_tensor, self.observation_space
+                )

         # Episode Starts
         if episode_start is not None:
@@ -493,7 +420,7 @@ class RolloutBuffer:

         # Get a view of the buffer up to the current position and for all envs
         # This slice will have batch_size [buffer_size, num_envs]
-        valid_buffer_data = self.buffer[:buffer_size]
+        valid_buffer_data: TensorDict = self.buffer[:buffer_size]

         # Reshape to flatten the num_envs dimension into the first batch dimension
         # New batch_size will be [buffer_size * num_envs]

agilerl/networks/actors.py
@@ -8,6 +8,7 @@ from agilerl.modules.configs import MlpNetConfig
 from agilerl.networks.base import EvolvableNetwork
 from agilerl.networks.distributions import EvolvableDistribution
 from agilerl.typing import ArrayOrTensor, NetConfigType, TorchObsType
+from agilerl.utils.algo_utils import get_output_size_from_space


 class DeterministicActor(EvolvableNetwork):
@@ -105,6 +106,8 @@ class DeterministicActor(EvolvableNetwork):
         else:
             head_config["output_activation"] = output_activation

+        self.output_size = get_output_size_from_space(self.action_space)
+
         self.build_network_head(head_config)
         self.output_activation = head_config.get("output_activation", output_activation)

@@ -155,7 +158,7 @@ class DeterministicActor(EvolvableNetwork):
         """
         self.head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=spaces.flatdim(self.action_space),
+            num_outputs=self.output_size,
             name="actor",
             net_config=net_config,
         )
@@ -188,7 +191,7 @@ class DeterministicActor(EvolvableNetwork):

         head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=spaces.flatdim(self.action_space),
+            num_outputs=self.output_size,
             name="actor",
             net_config=self.head_net.net_config,
         )
@@ -290,6 +293,7 @@ class StochasticActor(EvolvableNetwork):
         self.squash_output = squash_output
         self.action_space = action_space
         self.use_experimental_distribution = use_experimental_distribution
+        self.output_size = get_output_size_from_space(self.action_space)

         self.build_network_head(head_config)
         self.output_activation = None
@@ -327,7 +331,7 @@ class StochasticActor(EvolvableNetwork):
         """
         self.head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=spaces.flatdim(self.action_space),
+            num_outputs=self.output_size,
             name="actor",
             net_config=net_config,
         )
@@ -389,7 +393,7 @@ class StochasticActor(EvolvableNetwork):

         head_net = self.create_mlp(
             num_inputs=self.latent_dim,
-            num_outputs=spaces.flatdim(self.action_space),
+            num_outputs=self.output_size,
             name="actor",
             net_config=self.head_net.net_config,
         )

agilerl/networks/distributions.py
@@ -7,6 +7,7 @@ from torch.distributions import Bernoulli, Categorical, Distribution, Normal

 from agilerl.modules.base import EvolvableModule, EvolvableWrapper
 from agilerl.typing import ArrayOrTensor, DeviceType, NetConfigType
+from agilerl.utils.algo_utils import get_output_size_from_space

 DistributionType = Union[Distribution, List[Distribution]]

@@ -328,7 +329,7 @@ class EvolvableDistribution(EvolvableWrapper):
         super().__init__(network)

         self.action_space = action_space
-        self.action_dim = spaces.flatdim(action_space)
+        self.action_dim = get_output_size_from_space(action_space)
         self.action_std_init = action_std_init
         self.device = device
         self.squash_output = squash_output and isinstance(action_space, spaces.Box)

agilerl/rollouts/on_policy.py
@@ -55,9 +55,9 @@ def _collect_rollouts(

    if (
        last_obs is None
-        or last_done is None
-        or last_scores is None
-        or last_info is None
+        and last_done is None
+        and last_scores is None
+        and last_info is None
    ):
        obs, info = env.reset()
        scores = np.zeros(agent.num_envs)
@@ -169,7 +169,6 @@ def _collect_rollouts(
                scores[idx] = 0

    # Calculate last value to compute returns and advantages properly
-    # TODO: We shouldn't access a hidden method here...
    with torch.no_grad():
        if recurrent:
            _, _, _, last_value, _ = agent._get_action_and_values(

agilerl/training/train_multi_agent_on_policy.py
@@ -14,6 +14,7 @@ from agilerl.algorithms import IPPO
 from agilerl.hpo.mutation import Mutations
 from agilerl.hpo.tournament import TournamentSelection
 from agilerl.networks import StochasticActor
+from agilerl.typing import SingleAgentModule
 from agilerl.utils.algo_utils import obs_channels_to_first
 from agilerl.utils.utils import (
     default_progress_bar,
@@ -192,6 +193,7 @@ def train_multi_agent_on_policy(
    pop_episode_scores = []
    pop_fps = []
    for agent_idx, agent in enumerate(pop):  # Loop through population
+        compiled_agent = agent.torch_compiler is not None
        agent.set_training_mode(True)

        obs, info = env.reset()  # Reset environment at start of episode
@@ -244,7 +246,11 @@ def train_multi_agent_on_policy(
                        )
                        agent_space = agent.possible_action_spaces[agent_id]
                        policy = getattr(agent, agent.registry.policy())
-                        agent_policy = policy[network_id]
+                        agent_policy: SingleAgentModule = policy[network_id]
+
+                        if compiled_agent:
+                            agent_policy = agent_policy._orig_mod
+
                        if isinstance(agent_policy, StochasticActor) and isinstance(
                            agent_space, spaces.Box
                        ):