agilerl 2.5.0.dev0__tar.gz → 2.5.0.dev2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.pre-commit-config.yaml +2 -1
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/PKG-INFO +4 -4
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/base.py +9 -14
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ippo.py +10 -4
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/maddpg.py +10 -4
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/matd3.py +10 -4
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/actors.py +0 -13
- agilerl-2.5.0.dev0/agilerl/networks/distributions_experimental.py → agilerl-2.5.0.dev2/agilerl/networks/distributions.py +91 -202
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/algo_utils.py +65 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/probe_envs.py +7 -1
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/probe_envs_ma.py +11 -2
- agilerl-2.5.0.dev2/agilerl/utils/torch_utils.py +617 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo.yaml +2 -3
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/pyproject.toml +21 -4
- agilerl-2.5.0.dev2/sitecustomize.py +12 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/conftest.py +1 -1
- agilerl-2.5.0.dev2/tests/subprocess_runner.py +210 -0
- agilerl-2.5.0.dev2/tests/test_algorithms/test_llms/conftest.py +117 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/test_dpo.py +90 -86
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/test_grpo.py +104 -88
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_ippo.py +63 -1
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_maddpg.py +61 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/test_matd3.py +63 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ppo.py +26 -3
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_actors.py +206 -337
- agilerl-2.5.0.dev2/tests/test_networks/test_distributions.py +184 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_train/test_train.py +1 -1
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_train/test_train_llm.py +4 -4
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_algo_utils.py +77 -0
- agilerl-2.5.0.dev2/tests/test_utils/test_torch_utils.py +329 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_utils.py +1 -1
- agilerl-2.5.0.dev2/tests/utils.py +310 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/uv.lock +613 -258
- agilerl-2.5.0.dev0/.coveragerc +0 -10
- agilerl-2.5.0.dev0/agilerl/networks/distributions.py +0 -530
- agilerl-2.5.0.dev0/agilerl/utils/torch_utils.py +0 -114
- agilerl-2.5.0.dev0/tests/test_algorithms/test_llms/conftest.py +0 -95
- agilerl-2.5.0.dev0/tests/test_networks/test_distributions.py +0 -209
- agilerl-2.5.0.dev0/tests/test_utils/test_torch_utils.py +0 -106
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/badges/arena-github-badge.svg +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/workflows/codeql.yml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.github/workflows/python-app.yml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.gitignore +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/.readthedocs.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CITATION.cff +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CODE_OF_CONDUCT.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/CONTRIBUTING.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/LICENSE +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/README.md +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/bc_lm.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/optimizer_wrapper.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/core/registry.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/cqn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ddpg.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dpo.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dqn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/dqn_rainbow.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/grpo.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ilql.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/neural_ts_bandit.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/neural_ucb_bandit.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/ppo.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/algorithms/td3.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/data.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/multi_agent_replay_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/replay_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/rollout_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/sampler.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/components/segment_tree.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/language_environment.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/rl_data.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/tokenizer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/data/torch_datasets.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/mutation.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/hpo/tournament.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/base.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/bert.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/cnn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/configs.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/custom_components.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/dummy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/gpt.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/lstm.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/mlp.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/multi_input.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/resnet.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/modules/simba.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/base.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/custom_modules.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/q_networks.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/networks/value_networks.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/protocols.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/rollouts/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/rollouts/on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_bandits.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_llm.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_multi_agent_off_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_multi_agent_on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_off_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_offline.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/training/train_on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/typing.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/cache.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/evolvable_networks.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/ilql_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/llm_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/log_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/minari_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/sampling_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/utils/utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/pz_async_vec_env.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/vector/pz_vec_env.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/agent.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/learning.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/make_evolvable.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/pettingzoo_wrappers.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/agilerl/wrappers/utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_bandits.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_dpo.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_grpo.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_multi_agent_off_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_multi_agent_on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_off_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_off_policy_distributed.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_offline.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_offline_distributed.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_rainbow.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_recurrent.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_resnet.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/benchmarking_simba.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/configs/ds_config.json +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/make_evolvable_benchmarking.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/benchmarking/networks.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/accelerate/accelerate.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/accelerate/grpo_accelerate_config.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/bandit/neural_ts.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/bandit/neural_ucb.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/cqn.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg_lstm.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ddpg/ddpg_simba.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dpo.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn_lstm.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/dqn/dqn_rainbow.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/grpo.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/ippo.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/ippo_pong.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/maddpg.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_agent/matd3.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/multi_input.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo_image.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/ppo/ppo_recurrent.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/configs/training/td3.yaml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/cartpole/cartpole_random_v1.1.0.h5 +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/cartpole/cartpole_v1.1.0.h5 +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/pendulum/pendulum_random_v1.1.0.h5 +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/data/pendulum/pendulum_v1.1.0.h5 +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_bandit.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_custom_network.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_multi_agent.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_off_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_off_policy_distributed.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_offline.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_offline_distributed.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_cartpole.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_memory.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/demo_on_policy_rnn_minigrid.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_cartpole.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_lunar_lander.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_lunar_lander_rnn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/demos/performance_flamegraph_rnn_memory.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/dependabot.yml +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/Makefile +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/arena-github-badge.svg +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/css/custom.css +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/favicon.ico +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/js/expand_sidebar.js +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/logo_teal.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/logo_white.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/module.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/network.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/iris-thumbnail.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/pendigits-thumbnail.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/rainbow_performance.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/_static/thumbnails/simba_thumbnail.png +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/base.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/cql.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ddpg.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dpo.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dqn.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/dqn_rainbow.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/grpo.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ilql.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ippo.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/maddpg.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/matd3.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/neural_ts.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/neural_ucb.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/ppo.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/registry.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/td3.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/algorithms/wrappers.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/data.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/multi_agent_replay_buffer.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/replay_buffer.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/rollout_buffer.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/sampler.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/components/segment_tree.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/mutation.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/hpo/tournament.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/base.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/bert.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/cnn.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/custom_activation.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/dummy.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/gpt.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/lstm.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/mlp.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/multi_input.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/resnet.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/modules/simba.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/actors.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/base.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/q_networks.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/networks/value_networks.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/rollouts/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/rollouts/on_policy.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/train.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/algo_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/cache.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/evolvable_networks.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/ilql_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/llm_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/log_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/minari_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/probe_envs.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/torch_utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/utils/utils.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/petting_zoo_async_vector_env.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/vector/petting_zoo_vector_env.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/agent.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/learning.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/make_evolvable.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/api/wrappers/pettingzoo.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/bandits/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/conf.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/custom_algorithms/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/debugging_rl/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/distributed_training/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/evo_hyperparam_opt/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/evolvable_networks/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/get_started/agilerl2changes.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/get_started/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/llm_finetuning/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/make.bat +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/multi_agent_training/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/off_policy/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/offline_training/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/on_policy/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/pomdp/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/releases/index.rst +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/docs/requirements.txt +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/helper_functions.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/pz_vector_test_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/test_neural_ts.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bandits/test_neural_ucb.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_base.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_bc_lm.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_llms/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_multi_agent/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_optimizer_wrapper.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_registry.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_cqn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ddpg.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_dqn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_dqn_rainbow.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_ilql.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_algorithms/test_single_agent/test_td3.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_multi_agent_replay_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_replay_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_replay_data.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_rollout_buffer.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_sampler.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_components/test_segment_tree.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_data.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/test_mutation.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_hpo/test_tournament.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_base.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_bert.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_cnn.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_custom_activation.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_dummy.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_gpt.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_lstm.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_make_evolvable.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_mlp.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_multi_input.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_resnet.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_modules/test_simba.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_base.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_q_networks.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_networks/test_value_functions.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_protocols.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_cache.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_ilql_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_llm_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_log_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_minari_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_probe_envs.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_probe_envs_ma.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_sampling_utils.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_utils/test_utils_evolvable.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_vector/test_vector.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/__init__.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_agent.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_autoreset.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_bandit_env.py +0 -0
- {agilerl-2.5.0.dev0 → agilerl-2.5.0.dev2}/tests/test_wrappers/test_skills.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: agilerl
|
|
3
|
-
Version: 2.5.0.
|
|
3
|
+
Version: 2.5.0.dev2
|
|
4
4
|
Summary: AgileRL is a deep reinforcement learning library focused on improving RL development through RLOps.
|
|
5
5
|
Author-email: Nick Ustaran-Anderegg <dev@agilerl.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -29,7 +29,7 @@ Requires-Dist: redis~=4.4.4
|
|
|
29
29
|
Requires-Dist: supersuit~=3.9.0
|
|
30
30
|
Requires-Dist: tensordict~=0.8
|
|
31
31
|
Requires-Dist: termcolor~=1.1.0
|
|
32
|
-
Requires-Dist: torch==2.
|
|
32
|
+
Requires-Dist: torch==2.9.0
|
|
33
33
|
Requires-Dist: tqdm>=4.66.4
|
|
34
34
|
Requires-Dist: wandb~=0.18.0
|
|
35
35
|
Provides-Extra: all
|
|
@@ -37,13 +37,13 @@ Requires-Dist: datasets==4.4.1; extra == 'all'
|
|
|
37
37
|
Requires-Dist: deepspeed~=0.17.1; extra == 'all'
|
|
38
38
|
Requires-Dist: peft~=0.18.0; extra == 'all'
|
|
39
39
|
Requires-Dist: transformers~=4.57.1; extra == 'all'
|
|
40
|
-
Requires-Dist: vllm
|
|
40
|
+
Requires-Dist: vllm==0.13.0; extra == 'all'
|
|
41
41
|
Provides-Extra: llm
|
|
42
42
|
Requires-Dist: datasets==4.4.1; extra == 'llm'
|
|
43
43
|
Requires-Dist: deepspeed~=0.17.1; extra == 'llm'
|
|
44
44
|
Requires-Dist: peft~=0.18.0; extra == 'llm'
|
|
45
45
|
Requires-Dist: transformers~=4.57.1; extra == 'llm'
|
|
46
|
-
Requires-Dist: vllm
|
|
46
|
+
Requires-Dist: vllm==0.13.0; extra == 'llm'
|
|
47
47
|
Description-Content-Type: text/markdown
|
|
48
48
|
|
|
49
49
|
<p align="center">
|
|
@@ -1561,27 +1561,16 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1561
1561
|
nan_arr = np.empty(self.action_dims[agent_id])
|
|
1562
1562
|
nan_arr[:] = np.nan
|
|
1563
1563
|
else:
|
|
1564
|
-
nan_arr = np.array([
|
|
1564
|
+
nan_arr = np.array([np.nan])
|
|
1565
1565
|
|
|
1566
1566
|
env_defined_actions[agent_id] = nan_arr
|
|
1567
1567
|
val = nan_arr
|
|
1568
1568
|
|
|
1569
1569
|
# Handle discrete actions + env not vectorized
|
|
1570
1570
|
if isinstance(val, (int, float)):
|
|
1571
|
-
val = np.array([
|
|
1571
|
+
val = np.array([val])
|
|
1572
1572
|
env_defined_actions[agent_id] = val
|
|
1573
1573
|
|
|
1574
|
-
# Ensure additional dimension is added in so shapes align for masking
|
|
1575
|
-
if isinstance(val, np.ndarray) and len(val.shape) == 1:
|
|
1576
|
-
val = (
|
|
1577
|
-
val[:, np.newaxis]
|
|
1578
|
-
if isinstance(
|
|
1579
|
-
self.possible_action_spaces[agent_id],
|
|
1580
|
-
spaces.Discrete,
|
|
1581
|
-
)
|
|
1582
|
-
else val[np.newaxis, :]
|
|
1583
|
-
)
|
|
1584
|
-
env_defined_actions[agent_id] = val
|
|
1585
1574
|
agent_masks[agent_id] = np.where(
|
|
1586
1575
|
np.isnan(env_defined_actions[agent_id]),
|
|
1587
1576
|
0,
|
|
@@ -1814,6 +1803,12 @@ class MultiAgentRLAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
1814
1803
|
for i, agent_id in enumerate(agent_ids):
|
|
1815
1804
|
output_dict[agent_id] = group_outputs[group_id][i]
|
|
1816
1805
|
|
|
1806
|
+
if (
|
|
1807
|
+
isinstance(self.possible_action_spaces[agent_id], spaces.Discrete)
|
|
1808
|
+
and output_dict[agent_id].shape[-1] == 1
|
|
1809
|
+
):
|
|
1810
|
+
output_dict[agent_id] = output_dict[agent_id].squeeze(-1)
|
|
1811
|
+
|
|
1817
1812
|
return output_dict
|
|
1818
1813
|
|
|
1819
1814
|
def sum_shared_rewards(self, rewards: ArrayDict) -> ArrayDict:
|
|
@@ -2302,7 +2297,7 @@ class LLMAlgorithm(EvolvableAlgorithm, ABC):
|
|
|
2302
2297
|
None,
|
|
2303
2298
|
None,
|
|
2304
2299
|
)
|
|
2305
|
-
if hasattr(self, "llm"):
|
|
2300
|
+
if hasattr(self, "llm") and self.llm is not None:
|
|
2306
2301
|
del self.llm.llm_engine.model_executor
|
|
2307
2302
|
del self.llm
|
|
2308
2303
|
gc.collect()
|
|
@@ -27,6 +27,7 @@ from agilerl.typing import (
|
|
|
27
27
|
TorchObsType,
|
|
28
28
|
)
|
|
29
29
|
from agilerl.utils.algo_utils import (
|
|
30
|
+
apply_env_defined_actions,
|
|
30
31
|
concatenate_experiences_into_batches,
|
|
31
32
|
concatenate_tensors,
|
|
32
33
|
get_experiences_samples,
|
|
@@ -601,10 +602,15 @@ class IPPO(MultiAgentRLAlgorithm):
|
|
|
601
602
|
|
|
602
603
|
# If using env_defined_actions replace actions
|
|
603
604
|
if env_defined_actions is not None:
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
605
|
+
action_dict = apply_env_defined_actions(
|
|
606
|
+
unique_agents_ids,
|
|
607
|
+
action_dict,
|
|
608
|
+
env_defined_actions,
|
|
609
|
+
agent_masks,
|
|
610
|
+
discrete_actions=isinstance(
|
|
611
|
+
next(iter(self.action_space.values())), spaces.Discrete
|
|
612
|
+
),
|
|
613
|
+
)
|
|
608
614
|
|
|
609
615
|
return (
|
|
610
616
|
action_dict,
|
|
@@ -25,6 +25,7 @@ from agilerl.typing import (
|
|
|
25
25
|
SupportedObsSpaces,
|
|
26
26
|
)
|
|
27
27
|
from agilerl.utils.algo_utils import (
|
|
28
|
+
apply_env_defined_actions,
|
|
28
29
|
concatenate_spaces,
|
|
29
30
|
format_shared_critic_encoder,
|
|
30
31
|
get_deepest_head_config,
|
|
@@ -515,10 +516,15 @@ class MADDPG(MultiAgentRLAlgorithm):
|
|
|
515
516
|
|
|
516
517
|
# If using env_defined_actions replace actions
|
|
517
518
|
if env_defined_actions is not None:
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
519
|
+
action_dict = apply_env_defined_actions(
|
|
520
|
+
self.agent_ids,
|
|
521
|
+
processed_action_dict,
|
|
522
|
+
env_defined_actions,
|
|
523
|
+
agent_masks,
|
|
524
|
+
discrete_actions=isinstance(
|
|
525
|
+
next(iter(self.action_space.values())), spaces.Discrete
|
|
526
|
+
),
|
|
527
|
+
)
|
|
522
528
|
|
|
523
529
|
return processed_action_dict, action_dict
|
|
524
530
|
|
|
@@ -24,6 +24,7 @@ from agilerl.typing import (
|
|
|
24
24
|
StandardTensorDict,
|
|
25
25
|
)
|
|
26
26
|
from agilerl.utils.algo_utils import (
|
|
27
|
+
apply_env_defined_actions,
|
|
27
28
|
concatenate_spaces,
|
|
28
29
|
format_shared_critic_encoder,
|
|
29
30
|
get_deepest_head_config,
|
|
@@ -575,10 +576,15 @@ class MATD3(MultiAgentRLAlgorithm):
|
|
|
575
576
|
|
|
576
577
|
# If using env_defined_actions replace actions
|
|
577
578
|
if env_defined_actions is not None:
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
579
|
+
action_dict = apply_env_defined_actions(
|
|
580
|
+
self.agent_ids,
|
|
581
|
+
processed_action_dict,
|
|
582
|
+
env_defined_actions,
|
|
583
|
+
agent_masks,
|
|
584
|
+
discrete_actions=isinstance(
|
|
585
|
+
next(iter(self.action_space.values())), spaces.Discrete
|
|
586
|
+
),
|
|
587
|
+
)
|
|
582
588
|
|
|
583
589
|
return processed_action_dict, action_dict
|
|
584
590
|
|
|
@@ -252,9 +252,6 @@ class StochasticActor(EvolvableNetwork):
|
|
|
252
252
|
:type recurrent: bool
|
|
253
253
|
:param device: Device to use for the network.
|
|
254
254
|
:type device: str
|
|
255
|
-
:param use_experimental_distribution: Whether to use the experimental distribution implementation, which
|
|
256
|
-
includes several optimizations related to using torch primitives for statistics calculations. Defaults to False.
|
|
257
|
-
:type use_experimental_distribution: bool
|
|
258
255
|
:param random_seed: Random seed to use for the network. Defaults to None.
|
|
259
256
|
:type random_seed: int | None
|
|
260
257
|
:param encoder_name: Name of the encoder network.
|
|
@@ -284,7 +281,6 @@ class StochasticActor(EvolvableNetwork):
|
|
|
284
281
|
simba: bool = False,
|
|
285
282
|
recurrent: bool = False,
|
|
286
283
|
device: str = "cpu",
|
|
287
|
-
use_experimental_distribution: bool = False,
|
|
288
284
|
random_seed: int | None = None,
|
|
289
285
|
encoder_name: str = "encoder",
|
|
290
286
|
) -> None:
|
|
@@ -312,7 +308,6 @@ class StochasticActor(EvolvableNetwork):
|
|
|
312
308
|
self.action_std_init = action_std_init
|
|
313
309
|
self.squash_output = squash_output
|
|
314
310
|
self.action_space = action_space
|
|
315
|
-
self.use_experimental_distribution = use_experimental_distribution
|
|
316
311
|
self.output_size = get_output_size_from_space(self.action_space)
|
|
317
312
|
|
|
318
313
|
self.build_network_head(head_config)
|
|
@@ -332,14 +327,6 @@ class StochasticActor(EvolvableNetwork):
|
|
|
332
327
|
else:
|
|
333
328
|
self.action_low, self.action_high = None, None
|
|
334
329
|
|
|
335
|
-
# Wrap the network in an EvolvableDistribution
|
|
336
|
-
if use_experimental_distribution:
|
|
337
|
-
from agilerl.networks.distributions_experimental import (
|
|
338
|
-
EvolvableDistribution,
|
|
339
|
-
)
|
|
340
|
-
else:
|
|
341
|
-
from agilerl.networks.distributions import EvolvableDistribution
|
|
342
|
-
|
|
343
330
|
self.head_net = EvolvableDistribution(
|
|
344
331
|
action_space=action_space,
|
|
345
332
|
network=self.head_net,
|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
import
|
|
1
|
+
from typing import Union
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
|
-
import torch.nn.functional as F
|
|
6
5
|
from gymnasium import spaces
|
|
7
6
|
|
|
8
7
|
from agilerl.modules.base import EvolvableModule, EvolvableWrapper
|
|
9
8
|
from agilerl.typing import ArrayOrTensor, DeviceType, NetConfigType
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
9
|
+
from agilerl.utils.torch_utils import (
|
|
10
|
+
entropy_from_space,
|
|
11
|
+
log_prob_from_space,
|
|
12
|
+
sample_from_space,
|
|
13
|
+
)
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
def apply_action_mask_discrete(
|
|
16
|
-
logits: torch.Tensor,
|
|
17
|
-
mask: torch.Tensor,
|
|
17
|
+
logits: torch.Tensor, mask: torch.Tensor
|
|
18
18
|
) -> torch.Tensor:
|
|
19
19
|
"""Apply a mask to the logits.
|
|
20
20
|
|
|
@@ -37,185 +37,74 @@ class TorchDistribution:
|
|
|
37
37
|
:param action_space: Action space of the environment.
|
|
38
38
|
:type action_space: spaces.Space
|
|
39
39
|
:param logits: Logits.
|
|
40
|
-
:type logits: torch.Tensor
|
|
40
|
+
:type logits: torch.Tensor | None
|
|
41
41
|
:param mu: Mean.
|
|
42
|
-
:type mu: torch.Tensor
|
|
42
|
+
:type mu: torch.Tensor | None
|
|
43
43
|
:param log_std: Log standard deviation.
|
|
44
|
-
:type log_std: torch.Tensor
|
|
44
|
+
:type log_std: torch.Tensor | None
|
|
45
45
|
:param squash_output: Whether to squash the output to the action space.
|
|
46
46
|
:type squash_output: bool
|
|
47
|
-
|
|
48
47
|
"""
|
|
49
48
|
|
|
50
49
|
def __init__(
|
|
51
50
|
self,
|
|
52
51
|
*,
|
|
53
52
|
action_space: spaces.Space,
|
|
54
|
-
logits:
|
|
55
|
-
|
|
56
|
-
) = None, # for discrete / multidiscrete / multibinary
|
|
57
|
-
mu: torch.Tensor | None = None, # for Box
|
|
53
|
+
logits: torch.Tensor | None = None,
|
|
54
|
+
mu: torch.Tensor | None = None,
|
|
58
55
|
log_std: torch.Tensor | None = None,
|
|
59
56
|
squash_output: bool = False,
|
|
60
|
-
)
|
|
57
|
+
):
|
|
61
58
|
self.action_space = action_space
|
|
62
|
-
self.logits
|
|
59
|
+
self.logits = logits
|
|
60
|
+
self.mu = mu
|
|
61
|
+
self.log_std = log_std
|
|
63
62
|
self.squash_output = squash_output and isinstance(action_space, spaces.Box)
|
|
64
63
|
self._sampled_action: torch.Tensor | None = None
|
|
65
64
|
|
|
66
|
-
# ------------------------------------------------------------------ #
|
|
67
|
-
# fast tensor-only primitives #
|
|
68
|
-
# ------------------------------------------------------------------ #
|
|
69
65
|
def sample(self) -> torch.Tensor:
|
|
70
|
-
|
|
71
|
-
probs = torch.softmax(self.logits, dim=-1)
|
|
72
|
-
self._sampled_action = torch.multinomial(probs, 1).squeeze(-1)
|
|
73
|
-
return self._sampled_action
|
|
74
|
-
|
|
75
|
-
if isinstance(self.action_space, spaces.Box):
|
|
76
|
-
eps = torch.randn_like(self.mu)
|
|
77
|
-
out = self.mu + torch.exp(self.log_std) * eps
|
|
78
|
-
if self.squash_output:
|
|
79
|
-
out = torch.tanh(out)
|
|
80
|
-
self._sampled_action = out
|
|
81
|
-
return out
|
|
82
|
-
|
|
83
|
-
# -------- MultiDiscrete --------
|
|
84
|
-
if isinstance(self.action_space, spaces.MultiDiscrete):
|
|
85
|
-
actions = []
|
|
86
|
-
offset = 0
|
|
87
|
-
for size in self.action_space.nvec:
|
|
88
|
-
logits_i = self.logits[:, offset : offset + size]
|
|
89
|
-
probs_i = torch.softmax(logits_i, dim=-1)
|
|
90
|
-
act_i = torch.multinomial(probs_i, 1).squeeze(-1)
|
|
91
|
-
actions.append(act_i)
|
|
92
|
-
offset += size
|
|
93
|
-
self._sampled_action = torch.stack(actions, dim=-1)
|
|
94
|
-
return self._sampled_action
|
|
95
|
-
|
|
96
|
-
# -------- MultiBinary --------
|
|
97
|
-
if isinstance(self.action_space, spaces.MultiBinary):
|
|
98
|
-
probs = torch.sigmoid(self.logits)
|
|
99
|
-
self._sampled_action = torch.bernoulli(
|
|
100
|
-
probs,
|
|
101
|
-
) # Ensures float tensor, removed .to(torch.int64)
|
|
102
|
-
return self._sampled_action
|
|
66
|
+
"""Sample from the distribution for the given action space.
|
|
103
67
|
|
|
104
|
-
|
|
105
|
-
|
|
68
|
+
:return: Sampled action.
|
|
69
|
+
:rtype: torch.Tensor
|
|
70
|
+
"""
|
|
71
|
+
self._sampled_action = sample_from_space(
|
|
72
|
+
self.action_space,
|
|
73
|
+
logits=self.logits,
|
|
74
|
+
mu=self.mu,
|
|
75
|
+
log_std=self.log_std,
|
|
76
|
+
squash_output=self.squash_output,
|
|
77
|
+
)
|
|
78
|
+
return self._sampled_action
|
|
106
79
|
|
|
107
80
|
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
|
108
|
-
|
|
109
|
-
log_p_all = torch.log_softmax(self.logits, dim=-1) # Shape (B, N_actions)
|
|
110
|
-
action_long = action.long()
|
|
111
|
-
|
|
112
|
-
action_indices_for_gather: torch.Tensor
|
|
113
|
-
|
|
114
|
-
if action_long.ndim == log_p_all.ndim - 1: # action_long is (B,)
|
|
115
|
-
action_indices_for_gather = action_long.unsqueeze(
|
|
116
|
-
-1,
|
|
117
|
-
) # Converts to (B,1)
|
|
118
|
-
elif action_long.ndim == log_p_all.ndim: # action_long is (B, K)
|
|
119
|
-
if action_long.shape[-1] == 1: # action_long is (B,1)
|
|
120
|
-
action_indices_for_gather = action_long
|
|
121
|
-
elif (
|
|
122
|
-
action_long.shape == log_p_all.shape
|
|
123
|
-
and hasattr(self.action_space, "n")
|
|
124
|
-
and action_long.shape[-1] == self.action_space.n
|
|
125
|
-
):
|
|
126
|
-
# Special handling for test case: action is (B, N_actions) for Discrete(N_actions)
|
|
127
|
-
# Use argmax to get the action index.
|
|
128
|
-
action_indices_for_gather = torch.argmax(
|
|
129
|
-
action_long,
|
|
130
|
-
dim=-1,
|
|
131
|
-
keepdim=True,
|
|
132
|
-
) # Converts (B, N_actions) to (B,1)
|
|
133
|
-
else:
|
|
134
|
-
msg = (
|
|
135
|
-
f"Action shape {action.shape} is not compatible with Discrete space. "
|
|
136
|
-
f"Expected (batch_size,), (batch_size, 1), or (batch_size, num_actions) for argmax case. "
|
|
137
|
-
f"Logits shape: {log_p_all.shape}. Action space: {self.action_space}"
|
|
138
|
-
)
|
|
139
|
-
raise ValueError(
|
|
140
|
-
msg,
|
|
141
|
-
)
|
|
142
|
-
else:
|
|
143
|
-
msg = (
|
|
144
|
-
f"Action tensor ndim {action.ndim} is not compatible with Discrete space logits ndim {log_p_all.ndim}. "
|
|
145
|
-
f"Expected action ndim to be {log_p_all.ndim - 1} or {log_p_all.ndim}."
|
|
146
|
-
)
|
|
147
|
-
raise ValueError(
|
|
148
|
-
msg,
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
return log_p_all.gather(-1, action_indices_for_gather).squeeze(-1)
|
|
152
|
-
|
|
153
|
-
if isinstance(self.action_space, spaces.Box):
|
|
154
|
-
var = torch.exp(2 * self.log_std)
|
|
155
|
-
return (
|
|
156
|
-
-0.5
|
|
157
|
-
* (
|
|
158
|
-
((action - self.mu) ** 2) / var
|
|
159
|
-
+ 2 * self.log_std
|
|
160
|
-
+ math.log(2 * math.pi)
|
|
161
|
-
)
|
|
162
|
-
).sum(-1)
|
|
163
|
-
|
|
164
|
-
# -------- MultiDiscrete --------
|
|
165
|
-
if isinstance(self.action_space, spaces.MultiDiscrete):
|
|
166
|
-
logps = []
|
|
167
|
-
offset = 0
|
|
168
|
-
for idx, size in enumerate(self.action_space.nvec):
|
|
169
|
-
logits_i = self.logits[:, offset : offset + size]
|
|
170
|
-
logp_all = torch.log_softmax(logits_i, dim=-1)
|
|
171
|
-
act_i = action[:, idx].long()
|
|
172
|
-
logp_i = logp_all.gather(-1, act_i.unsqueeze(-1)).squeeze(-1)
|
|
173
|
-
logps.append(logp_i)
|
|
174
|
-
offset += size
|
|
175
|
-
return torch.stack(logps, dim=-1).sum(-1)
|
|
176
|
-
|
|
177
|
-
# -------- MultiBinary --------
|
|
178
|
-
if isinstance(self.action_space, spaces.MultiBinary):
|
|
179
|
-
# log sigma(x) and log (1-sigma(x))
|
|
180
|
-
log_p1 = -F.softplus(-self.logits)
|
|
181
|
-
log_p0 = -self.logits + log_p1
|
|
182
|
-
a = (
|
|
183
|
-
action.float()
|
|
184
|
-
) # Action for MultiBinary is expected to be float (0.0 or 1.0)
|
|
185
|
-
return (a * log_p1 + (1.0 - a) * log_p0).sum(-1)
|
|
81
|
+
"""Log probability of the action.
|
|
186
82
|
|
|
187
|
-
|
|
83
|
+
:param action: Action.
|
|
84
|
+
:type action: torch.Tensor
|
|
85
|
+
:return: Log probability of the action.
|
|
86
|
+
:rtype: torch.Tensor
|
|
87
|
+
"""
|
|
88
|
+
return log_prob_from_space(
|
|
89
|
+
self.action_space,
|
|
90
|
+
action,
|
|
91
|
+
logits=self.logits,
|
|
92
|
+
mu=self.mu,
|
|
93
|
+
log_std=self.log_std,
|
|
94
|
+
)
|
|
188
95
|
|
|
189
96
|
def entropy(self) -> torch.Tensor:
|
|
190
|
-
|
|
191
|
-
p = torch.softmax(self.logits, dim=-1)
|
|
192
|
-
return -(p * torch.log(p + 1e-8)).sum(-1)
|
|
97
|
+
"""Entropy of the distribution.
|
|
193
98
|
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
for size in self.action_space.nvec:
|
|
204
|
-
logits_i = self.logits[:, offset : offset + size]
|
|
205
|
-
p_i = torch.softmax(logits_i, dim=-1)
|
|
206
|
-
ent_i = -(p_i * torch.log(p_i + 1e-8)).sum(-1)
|
|
207
|
-
entropies.append(ent_i)
|
|
208
|
-
offset += size
|
|
209
|
-
return torch.stack(entropies, dim=-1).sum(-1)
|
|
210
|
-
|
|
211
|
-
# -------- MultiBinary --------
|
|
212
|
-
if isinstance(self.action_space, spaces.MultiBinary):
|
|
213
|
-
p = torch.sigmoid(self.logits)
|
|
214
|
-
return -(p * torch.log(p + 1e-8) + (1 - p) * torch.log(1 - p + 1e-8)).sum(
|
|
215
|
-
-1,
|
|
216
|
-
)
|
|
217
|
-
|
|
218
|
-
raise NotImplementedError
|
|
99
|
+
:return: Entropy of the distribution.
|
|
100
|
+
:rtype: torch.Tensor
|
|
101
|
+
"""
|
|
102
|
+
return entropy_from_space(
|
|
103
|
+
self.action_space,
|
|
104
|
+
logits=self.logits,
|
|
105
|
+
mu=self.mu,
|
|
106
|
+
log_std=self.log_std,
|
|
107
|
+
)
|
|
219
108
|
|
|
220
109
|
|
|
221
110
|
class EvolvableDistribution(EvolvableWrapper):
|
|
@@ -247,7 +136,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
247
136
|
action_std_init: float = 0.0,
|
|
248
137
|
squash_output: bool = False,
|
|
249
138
|
device: DeviceType = "cpu",
|
|
250
|
-
)
|
|
139
|
+
):
|
|
251
140
|
super().__init__(network)
|
|
252
141
|
|
|
253
142
|
self.action_space = action_space
|
|
@@ -263,7 +152,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
263
152
|
if isinstance(action_space, spaces.Box):
|
|
264
153
|
self.log_std = torch.nn.Parameter(
|
|
265
154
|
torch.ones(1, np.prod(action_space.shape), device=device)
|
|
266
|
-
* action_std_init
|
|
155
|
+
* action_std_init
|
|
267
156
|
)
|
|
268
157
|
|
|
269
158
|
@property
|
|
@@ -286,7 +175,6 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
286
175
|
# Normal distribution for Continuous action spaces
|
|
287
176
|
if isinstance(self.action_space, spaces.Box):
|
|
288
177
|
log_std = self.log_std.expand_as(logits)
|
|
289
|
-
# Pass mu and log_std directly to TorchDistribution
|
|
290
178
|
return TorchDistribution(
|
|
291
179
|
action_space=self.action_space,
|
|
292
180
|
mu=logits,
|
|
@@ -295,20 +183,30 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
295
183
|
)
|
|
296
184
|
|
|
297
185
|
# Categorical distribution for Discrete action spaces
|
|
298
|
-
if isinstance(
|
|
299
|
-
self.action_space,
|
|
300
|
-
(spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
|
|
301
|
-
):
|
|
302
|
-
# Pass logits directly to TorchDistribution
|
|
186
|
+
if isinstance(self.action_space, spaces.Discrete):
|
|
303
187
|
return TorchDistribution(
|
|
304
188
|
action_space=self.action_space,
|
|
305
189
|
logits=logits,
|
|
306
|
-
squash_output=self.squash_output,
|
|
190
|
+
squash_output=self.squash_output,
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
# List of categorical distributions for MultiDiscrete action spaces
|
|
194
|
+
if isinstance(self.action_space, spaces.MultiDiscrete):
|
|
195
|
+
return TorchDistribution(
|
|
196
|
+
action_space=self.action_space,
|
|
197
|
+
logits=logits,
|
|
198
|
+
squash_output=self.squash_output,
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
# Bernoulli distribution for MultiBinary action spaces
|
|
202
|
+
if isinstance(self.action_space, spaces.MultiBinary):
|
|
203
|
+
return TorchDistribution(
|
|
204
|
+
action_space=self.action_space,
|
|
205
|
+
logits=logits,
|
|
206
|
+
squash_output=self.squash_output,
|
|
307
207
|
)
|
|
308
208
|
msg = f"Action space {self.action_space} not supported."
|
|
309
|
-
raise NotImplementedError(
|
|
310
|
-
msg,
|
|
311
|
-
)
|
|
209
|
+
raise NotImplementedError(msg)
|
|
312
210
|
|
|
313
211
|
def log_prob(self, action: torch.Tensor) -> torch.Tensor:
|
|
314
212
|
"""Get the log probability of the action.
|
|
@@ -322,7 +220,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
322
220
|
msg = "Distribution not initialized. Call forward first."
|
|
323
221
|
raise ValueError(msg)
|
|
324
222
|
|
|
325
|
-
#
|
|
223
|
+
# Handles squashing correction internally for Box space
|
|
326
224
|
return self.dist.log_prob(action)
|
|
327
225
|
|
|
328
226
|
def entropy(self) -> torch.Tensor:
|
|
@@ -335,7 +233,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
335
233
|
msg = "Distribution not initialized. Call forward first."
|
|
336
234
|
raise ValueError(msg)
|
|
337
235
|
|
|
338
|
-
#
|
|
236
|
+
# Returns analytical entropy for supported spaces
|
|
339
237
|
return self.dist.entropy()
|
|
340
238
|
|
|
341
239
|
def apply_mask(self, logits: torch.Tensor, mask: ArrayOrTensor) -> torch.Tensor:
|
|
@@ -350,7 +248,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
350
248
|
"""
|
|
351
249
|
# Convert mask to tensor and reshape to match logits shape
|
|
352
250
|
mask = torch.as_tensor(mask, dtype=torch.bool, device=self.device).view(
|
|
353
|
-
logits.shape
|
|
251
|
+
logits.shape
|
|
354
252
|
)
|
|
355
253
|
|
|
356
254
|
if isinstance(self.action_space, spaces.Discrete):
|
|
@@ -360,7 +258,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
360
258
|
list(self.action_space.nvec)
|
|
361
259
|
if isinstance(self.action_space, spaces.MultiDiscrete)
|
|
362
260
|
else [
|
|
363
|
-
self.action_space.n
|
|
261
|
+
self.action_space.n
|
|
364
262
|
] # For MultiBinary, nvec is not present, use n
|
|
365
263
|
)
|
|
366
264
|
# Split mask and logits into separate distributions
|
|
@@ -370,12 +268,10 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
370
268
|
# Apply mask to each split
|
|
371
269
|
masked_logits = []
|
|
372
270
|
for split_logits_i, split_mask_i in zip(
|
|
373
|
-
split_logits,
|
|
374
|
-
|
|
375
|
-
strict=False,
|
|
376
|
-
): # Renamed for clarity
|
|
271
|
+
split_logits, split_masks, strict=False
|
|
272
|
+
):
|
|
377
273
|
masked_logits.append(
|
|
378
|
-
apply_action_mask_discrete(split_logits_i, split_mask_i)
|
|
274
|
+
apply_action_mask_discrete(split_logits_i, split_mask_i)
|
|
379
275
|
)
|
|
380
276
|
|
|
381
277
|
masked_logits = torch.cat(masked_logits, dim=1)
|
|
@@ -383,9 +279,7 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
383
279
|
# This should ideally not be reached if get_distribution handles the space,
|
|
384
280
|
# but keeping for safety.
|
|
385
281
|
msg = f"Action space {self.action_space} not supported for masking."
|
|
386
|
-
raise NotImplementedError(
|
|
387
|
-
msg,
|
|
388
|
-
)
|
|
282
|
+
raise NotImplementedError(msg)
|
|
389
283
|
|
|
390
284
|
return masked_logits
|
|
391
285
|
|
|
@@ -394,20 +288,19 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
394
288
|
latent: torch.Tensor,
|
|
395
289
|
action_mask: ArrayOrTensor | None = None,
|
|
396
290
|
sample: bool = True,
|
|
397
|
-
) ->
|
|
398
|
-
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
399
|
-
|
|
400
|
-
):
|
|
291
|
+
) -> Union[
|
|
292
|
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[None, None, torch.Tensor]
|
|
293
|
+
]:
|
|
401
294
|
"""Forward pass of the network.
|
|
402
295
|
|
|
403
296
|
:param latent: Latent space representation.
|
|
404
297
|
:type latent: torch.Tensor
|
|
405
298
|
:param action_mask: Mask to apply to the logits. Defaults to None.
|
|
406
|
-
:type action_mask: ArrayOrTensor
|
|
299
|
+
:type action_mask: Optional[ArrayOrTensor]
|
|
407
300
|
:param sample: Whether to sample an action or return the mode/mean. Defaults to True.
|
|
408
301
|
:type sample: bool
|
|
409
302
|
:return: Action and log probability of the action.
|
|
410
|
-
:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
|
303
|
+
:rtype: Union[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[None, torch.Tensor, torch.Tensor]]
|
|
411
304
|
"""
|
|
412
305
|
logits = self.wrapped(latent)
|
|
413
306
|
|
|
@@ -415,7 +308,8 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
415
308
|
if isinstance(action_mask, (np.ndarray, list)):
|
|
416
309
|
# Attempt to stack if it's a list of arrays or object array, typical for vectorized envs
|
|
417
310
|
if isinstance(action_mask, list) or (
|
|
418
|
-
isinstance(action_mask, np.ndarray)
|
|
311
|
+
isinstance(action_mask, np.ndarray)
|
|
312
|
+
and action_mask.dtype == np.object_
|
|
419
313
|
):
|
|
420
314
|
try:
|
|
421
315
|
action_mask = np.stack(action_mask)
|
|
@@ -428,15 +322,12 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
428
322
|
# Ensure action_mask is a tensor before applying.
|
|
429
323
|
# The view in apply_mask expects a compatible shape or will error.
|
|
430
324
|
action_mask = torch.as_tensor(
|
|
431
|
-
action_mask,
|
|
432
|
-
device=self.device,
|
|
433
|
-
dtype=torch.bool,
|
|
325
|
+
action_mask, device=self.device, dtype=torch.bool
|
|
434
326
|
)
|
|
435
327
|
|
|
436
328
|
logits = self.apply_mask(logits, action_mask)
|
|
437
329
|
|
|
438
330
|
# Distribution from logits
|
|
439
|
-
# get_distribution now creates the new TorchDistribution object
|
|
440
331
|
self.dist = self.get_distribution(logits)
|
|
441
332
|
|
|
442
333
|
# Sample action, compute log probability and entropy
|
|
@@ -444,16 +335,14 @@ class EvolvableDistribution(EvolvableWrapper):
|
|
|
444
335
|
action = self.dist.sample()
|
|
445
336
|
log_prob = self.dist.log_prob(action)
|
|
446
337
|
else:
|
|
447
|
-
action = None
|
|
448
|
-
log_prob =
|
|
449
|
-
None # Log prob of mode/mean typically not used in PPO sample step
|
|
450
|
-
)
|
|
338
|
+
action = None
|
|
339
|
+
log_prob = None
|
|
451
340
|
|
|
452
341
|
entropy = self.dist.entropy()
|
|
453
342
|
return action, log_prob, entropy
|
|
454
343
|
|
|
455
344
|
def clone(self) -> "EvolvableDistribution":
|
|
456
|
-
"""
|
|
345
|
+
"""Clones the distribution.
|
|
457
346
|
|
|
458
347
|
:return: Cloned distribution.
|
|
459
348
|
:rtype: EvolvableDistribution
|