torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0

benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py
@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import os
import pickle
import time
from pathlib import Path

import numpy as np

import ray

import vmas
from matplotlib import pyplot as plt
from ray import tune

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune import register_env
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.vmas import VmasEnv
from vmas import Wrapper


def store_pickled_evaluation(name: str, evaluation: dict):
    save_folder = os.path.dirname(os.path.realpath(__file__))
    file = f"{save_folder}/{name}.pkl"

    with open(file, "wb") as f:
        pickle.dump(evaluation, f)


def load_pickled_evaluation(
    name: str,
):
    save_folder = os.path.dirname(os.path.realpath(__file__))
    file = Path(f"{save_folder}/{name}.pkl")

    if file.is_file():
        with open(file, "rb") as f:
            return pickle.load(f)
    return None


def run_vmas_torchrl(
    scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0
):
    env = VmasEnv(
        scenario_name,
        device=device,
        num_envs=n_envs,
        continuous_actions=False,
        seed=seed,
    )

    collector = SyncDataCollector(
        env,
        policy=None,
        device=device,
        frames_per_batch=n_envs * n_steps,
        total_frames=n_envs * n_steps,
    )

    init_time = time.time()

    for _data in collector:
        pass

    total_time = time.time() - init_time
    collector.shutdown()
    return total_time


def run_vmas_rllib(
    scenario_name: str, n_envs: int, n_steps: int, device: str, seed: int = 0
):
    class TimerCallback(DefaultCallbacks):
        result_time = None

        def on_train_result(
            self,
            *,
            algorithm,
            result: dict,
            **kwargs,
        ) -> None:
            TimerCallback.result_time = (
                result["timers"]["training_iteration_time_ms"]
                - result["timers"]["learn_time_ms"]
            )

    def env_creator(config: dict):
        env = vmas.make_env(
            scenario=config["scenario_name"],
            num_envs=config["num_envs"],
            device=config["device"],
            continuous_actions=False,
            wrapper=Wrapper.RLLIB,
        )
        return env

    if not ray.is_initialized():
        ray.init()
    register_env(scenario_name, lambda config: env_creator(config))

    num_gpus = 0.5 if device == "cuda" else 0
    num_gpus_per_worker = 0.5 if device == "cuda" else 0
    tune.run(
        PPOTrainer,
        stop={"training_iteration": 1},
        config={
            "seed": seed,
            "framework": "torch",
            "env": scenario_name,
            "train_batch_size": n_envs * n_steps,
            "rollout_fragment_length": n_steps,
            "sgd_minibatch_size": n_envs * n_steps,
            "num_gpus": num_gpus,
            "num_workers": 0,
            "num_gpus_per_worker": num_gpus_per_worker,
            "num_envs_per_worker": n_envs,
            "batch_mode": "truncate_episodes",
            "env_config": {
                "device": device,
                "num_envs": n_envs,
                "scenario_name": scenario_name,
                "max_steps": n_steps,
            },
            "callbacks": TimerCallback,
        },
    )
    assert TimerCallback.result_time is not None
    TimerCallback.result_time /= 1_000  # convert to seconds
    return TimerCallback.result_time


def run_comparison_torchrl_rllib(
    scenario_name: str,
    device: str,
    n_steps: int = 100,
    max_n_envs: int = 3000,
    step_n_envs: int = 3,
):
    """Benchmark VMAS sampling throughput in TorchRL against RLlib.

    Args:
        scenario_name (str): name of the scenario to benchmark
        device (str): device to run the comparison on ("cpu" or "cuda")
        n_steps (int): number of environment steps
        max_n_envs (int): the maximum number of parallel environments to test
        step_n_envs (int): the number of environment counts to sample between 1 and max_n_envs

    """
    list_n_envs = np.linspace(1, max_n_envs, step_n_envs)

    figure_name = f"VMAS_{scenario_name}_{n_steps}_{device}_steps_rllib_vs_torchrl"
    figure_name_pkl = figure_name + f"_range_{1}_{max_n_envs}_num_{step_n_envs}"

    evaluation = load_pickled_evaluation(figure_name_pkl)
    if not evaluation:
        evaluation = {}
    for framework in ["TorchRL", "RLlib"]:
        if framework not in evaluation:
            torchrl_logger.info(f"\nFramework {framework}")
            vmas_times = []
            for n_envs in list_n_envs:
                n_envs = int(n_envs)
                torchrl_logger.info(f"Running {n_envs} environments")
                if framework == "TorchRL":
                    vmas_times.append(
                        (n_envs * n_steps)
                        / run_vmas_torchrl(
                            scenario_name=scenario_name,
                            n_envs=n_envs,
                            n_steps=n_steps,
                            device=device,
                        )
                    )
                else:
                    vmas_times.append(
                        (n_envs * n_steps)
                        / run_vmas_rllib(
                            scenario_name=scenario_name,
                            n_envs=n_envs,
                            n_steps=n_steps,
                            device=device,
                        )
                    )
                torchrl_logger.info(f"fps {vmas_times[-1]}")
            evaluation[framework] = vmas_times

    store_pickled_evaluation(name=figure_name_pkl, evaluation=evaluation)

    fig, ax = plt.subplots()
    for key, item in evaluation.items():
        ax.plot(
            list_n_envs,
            item,
            label=key,
        )

    plt.xlabel("Number of batched environments", fontsize=14)
    plt.ylabel("Frames per second", fontsize=14)
    ax.legend(loc="upper left")

    ax.set_title(
        f"Sampling throughput on '{scenario_name}' for {n_steps} steps on {device}.",
        fontsize=8,
    )

    save_folder = os.path.dirname(os.path.realpath(__file__))
    plt.savefig(f"{save_folder}/{figure_name}.pdf")


if __name__ == "__main__":
    # pip install matplotlib
    # pip install "ray[rllib]"==2.1.0
    # pip install torchrl
    # pip install vmas
    # pip install numpy==1.23.5

    run_comparison_torchrl_rllib(
        scenario_name="simple_spread",
        device="cuda",
        n_steps=100,
        max_n_envs=30000,
        step_n_envs=10,
    )
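
For a quick end-to-end check of the comparison harness above, a much smaller sweep on CPU keeps the runtime to minutes. The import path and the parameter values in this sketch are illustrative assumptions, not the benchmark's reference settings:

# Hypothetical smoke test; assumes the script above is importable under its
# file name, and uses deliberately tiny settings instead of the 30000-env sweep.
from vmas_rllib_vs_torchrl_sampling_performance import run_comparison_torchrl_rllib

run_comparison_torchrl_rllib(
    scenario_name="simple_spread",
    device="cpu",
    n_steps=50,      # short rollouts
    max_n_envs=32,   # far below the reference sweep in __main__
    step_n_envs=4,   # sample four env counts between 1 and 32
)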
benchmarks/storage/benchmark_sample_latency_over_rpc.py
@@ -0,0 +1,193 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Sample latency benchmarking (using RPC)
=======================================
A rough benchmark of sample latency using different storage types over the network using `torch.rpc`.
Run this script with the --rank=0 and --rank=1 flags set in separate processes - these ranks correspond to the trainer worker and buffer worker respectively, and both need to be initialised.
e.g. to benchmark LazyMemmapStorage, run the following commands using either two separate shells or multiprocessing.
- python3 benchmark_sample_latency_over_rpc.py --rank=0 --storage=LazyMemmapStorage
- python3 benchmark_sample_latency_over_rpc.py --rank=1 --storage=LazyMemmapStorage
This code is based on examples/distributed/distributed_replay_buffer.py.
"""
import argparse
import os
import pickle
import sys
import time
import timeit
from datetime import datetime

import torch
import torch.distributed.rpc as rpc
from tensordict import TensorDict
from torchrl._utils import logger as torchrl_logger
from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import (
    LazyMemmapStorage,
    LazyTensorStorage,
    ListStorage,
)
from torchrl.data.replay_buffers.writers import RoundRobinWriter

RETRY_LIMIT = 2
RETRY_DELAY_SECS = 3
REPLAY_BUFFER_NODE = "ReplayBuffer"
TRAINER_NODE = "Trainer"
TENSOR_SIZE = 3 * 86 * 86
BUFFER_SIZE = 1001
BATCH_SIZE = 256
REPEATS = 1000

storage_options = {
    "LazyMemmapStorage": LazyMemmapStorage,
    "LazyTensorStorage": LazyTensorStorage,
    "ListStorage": ListStorage,
}

storage_arg_options = {
    "LazyMemmapStorage": {"scratch_dir": "/tmp/", "device": torch.device("cpu")},
    "LazyTensorStorage": {},
    "ListStorage": {},
}
parser = argparse.ArgumentParser(
    description="RPC Replay Buffer Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument(
    "--rank",
    type=int,
    default=-1,
    help="Node Rank [0 = Trainer, 1 = Replay Buffer]",
)

parser.add_argument(
    "--storage",
    type=str,
    default="LazyMemmapStorage",
    help="Storage type [LazyMemmapStorage, LazyTensorStorage, ListStorage]",
)


class DummyTrainerNode:
    def __init__(self) -> None:
        self.id = rpc.get_worker_info().id
        self.replay_buffer = self._create_replay_buffer()
        self._ret = None

    def train(self, batch_size: int) -> float:
        start_time = timeit.default_timer()
        ret = rpc.rpc_sync(
            self.replay_buffer.owner(),
            ReplayBufferNode.sample,
            args=(self.replay_buffer, batch_size),
        )
        if storage_type == "ListStorage":
            self._ret = ret[0]
        else:
            if self._ret is None:
                self._ret = ret
            else:
                self._ret.update_(ret)
        # make sure the content is read
        self._ret["observation"] + 1
        self._ret["next_observation"] + 1
        return timeit.default_timer() - start_time

    def _create_replay_buffer(self) -> rpc.RRef:
        while True:
            try:
                replay_buffer_info = rpc.get_worker_info(REPLAY_BUFFER_NODE)
                buffer_rref = rpc.remote(
                    replay_buffer_info, ReplayBufferNode, args=(1000000,)
                )
                torchrl_logger.info(f"Connected to replay buffer {replay_buffer_info}")
                return buffer_rref
            except Exception:
                torchrl_logger.info("Failed to connect to replay buffer")
                time.sleep(RETRY_DELAY_SECS)


class ReplayBufferNode(RemoteTensorDictReplayBuffer):
    def __init__(self, capacity: int):
        super().__init__(
            storage=storage_options[storage_type](
                max_size=capacity, **storage_arg_options[storage_type]
            ),
            sampler=RandomSampler(),
            writer=RoundRobinWriter(),
            collate_fn=lambda x: x,
        )
        tds = TensorDict(
            {
                "observation": torch.randn(
                    BUFFER_SIZE,
                    TENSOR_SIZE,
                ),
                "next_observation": torch.randn(
                    BUFFER_SIZE,
                    TENSOR_SIZE,
                ),
            },
            batch_size=[BUFFER_SIZE],
        )
        self.extend(tds)


if __name__ == "__main__":
    args = parser.parse_args()
    rank = args.rank
    storage_type = args.storage

    torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16, init_method="tcp://localhost:10002", rpc_timeout=120
    )
    if rank == 0:
        # rank 0 is the trainer
        rpc.init_rpc(
            TRAINER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        trainer = DummyTrainerNode()
        results = []
        for i in range(REPEATS):
            result = trainer.train(batch_size=BATCH_SIZE)
            if i == 0:
                continue
            results.append(result)
            torchrl_logger.info(f"{i}, {results[-1]}")

        with open(
            f'./benchmark_{datetime.now().strftime("%d-%m-%Y%H:%M:%S")};batch_size={BATCH_SIZE};tensor_size={TENSOR_SIZE};repeat={REPEATS};storage={storage_type}.pkl',
            "wb+",
        ) as f:
            pickle.dump(results, f)

        tensor_results = torch.tensor(results)
        torchrl_logger.info(f"Mean: {torch.mean(tensor_results)}")
        breakpoint()
    elif rank == 1:
        # rank 1 is the replay buffer
        # replay buffer waits passively for construction instructions from trainer node
        rpc.init_rpc(
            REPLAY_BUFFER_NODE,
            rank=rank,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=options,
        )
        breakpoint()
    else:
        sys.exit(1)
    rpc.shutdown()
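
The two ranks above have to be started together, as the docstring notes. A minimal launcher sketch, assuming the script is saved as benchmark_sample_latency_over_rpc.py in the working directory; note that the trailing breakpoint() calls keep both processes alive for inspection, so the launcher blocks until they are resumed manually:

# launch_latency_benchmark.py - hypothetical two-process launcher sketch.
import subprocess
import sys

STORAGE = "LazyMemmapStorage"
procs = [
    subprocess.Popen(
        [
            sys.executable,
            "benchmark_sample_latency_over_rpc.py",
            f"--rank={rank}",
            f"--storage={STORAGE}",
        ]
    )
    # start the buffer node (rank 1) first so the trainer's retry loop connects quickly
    for rank in (1, 0)
]
for p in procs:
    p.wait()  # blocks until both ranks exit (each stops at breakpoint())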
benchmarks/test_collectors_benchmark.py
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import time

import pytest
import torch.cuda
import tqdm

from torchrl.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    SyncDataCollector,
)
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.modules import RandomPolicy


def single_collector_setup():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = TransformedEnv(DMControlEnv("cheetah", "run", device=device), StepCounter(50))
    c = SyncDataCollector(
        env,
        RandomPolicy(env.action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def sync_collector_setup():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = EnvCreator(
        lambda: TransformedEnv(
            DMControlEnv("cheetah", "run", device=device), StepCounter(50)
        )
    )
    c = MultiSyncDataCollector(
        [env, env],
        RandomPolicy(env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def async_collector_setup():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = EnvCreator(
        lambda: TransformedEnv(
            DMControlEnv("cheetah", "run", device=device), StepCounter(50)
        )
    )
    c = MultiaSyncDataCollector(
        [env, env],
        RandomPolicy(env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def single_collector_setup_pixels():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    # env = TransformedEnv(
    #     DMControlEnv("cheetah", "run", device=device, from_pixels=True), StepCounter(50)
    # )
    env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50))
    c = SyncDataCollector(
        env,
        RandomPolicy(env.action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def sync_collector_setup_pixels():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = EnvCreator(
        lambda: TransformedEnv(
            # DMControlEnv("cheetah", "run", device=device, from_pixels=True),
            GymEnv("ALE/Pong-v5"),
            StepCounter(50),
        )
    )
    c = MultiSyncDataCollector(
        [env, env],
        RandomPolicy(env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def async_collector_setup_pixels():
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = EnvCreator(
        lambda: TransformedEnv(
            # DMControlEnv("cheetah", "run", device=device, from_pixels=True),
            GymEnv("ALE/Pong-v5"),
            StepCounter(50),
        )
    )
    c = MultiaSyncDataCollector(
        [env, env],
        RandomPolicy(env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
    )
    c = iter(c)
    for i, _ in enumerate(c):
        if i == 10:
            break
    return ((c,), {})


def execute_collector(c):
    # consumes a single batch per call (the collector was warmed up during setup)
    next(c)


def test_single(benchmark):
    (c,), _ = single_collector_setup()
    benchmark(execute_collector, c)


def test_sync(benchmark):
    (c,), _ = sync_collector_setup()
    benchmark(execute_collector, c)


def test_async(benchmark):
    (c,), _ = async_collector_setup()
    benchmark(execute_collector, c)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_single_pixels(benchmark):
    (c,), _ = single_collector_setup_pixels()
    benchmark(execute_collector, c)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_sync_pixels(benchmark):
    (c,), _ = sync_collector_setup_pixels()
    benchmark(execute_collector, c)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_async_pixels(benchmark):
    (c,), _ = async_collector_setup_pixels()
    benchmark(execute_collector, c)


class TestRBGCollector:
    @pytest.mark.parametrize(
        "n_col,n_workers_per_col",
        [
            [2, 2],
            [4, 2],
            [8, 2],
            [16, 2],
            [2, 1],
            [4, 1],
            [8, 1],
            [16, 1],
        ],
    )
    def test_multiasync_rb(self, n_col, n_workers_per_col):
        make_env = EnvCreator(lambda: GymEnv("ALE/Pong-v5"))
        if n_workers_per_col > 1:
            make_env = ParallelEnv(n_workers_per_col, make_env)
            env = make_env
            policy = RandomPolicy(env.action_spec)
        else:
            env = make_env()
            policy = RandomPolicy(env.action_spec)

        storage = LazyTensorStorage(10_000)
        rb = ReplayBuffer(storage=storage)
        rb.extend(env.rollout(2, policy).reshape(-1))
        rb.append_transform(CloudpickleWrapper(lambda x: x.reshape(-1)), invert=True)

        fpb = n_workers_per_col * 100
        total_frames = n_workers_per_col * 100_000
        c = MultiaSyncDataCollector(
            [make_env] * n_col,
            policy,
            frames_per_batch=fpb,
            total_frames=total_frames,
            replay_buffer=rb,
        )
        frames = 0
        pbar = tqdm.tqdm(total=total_frames - (n_col * fpb))
        for i, _ in enumerate(c):
            if i == n_col:
                t0 = time.time()
            if i >= n_col:
                frames += fpb
            if i > n_col:
                fps = frames / (time.time() - t0)
                pbar.update(fpb)
                pbar.set_description(f"fps: {fps: 4.4f}")


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)