torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
sota-implementations/iql/utils.py
@@ -0,0 +1,437 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools

import torch.nn
import torch.optim
from tensordict.nn import InteractionType, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch.distributions import Categorical

from torchrl.collectors import SyncDataCollector
from torchrl.data import (
    Composite,
    LazyMemmapStorage,
    TensorDictPrioritizedReplayBuffer,
    TensorDictReplayBuffer,
)
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.envs import (
    CatTensors,
    Compose,
    DMControlEnv,
    DoubleToFloat,
    EnvCreator,
    InitTracker,
    ParallelEnv,
    RewardSum,
    TransformedEnv,
)

from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
    MLP,
    ProbabilisticActor,
    SafeModule,
    TanhNormal,
    ValueOperator,
)
from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate
from torchrl.record import VideoRecorder
from torchrl.trainers.helpers.models import ACTIVATIONS


# ====================================================================
# Environment utils
# -----------------


def env_maker(cfg, device="cpu", from_pixels=False):
    lib = cfg.env.backend
    if lib in ("gym", "gymnasium"):
        with set_gym_backend(lib):
            return GymEnv(
                cfg.env.name,
                device=device,
                from_pixels=from_pixels,
                pixels_only=False,
                categorical_action_encoding=True,
            )
    elif lib == "dm_control":
        env = DMControlEnv(
            cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False
        )
        return TransformedEnv(
            env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation")
        )
    else:
        raise NotImplementedError(f"Unknown lib {lib}.")


def apply_env_transforms(
    env,
):
    transformed_env = TransformedEnv(
        env,
        Compose(
            InitTracker(),
            DoubleToFloat(),
            RewardSum(),
        ),
    )
    return transformed_env


def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None):
    """Make environments for training and evaluation."""
    maker = functools.partial(env_maker, cfg)
    parallel_env = ParallelEnv(
        train_num_envs,
        EnvCreator(maker),
        serial_for_single=True,
    )
    parallel_env.set_seed(cfg.env.seed)

    train_env = apply_env_transforms(parallel_env)

    maker = functools.partial(env_maker, cfg, from_pixels=cfg.logger.video)
    eval_env = TransformedEnv(
        ParallelEnv(
            eval_num_envs,
            EnvCreator(maker),
            serial_for_single=True,
        ),
        train_env.transform.clone(),
    )
    if cfg.logger.video:
        eval_env.insert_transform(
            0, VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
        )
    return train_env, eval_env


# ====================================================================
# Collector and replay buffer
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore, compile_mode):
    """Make collector."""
    device = cfg.collector.device
    if device in ("", None):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    collector = SyncDataCollector(
        train_env,
        actor_model_explore,
        frames_per_batch=cfg.collector.frames_per_batch,
        init_random_frames=cfg.collector.init_random_frames,
        max_frames_per_traj=cfg.collector.max_frames_per_traj,
        total_frames=cfg.collector.total_frames,
        device=device,
        compile_policy={"mode": compile_mode} if compile_mode else False,
        cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
    )
    collector.set_seed(cfg.env.seed)
    return collector


def make_replay_buffer(
    batch_size,
    prb=False,
    buffer_size=1000000,
    scratch_dir=None,
    device="cpu",
    prefetch=3,
):
    if prb:
        replay_buffer = TensorDictPrioritizedReplayBuffer(
            alpha=0.7,
            beta=0.5,
            pin_memory=False,
            prefetch=prefetch,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=scratch_dir,
                device=device,
            ),
            batch_size=batch_size,
        )
    else:
        replay_buffer = TensorDictReplayBuffer(
            pin_memory=False,
            prefetch=prefetch,
            storage=LazyMemmapStorage(
                buffer_size,
                scratch_dir=scratch_dir,
                device=device,
            ),
            batch_size=batch_size,
        )
    return replay_buffer


def make_offline_replay_buffer(rb_cfg):
    data = D4RLExperienceReplay(
        dataset_id=rb_cfg.dataset,
        split_trajs=False,
        batch_size=rb_cfg.batch_size,
        # We use drop_last to avoid recompiles (and dynamic shapes)
        sampler=SamplerWithoutReplacement(drop_last=True),
        prefetch=4,
        direct_download=True,
    )

    data.append_transform(DoubleToFloat())

    return data


# ====================================================================
# Model
# -----
#
# We give one version of the model for learning from pixels, and one for state.
# TorchRL comes in handy at this point, as the high-level interactions with
# these models is unchanged, regardless of the modality.
#


def make_iql_model(cfg, train_env, eval_env, device="cpu"):
    model_cfg = cfg.model

    in_keys = ["observation"]
    action_spec = train_env.action_spec_unbatched
    actor_net, q_net, value_net = make_iql_modules_state(model_cfg, eval_env)

    out_keys = ["loc", "scale"]

    actor_module = TensorDictModule(actor_net, in_keys=in_keys, out_keys=out_keys)

    # We use a ProbabilisticActor to make sure that we map the
    # network output to the right space using a TanhDelta
    # distribution.
    actor = ProbabilisticActor(
        module=actor_module,
        in_keys=["loc", "scale"],
        spec=action_spec,
        distribution_class=TanhNormal,
        distribution_kwargs={
            "low": action_spec.space.low.to(device),
            "high": action_spec.space.high.to(device),
            "tanh_loc": False,
        },
        default_interaction_type=ExplorationType.RANDOM,
    )

    in_keys = ["observation", "action"]

    out_keys = ["state_action_value"]
    qvalue = ValueOperator(
        in_keys=in_keys,
        out_keys=out_keys,
        module=q_net,
    )
    in_keys = ["observation"]
    out_keys = ["state_value"]
    value_net = ValueOperator(
        in_keys=in_keys,
        out_keys=out_keys,
        module=value_net,
    )
    model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device)
    # init nets
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        td = eval_env.fake_tensordict()
        td = td.to(device)
        for net in model:
            net(td)

    return model


def make_iql_modules_state(model_cfg, proof_environment):
    action_spec = proof_environment.action_spec_unbatched

    actor_net_kwargs = {
        "num_cells": model_cfg.hidden_sizes,
        "out_features": 2 * action_spec.shape[-1],
        "activation_class": ACTIVATIONS[model_cfg.activation],
    }
    actor_net = MLP(**actor_net_kwargs)
    actor_extractor = NormalParamExtractor(
        scale_mapping=f"biased_softplus_{model_cfg.default_policy_scale}",
        scale_lb=model_cfg.scale_lb,
    )
    actor_net = torch.nn.Sequential(actor_net, actor_extractor)

    qvalue_net_kwargs = {
        "num_cells": model_cfg.hidden_sizes,
        "out_features": 1,
        "activation_class": ACTIVATIONS[model_cfg.activation],
    }

    q_net = MLP(**qvalue_net_kwargs)

    # Define Value Network
    value_net_kwargs = {
        "num_cells": model_cfg.hidden_sizes,
        "out_features": 1,
        "activation_class": ACTIVATIONS[model_cfg.activation],
    }
    value_net = MLP(**value_net_kwargs)

    return actor_net, q_net, value_net


def make_discrete_iql_model(cfg, train_env, eval_env, device):
    """Make discrete IQL agent."""
    # Define Actor Network
    in_keys = ["observation"]
    action_spec = train_env.action_spec_unbatched
    # Define Actor Network
    in_keys = ["observation"]

    actor_net = MLP(
        num_cells=cfg.model.hidden_sizes,
        out_features=action_spec.space.n,
        activation_class=ACTIVATIONS[cfg.model.activation],
        device=device,
    )

    actor_module = SafeModule(
        module=actor_net,
        in_keys=in_keys,
        out_keys=["logits"],
    )
    actor = ProbabilisticActor(
        spec=Composite(action=eval_env.action_spec_unbatched).to(device),
        module=actor_module,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        distribution_kwargs={},
        default_interaction_type=InteractionType.RANDOM,
        return_log_prob=False,
    )

    # Define Critic Network
    qvalue_net = MLP(
        num_cells=cfg.model.hidden_sizes,
        out_features=action_spec.space.n,
        activation_class=ACTIVATIONS[cfg.model.activation],
        device=device,
    )
    qvalue = TensorDictModule(
        in_keys=["observation"],
        out_keys=["state_action_value"],
        module=qvalue_net,
    )

    # Define Value Network
    value_net = MLP(
        num_cells=cfg.model.hidden_sizes,
        out_features=1,
        activation_class=ACTIVATIONS[cfg.model.activation],
        device=device,
    )
    value_net = TensorDictModule(
        in_keys=["observation"],
        out_keys=["state_value"],
        module=value_net,
    )

    model = torch.nn.ModuleList([actor, qvalue, value_net])
    # init nets
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        td = eval_env.fake_tensordict()
        td = td.to(device)
        for net in model:
            net(td)

    return model


# ====================================================================
# IQL Loss
# ---------


def make_loss(loss_cfg, model, device):
    loss_module = IQLLoss(
        model[0],
        model[1],
        value_network=model[2],
        loss_function=loss_cfg.loss_function,
        temperature=loss_cfg.temperature,
        expectile=loss_cfg.expectile,
    )
    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
    target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

    return loss_module, target_net_updater


def make_discrete_loss(loss_cfg, model, device):
    loss_module = DiscreteIQLLoss(
        model[0],
        model[1],
        value_network=model[2],
        loss_function=loss_cfg.loss_function,
        temperature=loss_cfg.temperature,
        expectile=loss_cfg.expectile,
        action_space="categorical",
    )
    loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
    target_net_updater = HardUpdate(
        loss_module, value_network_update_interval=loss_cfg.hard_update_interval
    )

    return loss_module, target_net_updater


def make_iql_optimizer(optim_cfg, loss_module):
    critic_params = list(loss_module.qvalue_network_params.flatten_keys().values())
    actor_params = list(loss_module.actor_network_params.flatten_keys().values())
    value_params = list(loss_module.value_network_params.flatten_keys().values())

    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr=optim_cfg.lr,
        weight_decay=optim_cfg.weight_decay,
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr=optim_cfg.lr,
        weight_decay=optim_cfg.weight_decay,
    )
    optimizer_value = torch.optim.Adam(
        value_params,
        lr=optim_cfg.lr,
        weight_decay=optim_cfg.weight_decay,
    )
    return optimizer_actor, optimizer_critic, optimizer_value


# ====================================================================
# General utils
# ---------


def log_metrics(logger, metrics, step):
    if logger is not None:
        for metric_name, metric_value in metrics.items():
            logger.log_scalar(metric_name, metric_value, step)


def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()
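The file above only defines building blocks; the actual entry points in the same folder (`iql_online.py`, `iql_offline.py`, `discrete_iql.py`, listed earlier in this diff) compose them under their own Hydra configs. The sketch below is only an orientation on how the pieces fit together, not the shipped training loop: the config values, the `Pendulum-v1` environment, and the single summed-loss update are illustrative assumptions.

```python
# Hedged sketch only: not the shipped entry point. It assumes the helpers
# defined above are importable (e.g. this file sits next to utils.py) and
# uses made-up config values mirroring the fields the helpers read.
import torch
from omegaconf import OmegaConf

from utils import (  # the helpers reconstructed above
    make_collector,
    make_environment,
    make_iql_model,
    make_iql_optimizer,
    make_loss,
    make_replay_buffer,
)

cfg = OmegaConf.create(
    {
        "env": {"backend": "gym", "name": "Pendulum-v1", "seed": 0},
        "logger": {"video": False},
        "collector": {
            "device": "",
            "frames_per_batch": 200,
            "init_random_frames": 200,
            "max_frames_per_traj": 200,
            "total_frames": 1_000,
        },
        "compile": {"cudagraphs": False},
        "model": {
            "hidden_sizes": [256, 256],
            "activation": "relu",
            "default_policy_scale": 1.0,
            "scale_lb": 0.1,
        },
        "loss": {
            "loss_function": "l2",
            "temperature": 3.0,
            "expectile": 0.7,
            "gamma": 0.99,
            "tau": 0.005,
        },
        "optim": {"lr": 3e-4, "weight_decay": 0.0, "batch_size": 256},
    }
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the environments, networks, loss, optimizers, collector and buffer
# with the helpers defined in utils.py.
train_env, eval_env = make_environment(cfg)
model = make_iql_model(cfg, train_env, eval_env, device)
loss_module, target_updater = make_loss(cfg.loss, model, device)
opt_actor, opt_critic, opt_value = make_iql_optimizer(cfg.optim, loss_module)
collector = make_collector(cfg, train_env, model[0], compile_mode=None)
replay_buffer = make_replay_buffer(batch_size=cfg.optim.batch_size)

for data in collector:
    replay_buffer.extend(data.reshape(-1))
    sample = replay_buffer.sample().to(device)
    losses = loss_module(sample)
    # Summing the three loss terms into one backward pass is a simplification;
    # the shipped scripts organize the per-network updates themselves.
    loss = losses["loss_actor"] + losses["loss_qvalue"] + losses["loss_value"]
    loss.backward()
    for opt in (opt_actor, opt_critic, opt_value):
        opt.step()
        opt.zero_grad(set_to_none=True)
    target_updater.step()

collector.shutdown()
```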
sota-implementations/multiagent/README.md
@@ -0,0 +1,74 @@
# Multi-agent examples

In this folder we provide a set of multi-agent example scripts using the [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) simulator.

<p align="center">
<img src="https://pytorch.s3.amazonaws.com/torchrl/github-artifacts/img/marl_vmas.png" width="600px">
</p>

<center><i>The MARL algorithms contained in the scripts of this folder run on three multi-robot tasks in VMAS.</i></center>

For more details on the experiment setup and the environments, please refer to the corresponding section of the appendix in the [TorchRL paper](https://arxiv.org/abs/2306.00577).

> [!NOTE]
> If you are interested in Multi-Agent Reinforcement Learning (MARL) in TorchRL, check out [BenchMARL](https://github.com/facebookresearch/BenchMARL):
> a benchmarking library where you
> can train and compare MARL algorithms, tasks, and models using TorchRL!

## Using the scripts

### Install

First you need to install vmas and the dependencies of the scripts.

Install torchrl and tensordict following the repo instructions.

Install vmas and dependencies:

```bash
pip install vmas
pip install wandb "moviepy<2.0.0"
pip install hydra-core
```

### Run

To run a script, execute the corresponding Python file after modifying its config according to your needs.
The config can be found in the .yaml file with the same name.

For example:
```bash
python mappo_ippo.py
```

You can also override the config from the command line, for example:

```bash
python mappo_ippo.py --m env.scenario_name=navigation
```

### Computational demand
The scripts are set up to collect many frames. If your compute is limited, you can change the "frames_per_batch"
and "num_epochs" parameters to reduce the compute requirements.
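As a hedged illustration, such an override could look like the command below; the exact key paths vary per script, so check the script's .yaml config first.

```bash
# Hypothetical override; verify the key names (e.g. collector.frames_per_batch,
# and wherever num_epochs lives) against the script's .yaml config.
python mappo_ippo.py collector.frames_per_batch=6000
```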
### Script structure

The scripts are self-contained.
This means that all the code you will need to look at is contained in the script file.
No helper functions are used.

The structure of the scripts follows this order:
- Configuration dictionary for the script
- Environment creation
- Modules creation
- Collector instantiation
- Replay buffer instantiation
- Loss module creation
- Training loop (with inner minibatch loops)
- Evaluation run (at the desired frequency)

Logging is done by default to wandb.
The logging backend can be changed in the config files to one of "wandb", "tensorboard", "csv", "mlflow".

All the scripts follow the same on-policy training structure so that results can be compared across different algorithms.