agilerl 2.5.0.dev0__tar.gz → 2.5.0.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (355) hide show
  1. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.pre-commit-config.yaml +2 -1
  2. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/PKG-INFO +4 -4
  3. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/base.py +9 -14
  4. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ippo.py +10 -4
  5. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/maddpg.py +10 -4
  6. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/matd3.py +10 -4
  7. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/actors.py +0 -13
  8. agilerl-2.5.0.dev0/agilerl/networks/distributions_experimental.py → agilerl-2.5.0.dev2/agilerl/networks/distributions.py +91 -202
  9. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/algo_utils.py +65 -0
  10. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/probe_envs.py +7 -1
  11. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/probe_envs_ma.py +11 -2
  12. agilerl-2.5.0.dev2/agilerl/utils/torch_utils.py +617 -0
  13. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo.yaml +2 -3
  14. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/pyproject.toml +21 -4
  15. agilerl-2.5.0.dev2/sitecustomize.py +12 -0
  16. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/conftest.py +1 -1
  17. agilerl-2.5.0.dev2/tests/subprocess_runner.py +210 -0
  18. agilerl-2.5.0.dev2/tests/test_algorithms/test_llms/conftest.py +117 -0
  19. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/test_dpo.py +90 -86
  20. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/test_grpo.py +104 -88
  21. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_ippo.py +63 -1
  22. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_maddpg.py +61 -0
  23. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_matd3.py +63 -0
  24. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ppo.py +26 -3
  25. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_actors.py +206 -337
  26. agilerl-2.5.0.dev2/tests/test_networks/test_distributions.py +184 -0
  27. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_train/test_train.py +1 -1
  28. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_train/test_train_llm.py +4 -4
  29. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_algo_utils.py +77 -0
  30. agilerl-2.5.0.dev2/tests/test_utils/test_torch_utils.py +329 -0
  31. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_utils.py +1 -1
  32. agilerl-2.5.0.dev2/tests/utils.py +310 -0
  33. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/uv.lock +613 -258
  34. agilerl-2.5.0.dev0/.coveragerc +0 -10
  35. agilerl-2.5.0.dev0/agilerl/networks/distributions.py +0 -530
  36. agilerl-2.5.0.dev0/agilerl/utils/torch_utils.py +0 -114
  37. agilerl-2.5.0.dev0/tests/test_algorithms/test_llms/conftest.py +0 -95
  38. agilerl-2.5.0.dev0/tests/test_networks/test_distributions.py +0 -209
  39. agilerl-2.5.0.dev0/tests/test_utils/test_torch_utils.py +0 -106
  40. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  41. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  42. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +0 -0
  43. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/badges/arena-github-badge.svg +0 -0
  44. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/workflows/codeql.yml +0 -0
  45. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/workflows/python-app.yml +0 -0
  46. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.gitignore +0 -0
  47. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.readthedocs.yaml +0 -0
  48. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CITATION.cff +0 -0
  49. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CODE_OF_CONDUCT.md +0 -0
  50. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CONTRIBUTING.md +0 -0
  51. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/LICENSE +0 -0
  52. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/README.md +0 -0
  53. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/__init__.py +0 -0
  54. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/__init__.py +0 -0
  55. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/bc_lm.py +0 -0
  56. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/__init__.py +0 -0
  57. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
  58. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/registry.py +0 -0
  59. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/cqn.py +0 -0
  60. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ddpg.py +0 -0
  61. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dpo.py +0 -0
  62. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dqn.py +0 -0
  63. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dqn_rainbow.py +0 -0
  64. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/grpo.py +0 -0
  65. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ilql.py +0 -0
  66. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  67. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  68. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ppo.py +0 -0
  69. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/td3.py +0 -0
  70. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/__init__.py +0 -0
  71. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/data.py +0 -0
  72. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  73. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/replay_buffer.py +0 -0
  74. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/rollout_buffer.py +0 -0
  75. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/sampler.py +0 -0
  76. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/segment_tree.py +0 -0
  77. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/__init__.py +0 -0
  78. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/language_environment.py +0 -0
  79. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/rl_data.py +0 -0
  80. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/tokenizer.py +0 -0
  81. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/torch_datasets.py +0 -0
  82. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/__init__.py +0 -0
  83. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/mutation.py +0 -0
  84. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/tournament.py +0 -0
  85. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/__init__.py +0 -0
  86. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/base.py +0 -0
  87. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/bert.py +0 -0
  88. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/cnn.py +0 -0
  89. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/configs.py +0 -0
  90. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/custom_components.py +0 -0
  91. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/dummy.py +0 -0
  92. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/gpt.py +0 -0
  93. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/lstm.py +0 -0
  94. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/mlp.py +0 -0
  95. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/multi_input.py +0 -0
  96. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/resnet.py +0 -0
  97. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/simba.py +0 -0
  98. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/__init__.py +0 -0
  99. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/base.py +0 -0
  100. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/custom_modules.py +0 -0
  101. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/q_networks.py +0 -0
  102. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/value_networks.py +0 -0
  103. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/protocols.py +0 -0
  104. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/rollouts/__init__.py +0 -0
  105. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/rollouts/on_policy.py +0 -0
  106. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/__init__.py +0 -0
  107. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_bandits.py +0 -0
  108. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_llm.py +0 -0
  109. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_multi_agent_off_policy.py +0 -0
  110. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_multi_agent_on_policy.py +0 -0
  111. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_off_policy.py +0 -0
  112. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_offline.py +0 -0
  113. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_on_policy.py +0 -0
  114. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/typing.py +0 -0
  115. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/__init__.py +0 -0
  116. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/cache.py +0 -0
  117. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/evolvable_networks.py +0 -0
  118. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/ilql_utils.py +0 -0
  119. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/llm_utils.py +0 -0
  120. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/log_utils.py +0 -0
  121. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/minari_utils.py +0 -0
  122. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/sampling_utils.py +0 -0
  123. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/utils.py +0 -0
  124. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/__init__.py +0 -0
  125. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/pz_async_vec_env.py +0 -0
  126. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/pz_vec_env.py +0 -0
  127. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/__init__.py +0 -0
  128. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/agent.py +0 -0
  129. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/learning.py +0 -0
  130. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/make_evolvable.py +0 -0
  131. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  132. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/utils.py +0 -0
  133. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_bandits.py +0 -0
  134. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_dpo.py +0 -0
  135. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_grpo.py +0 -0
  136. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_multi_agent_off_policy.py +0 -0
  137. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_multi_agent_on_policy.py +0 -0
  138. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_off_policy.py +0 -0
  139. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_off_policy_distributed.py +0 -0
  140. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_offline.py +0 -0
  141. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_offline_distributed.py +0 -0
  142. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_on_policy.py +0 -0
  143. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_rainbow.py +0 -0
  144. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_recurrent.py +0 -0
  145. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_resnet.py +0 -0
  146. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_simba.py +0 -0
  147. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/configs/ds_config.json +0 -0
  148. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/make_evolvable_benchmarking.py +0 -0
  149. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/networks.py +0 -0
  150. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/accelerate/accelerate.yaml +0 -0
  151. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/accelerate/grpo_accelerate_config.yaml +0 -0
  152. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/bandit/neural_ts.yaml +0 -0
  153. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/bandit/neural_ucb.yaml +0 -0
  154. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/cqn.yaml +0 -0
  155. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg.yaml +0 -0
  156. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg_lstm.yaml +0 -0
  157. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg_simba.yaml +0 -0
  158. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dpo.yaml +0 -0
  159. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn.yaml +0 -0
  160. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn_lstm.yaml +0 -0
  161. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn_rainbow.yaml +0 -0
  162. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/grpo.yaml +0 -0
  163. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/ippo.yaml +0 -0
  164. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/ippo_pong.yaml +0 -0
  165. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/maddpg.yaml +0 -0
  166. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/matd3.yaml +0 -0
  167. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_input.yaml +0 -0
  168. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo_image.yaml +0 -0
  169. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo_recurrent.yaml +0 -0
  170. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/td3.yaml +0 -0
  171. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/cartpole/cartpole_random_v1.1.0.h5 +0 -0
  172. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/cartpole/cartpole_v1.1.0.h5 +0 -0
  173. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/pendulum/pendulum_random_v1.1.0.h5 +0 -0
  174. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/pendulum/pendulum_v1.1.0.h5 +0 -0
  175. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_bandit.py +0 -0
  176. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_custom_network.py +0 -0
  177. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_multi_agent.py +0 -0
  178. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_off_policy.py +0 -0
  179. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_off_policy_distributed.py +0 -0
  180. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_offline.py +0 -0
  181. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_offline_distributed.py +0 -0
  182. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy.py +0 -0
  183. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_cartpole.py +0 -0
  184. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_memory.py +0 -0
  185. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_minigrid.py +0 -0
  186. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_cartpole.py +0 -0
  187. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_lunar_lander.py +0 -0
  188. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_lunar_lander_rnn.py +0 -0
  189. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_rnn_memory.py +0 -0
  190. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/dependabot.yml +0 -0
  191. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/Makefile +0 -0
  192. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/__init__.py +0 -0
  193. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/arena-github-badge.svg +0 -0
  194. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/css/custom.css +0 -0
  195. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/favicon.ico +0 -0
  196. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/js/expand_sidebar.js +0 -0
  197. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/logo_teal.png +0 -0
  198. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/logo_white.png +0 -0
  199. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/module.png +0 -0
  200. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/network.png +0 -0
  201. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/iris-thumbnail.png +0 -0
  202. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/pendigits-thumbnail.png +0 -0
  203. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/rainbow_performance.png +0 -0
  204. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/simba_thumbnail.png +0 -0
  205. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/base.rst +0 -0
  206. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/cql.rst +0 -0
  207. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ddpg.rst +0 -0
  208. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dpo.rst +0 -0
  209. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dqn.rst +0 -0
  210. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dqn_rainbow.rst +0 -0
  211. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/grpo.rst +0 -0
  212. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ilql.rst +0 -0
  213. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/index.rst +0 -0
  214. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ippo.rst +0 -0
  215. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/maddpg.rst +0 -0
  216. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/matd3.rst +0 -0
  217. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/neural_ts.rst +0 -0
  218. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/neural_ucb.rst +0 -0
  219. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ppo.rst +0 -0
  220. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/registry.rst +0 -0
  221. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/td3.rst +0 -0
  222. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/wrappers.rst +0 -0
  223. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/data.rst +0 -0
  224. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/index.rst +0 -0
  225. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/multi_agent_replay_buffer.rst +0 -0
  226. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/replay_buffer.rst +0 -0
  227. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/rollout_buffer.rst +0 -0
  228. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/sampler.rst +0 -0
  229. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/segment_tree.rst +0 -0
  230. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/index.rst +0 -0
  231. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/mutation.rst +0 -0
  232. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/tournament.rst +0 -0
  233. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/base.rst +0 -0
  234. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/bert.rst +0 -0
  235. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/cnn.rst +0 -0
  236. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/custom_activation.rst +0 -0
  237. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/dummy.rst +0 -0
  238. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/gpt.rst +0 -0
  239. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/index.rst +0 -0
  240. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/lstm.rst +0 -0
  241. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/mlp.rst +0 -0
  242. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/multi_input.rst +0 -0
  243. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/resnet.rst +0 -0
  244. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/simba.rst +0 -0
  245. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/actors.rst +0 -0
  246. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/base.rst +0 -0
  247. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/index.rst +0 -0
  248. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/q_networks.rst +0 -0
  249. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/value_networks.rst +0 -0
  250. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/rollouts/index.rst +0 -0
  251. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/rollouts/on_policy.rst +0 -0
  252. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/train.rst +0 -0
  253. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/algo_utils.rst +0 -0
  254. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/cache.rst +0 -0
  255. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/evolvable_networks.rst +0 -0
  256. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/ilql_utils.rst +0 -0
  257. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/index.rst +0 -0
  258. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/llm_utils.rst +0 -0
  259. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/log_utils.rst +0 -0
  260. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/minari_utils.rst +0 -0
  261. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/probe_envs.rst +0 -0
  262. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/torch_utils.rst +0 -0
  263. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/utils.rst +0 -0
  264. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/index.rst +0 -0
  265. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/petting_zoo_async_vector_env.rst +0 -0
  266. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/petting_zoo_vector_env.rst +0 -0
  267. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/agent.rst +0 -0
  268. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/index.rst +0 -0
  269. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/learning.rst +0 -0
  270. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/make_evolvable.rst +0 -0
  271. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/pettingzoo.rst +0 -0
  272. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/bandits/index.rst +0 -0
  273. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/conf.py +0 -0
  274. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/custom_algorithms/index.rst +0 -0
  275. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/debugging_rl/index.rst +0 -0
  276. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/distributed_training/index.rst +0 -0
  277. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/evo_hyperparam_opt/index.rst +0 -0
  278. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/evolvable_networks/index.rst +0 -0
  279. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/get_started/agilerl2changes.rst +0 -0
  280. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/get_started/index.rst +0 -0
  281. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/index.rst +0 -0
  282. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/llm_finetuning/index.rst +0 -0
  283. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/make.bat +0 -0
  284. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/multi_agent_training/index.rst +0 -0
  285. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/off_policy/index.rst +0 -0
  286. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/offline_training/index.rst +0 -0
  287. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/on_policy/index.rst +0 -0
  288. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/pomdp/index.rst +0 -0
  289. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/releases/index.rst +0 -0
  290. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/requirements.txt +0 -0
  291. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/__init__.py +0 -0
  292. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/helper_functions.py +0 -0
  293. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/pz_vector_test_utils.py +0 -0
  294. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/__init__.py +0 -0
  295. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/__init__.py +0 -0
  296. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/test_neural_ts.py +0 -0
  297. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/test_neural_ucb.py +0 -0
  298. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_base.py +0 -0
  299. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bc_lm.py +0 -0
  300. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/__init__.py +0 -0
  301. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/__init__.py +0 -0
  302. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_optimizer_wrapper.py +0 -0
  303. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_registry.py +0 -0
  304. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/__init__.py +0 -0
  305. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_cqn.py +0 -0
  306. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ddpg.py +0 -0
  307. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_dqn.py +0 -0
  308. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_dqn_rainbow.py +0 -0
  309. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ilql.py +0 -0
  310. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_td3.py +0 -0
  311. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/__init__.py +0 -0
  312. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_multi_agent_replay_buffer.py +0 -0
  313. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_replay_buffer.py +0 -0
  314. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_replay_data.py +0 -0
  315. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_rollout_buffer.py +0 -0
  316. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_sampler.py +0 -0
  317. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_segment_tree.py +0 -0
  318. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_data.py +0 -0
  319. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/__init__.py +0 -0
  320. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/test_mutation.py +0 -0
  321. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/test_tournament.py +0 -0
  322. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/__init__.py +0 -0
  323. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_base.py +0 -0
  324. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_bert.py +0 -0
  325. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_cnn.py +0 -0
  326. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_custom_activation.py +0 -0
  327. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_dummy.py +0 -0
  328. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_gpt.py +0 -0
  329. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_lstm.py +0 -0
  330. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_make_evolvable.py +0 -0
  331. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_mlp.py +0 -0
  332. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_multi_input.py +0 -0
  333. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_resnet.py +0 -0
  334. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_simba.py +0 -0
  335. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/__init__.py +0 -0
  336. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_base.py +0 -0
  337. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_q_networks.py +0 -0
  338. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_value_functions.py +0 -0
  339. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_protocols.py +0 -0
  340. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/__init__.py +0 -0
  341. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_cache.py +0 -0
  342. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_ilql_utils.py +0 -0
  343. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_llm_utils.py +0 -0
  344. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_log_utils.py +0 -0
  345. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_minari_utils.py +0 -0
  346. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_probe_envs.py +0 -0
  347. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_probe_envs_ma.py +0 -0
  348. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_sampling_utils.py +0 -0
  349. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_utils_evolvable.py +0 -0
  350. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_vector/test_vector.py +0 -0
  351. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/__init__.py +0 -0
  352. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_agent.py +0 -0
  353. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_autoreset.py +0 -0
  354. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_bandit_env.py +0 -0
  355. {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_skills.py +0 -0
@@ -43,8 +43,9 @@ repos:
43
43
  rev: 0.2.3
44
44
  hooks:
45
45
  - id: yamlfmt
46
+
46
47
  - repo: https://github.com/astral-sh/uv-pre-commit
47
48
  # uv version.
48
- rev: 0.10.2
49
+ rev: 0.10.3
49
50
  hooks:
50
51
  - id: uv-lock
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agilerl
3
- Version: 2.5.0.dev0
3
+ Version: 2.5.0.dev2
4
4
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
5
5
  Author-email: Nick Ustaran-Anderegg <dev@agilerl.com>
6
6
  License-Expression: Apache-2.0
@@ -29,7 +29,7 @@ Requires-Dist: redis~=4.4.4
29
29
  Requires-Dist: supersuit~=3.9.0
30
30
  Requires-Dist: tensordict~=0.8
31
31
  Requires-Dist: termcolor~=1.1.0
32
- Requires-Dist: torch==2.7.1
32
+ Requires-Dist: torch==2.9.0
33
33
  Requires-Dist: tqdm>=4.66.4
34
34
  Requires-Dist: wandb~=0.18.0
35
35
  Provides-Extra: all
@@ -37,13 +37,13 @@ Requires-Dist: datasets==4.4.1; extra == 'all'
37
37
  Requires-Dist: deepspeed~=0.17.1; extra == 'all'
38
38
  Requires-Dist: peft~=0.18.0; extra == 'all'
39
39
  Requires-Dist: transformers~=4.57.1; extra == 'all'
40
- Requires-Dist: vllm~=0.10.0; extra == 'all'
40
+ Requires-Dist: vllm==0.13.0; extra == 'all'
41
41
  Provides-Extra: llm
42
42
  Requires-Dist: datasets==4.4.1; extra == 'llm'
43
43
  Requires-Dist: deepspeed~=0.17.1; extra == 'llm'
44
44
  Requires-Dist: peft~=0.18.0; extra == 'llm'
45
45
  Requires-Dist: transformers~=4.57.1; extra == 'llm'
46
- Requires-Dist: vllm~=0.10.0; extra == 'llm'
46
+ Requires-Dist: vllm==0.13.0; extra == 'llm'
47
47
  Description-Content-Type: text/markdown
48
48
 
49
49
  <p align="center">
@@ -1561,27 +1561,16 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1561
1561
  nan_arr = np.empty(self.action_dims[agent_id])
1562
1562
  nan_arr[:] = np.nan
1563
1563
  else:
1564
- nan_arr = np.array([[np.nan]])
1564
+ nan_arr = np.array([np.nan])
1565
1565
 
1566
1566
  env_defined_actions[agent_id] = nan_arr
1567
1567
  val = nan_arr
1568
1568
 
1569
1569
  # Handle discrete actions + env not vectorized
1570
1570
  if isinstance(val, (int, float)):
1571
- val = np.array([[val]])
1571
+ val = np.array([val])
1572
1572
  env_defined_actions[agent_id] = val
1573
1573
 
1574
- # Ensure additional dimension is added in so shapes align for masking
1575
- if isinstance(val, np.ndarray) and len(val.shape) == 1:
1576
- val = (
1577
- val[:, np.newaxis]
1578
- if isinstance(
1579
- self.possible_action_spaces[agent_id],
1580
- spaces.Discrete,
1581
- )
1582
- else val[np.newaxis, :]
1583
- )
1584
- env_defined_actions[agent_id] = val
1585
1574
  agent_masks[agent_id] = np.where(
1586
1575
  np.isnan(env_defined_actions[agent_id]),
1587
1576
  0,
@@ -1814,6 +1803,12 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
1814
1803
  for i, agent_id in enumerate(agent_ids):
1815
1804
  output_dict[agent_id] = group_outputs[group_id][i]
1816
1805
 
1806
+ if (
1807
+ isinstance(self.possible_action_spaces[agent_id], spaces.Discrete)
1808
+ and output_dict[agent_id].shape[-1] == 1
1809
+ ):
1810
+ output_dict[agent_id] = output_dict[agent_id].squeeze(-1)
1811
+
1817
1812
  return output_dict
1818
1813
 
1819
1814
  def sum_shared_rewards(self, rewards: ArrayDict) -> ArrayDict:
@@ -2302,7 +2297,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2302
2297
  None,
2303
2298
  None,
2304
2299
  )
2305
- if hasattr(self, "llm"):
2300
+ if hasattr(self, "llm") and self.llm is not None:
2306
2301
  del self.llm.llm_engine.model_executor
2307
2302
  del self.llm
2308
2303
  gc.collect()
@@ -27,6 +27,7 @@ from agilerl.typing import (
27
27
  TorchObsType,
28
28
  )
29
29
  from agilerl.utils.algo_utils import (
30
+ apply_env_defined_actions,
30
31
  concatenate_experiences_into_batches,
31
32
  concatenate_tensors,
32
33
  get_experiences_samples,
@@ -601,10 +602,15 @@ class IPPO(MultiAgentRLAlgorithm):
601
602
 
602
603
  # If using env_defined_actions replace actions
603
604
  if env_defined_actions is not None:
604
- for agent_id in unique_agents_ids:
605
- action_dict[agent_id][agent_masks[agent_id]] = env_defined_actions[
606
- agent_id
607
- ][agent_masks[agent_id]]
605
+ action_dict = apply_env_defined_actions(
606
+ unique_agents_ids,
607
+ action_dict,
608
+ env_defined_actions,
609
+ agent_masks,
610
+ discrete_actions=isinstance(
611
+ next(iter(self.action_space.values())), spaces.Discrete
612
+ ),
613
+ )
608
614
 
609
615
  return (
610
616
  action_dict,
@@ -25,6 +25,7 @@ from agilerl.typing import (
25
25
  SupportedObsSpaces,
26
26
  )
27
27
  from agilerl.utils.algo_utils import (
28
+ apply_env_defined_actions,
28
29
  concatenate_spaces,
29
30
  format_shared_critic_encoder,
30
31
  get_deepest_head_config,
@@ -515,10 +516,15 @@ class MADDPG(MultiAgentRLAlgorithm):
515
516
 
516
517
  # If using env_defined_actions replace actions
517
518
  if env_defined_actions is not None:
518
- for agent in self.agent_ids:
519
- processed_action_dict[agent][agent_masks[agent]] = env_defined_actions[
520
- agent
521
- ][agent_masks[agent]]
519
+ action_dict = apply_env_defined_actions(
520
+ self.agent_ids,
521
+ processed_action_dict,
522
+ env_defined_actions,
523
+ agent_masks,
524
+ discrete_actions=isinstance(
525
+ next(iter(self.action_space.values())), spaces.Discrete
526
+ ),
527
+ )
522
528
 
523
529
  return processed_action_dict, action_dict
524
530
 
@@ -24,6 +24,7 @@ from agilerl.typing import (
24
24
  StandardTensorDict,
25
25
  )
26
26
  from agilerl.utils.algo_utils import (
27
+ apply_env_defined_actions,
27
28
  concatenate_spaces,
28
29
  format_shared_critic_encoder,
29
30
  get_deepest_head_config,
@@ -575,10 +576,15 @@ class MATD3(MultiAgentRLAlgorithm):
575
576
 
576
577
  # If using env_defined_actions replace actions
577
578
  if env_defined_actions is not None:
578
- for agent in self.agent_ids:
579
- processed_action_dict[agent][agent_masks[agent]] = env_defined_actions[
580
- agent
581
- ][agent_masks[agent]]
579
+ action_dict = apply_env_defined_actions(
580
+ self.agent_ids,
581
+ processed_action_dict,
582
+ env_defined_actions,
583
+ agent_masks,
584
+ discrete_actions=isinstance(
585
+ next(iter(self.action_space.values())), spaces.Discrete
586
+ ),
587
+ )
582
588
 
583
589
  return processed_action_dict, action_dict
584
590
 
@@ -252,9 +252,6 @@ class StochasticActor(EvolvableNetwork):
252
252
  :type recurrent: bool
253
253
  :param device: Device to use for the network.
254
254
  :type device: str
255
- :param use_experimental_distribution: Whether to use the experimental distribution implementation, which
256
- includes several optimizations related to using torch primitives for statistics calculations. Defaults to False.
257
- :type use_experimental_distribution: bool
258
255
  :param random_seed: Random seed to use for the network. Defaults to None.
259
256
  :type random_seed: int | None
260
257
  :param encoder_name: Name of the encoder network.
@@ -284,7 +281,6 @@ class StochasticActor(EvolvableNetwork):
284
281
  simba: bool = False,
285
282
  recurrent: bool = False,
286
283
  device: str = "cpu",
287
- use_experimental_distribution: bool = False,
288
284
  random_seed: int | None = None,
289
285
  encoder_name: str = "encoder",
290
286
  ) -> None:
@@ -312,7 +308,6 @@ class StochasticActor(EvolvableNetwork):
312
308
  self.action_std_init = action_std_init
313
309
  self.squash_output = squash_output
314
310
  self.action_space = action_space
315
- self.use_experimental_distribution = use_experimental_distribution
316
311
  self.output_size = get_output_size_from_space(self.action_space)
317
312
 
318
313
  self.build_network_head(head_config)
@@ -332,14 +327,6 @@ class StochasticActor(EvolvableNetwork):
332
327
  else:
333
328
  self.action_low, self.action_high = None, None
334
329
 
335
- # Wrap the network in an EvolvableDistribution
336
- if use_experimental_distribution:
337
- from agilerl.networks.distributions_experimental import (
338
- EvolvableDistribution,
339
- )
340
- else:
341
- from agilerl.networks.distributions import EvolvableDistribution
342
-
343
330
  self.head_net = EvolvableDistribution(
344
331
  action_space=action_space,
345
332
  network=self.head_net,
@@ -1,20 +1,20 @@
1
- import math
1
+ from typing import Union
2
2
 
3
3
  import numpy as np
4
4
  import torch
5
- import torch.nn.functional as F
6
5
  from gymnasium import spaces
7
6
 
8
7
  from agilerl.modules.base import EvolvableModule, EvolvableWrapper
9
8
  from agilerl.typing import ArrayOrTensor, DeviceType, NetConfigType
10
-
11
- # NOTE: we still import Normal / Bernoulli solely for continuous & binary helpers,
12
- # but no Categorical objects are ever instantiated any more.
9
+ from agilerl.utils.torch_utils import (
10
+ entropy_from_space,
11
+ log_prob_from_space,
12
+ sample_from_space,
13
+ )
13
14
 
14
15
 
15
16
  def apply_action_mask_discrete(
16
- logits: torch.Tensor,
17
- mask: torch.Tensor,
17
+ logits: torch.Tensor, mask: torch.Tensor
18
18
  ) -> torch.Tensor:
19
19
  """Apply a mask to the logits.
20
20
 
@@ -37,185 +37,74 @@ class TorchDistribution:
37
37
  :param action_space: Action space of the environment.
38
38
  :type action_space: spaces.Space
39
39
  :param logits: Logits.
40
- :type logits: torch.Tensor
40
+ :type logits: torch.Tensor | None
41
41
  :param mu: Mean.
42
- :type mu: torch.Tensor
42
+ :type mu: torch.Tensor | None
43
43
  :param log_std: Log standard deviation.
44
- :type log_std: torch.Tensor
44
+ :type log_std: torch.Tensor | None
45
45
  :param squash_output: Whether to squash the output to the action space.
46
46
  :type squash_output: bool
47
-
48
47
  """
49
48
 
50
49
  def __init__(
51
50
  self,
52
51
  *,
53
52
  action_space: spaces.Space,
54
- logits: (
55
- torch.Tensor | None
56
- ) = None, # for discrete / multidiscrete / multibinary
57
- mu: torch.Tensor | None = None, # for Box
53
+ logits: torch.Tensor | None = None,
54
+ mu: torch.Tensor | None = None,
58
55
  log_std: torch.Tensor | None = None,
59
56
  squash_output: bool = False,
60
- ) -> None:
57
+ ):
61
58
  self.action_space = action_space
62
- self.logits, self.mu, self.log_std = logits, mu, log_std
59
+ self.logits = logits
60
+ self.mu = mu
61
+ self.log_std = log_std
63
62
  self.squash_output = squash_output and isinstance(action_space, spaces.Box)
64
63
  self._sampled_action: torch.Tensor | None = None
65
64
 
66
- # ------------------------------------------------------------------ #
67
- # fast tensor-only primitives #
68
- # ------------------------------------------------------------------ #
69
65
  def sample(self) -> torch.Tensor:
70
- if isinstance(self.action_space, spaces.Discrete):
71
- probs = torch.softmax(self.logits, dim=-1)
72
- self._sampled_action = torch.multinomial(probs, 1).squeeze(-1)
73
- return self._sampled_action
74
-
75
- if isinstance(self.action_space, spaces.Box):
76
- eps = torch.randn_like(self.mu)
77
- out = self.mu + torch.exp(self.log_std) * eps
78
- if self.squash_output:
79
- out = torch.tanh(out)
80
- self._sampled_action = out
81
- return out
82
-
83
- # -------- MultiDiscrete --------
84
- if isinstance(self.action_space, spaces.MultiDiscrete):
85
- actions = []
86
- offset = 0
87
- for size in self.action_space.nvec:
88
- logits_i = self.logits[:, offset : offset + size]
89
- probs_i = torch.softmax(logits_i, dim=-1)
90
- act_i = torch.multinomial(probs_i, 1).squeeze(-1)
91
- actions.append(act_i)
92
- offset += size
93
- self._sampled_action = torch.stack(actions, dim=-1)
94
- return self._sampled_action
95
-
96
- # -------- MultiBinary --------
97
- if isinstance(self.action_space, spaces.MultiBinary):
98
- probs = torch.sigmoid(self.logits)
99
- self._sampled_action = torch.bernoulli(
100
- probs,
101
- ) # Ensures float tensor, removed .to(torch.int64)
102
- return self._sampled_action
66
+ """Sample from the distribution for the given action space.
103
67
 
104
- msg = "Unsupported action space in fast path."
105
- raise NotImplementedError(msg)
68
+ :return: Sampled action.
69
+ :rtype: torch.Tensor
70
+ """
71
+ self._sampled_action = sample_from_space(
72
+ self.action_space,
73
+ logits=self.logits,
74
+ mu=self.mu,
75
+ log_std=self.log_std,
76
+ squash_output=self.squash_output,
77
+ )
78
+ return self._sampled_action
106
79
 
107
80
  def log_prob(self, action: torch.Tensor) -> torch.Tensor:
108
- if isinstance(self.action_space, spaces.Discrete):
109
- log_p_all = torch.log_softmax(self.logits, dim=-1) # Shape (B, N_actions)
110
- action_long = action.long()
111
-
112
- action_indices_for_gather: torch.Tensor
113
-
114
- if action_long.ndim == log_p_all.ndim - 1: # action_long is (B,)
115
- action_indices_for_gather = action_long.unsqueeze(
116
- -1,
117
- ) # Converts to (B,1)
118
- elif action_long.ndim == log_p_all.ndim: # action_long is (B, K)
119
- if action_long.shape[-1] == 1: # action_long is (B,1)
120
- action_indices_for_gather = action_long
121
- elif (
122
- action_long.shape == log_p_all.shape
123
- and hasattr(self.action_space, "n")
124
- and action_long.shape[-1] == self.action_space.n
125
- ):
126
- # Special handling for test case: action is (B, N_actions) for Discrete(N_actions)
127
- # Use argmax to get the action index.
128
- action_indices_for_gather = torch.argmax(
129
- action_long,
130
- dim=-1,
131
- keepdim=True,
132
- ) # Converts (B, N_actions) to (B,1)
133
- else:
134
- msg = (
135
- f"Action shape {action.shape} is not compatible with Discrete space. "
136
- f"Expected (batch_size,), (batch_size, 1), or (batch_size, num_actions) for argmax case. "
137
- f"Logits shape: {log_p_all.shape}. Action space: {self.action_space}"
138
- )
139
- raise ValueError(
140
- msg,
141
- )
142
- else:
143
- msg = (
144
- f"Action tensor ndim {action.ndim} is not compatible with Discrete space logits ndim {log_p_all.ndim}. "
145
- f"Expected action ndim to be {log_p_all.ndim - 1} or {log_p_all.ndim}."
146
- )
147
- raise ValueError(
148
- msg,
149
- )
150
-
151
- return log_p_all.gather(-1, action_indices_for_gather).squeeze(-1)
152
-
153
- if isinstance(self.action_space, spaces.Box):
154
- var = torch.exp(2 * self.log_std)
155
- return (
156
- -0.5
157
- * (
158
- ((action - self.mu) ** 2) / var
159
- + 2 * self.log_std
160
- + math.log(2 * math.pi)
161
- )
162
- ).sum(-1)
163
-
164
- # -------- MultiDiscrete --------
165
- if isinstance(self.action_space, spaces.MultiDiscrete):
166
- logps = []
167
- offset = 0
168
- for idx, size in enumerate(self.action_space.nvec):
169
- logits_i = self.logits[:, offset : offset + size]
170
- logp_all = torch.log_softmax(logits_i, dim=-1)
171
- act_i = action[:, idx].long()
172
- logp_i = logp_all.gather(-1, act_i.unsqueeze(-1)).squeeze(-1)
173
- logps.append(logp_i)
174
- offset += size
175
- return torch.stack(logps, dim=-1).sum(-1)
176
-
177
- # -------- MultiBinary --------
178
- if isinstance(self.action_space, spaces.MultiBinary):
179
- # log sigma(x) and log (1-sigma(x))
180
- log_p1 = -F.softplus(-self.logits)
181
- log_p0 = -self.logits + log_p1
182
- a = (
183
- action.float()
184
- ) # Action for MultiBinary is expected to be float (0.0 or 1.0)
185
- return (a * log_p1 + (1.0 - a) * log_p0).sum(-1)
81
+ """Log probability of the action.
186
82
 
187
- raise NotImplementedError
83
+ :param action: Action.
84
+ :type action: torch.Tensor
85
+ :return: Log probability of the action.
86
+ :rtype: torch.Tensor
87
+ """
88
+ return log_prob_from_space(
89
+ self.action_space,
90
+ action,
91
+ logits=self.logits,
92
+ mu=self.mu,
93
+ log_std=self.log_std,
94
+ )
188
95
 
189
96
  def entropy(self) -> torch.Tensor:
190
- if isinstance(self.action_space, spaces.Discrete):
191
- p = torch.softmax(self.logits, dim=-1)
192
- return -(p * torch.log(p + 1e-8)).sum(-1)
97
+ """Entropy of the distribution.
193
98
 
194
- if isinstance(self.action_space, spaces.Box):
195
- return 0.5 * (1 + math.log(2 * math.pi)) * self.mu.size(
196
- -1,
197
- ) + self.log_std.sum(-1)
198
-
199
- # -------- MultiDiscrete --------
200
- if isinstance(self.action_space, spaces.MultiDiscrete):
201
- entropies = []
202
- offset = 0
203
- for size in self.action_space.nvec:
204
- logits_i = self.logits[:, offset : offset + size]
205
- p_i = torch.softmax(logits_i, dim=-1)
206
- ent_i = -(p_i * torch.log(p_i + 1e-8)).sum(-1)
207
- entropies.append(ent_i)
208
- offset += size
209
- return torch.stack(entropies, dim=-1).sum(-1)
210
-
211
- # -------- MultiBinary --------
212
- if isinstance(self.action_space, spaces.MultiBinary):
213
- p = torch.sigmoid(self.logits)
214
- return -(p * torch.log(p + 1e-8) + (1 - p) * torch.log(1 - p + 1e-8)).sum(
215
- -1,
216
- )
217
-
218
- raise NotImplementedError
99
+ :return: Entropy of the distribution.
100
+ :rtype: torch.Tensor
101
+ """
102
+ return entropy_from_space(
103
+ self.action_space,
104
+ logits=self.logits,
105
+ mu=self.mu,
106
+ log_std=self.log_std,
107
+ )
219
108
 
220
109
 
221
110
  class EvolvableDistribution(EvolvableWrapper):
@@ -247,7 +136,7 @@ class EvolvableDistribution(EvolvableWrapper):
247
136
  action_std_init: float = 0.0,
248
137
  squash_output: bool = False,
249
138
  device: DeviceType = "cpu",
250
- ) -> None:
139
+ ):
251
140
  super().__init__(network)
252
141
 
253
142
  self.action_space = action_space
@@ -263,7 +152,7 @@ class EvolvableDistribution(EvolvableWrapper):
263
152
  if isinstance(action_space, spaces.Box):
264
153
  self.log_std = torch.nn.Parameter(
265
154
  torch.ones(1, np.prod(action_space.shape), device=device)
266
- * action_std_init,
155
+ * action_std_init
267
156
  )
268
157
 
269
158
  @property
@@ -286,7 +175,6 @@ class EvolvableDistribution(EvolvableWrapper):
286
175
  # Normal distribution for Continuous action spaces
287
176
  if isinstance(self.action_space, spaces.Box):
288
177
  log_std = self.log_std.expand_as(logits)
289
- # Pass mu and log_std directly to TorchDistribution
290
178
  return TorchDistribution(
291
179
  action_space=self.action_space,
292
180
  mu=logits,
@@ -295,20 +183,30 @@ class EvolvableDistribution(EvolvableWrapper):
295
183
  )
296
184
 
297
185
  # Categorical distribution for Discrete action spaces
298
- if isinstance(
299
- self.action_space,
300
- (spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
301
- ):
302
- # Pass logits directly to TorchDistribution
186
+ if isinstance(self.action_space, spaces.Discrete):
303
187
  return TorchDistribution(
304
188
  action_space=self.action_space,
305
189
  logits=logits,
306
- squash_output=self.squash_output, # squash_output is ignored for discrete
190
+ squash_output=self.squash_output,
191
+ )
192
+
193
+ # List of categorical distributions for MultiDiscrete action spaces
194
+ if isinstance(self.action_space, spaces.MultiDiscrete):
195
+ return TorchDistribution(
196
+ action_space=self.action_space,
197
+ logits=logits,
198
+ squash_output=self.squash_output,
199
+ )
200
+
201
+ # Bernoulli distribution for MultiBinary action spaces
202
+ if isinstance(self.action_space, spaces.MultiBinary):
203
+ return TorchDistribution(
204
+ action_space=self.action_space,
205
+ logits=logits,
206
+ squash_output=self.squash_output,
307
207
  )
308
208
  msg = f"Action space {self.action_space} not supported."
309
- raise NotImplementedError(
310
- msg,
311
- )
209
+ raise NotImplementedError(msg)
312
210
 
313
211
  def log_prob(self, action: torch.Tensor) -> torch.Tensor:
314
212
  """Get the log probability of the action.
@@ -322,7 +220,7 @@ class EvolvableDistribution(EvolvableWrapper):
322
220
  msg = "Distribution not initialized. Call forward first."
323
221
  raise ValueError(msg)
324
222
 
325
- # The new TorchDistribution handles squashing correction internally for Box space
223
+ # Handles squashing correction internally for Box space
326
224
  return self.dist.log_prob(action)
327
225
 
328
226
  def entropy(self) -> torch.Tensor:
@@ -335,7 +233,7 @@ class EvolvableDistribution(EvolvableWrapper):
335
233
  msg = "Distribution not initialized. Call forward first."
336
234
  raise ValueError(msg)
337
235
 
338
- # The new TorchDistribution returns analytical entropy for supported spaces
236
+ # Returns analytical entropy for supported spaces
339
237
  return self.dist.entropy()
340
238
 
341
239
  def apply_mask(self, logits: torch.Tensor, mask: ArrayOrTensor) -> torch.Tensor:
@@ -350,7 +248,7 @@ class EvolvableDistribution(EvolvableWrapper):
350
248
  """
351
249
  # Convert mask to tensor and reshape to match logits shape
352
250
  mask = torch.as_tensor(mask, dtype=torch.bool, device=self.device).view(
353
- logits.shape,
251
+ logits.shape
354
252
  )
355
253
 
356
254
  if isinstance(self.action_space, spaces.Discrete):
@@ -360,7 +258,7 @@ class EvolvableDistribution(EvolvableWrapper):
360
258
  list(self.action_space.nvec)
361
259
  if isinstance(self.action_space, spaces.MultiDiscrete)
362
260
  else [
363
- self.action_space.n,
261
+ self.action_space.n
364
262
  ] # For MultiBinary, nvec is not present, use n
365
263
  )
366
264
  # Split mask and logits into separate distributions
@@ -370,12 +268,10 @@ class EvolvableDistribution(EvolvableWrapper):
370
268
  # Apply mask to each split
371
269
  masked_logits = []
372
270
  for split_logits_i, split_mask_i in zip(
373
- split_logits,
374
- split_masks,
375
- strict=False,
376
- ): # Renamed for clarity
271
+ split_logits, split_masks, strict=False
272
+ ):
377
273
  masked_logits.append(
378
- apply_action_mask_discrete(split_logits_i, split_mask_i),
274
+ apply_action_mask_discrete(split_logits_i, split_mask_i)
379
275
  )
380
276
 
381
277
  masked_logits = torch.cat(masked_logits, dim=1)
@@ -383,9 +279,7 @@ class EvolvableDistribution(EvolvableWrapper):
383
279
  # This should ideally not be reached if get_distribution handles the space,
384
280
  # but keeping for safety.
385
281
  msg = f"Action space {self.action_space} not supported for masking."
386
- raise NotImplementedError(
387
- msg,
388
- )
282
+ raise NotImplementedError(msg)
389
283
 
390
284
  return masked_logits
391
285
 
@@ -394,20 +288,19 @@ class EvolvableDistribution(EvolvableWrapper):
394
288
  latent: torch.Tensor,
395
289
  action_mask: ArrayOrTensor | None = None,
396
290
  sample: bool = True,
397
- ) -> (
398
- tuple[torch.Tensor, torch.Tensor, torch.Tensor]
399
- | tuple[None, None, torch.Tensor]
400
- ):
291
+ ) -> Union[
292
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[None, None, torch.Tensor]
293
+ ]:
401
294
  """Forward pass of the network.
402
295
 
403
296
  :param latent: Latent space representation.
404
297
  :type latent: torch.Tensor
405
298
  :param action_mask: Mask to apply to the logits. Defaults to None.
406
- :type action_mask: ArrayOrTensor | None
299
+ :type action_mask: Optional[ArrayOrTensor]
407
300
  :param sample: Whether to sample an action or return the mode/mean. Defaults to True.
408
301
  :type sample: bool
409
302
  :return: Action and log probability of the action.
410
- :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | tuple[None, torch.Tensor, torch.Tensor]
303
+ :rtype: Union[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[None, torch.Tensor, torch.Tensor]]
411
304
  """
