agilerl 2.4.2.dev0__tar.gz → 2.4.3.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (351)
  1. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.pre-commit-config.yaml +3 -3
  2. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/PKG-INFO +2 -2
  3. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/base.py +1 -2
  4. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ippo.py +1 -1
  5. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/agent.py +2 -2
  6. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/pyproject.toml +2 -2
  7. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_make_evolvable.py +1 -4
  8. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_train/test_train.py +18 -6
  9. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/uv.lock +21 -21
  10. agilerl-2.4.2.dev0/DQN_LEARNING_ALGORITHM_ANALYSIS.md +0 -309
  11. agilerl-2.4.2.dev0/DQN_LEARNING_ANALYSIS.md +0 -168
  12. agilerl-2.4.2.dev0/GPU_CLEANUP_ANALYSIS.md +0 -541
  13. agilerl-2.4.2.dev0/find_dqn_commit.sh +0 -82
  14. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.coveragerc +0 -0
  15. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  16. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  17. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +0 -0
  18. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/badges/arena-github-badge.svg +0 -0
  19. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/workflows/codeql.yml +0 -0
  20. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/workflows/python-app.yml +0 -0
  21. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.gitignore +0 -0
  22. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.readthedocs.yaml +0 -0
  23. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CITATION.cff +0 -0
  24. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CODE_OF_CONDUCT.md +0 -0
  25. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CONTRIBUTING.md +0 -0
  26. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/LICENSE +0 -0
  27. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/README.md +0 -0
  28. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/__init__.py +0 -0
  29. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/__init__.py +0 -0
  30. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/bc_lm.py +0 -0
  31. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/__init__.py +0 -0
  32. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
  33. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/registry.py +0 -0
  34. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/cqn.py +0 -0
  35. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ddpg.py +0 -0
  36. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dpo.py +0 -0
  37. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dqn.py +0 -0
  38. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dqn_rainbow.py +0 -0
  39. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/grpo.py +0 -0
  40. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ilql.py +0 -0
  41. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/maddpg.py +0 -0
  42. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/matd3.py +0 -0
  43. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
  44. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
  45. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ppo.py +0 -0
  46. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/td3.py +0 -0
  47. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/__init__.py +0 -0
  48. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/data.py +0 -0
  49. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
  50. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/replay_buffer.py +0 -0
  51. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/rollout_buffer.py +0 -0
  52. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/sampler.py +0 -0
  53. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/segment_tree.py +0 -0
  54. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/__init__.py +0 -0
  55. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/language_environment.py +0 -0
  56. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/rl_data.py +0 -0
  57. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/tokenizer.py +0 -0
  58. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/torch_datasets.py +0 -0
  59. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/__init__.py +0 -0
  60. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/mutation.py +0 -0
  61. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/tournament.py +0 -0
  62. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/__init__.py +0 -0
  63. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/base.py +0 -0
  64. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/bert.py +0 -0
  65. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/cnn.py +0 -0
  66. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/configs.py +0 -0
  67. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/custom_components.py +0 -0
  68. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/dummy.py +0 -0
  69. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/gpt.py +0 -0
  70. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/lstm.py +0 -0
  71. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/mlp.py +0 -0
  72. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/multi_input.py +0 -0
  73. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/resnet.py +0 -0
  74. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/simba.py +0 -0
  75. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/__init__.py +0 -0
  76. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/actors.py +0 -0
  77. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/base.py +0 -0
  78. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/custom_modules.py +0 -0
  79. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/distributions.py +0 -0
  80. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/distributions_experimental.py +0 -0
  81. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/q_networks.py +0 -0
  82. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/value_networks.py +0 -0
  83. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/protocols.py +0 -0
  84. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/rollouts/__init__.py +0 -0
  85. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/rollouts/on_policy.py +0 -0
  86. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/__init__.py +0 -0
  87. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_bandits.py +0 -0
  88. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_llm.py +0 -0
  89. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_multi_agent_off_policy.py +0 -0
  90. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_multi_agent_on_policy.py +0 -0
  91. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_off_policy.py +0 -0
  92. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_offline.py +0 -0
  93. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_on_policy.py +0 -0
  94. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/typing.py +0 -0
  95. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/__init__.py +0 -0
  96. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/algo_utils.py +0 -0
  97. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/cache.py +0 -0
  98. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/evolvable_networks.py +0 -0
  99. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/ilql_utils.py +0 -0
  100. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/llm_utils.py +0 -0
  101. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/log_utils.py +0 -0
  102. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/minari_utils.py +0 -0
  103. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/probe_envs.py +0 -0
  104. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
  105. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/sampling_utils.py +0 -0
  106. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/torch_utils.py +0 -0
  107. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/utils.py +0 -0
  108. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/__init__.py +0 -0
  109. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
  110. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/pz_vec_env.py +0 -0
  111. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/__init__.py +0 -0
  112. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/learning.py +0 -0
  113. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
  114. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
  115. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/utils.py +0 -0
  116. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_bandits.py +0 -0
  117. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_dpo.py +0 -0
  118. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_grpo.py +0 -0
  119. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_multi_agent_off_policy.py +0 -0
  120. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_multi_agent_on_policy.py +0 -0
  121. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_off_policy.py +0 -0
  122. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_off_policy_distributed.py +0 -0
  123. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_offline.py +0 -0
  124. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_offline_distributed.py +0 -0
  125. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_on_policy.py +0 -0
  126. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_rainbow.py +0 -0
  127. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_recurrent.py +0 -0
  128. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_resnet.py +0 -0
  129. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_simba.py +0 -0
  130. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/configs/ds_config.json +0 -0
  131. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/make_evolvable_benchmarking.py +0 -0
  132. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/networks.py +0 -0
  133. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/accelerate/accelerate.yaml +0 -0
  134. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/accelerate/grpo_accelerate_config.yaml +0 -0
  135. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/bandit/neural_ts.yaml +0 -0
  136. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/bandit/neural_ucb.yaml +0 -0
  137. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/cqn.yaml +0 -0
  138. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg.yaml +0 -0
  139. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg_lstm.yaml +0 -0
  140. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg_simba.yaml +0 -0
  141. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dpo.yaml +0 -0
  142. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn.yaml +0 -0
  143. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn_lstm.yaml +0 -0
  144. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn_rainbow.yaml +0 -0
  145. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/grpo.yaml +0 -0
  146. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/ippo.yaml +0 -0
  147. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/ippo_pong.yaml +0 -0
  148. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/maddpg.yaml +0 -0
  149. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/matd3.yaml +0 -0
  150. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_input.yaml +0 -0
  151. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo.yaml +0 -0
  152. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo_image.yaml +0 -0
  153. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo_recurrent.yaml +0 -0
  154. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/td3.yaml +0 -0
  155. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/cartpole/cartpole_random_v1.1.0.h5 +0 -0
  156. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/cartpole/cartpole_v1.1.0.h5 +0 -0
  157. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/pendulum/pendulum_random_v1.1.0.h5 +0 -0
  158. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/pendulum/pendulum_v1.1.0.h5 +0 -0
  159. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_bandit.py +0 -0
  160. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_custom_network.py +0 -0
  161. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_multi_agent.py +0 -0
  162. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_off_policy.py +0 -0
  163. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_off_policy_distributed.py +0 -0
  164. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_offline.py +0 -0
  165. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_offline_distributed.py +0 -0
  166. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy.py +0 -0
  167. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_cartpole.py +0 -0
  168. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_memory.py +0 -0
  169. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_minigrid.py +0 -0
  170. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_cartpole.py +0 -0
  171. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_lunar_lander.py +0 -0
  172. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_lunar_lander_rnn.py +0 -0
  173. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_rnn_memory.py +0 -0
  174. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/dependabot.yml +0 -0
  175. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/Makefile +0 -0
  176. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/arena-github-badge.svg +0 -0
  177. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/css/custom.css +0 -0
  178. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/favicon.ico +0 -0
  179. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/js/expand_sidebar.js +0 -0
  180. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/logo_teal.png +0 -0
  181. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/logo_white.png +0 -0
  182. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/module.png +0 -0
  183. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/network.png +0 -0
  184. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/iris-thumbnail.png +0 -0
  185. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/pendigits-thumbnail.png +0 -0
  186. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/rainbow_performance.png +0 -0
  187. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/simba_thumbnail.png +0 -0
  188. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/base.rst +0 -0
  189. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/cql.rst +0 -0
  190. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ddpg.rst +0 -0
  191. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dpo.rst +0 -0
  192. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dqn.rst +0 -0
  193. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dqn_rainbow.rst +0 -0
  194. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/grpo.rst +0 -0
  195. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ilql.rst +0 -0
  196. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/index.rst +0 -0
  197. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ippo.rst +0 -0
  198. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/maddpg.rst +0 -0
  199. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/matd3.rst +0 -0
  200. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/neural_ts.rst +0 -0
  201. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/neural_ucb.rst +0 -0
  202. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ppo.rst +0 -0
  203. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/registry.rst +0 -0
  204. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/td3.rst +0 -0
  205. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/wrappers.rst +0 -0
  206. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/data.rst +0 -0
  207. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/index.rst +0 -0
  208. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/multi_agent_replay_buffer.rst +0 -0
  209. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/replay_buffer.rst +0 -0
  210. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/rollout_buffer.rst +0 -0
  211. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/sampler.rst +0 -0
  212. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/segment_tree.rst +0 -0
  213. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/index.rst +0 -0
  214. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/mutation.rst +0 -0
  215. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/tournament.rst +0 -0
  216. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/base.rst +0 -0
  217. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/bert.rst +0 -0
  218. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/cnn.rst +0 -0
  219. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/custom_activation.rst +0 -0
  220. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/dummy.rst +0 -0
  221. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/gpt.rst +0 -0
  222. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/index.rst +0 -0
  223. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/lstm.rst +0 -0
  224. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/mlp.rst +0 -0
  225. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/multi_input.rst +0 -0
  226. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/resnet.rst +0 -0
  227. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/simba.rst +0 -0
  228. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/actors.rst +0 -0
  229. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/base.rst +0 -0
  230. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/index.rst +0 -0
  231. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/q_networks.rst +0 -0
  232. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/value_networks.rst +0 -0
  233. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/rollouts/index.rst +0 -0
  234. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/rollouts/on_policy.rst +0 -0
  235. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/train.rst +0 -0
  236. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/algo_utils.rst +0 -0
  237. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/cache.rst +0 -0
  238. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/evolvable_networks.rst +0 -0
  239. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/ilql_utils.rst +0 -0
  240. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/index.rst +0 -0
  241. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/llm_utils.rst +0 -0
  242. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/log_utils.rst +0 -0
  243. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/minari_utils.rst +0 -0
  244. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/probe_envs.rst +0 -0
  245. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/torch_utils.rst +0 -0
  246. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/utils.rst +0 -0
  247. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/index.rst +0 -0
  248. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/petting_zoo_async_vector_env.rst +0 -0
  249. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/petting_zoo_vector_env.rst +0 -0
  250. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/agent.rst +0 -0
  251. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/index.rst +0 -0
  252. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/learning.rst +0 -0
  253. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/make_evolvable.rst +0 -0
  254. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/pettingzoo.rst +0 -0
  255. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/bandits/index.rst +0 -0
  256. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/conf.py +0 -0
  257. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/custom_algorithms/index.rst +0 -0
  258. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/debugging_rl/index.rst +0 -0
  259. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/distributed_training/index.rst +0 -0
  260. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/evo_hyperparam_opt/index.rst +0 -0
  261. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/evolvable_networks/index.rst +0 -0
  262. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/get_started/agilerl2changes.rst +0 -0
  263. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/get_started/index.rst +0 -0
  264. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/index.rst +0 -0
  265. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/llm_finetuning/index.rst +0 -0
  266. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/make.bat +0 -0
  267. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/multi_agent_training/index.rst +0 -0
  268. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/off_policy/index.rst +0 -0
  269. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/offline_training/index.rst +0 -0
  270. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/on_policy/index.rst +0 -0
  271. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/pomdp/index.rst +0 -0
  272. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/releases/index.rst +0 -0
  273. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/requirements.txt +0 -0
  274. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/pytest.ini +0 -0
  275. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/__init__.py +0 -0
  276. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/conftest.py +0 -0
  277. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/helper_functions.py +0 -0
  278. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/pz_vector_test_utils.py +0 -0
  279. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/__init__.py +0 -0
  280. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/__init__.py +0 -0
  281. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/test_neural_ts.py +0 -0
  282. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/test_neural_ucb.py +0 -0
  283. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_base.py +0 -0
  284. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bc_lm.py +0 -0
  285. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/__init__.py +0 -0
  286. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/conftest.py +0 -0
  287. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/test_dpo.py +0 -0
  288. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/test_grpo.py +0 -0
  289. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/__init__.py +0 -0
  290. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_ippo.py +0 -0
  291. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_maddpg.py +0 -0
  292. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_matd3.py +0 -0
  293. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_optimizer_wrapper.py +0 -0
  294. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_registry.py +0 -0
  295. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/__init__.py +0 -0
  296. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_cqn.py +0 -0
  297. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ddpg.py +0 -0
  298. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_dqn.py +0 -0
  299. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_dqn_rainbow.py +0 -0
  300. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ilql.py +0 -0
  301. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ppo.py +0 -0
  302. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_td3.py +0 -0
  303. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/__init__.py +0 -0
  304. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_multi_agent_replay_buffer.py +0 -0
  305. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_replay_buffer.py +0 -0
  306. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_replay_data.py +0 -0
  307. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_rollout_buffer.py +0 -0
  308. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_sampler.py +0 -0
  309. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_segment_tree.py +0 -0
  310. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_data.py +0 -0
  311. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/__init__.py +0 -0
  312. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/test_mutation.py +0 -0
  313. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/test_tournament.py +0 -0
  314. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/__init__.py +0 -0
  315. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_base.py +0 -0
  316. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_bert.py +0 -0
  317. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_cnn.py +0 -0
  318. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_custom_activation.py +0 -0
  319. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_dummy.py +0 -0
  320. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_gpt.py +0 -0
  321. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_lstm.py +0 -0
  322. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_mlp.py +0 -0
  323. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_multi_input.py +0 -0
  324. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_resnet.py +0 -0
  325. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_simba.py +0 -0
  326. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/__init__.py +0 -0
  327. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_actors.py +0 -0
  328. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_base.py +0 -0
  329. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_q_networks.py +0 -0
  330. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_value_functions.py +0 -0
  331. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_protocols.py +0 -0
  332. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_train/test_train_llm.py +0 -0
  333. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/__init__.py +0 -0
  334. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_algo_utils.py +0 -0
  335. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_cache.py +0 -0
  336. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_ilql_utils.py +0 -0
  337. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_llm_utils.py +0 -0
  338. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_log_utils.py +0 -0
  339. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_minari_utils.py +0 -0
  340. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_probe_envs.py +0 -0
  341. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_probe_envs_ma.py +0 -0
  342. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_sampling_utils.py +0 -0
  343. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_torch_utils.py +0 -0
  344. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_utils.py +0 -0
  345. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_utils_evolvable.py +0 -0
  346. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_vector/test_vector.py +0 -0
  347. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/__init__.py +0 -0
  348. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_agent.py +0 -0
  349. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_autoreset.py +0 -0
  350. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_bandit_env.py +0 -0
  351. {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_skills.py +0 -0
@@ -24,7 +24,7 @@ repos:
24
24
  - id: mixed-line-ending
25
25
  args: [--fix=lf]
26
26
  - repo: https://github.com/psf/black-pre-commit-mirror
27
- rev: 25.12.0
27
+ rev: 26.1.0
28
28
  hooks:
29
29
  - id: black
30
30
  - repo: https://github.com/codespell-project/codespell
@@ -35,7 +35,7 @@ repos:
35
35
  - --skip=*.css,*.js,*.map,*.scss,*.svg
36
36
  - --ignore-words-list=magent,pres,roate
37
37
  - repo: https://github.com/astral-sh/ruff-pre-commit
38
- rev: v0.14.9
38
+ rev: v0.14.14
39
39
  hooks:
40
40
  - id: ruff-check
41
41
  args:
@@ -53,6 +53,6 @@ repos:
53
53
  - id: yamlfmt
54
54
  - repo: https://github.com/astral-sh/uv-pre-commit
55
55
  # uv version.
56
- rev: 0.9.27
56
+ rev: 0.9.28
57
57
  hooks:
58
58
  - id: uv-lock
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agilerl
3
- Version: 2.4.2.dev0
3
+ Version: 2.4.3.dev0
4
4
  Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
5
5
  Author-email: Nick Ustaran-Anderegg <dev@agilerl.com>
6
6
  License-Expression: Apache-2.0
@@ -22,7 +22,7 @@ Requires-Dist: omegaconf~=2.3.0
22
22
  Requires-Dist: packaging>=20.0
23
23
  Requires-Dist: pandas~=2.2.3
24
24
  Requires-Dist: pettingzoo~=1.23.1
25
- Requires-Dist: pre-commit~=3.4.0
25
+ Requires-Dist: pre-commit~=3.8.0
26
26
  Requires-Dist: pygame~=2.6.0
27
27
  Requires-Dist: pymunk~=6.2.0
28
28
  Requires-Dist: redis~=4.4.4
@@ -2066,8 +2066,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
2066
2066
  accelerator: Optional[Accelerator] = None,
2067
2067
  ) -> None:
2068
2068
  raise NotImplementedError(
2069
- "The load class method is not supported for this algorithm class."
2070
- """
2069
+ "The load class method is not supported for this algorithm class." """
2071
2070
  To load a saved LLM, please load the model as follows, and then re-instantiate the GRPO
2072
2071
  class, using the pre-trained model.
2073
2072
 
@@ -671,7 +671,7 @@ class IPPO(MultiAgentRLAlgorithm):
671
671
  :param action_space: Action space for the agent
672
672
  :type action_space: gymnasium.spaces
673
673
  """
674
- (states, actions, log_probs, rewards, dones, values, next_state, next_done) = (
674
+ states, actions, log_probs, rewards, dones, values, next_state, next_done = (
675
675
  experiences
676
676
  )
677
677
 
@@ -597,8 +597,8 @@ class AsyncAgentsWrapper(AgentWrapper[MultiAgentRLAlgorithm]):
597
597
  :return: Learning information
598
598
  :rtype: Any
599
599
  """
600
- (states, actions, log_probs, rewards, dones, values, next_state, next_done) = (
601
- map(self.stack_experiences, experiences)
600
+ states, actions, log_probs, rewards, dones, values, next_state, next_done = map(
601
+ self.stack_experiences, experiences
602
602
  )
603
603
 
604
604
  # Handle case where we haven't collected a next state for each sub-agent
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "agilerl"
3
- version = "2.4.2.dev0"
3
+ version = "2.4.3.dev0"
4
4
  description = "AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps."
5
5
  authors = [{ name = "Nick Ustaran-Anderegg", email = "dev@agilerl.com" }]
6
6
  license = "Apache-2.0"
@@ -24,7 +24,7 @@ dependencies = [
24
24
  "pettingzoo~=1.23.1",
25
25
  "jax[cpu]~=0.4.31",
26
26
  "packaging>=20.0",
27
- "pre-commit~=3.4.0",
27
+ "pre-commit~=3.8.0",
28
28
  "pygame~=2.6.0",
29
29
  "pymunk~=6.2.0",
30
30
  "redis~=4.4.4",
@@ -194,9 +194,7 @@ def test_instantiation_with_rainbow():
194
194
  network, input_tensor, support=support, rainbow=True
195
195
  )
196
196
  assert isinstance(evolvable_network, MakeEvolvable)
197
- assert (
198
- str(evolvable_network)
199
- == """MakeEvolvable(
197
+ assert str(evolvable_network) == """MakeEvolvable(
200
198
  (feature_net): Sequential(
201
199
  (feature_linear_layer_0): Linear(in_features=3, out_features=128, bias=True)
202
200
  (feature_activation_0): ReLU()
@@ -212,7 +210,6 @@ def test_instantiation_with_rainbow():
212
210
  (advantage_linear_layer_output): NoisyLinear(in_features=8, out_features=102)
213
211
  )
214
212
  )"""
215
- )
216
213
  del network, evolvable_network
217
214
 
218
215
 
@@ -1498,8 +1498,10 @@ def test_train_off_policy_agent_calls_made_rainbow(
1498
1498
  def test_train_off_policy_save_elite_warning(
1499
1499
  env, population_off_policy, tournament, mutations, memory
1500
1500
  ):
1501
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
1501
+ warning_string = (
1502
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
1502
1503
  be saved unless 'save_elite' is set to True."
1504
+ )
1503
1505
  with pytest.warns(match=warning_string):
1504
1506
  pop, pop_fitnesses = train_off_policy(
1505
1507
  env,
@@ -2137,8 +2139,10 @@ def test_train_on_policy_save_elite_warning(
2137
2139
  tournament,
2138
2140
  mutations,
2139
2141
  ):
2140
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2142
+ warning_string = (
2143
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2141
2144
  be saved unless 'save_elite' is set to True."
2145
+ )
2142
2146
  with pytest.warns(match=warning_string):
2143
2147
  pop, pop_fitnesses = train_on_policy(
2144
2148
  env,
@@ -2703,8 +2707,10 @@ def test_train_multi_agent_on_policy_rgb_vectorized(
2703
2707
  def test_train_multi_save_elite_warning(
2704
2708
  multi_env, population_multi_agent, on_policy, multi_memory, tournament, mutations
2705
2709
  ):
2706
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2710
+ warning_string = (
2711
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2707
2712
  be saved unless 'save_elite' is set to True."
2713
+ )
2708
2714
  with pytest.warns(match=warning_string):
2709
2715
  pop, pop_fitnesses = train_multi_agent_off_policy(
2710
2716
  multi_env,
@@ -2730,8 +2736,10 @@ def test_train_multi_save_elite_warning(
2730
2736
  def test_train_multi_save_elite_warning_on_policy(
2731
2737
  multi_env, population_multi_agent, on_policy, multi_memory, tournament, mutations
2732
2738
  ):
2733
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2739
+ warning_string = (
2740
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
2734
2741
  be saved unless 'save_elite' is set to True."
2742
+ )
2735
2743
  with pytest.warns(match=warning_string):
2736
2744
  pop, pop_fitnesses = train_multi_agent_on_policy(
2737
2745
  multi_env,
@@ -3567,8 +3575,10 @@ def test_train_offline_save_elite_warning(
3567
3575
  offline_init_hp,
3568
3576
  dummy_h5py_data,
3569
3577
  ):
3570
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
3578
+ warning_string = (
3579
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
3571
3580
  be saved unless 'save_elite' is set to True."
3581
+ )
3572
3582
  with pytest.warns(match=warning_string):
3573
3583
  pop, pop_fitness = train_offline(
3574
3584
  env,
@@ -4057,8 +4067,10 @@ def test_train_bandit_agent_calls_made(
4057
4067
  def test_train_bandit_save_elite_warning(
4058
4068
  bandit_env, population_bandit, tournament, mutations, bandit_memory
4059
4069
  ):
4060
- warning_string = "'save_elite' set to False but 'elite_path' has been defined, elite will not\
4070
+ warning_string = (
4071
+ "'save_elite' set to False but 'elite_path' has been defined, elite will not\
4061
4072
  be saved unless 'save_elite' is set to True."
4073
+ )
4062
4074
  with pytest.warns(match=warning_string):
4063
4075
  pop, pop_fitnesses = train_bandits(
4064
4076
  bandit_env,
@@ -53,7 +53,7 @@ wheels = [
53
53
 
54
54
  [[package]]
55
55
  name = "agilerl"
56
- version = "2.4.2.dev0"
56
+ version = "2.4.3.dev0"
57
57
  source = { editable = "." }
58
58
  dependencies = [
59
59
  { name = "accelerate" },
@@ -139,7 +139,7 @@ requires-dist = [
139
139
  { name = "peft", marker = "extra == 'all'", specifier = "~=0.18.0" },
140
140
  { name = "peft", marker = "extra == 'llm'", specifier = "~=0.18.0" },
141
141
  { name = "pettingzoo", specifier = "~=1.23.1" },
142
- { name = "pre-commit", specifier = "~=3.4.0" },
142
+ { name = "pre-commit", specifier = "~=3.8.0" },
143
143
  { name = "pygame", specifier = "~=2.6.0" },
144
144
  { name = "pymunk", specifier = "~=6.2.0" },
145
145
  { name = "redis", specifier = "~=4.4.4" },
@@ -2261,13 +2261,13 @@ name = "mlx-lm"
2261
2261
  version = "0.29.1"
2262
2262
  source = { registry = "https://pypi.org/simple" }
2263
2263
  dependencies = [
2264
- { name = "jinja2", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2264
+ { name = "jinja2", marker = "sys_platform == 'darwin'" },
2265
2265
  { name = "mlx", marker = "sys_platform == 'darwin'" },
2266
- { name = "numpy", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2267
- { name = "protobuf", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2268
- { name = "pyyaml", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2269
- { name = "sentencepiece", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2270
- { name = "transformers", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')" },
2266
+ { name = "numpy", marker = "sys_platform == 'darwin'" },
2267
+ { name = "protobuf", marker = "sys_platform == 'darwin'" },
2268
+ { name = "pyyaml", marker = "sys_platform == 'darwin'" },
2269
+ { name = "sentencepiece", marker = "sys_platform == 'darwin'" },
2270
+ { name = "transformers", marker = "sys_platform == 'darwin'" },
2271
2271
  ]
2272
2272
  sdist = { url = "https://files.pythonhosted.org/packages/e3/62/f46e1355256a114808517947f8e83ad6be310c7288c551db0fa678f47923/mlx_lm-0.29.1.tar.gz", hash = "sha256:b99180d8f33d33a077b814e550bfb2d8a59ae003d668fd1f4b3fff62a381d34b", size = 232302, upload-time = "2025-12-16T16:58:27.959Z" }
2273
2273
  wheels = [
@@ -2634,7 +2634,7 @@ name = "nvidia-cudnn-cu12"
2634
2634
  version = "9.5.1.17"
2635
2635
  source = { registry = "https://pypi.org/simple" }
2636
2636
  dependencies = [
2637
- { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
2637
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2638
2638
  ]
2639
2639
  wheels = [
2640
2640
  { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386, upload-time = "2024-10-25T19:54:26.39Z" },
@@ -2645,7 +2645,7 @@ name = "nvidia-cufft-cu12"
2645
2645
  version = "11.3.0.4"
2646
2646
  source = { registry = "https://pypi.org/simple" }
2647
2647
  dependencies = [
2648
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
2648
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2649
2649
  ]
2650
2650
  wheels = [
2651
2651
  { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" },
@@ -2674,9 +2674,9 @@ name = "nvidia-cusolver-cu12"
2674
2674
  version = "11.7.1.2"
2675
2675
  source = { registry = "https://pypi.org/simple" }
2676
2676
  dependencies = [
2677
- { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
2678
- { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
2679
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
2677
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2678
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2679
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2680
2680
  ]
2681
2681
  wheels = [
2682
2682
  { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" },
@@ -2688,7 +2688,7 @@ name = "nvidia-cusparse-cu12"
2688
2688
  version = "12.5.4.2"
2689
2689
  source = { registry = "https://pypi.org/simple" }
2690
2690
  dependencies = [
2691
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
2691
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
2692
2692
  ]
2693
2693
  wheels = [
2694
2694
  { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" },
@@ -3024,7 +3024,7 @@ wheels = [
3024
3024
 
3025
3025
  [[package]]
3026
3026
  name = "pre-commit"
3027
- version = "3.4.0"
3027
+ version = "3.8.0"
3028
3028
  source = { registry = "https://pypi.org/simple" }
3029
3029
  dependencies = [
3030
3030
  { name = "cfgv" },
@@ -3033,9 +3033,9 @@ dependencies = [
3033
3033
  { name = "pyyaml" },
3034
3034
  { name = "virtualenv" },
3035
3035
  ]
3036
- sdist = { url = "https://files.pythonhosted.org/packages/56/a5/cb576829ab7c94e768221cf0629e0da8519e744d993e0c99a6ae9803babd/pre_commit-3.4.0.tar.gz", hash = "sha256:6bbd5129a64cad4c0dfaeeb12cd8f7ea7e15b77028d985341478c8af3c759522", size = 177006, upload-time = "2023-09-02T17:09:41.966Z" }
3036
+ sdist = { url = "https://files.pythonhosted.org/packages/64/10/97ee2fa54dff1e9da9badbc5e35d0bbaef0776271ea5907eccf64140f72f/pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af", size = 177815, upload-time = "2024-07-28T19:59:01.538Z" }
3037
3037
  wheels = [
3038
- { url = "https://files.pythonhosted.org/packages/58/56/3b24f8641c39021218ca16115a9cd88512ae16eab790513e832a36269e90/pre_commit-3.4.0-py2.py3-none-any.whl", hash = "sha256:96d529a951f8b677f730a7212442027e8ba53f9b04d217c4c67dc56c393ad945", size = 203701, upload-time = "2023-09-02T17:09:40.215Z" },
3038
+ { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643, upload-time = "2024-07-28T19:58:59.335Z" },
3039
3039
  ]
3040
3040
 
3041
3041
  [[package]]
@@ -5030,8 +5030,8 @@ name = "triton"
5030
5030
  version = "3.3.1"
5031
5031
  source = { registry = "https://pypi.org/simple" }
5032
5032
  dependencies = [
5033
- { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12' and sys_platform == 'linux'" },
5034
- { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12' and sys_platform == 'linux'" },
5033
+ { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'" },
5034
+ { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'" },
5035
5035
  ]
5036
5036
  wheels = [
5037
5037
  { url = "https://files.pythonhosted.org/packages/8d/a9/549e51e9b1b2c9b854fd761a1d23df0ba2fbc60bd0c13b489ffa518cfcb7/triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e", size = 155600257, upload-time = "2025-05-29T23:39:36.085Z" },
@@ -5378,8 +5378,8 @@ name = "xformers"
5378
5378
  version = "0.0.31"
5379
5379
  source = { registry = "https://pypi.org/simple" }
5380
5380
  dependencies = [
5381
- { name = "numpy", marker = "sys_platform == 'linux'" },
5382
- { name = "torch", marker = "sys_platform == 'linux'" },
5381
+ { name = "numpy", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
5382
+ { name = "torch", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
5383
5383
  ]
5384
5384
  sdist = { url = "https://files.pythonhosted.org/packages/33/35/91c172a57681e1c03de5ad1ca654dc87c282279b941052ed04e616ae5bcd/xformers-0.0.31.tar.gz", hash = "sha256:3fccb159c6327c13fc1b08f8b963c2779ca526e2e50755dee9bcc1bac67d20c6", size = 12102740, upload-time = "2025-06-25T15:12:10.241Z" }
5385
5385
  wheels = [
@@ -1,309 +0,0 @@
1
- # DQN Learning Algorithm Analysis
2
-
3
- ## Overview
4
- Detailed analysis of the DQN learning algorithm implementation, focusing on the `learn()`, `update()`, and `soft_update()` methods.
5
-
6
- ## Algorithm Flow
7
-
8
- ### 1. `learn()` Method (lines 338-359)
9
-
10
- ```python
11
- def learn(self, experiences: ExperiencesType) -> float:
12
- obs = experiences["obs"]
13
- actions = experiences["action"]
14
- rewards = experiences["reward"]
15
- next_obs = experiences["next_obs"]
16
- dones = experiences["done"]
17
-
18
- obs = self.preprocess_observation(obs)
19
- next_obs = self.preprocess_observation(next_obs)
20
-
21
- loss = self.update(obs, actions, rewards, next_obs, dones)
22
-
23
- # soft update target network
24
- self.soft_update()
25
- return loss.item()
26
- ```
27
-
28
- **Analysis**: ✅ Looks correct
29
- - Extracts experiences correctly
30
- - Preprocesses observations
31
- - Calls `update()` to compute loss and backpropagate
32
- - Calls `soft_update()` after each learning step
33
- - Returns scalar loss value
34
-
35
- ### 2. `update()` Method (lines 286-336)
36
-
37
- ```python
38
- def update(self, obs, actions, rewards, next_obs, dones) -> torch.Tensor:
39
- with torch.no_grad():
40
- if self.double: # Double Q-learning
41
- q_idx = self.actor(next_obs).argmax(dim=1).unsqueeze(1)
42
- q_target = (
43
- self.actor_target(next_obs).gather(dim=1, index=q_idx).detach()
44
- )
45
- else:
46
- q_target = self.actor_target(next_obs).max(axis=1)[0].unsqueeze(1)
47
-
48
- # target, if terminal then y_j = rewards
49
- y_j = rewards + self.gamma * q_target * (1 - dones)
50
-
51
- if actions.ndim == 1:
52
- actions = actions.unsqueeze(-1)
53
-
54
- # Compute Q-values for actions taken and loss
55
- q_eval = self.actor(obs).gather(1, actions.long())
56
- loss: torch.Tensor = self.criterion(q_eval, y_j)
57
-
58
- # zero gradients, perform a backward pass, and update the weights
59
- self.optimizer.zero_grad()
60
- if self.accelerator is not None:
61
- self.accelerator.backward(loss)
62
- else:
63
- loss.backward()
64
-
65
- self.optimizer.step()
66
- return loss.detach()
67
- ```
68
-
69
- ## Issues Found
70
-
71
- ### ⚠️ Issue 1: Inconsistent `max()` Usage (Line 316)
72
-
73
- **Problem**: Uses `axis=1` instead of `dim=1`
74
-
75
- ```python
76
- q_target = self.actor_target(next_obs).max(axis=1)[0].unsqueeze(1)
77
- ```
78
-
79
- **Impact**:
80
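- - PyTorch's `max()` accepts `axis` as a NumPy-compatibility alias for `dim`, so the call is valid and not deprecated
81
- - Should use `dim=1` for consistency
82
- - **However**: This is purely a style inconsistency; it does not change the result, raise a warning, or prevent learning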
83
-
84
- **Comparison**:
85
- - Line 311: Uses `.argmax(dim=1)` ✅ (correct)
86
- - Line 316: Uses `.max(axis=1)` ❌ (should be `dim=1`)
87
-
88
- **Fix**:
89
- ```python
90
- q_target = self.actor_target(next_obs).max(dim=1)[0].unsqueeze(1)
91
- ```
92
-
93
- ### ⚠️ Issue 2: Target Network Initialization Method
94
-
95
- **Problem**: DQN uses a complex TensorDict-based initialization via `init_hook()`, while other algorithms use simple `load_state_dict()`
96
-
97
- **DQN Approach** (lines 185-203):
98
- ```python
99
- def init_hook(self) -> None:
100
- param_vals: TensorDict = from_module(self.actor).detach()
101
- target_params: TensorDict = param_vals.clone().lock_()
102
- try:
103
- target_params.to_module(self.actor_target)
104
- except KeyError:
105
- pass
106
- finally:
107
- self.param_vals = param_vals
108
- self.target_params = target_params
109
- ```
110
-
111
- **RainbowDQN/CQN Approach**:
112
- ```python
113
- self.actor_target.load_state_dict(self.actor.state_dict())
114
- ```
115
-
116
- **Potential Issues**:
117
- 1. `param_vals.clone().lock_()` creates a locked TensorDict; because `param_vals` comes from `.detach()`, it is also detached from the computation graph
118
- 2. If `to_module()` raises a `KeyError`, the exception is swallowed by `except KeyError: pass` and the target network might not be initialized
119
- 3. The locked TensorDict might interfere with `soft_update()` parameter updates
120
-
121
- **Impact**:
122
- - If `to_module()` fails, the target network starts with random weights instead of a copy of the actor's weights
123
- - This would cause incorrect Q-targets and prevent learning
124
- - The silent exception handling makes this hard to detect
125
-
126
- **Recommendation**: Add logging or assertion to verify target network is initialized:
127
- ```python
128
- def init_hook(self) -> None:
129
- param_vals: TensorDict = from_module(self.actor).detach()
130
- target_params: TensorDict = param_vals.clone().lock_()
131
- try:
132
- target_params.to_module(self.actor_target)
133
- except KeyError as e:
134
- # Log the error instead of silently passing
135
- warnings.warn(f"Failed to initialize target network: {e}. Using load_state_dict fallback.")
136
- self.actor_target.load_state_dict(self.actor.state_dict())
137
- finally:
138
- self.param_vals = param_vals
139
- self.target_params = target_params
140
- ```
141
-
142
- ### ⚠️ Issue 3: Missing Gradient Clipping
143
-
144
- **Problem**: DQN doesn't clip gradients, while RainbowDQN does
145
-
146
- **RainbowDQN** (line 442):
147
- ```python
148
- clip_grad_norm_(self.actor.parameters(), 10.0)
149
- ```
150
-
151
- **DQN**: No gradient clipping
152
-
153
- **Impact**:
154
- - Could lead to gradient explosion in some cases
155
- - Not necessarily a bug, but could cause instability
156
-
157
- **Recommendation**: Consider adding gradient clipping:
158
- ```python
159
- from torch.nn.utils import clip_grad_norm_
160
-
161
- # After loss.backward(), before optimizer.step()
162
- clip_grad_norm_(self.actor.parameters(), max_norm=10.0)
163
- self.optimizer.step()
164
- ```
165
-
166
- ### ✅ Correct Implementations
167
-
168
- 1. **Q-Learning Update Formula** (line 319): ✅ Correct
169
- ```python
170
- y_j = rewards + self.gamma * q_target * (1 - dones)
171
- ```
172
-
173
- 2. **Double Q-Learning** (lines 310-314): ✅ Correct
174
- - Uses actor to select action, target to evaluate
175
-
176
- 3. **Loss Computation** (line 326): ✅ Correct
177
- ```python
178
- q_eval = self.actor(obs).gather(1, actions.long())
179
- loss = self.criterion(q_eval, y_j)
180
- ```
181
-
182
- 4. **Gradient Flow** (lines 329-335): ✅ Correct
183
- - Zero gradients
184
- - Backward pass
185
- - Optimizer step
186
-
187
- 5. **Soft Update** (lines 361-368): ✅ Correct formula
188
- ```python
189
- target_param.data.copy_(
190
- self.tau * eval_param.data + (1.0 - self.tau) * target_param.data
191
- )
192
- ```
193
-
194
- ## Potential Learning Issues
195
-
196
- ### 1. Target Network Not Initialized Properly
197
-
198
- **Most Likely Issue**: If `init_hook()` fails silently, the target network keeps its random initial weights, causing:
199
- - Incorrect Q-targets
200
- - No learning signal
201
- - Random behavior
202
-
203
- **How to Verify**:
204
- ```python
205
- # After initialization, check if target network matches actor
206
- actor_params = list(agent.actor.parameters())
207
- target_params = list(agent.actor_target.parameters())
208
- for a, t in zip(actor_params, target_params):
209
- if not torch.allclose(a.data, t.data, atol=1e-6):
210
- print("WARNING: Target network not initialized correctly!")
211
- ```
212
-
213
- ### 2. Tau Too Small
214
-
215
- **Config**: `TAU: 0.001` (line 18)
216
-
217
- **Impact**:
218
- - Very slow target network updates
219
- - Target network stays close to initial values for a long time
220
- - Slower learning convergence
221
-
222
- **Typical Values**:
223
- - DQN implementations that use soft target updates often use `tau=0.01` or `tau=0.005` (the original DQN paper uses periodic hard updates instead)
224
- - `tau=0.001` means the target network moves only 0.1% of the way toward the actor on each update
225
-
226
- **Recommendation**: Try `tau=0.01` or `tau=0.005`
227
-
228
- ### 3. Learning Rate
229
-
230
- **Config**: `LR: 0.001` (line 12)
231
-
232
- **Impact**:
233
- - Might be too high for some environments
234
- - Could cause instability
235
-
236
- **Typical Values**:
237
- - DQN often uses `lr=1e-4` to `lr=5e-4`
238
- - `lr=0.001` is on the higher side
239
-
240
- **Recommendation**: Try `lr=5e-4` or `lr=1e-4`
241
-
242
- ### 4. Learn Step Frequency
243
-
244
- **Config**: `LEARN_STEP: 1` (line 17)
245
-
246
- **Impact**:
247
- - Learning every step (with 16 parallel envs, one vectorized step corresponds to 16 single-environment steps)
248
- - Very frequent learning might cause instability
249
- - Typical DQN learns every 4-5 steps
250
-
251
- **Recommendation**: Try `LEARN_STEP: 4` or `LEARN_STEP: 5`
252
-
253
- ## Summary of Critical Issues
254
-
255
- 1. **🔴 HIGH PRIORITY**: Target network initialization might fail silently
256
- - Check if `init_hook()` actually initializes target network
257
- - Add fallback to `load_state_dict()` if TensorDict method fails
258
-
259
- 2. **🟡 MEDIUM PRIORITY**: Inconsistent `max()` usage
260
- - Change `axis=1` to `dim=1` for consistency
261
-
262
- 3. **🟡 MEDIUM PRIORITY**: Consider adding gradient clipping
263
- - Prevents gradient explosion
264
-
265
- 4. **🟡 MEDIUM PRIORITY**: Hyperparameter tuning
266
- - `tau=0.001` might be too small
267
- - `lr=0.001` might be too high
268
- - `learn_step=1` might be too frequent
269
-
270
- ## Recommended Fixes
271
-
272
- ### Fix 1: Improve Target Network Initialization
273
-
274
- ```python
275
- def init_hook(self) -> None:
276
- """Resets module parameters for the detached and target networks."""
277
- param_vals: TensorDict = from_module(self.actor).detach()
278
- target_params: TensorDict = param_vals.clone().lock_()
279
-
280
- try:
281
- target_params.to_module(self.actor_target)
282
- # Verify initialization succeeded
283
- actor_first_param = next(self.actor.parameters()).data
284
- target_first_param = next(self.actor_target.parameters()).data
285
- if not torch.allclose(actor_first_param, target_first_param, atol=1e-5):
286
- raise RuntimeError("Target network initialization verification failed")
287
- except (KeyError, RuntimeError) as e:
288
- warnings.warn(f"TensorDict initialization failed ({e}), using load_state_dict fallback")
289
- self.actor_target.load_state_dict(self.actor.state_dict())
290
- finally:
291
- self.param_vals = param_vals
292
- self.target_params = target_params
293
- ```
294
-
295
- ### Fix 2: Fix max() Usage
296
-
297
- ```python
298
- # Line 316
299
- q_target = self.actor_target(next_obs).max(dim=1)[0].unsqueeze(1)
300
- ```
301
-
302
- ### Fix 3: Add Gradient Clipping (Optional)
303
-
304
- ```python
305
- # After line 333 (loss.backward())
306
- from torch.nn.utils import clip_grad_norm_
307
- clip_grad_norm_(self.actor.parameters(), max_norm=10.0)
308
- self.optimizer.step()
309
- ```