agilerl 2.4.2.dev0__tar.gz → 2.4.3.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.pre-commit-config.yaml +3 -3
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/PKG-INFO +2 -2
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/base.py +1 -2
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ippo.py +1 -1
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/agent.py +2 -2
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/pyproject.toml +2 -2
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_make_evolvable.py +1 -4
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_train/test_train.py +18 -6
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/uv.lock +21 -21
- agilerl-2.4.2.dev0/DQN_LEARNING_ALGORITHM_ANALYSIS.md +0 -309
- agilerl-2.4.2.dev0/DQN_LEARNING_ANALYSIS.md +0 -168
- agilerl-2.4.2.dev0/GPU_CLEANUP_ANALYSIS.md +0 -541
- agilerl-2.4.2.dev0/find_dqn_commit.sh +0 -82
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.coveragerc +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/badges/arena-github-badge.svg +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/workflows/codeql.yml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.github/workflows/python-app.yml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.gitignore +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/.readthedocs.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CITATION.cff +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CODE_OF_CONDUCT.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/CONTRIBUTING.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/LICENSE +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/README.md +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/core/registry.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/grpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ilql.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/maddpg.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/matd3.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/data.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/sampler.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/base.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/bert.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/configs.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/modules/simba.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/actors.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/base.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/distributions.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/distributions_experimental.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/protocols.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_llm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/typing.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/algo_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/cache.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/llm_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/probe_envs.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/probe_envs_ma.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/torch_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/utils/utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/agilerl/wrappers/utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_bandits.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_dpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_grpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_multi_agent_off_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_multi_agent_on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_off_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_off_policy_distributed.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_offline.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_offline_distributed.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_rainbow.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_recurrent.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_resnet.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/benchmarking_simba.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/configs/ds_config.json +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/make_evolvable_benchmarking.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/benchmarking/networks.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/accelerate/accelerate.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/accelerate/grpo_accelerate_config.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/bandit/neural_ts.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/bandit/neural_ucb.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/cqn.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg_lstm.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ddpg/ddpg_simba.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dpo.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn_lstm.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/dqn/dqn_rainbow.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/grpo.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/ippo.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/ippo_pong.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/maddpg.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_agent/matd3.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/multi_input.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo_image.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/ppo/ppo_recurrent.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/configs/training/td3.yaml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/cartpole/cartpole_random_v1.1.0.h5 +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/cartpole/cartpole_v1.1.0.h5 +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/pendulum/pendulum_random_v1.1.0.h5 +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/data/pendulum/pendulum_v1.1.0.h5 +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_bandit.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_custom_network.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_multi_agent.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_off_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_off_policy_distributed.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_offline.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_offline_distributed.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_cartpole.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_memory.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/demo_on_policy_rnn_minigrid.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_cartpole.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_lunar_lander.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_lunar_lander_rnn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/demos/performance_flamegraph_rnn_memory.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/dependabot.yml +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/Makefile +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/arena-github-badge.svg +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/css/custom.css +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/favicon.ico +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/js/expand_sidebar.js +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/logo_teal.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/logo_white.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/module.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/network.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/iris-thumbnail.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/pendigits-thumbnail.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/rainbow_performance.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/_static/thumbnails/simba_thumbnail.png +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/base.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/cql.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ddpg.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dpo.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dqn.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/dqn_rainbow.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/grpo.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ilql.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ippo.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/maddpg.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/matd3.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/neural_ts.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/neural_ucb.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/ppo.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/registry.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/td3.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/algorithms/wrappers.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/data.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/multi_agent_replay_buffer.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/replay_buffer.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/rollout_buffer.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/sampler.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/components/segment_tree.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/mutation.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/hpo/tournament.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/base.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/bert.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/cnn.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/custom_activation.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/dummy.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/gpt.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/lstm.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/mlp.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/multi_input.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/resnet.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/modules/simba.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/actors.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/base.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/q_networks.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/networks/value_networks.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/rollouts/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/rollouts/on_policy.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/train.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/algo_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/cache.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/evolvable_networks.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/ilql_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/llm_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/log_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/minari_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/probe_envs.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/torch_utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/utils/utils.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/petting_zoo_async_vector_env.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/vector/petting_zoo_vector_env.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/agent.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/learning.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/make_evolvable.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/api/wrappers/pettingzoo.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/bandits/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/conf.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/custom_algorithms/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/debugging_rl/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/distributed_training/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/evo_hyperparam_opt/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/evolvable_networks/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/get_started/agilerl2changes.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/get_started/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/llm_finetuning/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/make.bat +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/multi_agent_training/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/off_policy/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/offline_training/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/on_policy/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/pomdp/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/releases/index.rst +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/docs/requirements.txt +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/pytest.ini +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/conftest.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/helper_functions.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/pz_vector_test_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/test_neural_ts.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bandits/test_neural_ucb.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_base.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_bc_lm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/conftest.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/test_dpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_llms/test_grpo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_ippo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_maddpg.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_multi_agent/test_matd3.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_optimizer_wrapper.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_registry.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_cqn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ddpg.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_dqn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_dqn_rainbow.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ilql.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_ppo.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_algorithms/test_single_agent/test_td3.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_multi_agent_replay_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_replay_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_replay_data.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_rollout_buffer.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_sampler.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_components/test_segment_tree.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_data.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/test_mutation.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_hpo/test_tournament.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_base.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_bert.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_cnn.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_custom_activation.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_dummy.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_gpt.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_lstm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_mlp.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_multi_input.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_resnet.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_modules/test_simba.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_actors.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_base.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_q_networks.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_networks/test_value_functions.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_protocols.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_train/test_train_llm.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_algo_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_cache.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_ilql_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_llm_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_log_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_minari_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_probe_envs.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_probe_envs_ma.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_sampling_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_torch_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_utils.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_utils/test_utils_evolvable.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_vector/test_vector.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/__init__.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_agent.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_autoreset.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_bandit_env.py +0 -0
- {agilerl-2.4.2.dev0 → agilerl-2.4.3.dev0}/tests/test_wrappers/test_skills.py +0 -0
.pre-commit-config.yaml

```diff
@@ -24,7 +24,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev:
+    rev: 26.1.0
     hooks:
       - id: black
   - repo: https://github.com/codespell-project/codespell
@@ -35,7 +35,7 @@ repos:
           - --skip=*.css,*.js,*.map,*.scss,*.svg
           - --ignore-words-list=magent,pres,roate
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.
+    rev: v0.14.14
     hooks:
       - id: ruff-check
         args:
@@ -53,6 +53,6 @@ repos:
       - id: yamlfmt
   - repo: https://github.com/astral-sh/uv-pre-commit
     # uv version.
-    rev: 0.9.
+    rev: 0.9.28
     hooks:
       - id: uv-lock
```
PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: agilerl
-Version: 2.4.2.dev0
+Version: 2.4.3.dev0
 Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
 Author-email: Nick Ustaran-Anderegg <dev@agilerl.com>
 License-Expression: Apache-2.0
@@ -22,7 +22,7 @@ Requires-Dist: omegaconf~=2.3.0
 Requires-Dist: packaging>=20.0
 Requires-Dist: pandas~=2.2.3
 Requires-Dist: pettingzoo~=1.23.1
-Requires-Dist: pre-commit~=3.
+Requires-Dist: pre-commit~=3.8.0
 Requires-Dist: pygame~=2.6.0
 Requires-Dist: pymunk~=6.2.0
 Requires-Dist: redis~=4.4.4
```
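The metadata changes above are just the version bump and the updated pre-commit pin. As a quick reference, a minimal sketch (standard library only, assuming agilerl is installed in the current environment) of how these metadata fields can be read back at runtime:

```python
from importlib.metadata import requires, version

# Reads the Version field of the installed distribution -- the same
# PKG-INFO value changed in the hunk above.
print(version("agilerl"))

# Requires-Dist entries (e.g. "pre-commit~=3.8.0") are exposed the same way.
for requirement in requires("agilerl") or []:
    print(requirement)
```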
agilerl/algorithms/core/base.py

```diff
@@ -2066,8 +2066,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
         accelerator: Optional[Accelerator] = None,
     ) -> None:
         raise NotImplementedError(
-            "The load class method is not supported for this algorithm class."
-            """
+            "The load class method is not supported for this algorithm class." """
             To load a saved LLM, please load the model as follows, and then re-instantiate the GRPO
             class, using the pre-trained model.

```
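The base.py hunk only reflows the `NotImplementedError` message: the short string literal and the triple-quoted literal that follows it now sit on one line, and Python concatenates adjacent string literals at compile time either way. A small standalone sketch of that mechanism (the message text here is abridged, not the library's exact wording):

```python
# Adjacent string literals -- here a plain string followed by a triple-quoted
# string -- are joined into a single string constant at compile time.
message = "The load class method is not supported for this algorithm class." """
To load a saved LLM, load the pre-trained model yourself and re-instantiate the algorithm with it.
"""

try:
    raise NotImplementedError(message)
except NotImplementedError as err:
    print(err)  # one concatenated message, however the literals were laid out
```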
agilerl/algorithms/ippo.py

```diff
@@ -671,7 +671,7 @@ class IPPO(MultiAgentRLAlgorithm):
         :param action_space: Action space for the agent
         :type action_space: gymnasium.spaces
         """
-
+        states, actions, log_probs, rewards, dones, values, next_state, next_done = (
             experiences
         )

```
agilerl/wrappers/agent.py

```diff
@@ -597,8 +597,8 @@ class AsyncAgentsWrapper(AgentWrapper[MultiAgentRLAlgorithm]):
         :return: Learning information
         :rtype: Any
         """
-
-
+        states, actions, log_probs, rewards, dones, values, next_state, next_done = map(
+            self.stack_experiences, experiences
         )

         # Handle case where we haven't collected a next state for each sub-agent
```
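This hunk and the ippo.py hunk above both unpack the experiences tuple into named fields; the wrapper version additionally maps a stacking function over every field first. A minimal sketch of that pattern with NumPy stand-ins (the `stack` helper below is a placeholder, not agilerl's `stack_experiences`):

```python
import numpy as np

def stack(field):
    # Placeholder stacking function: turn a list of per-step arrays
    # into one array with a leading batch axis.
    return np.stack(field)

# Toy experiences tuple with two fields, each holding 4 time steps.
experiences = ([np.zeros(3)] * 4, [np.ones(2)] * 4)

# map() applies the stacking function to each field before unpacking.
states, actions = map(stack, experiences)
print(states.shape, actions.shape)  # (4, 3) (4, 2)
```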
pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [project]
 name = "agilerl"
-version = "2.4.2.dev0"
+version = "2.4.3.dev0"
 description = "AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps."
 authors = [{ name = "Nick Ustaran-Anderegg", email = "dev@agilerl.com" }]
 license = "Apache-2.0"
@@ -24,7 +24,7 @@ dependencies = [
     "pettingzoo~=1.23.1",
     "jax[cpu]~=0.4.31",
     "packaging>=20.0",
-    "pre-commit~=3.
+    "pre-commit~=3.8.0",
     "pygame~=2.6.0",
     "pymunk~=6.2.0",
     "redis~=4.4.4",
```
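The dependency change pins pre-commit with a compatible-release specifier. A short sketch of what `~=3.8.0` accepts, using the `packaging` library that is already listed as a dependency above:

```python
from packaging.specifiers import SpecifierSet

# "~=3.8.0" is equivalent to ">=3.8.0, ==3.8.*": patch releases only.
spec = SpecifierSet("~=3.8.0")

for candidate in ["3.8.0", "3.8.2", "3.9.0", "4.0.0"]:
    print(candidate, spec.contains(candidate))
# 3.8.0 True / 3.8.2 True / 3.9.0 False / 4.0.0 False
```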
tests/test_modules/test_make_evolvable.py

```diff
@@ -194,9 +194,7 @@ def test_instantiation_with_rainbow():
         network, input_tensor, support=support, rainbow=True
     )
     assert isinstance(evolvable_network, MakeEvolvable)
-    assert (
-        str(evolvable_network)
-        == """MakeEvolvable(
+    assert str(evolvable_network) == """MakeEvolvable(
   (feature_net): Sequential(
     (feature_linear_layer_0): Linear(in_features=3, out_features=128, bias=True)
     (feature_activation_0): ReLU()
@@ -212,7 +210,6 @@ def test_instantiation_with_rainbow():
     (advantage_linear_layer_output): NoisyLinear(in_features=8, out_features=102)
   )
 )"""
-    )
     del network, evolvable_network

```
tests/test_train/test_train.py

```diff
@@ -1498,8 +1498,10 @@ def test_train_off_policy_agent_calls_made_rainbow(
 def test_train_off_policy_save_elite_warning(
     env, population_off_policy, tournament, mutations, memory
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitnesses = train_off_policy(
             env,
@@ -2137,8 +2139,10 @@ def test_train_on_policy_save_elite_warning(
     tournament,
     mutations,
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitnesses = train_on_policy(
             env,
@@ -2703,8 +2707,10 @@ def test_train_multi_agent_on_policy_rgb_vectorized(
 def test_train_multi_save_elite_warning(
     multi_env, population_multi_agent, on_policy, multi_memory, tournament, mutations
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitnesses = train_multi_agent_off_policy(
             multi_env,
@@ -2730,8 +2736,10 @@ def test_train_multi_save_elite_warning(
 def test_train_multi_save_elite_warning_on_policy(
     multi_env, population_multi_agent, on_policy, multi_memory, tournament, mutations
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitnesses = train_multi_agent_on_policy(
             multi_env,
@@ -3567,8 +3575,10 @@ def test_train_offline_save_elite_warning(
     offline_init_hp,
     dummy_h5py_data,
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitness = train_offline(
             env,
@@ -4057,8 +4067,10 @@ def test_train_bandit_agent_calls_made(
 def test_train_bandit_save_elite_warning(
     bandit_env, population_bandit, tournament, mutations, bandit_memory
 ):
-    warning_string =
+    warning_string = (
+        "'save_elite' set to False but 'elite_path' has been defined, elite will not\
         be saved unless 'save_elite' is set to True."
+    )
     with pytest.warns(match=warning_string):
         pop, pop_fitnesses = train_bandits(
             bandit_env,
```
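Every hunk in this file makes the same change: the long warning string is wrapped in parentheses so the assignment reads as a single expression rather than a bare continuation. A self-contained sketch of the pattern these tests exercise (the `maybe_save_elite` function is a hypothetical stand-in for the trainers' behaviour):

```python
import warnings
from typing import Optional

import pytest

def maybe_save_elite(save_elite: bool, elite_path: Optional[str]) -> None:
    # Hypothetical stand-in: warn when an elite path is given but saving is off.
    if not save_elite and elite_path is not None:
        warnings.warn(
            "'save_elite' set to False but 'elite_path' has been defined, "
            "elite will not be saved unless 'save_elite' is set to True."
        )

def test_save_elite_warning():
    # Parentheses let a long string span lines without a trailing backslash.
    warning_string = (
        "'save_elite' set to False but 'elite_path' has been defined"
    )
    # match= is treated as a regular expression searched in the warning message.
    with pytest.warns(UserWarning, match=warning_string):
        maybe_save_elite(save_elite=False, elite_path="elite.pt")
```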
uv.lock

```diff
@@ -53,7 +53,7 @@ wheels = [

 [[package]]
 name = "agilerl"
-version = "2.4.2.dev0"
+version = "2.4.3.dev0"
 source = { editable = "." }
 dependencies = [
     { name = "accelerate" },
@@ -139,7 +139,7 @@ requires-dist = [
     { name = "peft", marker = "extra == 'all'", specifier = "~=0.18.0" },
     { name = "peft", marker = "extra == 'llm'", specifier = "~=0.18.0" },
     { name = "pettingzoo", specifier = "~=1.23.1" },
-    { name = "pre-commit", specifier = "~=3.
+    { name = "pre-commit", specifier = "~=3.8.0" },
     { name = "pygame", specifier = "~=2.6.0" },
     { name = "pymunk", specifier = "~=6.2.0" },
     { name = "redis", specifier = "~=4.4.4" },
@@ -2261,13 +2261,13 @@ name = "mlx-lm"
 version = "0.29.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "jinja2", marker = "
+    { name = "jinja2", marker = "sys_platform == 'darwin'" },
     { name = "mlx", marker = "sys_platform == 'darwin'" },
-    { name = "numpy", marker = "
-    { name = "protobuf", marker = "
-    { name = "pyyaml", marker = "
-    { name = "sentencepiece", marker = "
-    { name = "transformers", marker = "
+    { name = "numpy", marker = "sys_platform == 'darwin'" },
+    { name = "protobuf", marker = "sys_platform == 'darwin'" },
+    { name = "pyyaml", marker = "sys_platform == 'darwin'" },
+    { name = "sentencepiece", marker = "sys_platform == 'darwin'" },
+    { name = "transformers", marker = "sys_platform == 'darwin'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/e3/62/f46e1355256a114808517947f8e83ad6be310c7288c551db0fa678f47923/mlx_lm-0.29.1.tar.gz", hash = "sha256:b99180d8f33d33a077b814e550bfb2d8a59ae003d668fd1f4b3fff62a381d34b", size = 232302, upload-time = "2025-12-16T16:58:27.959Z" }
 wheels = [
@@ -2634,7 +2634,7 @@ name = "nvidia-cudnn-cu12"
 version = "9.5.1.17"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
+    { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386, upload-time = "2024-10-25T19:54:26.39Z" },
@@ -2645,7 +2645,7 @@ name = "nvidia-cufft-cu12"
 version = "11.3.0.4"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+    { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" },
@@ -2674,9 +2674,9 @@ name = "nvidia-cusolver-cu12"
 version = "11.7.1.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
-    { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
-    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+    { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+    { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+    { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" },
@@ -2688,7 +2688,7 @@ name = "nvidia-cusparse-cu12"
 version = "12.5.4.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+    { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" },
@@ -3024,7 +3024,7 @@ wheels = [

 [[package]]
 name = "pre-commit"
-version = "3.
+version = "3.8.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "cfgv" },
@@ -3033,9 +3033,9 @@ dependencies = [
     { name = "pyyaml" },
     { name = "virtualenv" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/
+sdist = { url = "https://files.pythonhosted.org/packages/64/10/97ee2fa54dff1e9da9badbc5e35d0bbaef0776271ea5907eccf64140f72f/pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af", size = 177815, upload-time = "2024-07-28T19:59:01.538Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/
+    { url = "https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643, upload-time = "2024-07-28T19:58:59.335Z" },
 ]

 [[package]]
@@ -5030,8 +5030,8 @@ name = "triton"
 version = "3.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12' and sys_platform == 'linux'" },
-    { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12' and sys_platform == 'linux'" },
+    { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'" },
+    { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/8d/a9/549e51e9b1b2c9b854fd761a1d23df0ba2fbc60bd0c13b489ffa518cfcb7/triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e", size = 155600257, upload-time = "2025-05-29T23:39:36.085Z" },
@@ -5378,8 +5378,8 @@ name = "xformers"
 version = "0.0.31"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "numpy", marker = "sys_platform == 'linux'" },
-    { name = "torch", marker = "sys_platform == 'linux'" },
+    { name = "numpy", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
+    { name = "torch", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/33/35/91c172a57681e1c03de5ad1ca654dc87c282279b941052ed04e616ae5bcd/xformers-0.0.31.tar.gz", hash = "sha256:3fccb159c6327c13fc1b08f8b963c2779ca526e2e50755dee9bcc1bac67d20c6", size = 12102740, upload-time = "2025-06-25T15:12:10.241Z" }
 wheels = [
```
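Most of the lock-file churn is in PEP 508 environment markers: several CUDA-related wheels now additionally require `platform_machine != 'aarch64'` on Linux, and the mlx-lm dependency markers read `sys_platform == 'darwin'`. A small sketch of how such markers evaluate, using the `packaging` library (the environments passed below are illustrative):

```python
from packaging.markers import Marker

linux_x86_only = Marker("platform_machine != 'aarch64' and sys_platform == 'linux'")
darwin_only = Marker("sys_platform == 'darwin'")

# evaluate() overlays the given keys on the running interpreter's environment.
print(linux_x86_only.evaluate({"platform_machine": "x86_64", "sys_platform": "linux"}))   # True
print(linux_x86_only.evaluate({"platform_machine": "aarch64", "sys_platform": "linux"}))  # False
print(darwin_only.evaluate({"sys_platform": "darwin"}))  # True
print(darwin_only.evaluate({"sys_platform": "linux"}))   # False
```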
agilerl-2.4.2.dev0/DQN_LEARNING_ALGORITHM_ANALYSIS.md (entire file removed; +0 -309). The removed file read:

# DQN Learning Algorithm Analysis

## Overview
Detailed analysis of the DQN learning algorithm implementation, focusing on the `learn()`, `update()`, and `soft_update()` methods.

## Algorithm Flow

### 1. `learn()` Method (lines 338-359)

```python
def learn(self, experiences: ExperiencesType) -> float:
    obs = experiences["obs"]
    actions = experiences["action"]
    rewards = experiences["reward"]
    next_obs = experiences["next_obs"]
    dones = experiences["done"]

    obs = self.preprocess_observation(obs)
    next_obs = self.preprocess_observation(next_obs)

    loss = self.update(obs, actions, rewards, next_obs, dones)

    # soft update target network
    self.soft_update()
    return loss.item()
```

**Analysis**: ✅ Looks correct
- Extracts experiences correctly
- Preprocesses observations
- Calls `update()` to compute loss and backpropagate
- Calls `soft_update()` after each learning step
- Returns scalar loss value

### 2. `update()` Method (lines 286-336)

```python
def update(self, obs, actions, rewards, next_obs, dones) -> torch.Tensor:
    with torch.no_grad():
        if self.double:  # Double Q-learning
            q_idx = self.actor(next_obs).argmax(dim=1).unsqueeze(1)
            q_target = (
                self.actor_target(next_obs).gather(dim=1, index=q_idx).detach()
            )
        else:
            q_target = self.actor_target(next_obs).max(axis=1)[0].unsqueeze(1)

        # target, if terminal then y_j = rewards
        y_j = rewards + self.gamma * q_target * (1 - dones)

    if actions.ndim == 1:
        actions = actions.unsqueeze(-1)

    # Compute Q-values for actions taken and loss
    q_eval = self.actor(obs).gather(1, actions.long())
    loss: torch.Tensor = self.criterion(q_eval, y_j)

    # zero gradients, perform a backward pass, and update the weights
    self.optimizer.zero_grad()
    if self.accelerator is not None:
        self.accelerator.backward(loss)
    else:
        loss.backward()

    self.optimizer.step()
    return loss.detach()
```

## Issues Found

### ⚠️ Issue 1: Inconsistent `max()` Usage (Line 316)

**Problem**: Uses `axis=1` instead of `dim=1`

```python
q_target = self.actor_target(next_obs).max(axis=1)[0].unsqueeze(1)
```

**Impact**:
- PyTorch's `max()` accepts `axis` but it's deprecated
- Should use `dim=1` for consistency
- **However**: This shouldn't prevent learning, just causes deprecation warning

**Comparison**:
- Line 311: Uses `.argmax(dim=1)` ✅ (correct)
- Line 316: Uses `.max(axis=1)` ❌ (should be `dim=1`)

**Fix**:
```python
q_target = self.actor_target(next_obs).max(dim=1)[0].unsqueeze(1)
```

### ⚠️ Issue 2: Target Network Initialization Method

**Problem**: DQN uses a complex TensorDict-based initialization via `init_hook()`, while other algorithms use simple `load_state_dict()`

**DQN Approach** (lines 185-203):
```python
def init_hook(self) -> None:
    param_vals: TensorDict = from_module(self.actor).detach()
    target_params: TensorDict = param_vals.clone().lock_()
    try:
        target_params.to_module(self.actor_target)
    except KeyError:
        pass
    finally:
        self.param_vals = param_vals
        self.target_params = target_params
```

**RainbowDQN/CQN Approach**:
```python
self.actor_target.load_state_dict(self.actor.state_dict())
```

**Potential Issues**:
1. The `lock_()` creates a locked TensorDict that's detached from computation graph
2. If `to_module()` fails silently (caught by `except KeyError: pass`), target network might not be initialized
3. The locked TensorDict might interfere with `soft_update()` parameter updates

**Impact**:
- If `to_module()` fails, target network starts with random weights instead of copying from actor
- This would cause incorrect Q-targets and prevent learning
- The silent exception handling makes this hard to detect

**Recommendation**: Add logging or assertion to verify target network is initialized:
```python
def init_hook(self) -> None:
    param_vals: TensorDict = from_module(self.actor).detach()
    target_params: TensorDict = param_vals.clone().lock_()
    try:
        target_params.to_module(self.actor_target)
    except KeyError as e:
        # Log the error instead of silently passing
        warnings.warn(f"Failed to initialize target network: {e}. Using load_state_dict fallback.")
        self.actor_target.load_state_dict(self.actor.state_dict())
    finally:
        self.param_vals = param_vals
        self.target_params = target_params
```

### ⚠️ Issue 3: Missing Gradient Clipping

**Problem**: DQN doesn't clip gradients, while RainbowDQN does

**RainbowDQN** (line 442):
```python
clip_grad_norm_(self.actor.parameters(), 10.0)
```

**DQN**: No gradient clipping

**Impact**:
- Could lead to gradient explosion in some cases
- Not necessarily a bug, but could cause instability

**Recommendation**: Consider adding gradient clipping:
```python
from torch.nn.utils import clip_grad_norm_

# After loss.backward(), before optimizer.step()
clip_grad_norm_(self.actor.parameters(), max_norm=10.0)
self.optimizer.step()
```

### ✅ Correct Implementations

1. **Q-Learning Update Formula** (line 319): ✅ Correct
```python
y_j = rewards + self.gamma * q_target * (1 - dones)
```

2. **Double Q-Learning** (lines 310-314): ✅ Correct
   - Uses actor to select action, target to evaluate

3. **Loss Computation** (line 326): ✅ Correct
```python
q_eval = self.actor(obs).gather(1, actions.long())
loss = self.criterion(q_eval, y_j)
```

4. **Gradient Flow** (lines 329-335): ✅ Correct
   - Zero gradients
   - Backward pass
   - Optimizer step

5. **Soft Update** (lines 361-368): ✅ Correct formula
```python
target_param.data.copy_(
    self.tau * eval_param.data + (1.0 - self.tau) * target_param.data
)
```

## Potential Learning Issues

### 1. Target Network Not Initialized Properly

**Most Likely Issue**: If `init_hook()` fails silently, target network has random weights, causing:
- Incorrect Q-targets
- No learning signal
- Random behavior

**How to Verify**:
```python
# After initialization, check if target network matches actor
actor_params = list(agent.actor.parameters())
target_params = list(agent.actor_target.parameters())
for a, t in zip(actor_params, target_params):
    if not torch.allclose(a.data, t.data, atol=1e-6):
        print("WARNING: Target network not initialized correctly!")
```

### 2. Tau Too Small

**Config**: `TAU: 0.001` (line 18)

**Impact**:
- Very slow target network updates
- Target network stays close to initial values for a long time
- Slower learning convergence

**Typical Values**:
- DQN papers often use `tau=0.01` or `tau=0.005`
- `tau=0.001` means only 0.1% update per step

**Recommendation**: Try `tau=0.01` or `tau=0.005`

### 3. Learning Rate

**Config**: `LR: 0.001` (line 12)

**Impact**:
- Might be too high for some environments
- Could cause instability

**Typical Values**:
- DQN often uses `lr=1e-4` to `lr=5e-4`
- `lr=0.001` is on the higher side

**Recommendation**: Try `lr=5e-4` or `lr=1e-4`

### 4. Learn Step Frequency

**Config**: `LEARN_STEP: 1` (line 17)

**Impact**:
- Learning every step (with 16 parallel envs, that's 16 steps per environment step)
- Very frequent learning might cause instability
- Typical DQN learns every 4-5 steps

**Recommendation**: Try `LEARN_STEP: 4` or `LEARN_STEP: 5`

## Summary of Critical Issues

1. **🔴 HIGH PRIORITY**: Target network initialization might fail silently
   - Check if `init_hook()` actually initializes target network
   - Add fallback to `load_state_dict()` if TensorDict method fails

2. **🟡 MEDIUM PRIORITY**: Inconsistent `max()` usage
   - Change `axis=1` to `dim=1` for consistency

3. **🟡 MEDIUM PRIORITY**: Consider adding gradient clipping
   - Prevents gradient explosion

4. **🟡 MEDIUM PRIORITY**: Hyperparameter tuning
   - `tau=0.001` might be too small
   - `lr=0.001` might be too high
   - `learn_step=1` might be too frequent

## Recommended Fixes

### Fix 1: Improve Target Network Initialization

```python
def init_hook(self) -> None:
    """Resets module parameters for the detached and target networks."""
    param_vals: TensorDict = from_module(self.actor).detach()
    target_params: TensorDict = param_vals.clone().lock_()

    try:
        target_params.to_module(self.actor_target)
        # Verify initialization succeeded
        actor_first_param = next(self.actor.parameters()).data
        target_first_param = next(self.actor_target.parameters()).data
        if not torch.allclose(actor_first_param, target_first_param, atol=1e-5):
            raise RuntimeError("Target network initialization verification failed")
    except (KeyError, RuntimeError) as e:
        warnings.warn(f"TensorDict initialization failed ({e}), using load_state_dict fallback")
        self.actor_target.load_state_dict(self.actor.state_dict())
    finally:
        self.param_vals = param_vals
        self.target_params = target_params
```

### Fix 2: Fix max() Usage

```python
# Line 316
q_target = self.actor_target(next_obs).max(dim=1)[0].unsqueeze(1)
```

### Fix 3: Add Gradient Clipping (Optional)

```python
# After line 333 (loss.backward())
from torch.nn.utils import clip_grad_norm_
clip_grad_norm_(self.actor.parameters(), max_norm=10.0)
self.optimizer.step()
```