412
305
  logits = self.wrapped(latent)
413
306
 
@@ -415,7 +308,8 @@ class EvolvableDistribution(EvolvableWrapper):
415
308
  if isinstance(action_mask, (np.ndarray, list)):
416
309
  # Attempt to stack if it's a list of arrays or object array, typical for vectorized envs
417
310
  if isinstance(action_mask, list) or (
418
- isinstance(action_mask, np.ndarray) and action_mask.dtype == object
311
+ isinstance(action_mask, np.ndarray)
312
+ and action_mask.dtype == np.object_
419
313
  ):
420
314
  try:
421
315
  action_mask = np.stack(action_mask)
@@ -428,15 +322,12 @@ class EvolvableDistribution(EvolvableWrapper):
428
322
  # Ensure action_mask is a tensor before applying.
429
323
  # The view in apply_mask expects a compatible shape or will error.
430
324
  action_mask = torch.as_tensor(
431
- action_mask,
432
- device=self.device,
433
- dtype=torch.bool,
325
+ action_mask, device=self.device, dtype=torch.bool
434
326
  )
435
327
 
436
328
  logits = self.apply_mask(logits, action_mask)
437
329
 
438
330
  # Distribution from logits
439
- # get_distribution now creates the new TorchDistribution object
440
331
  self.dist = self.get_distribution(logits)
441
332
 
442
333
  # Sample action, compute log probability and entropy
@@ -444,16 +335,14 @@ class EvolvableDistribution(EvolvableWrapper):
444
335
  action = self.dist.sample()
445
336
  log_prob = self.dist.log_prob(action)
446
337
  else:
447
- action = None # Mode/mean might be more appropriate if not sampling
448
- log_prob = (
449
- None # Log prob of mode/mean typically not used in PPO sample step
450
- )
338
+ action = None
339
+ log_prob = None
451
340
 
452
341
  entropy = self.dist.entropy()
453
342
  return action, log_prob, entropy
454
343
 
455
344
  def clone(self) -> "EvolvableDistribution":
456
- """Clone the distribution.
345
+ """Clones the distribution.
457
346
 
458
347
  :return: Cloned distribution.
459
348
  :rtype: EvolvableDistribution