torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
benchmarks/benchmark_batched_envs.py
ADDED
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+Benchmarking different types of batched environments
+====================================================
+Compares runtime for different environments which allow performing operations in a batch.
+- SerialEnv executes the operations sequentially
+- ParallelEnv uses multiprocess parallelism
+- MultiThreadedEnv uses multithreaded parallelism and is based on envpool library.
+
+Run as "python benchmarks/benchmark_batched_envs.py"
+Requires pandas ("pip install pandas").
+
+"""
+
+import pandas as pd
+from torchrl._utils import logger as torchrl_logger
+
+pd.set_option("display.max_columns", 100)
+pd.set_option("display.width", 1000)
+import torch
+from torch.utils.benchmark import Timer
+from torchrl.envs import MultiThreadedEnv, ParallelEnv, SerialEnv
+from torchrl.envs.libs.gym import GymEnv
+
+N_STEPS = 1000
+
+
+def create_multithreaded(num_workers, device):
+    env = MultiThreadedEnv(num_workers=num_workers, env_name="Pendulum-v1")
+    # GPU doesn't lead to any speedup for MultiThreadedEnv, as the underlying library (envpool) works only on CPU
+    env = env.to(device=torch.device(device))
+    env.rollout(policy=None, max_steps=5)  # Warm-up
+    return env
+
+
+def factory():
+    return GymEnv("Pendulum-v1")
+
+
+def create_serial(num_workers, device):
+    env = SerialEnv(num_workers=num_workers, create_env_fn=factory)
+    env = env.to(device=torch.device(device))
+    env.rollout(policy=None, max_steps=5)  # Warm-up
+    return env
+
+
+def create_parallel(num_workers, device):
+    env = ParallelEnv(num_workers=num_workers, create_env_fn=factory)
+    env = env.to(device=torch.device(device))
+    env.rollout(policy=None, max_steps=5)  # Warm-up
+    return env
+
+
+def run_env(env):
+    env.rollout(policy=None, max_steps=N_STEPS)
+
+
+if __name__ == "__main__":
+    res = {}
+    devices = ["cpu"]
+    if torch.cuda.is_available():
+        devices.append("cuda")
+    for device in devices:
+        for num_workers in [1, 4, 16]:
+            torchrl_logger.info(f"With num_workers={num_workers}, {device}")
+            torchrl_logger.info("Multithreaded...")
+            env_multithreaded = create_multithreaded(num_workers, device)
+            res_multithreaded = Timer(
+                stmt="run_env(env)",
+                setup="from __main__ import run_env",
+                globals={"env": env_multithreaded},
+            )
+            time_multithreaded = res_multithreaded.blocked_autorange().mean
+
+            torchrl_logger.info("Serial...")
+            env_serial = create_serial(num_workers, device)
+            res_serial = Timer(
+                stmt="run_env(env)",
+                setup="from __main__ import run_env",
+                globals={"env": env_serial},
+            )
+            time_serial = res_serial.blocked_autorange().mean
+
+            torchrl_logger.info("Parallel...")
+            env_parallel = create_parallel(num_workers, device)
+            res_parallel = Timer(
+                stmt="run_env(env)",
+                setup="from __main__ import run_env",
+                globals={"env": env_parallel},
+            )
+            time_parallel = res_parallel.blocked_autorange().mean
+
+            res[f"num_workers_{num_workers}_{device}"] = {
+                "Serial, s": time_serial,
+                "Parallel, s": time_parallel,
+                "Multithreaded, s": time_multithreaded,
+            }
+    df = pd.DataFrame(res).round(3)
+    gain = 1 - df.loc["Multithreaded, s"] / df.loc["Parallel, s"]
+    df.loc["Gain, %", :] = (gain * 100).round(1)
+    df.to_csv("multithreaded_benchmark.csv")
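Aside (illustrative, not part of the wheel contents): the script above needs envpool for MultiThreadedEnv. A minimal sketch of the same comparison restricted to SerialEnv and ParallelEnv, assuming only gymnasium and its Pendulum-v1 environment are installed:

# Minimal sketch: time SerialEnv vs ParallelEnv rollouts, mirroring the
# benchmark above but without the envpool dependency.
from torch.utils.benchmark import Timer
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv


def factory():
    return GymEnv("Pendulum-v1")


if __name__ == "__main__":
    for cls in (SerialEnv, ParallelEnv):
        env = cls(num_workers=4, create_env_fn=factory)
        env.rollout(policy=None, max_steps=5)  # warm-up, as in the script above
        timer = Timer(
            stmt="env.rollout(policy=None, max_steps=1000)",
            globals={"env": env},
        )
        print(f"{cls.__name__}: {timer.blocked_autorange().mean:.3f} s")
        env.close()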
benchmarks/conftest.py
ADDED
@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import time
+import warnings
+from collections import defaultdict
+
+import pytest
+from torchrl._utils import logger as torchrl_logger
+
+CALL_TIMES = defaultdict(float)
+
+
+def pytest_sessionfinish(maxprint=50):
+    out_str = """
+Call times:
+===========
+"""
+    keys = list(CALL_TIMES.keys())
+    if len(keys) > 1:
+        maxchar = max(*[len(key) for key in keys])
+    elif len(keys) == 1:
+        maxchar = len(keys[0])
+    else:
+        return
+    for i, (key, item) in enumerate(
+        sorted(CALL_TIMES.items(), key=lambda x: x[1], reverse=True)
+    ):
+        spaces = " " + " " * (maxchar - len(key))
+        out_str += f"\t{key}{spaces}{item: 4.4f}s\n"
+        if i == maxprint - 1:
+            break
+    torchrl_logger.info(out_str)
+
+
+@pytest.fixture(autouse=True)
+def measure_duration(request: pytest.FixtureRequest):
+    start_time = time.time()
+
+    def fin():
+        duration = time.time() - start_time
+        name = request.node.name
+        class_name = request.cls.__name__ if request.cls else None
+        name = name.split("[")[0]
+        if class_name is not None:
+            name = "::".join([class_name, name])
+        file = os.path.basename(request.path)
+        name = f"{file}::{name}"
+        CALL_TIMES[name] = CALL_TIMES[name] + duration
+
+    request.addfinalizer(fin)
+
+
+def pytest_addoption(parser):
+    parser.addoption("--rank", action="store")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def set_warnings() -> None:
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"Lazy modules are a new feature under heavy development",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"Couldn't cast the policy onto the desired device on remote process",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message=r"Deprecated call to `pkg_resources.declare_namespace",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message=r"Using or importing the ABCs",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message=r"Please use `coo_matrix` from the `scipy.sparse` namespace",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message=r"jax.tree_util.register_keypaths is deprecated|jax.ShapedArray is deprecated",
+    )
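Aside (illustrative, not part of the wheel contents): the conftest above uses a classic pytest pattern, an autouse fixture that accumulates each test's wall-clock time plus a session-finish hook that reports the slowest tests. The same pattern condenses to a yield-style fixture:

# Minimal sketch of the per-test timing pattern in benchmarks/conftest.py.
import time
from collections import defaultdict

import pytest

CALL_TIMES = defaultdict(float)


@pytest.fixture(autouse=True)
def measure_duration(request: pytest.FixtureRequest):
    start = time.time()
    yield  # the test body runs here
    # Strip the parametrization suffix so repeated cases aggregate.
    CALL_TIMES[request.node.name.split("[")[0]] += time.time() - start


def pytest_sessionfinish(session, exitstatus):
    # Print the ten slowest tests, longest first.
    for name, duration in sorted(
        CALL_TIMES.items(), key=lambda x: x[1], reverse=True
    )[:10]:
        print(f"{name}: {duration:.4f}s")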
benchmarks/ecosystem/gym_env_throughput.py
ADDED
@@ -0,0 +1,321 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""This script executes some envs across the Gym library with the explicit scope of testing the throughput using the various TorchRL components.
+
+We test:
+- gym async envs embedded in a TorchRL's GymEnv wrapper,
+- ParallelEnv with regular GymEnv instances,
+- Data collector
+- Multiprocessed data collectors with parallel envs.
+
+The tests are executed with various number of cpus, and on different devices.
+
+"""
+import time
+
+# import myosuite  # noqa: F401
+import torch
+import tqdm
+from torchrl._utils import timeit
+from torchrl.collectors import (
+    MultiaSyncDataCollector,
+    MultiSyncDataCollector,
+    SyncDataCollector,
+)
+from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
+from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
+from torchrl.modules import RandomPolicy
+
+if __name__ == "__main__":
+    avail_devices = ("cpu",)
+    if torch.cuda.is_available():
+        avail_devices = avail_devices + ("cuda:0",)
+
+    for envname in [
+        "CartPole-v1",
+        "HalfCheetah-v4",
+        "myoHandReachRandom-v0",
+        "ALE/Breakout-v5",
+    ]:
+        # the number of collectors won't affect the resources, just impacts how the envs are split in sub-sub-processes
+        for num_workers, num_collectors in zip((32, 64, 8, 16), (8, 8, 2, 4)):
+            with open(f"{envname}_{num_workers}.txt".replace("/", "-"), "w+") as log:
+                if "myo" in envname:
+                    gym_backend = "gym"
+                else:
+                    gym_backend = "gymnasium"
+
+                total_frames = num_workers * 10_000
+
+                # pure gym
+                def make(envname=envname, gym_backend=gym_backend):
+                    with set_gym_backend(gym_backend):
+                        return gym_bc().make(envname)
+
+                with set_gym_backend(gym_backend):
+                    env = gym_bc().vector.AsyncVectorEnv(
+                        [make for _ in range(num_workers)]
+                    )
+                env.reset()
+                global_step = 0
+                times = []
+                start = time.time()
+                for _ in tqdm.tqdm(range(total_frames // num_workers)):
+                    env.step(env.action_space.sample())
+                    global_step += num_workers
+                env.close()
+                log.write(
+                    f"pure gym: {num_workers * 10_000 / (time.time() - start): 4.4f} fps\n"
+                )
+                log.flush()
+
+                # regular parallel env
+                for device in avail_devices:
+
+                    def make(envname=envname, gym_backend=gym_backend):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device="cpu")
+
+                    # env_make = EnvCreator(make)
+                    penv = ParallelEnv(num_workers, EnvCreator(make), device=device)
+                    with torch.inference_mode():
+                        # warmup
+                        penv.rollout(2)
+                        pbar = tqdm.tqdm(total=num_workers * 10_000)
+                        t0 = time.time()
+                        data = None
+                        for _ in range(100):
+                            data = penv.rollout(
+                                100, break_when_any_done=False, out=data
+                            )
+                            pbar.update(100 * num_workers)
+                    log.write(
+                        f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    penv.close()
+                    timeit.print()
+                    del penv
+
+                for device in avail_devices:
+
+                    def make(envname=envname, gym_backend=gym_backend):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device="cpu")
+
+                    env_make = EnvCreator(make)
+                    # penv = SerialEnv(num_workers, env_make)
+                    penv = ParallelEnv(num_workers, env_make, device=device)
+                    collector = SyncDataCollector(
+                        penv,
+                        RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=num_workers * 10_000,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    total_frames = 0
+                    t0 = time.time()
+                    for data in collector:
+                        total_frames += data.numel()
+                        pbar.update(data.numel())
+                        pbar.set_description(
+                            f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                        )
+                    log.write(
+                        f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                for device in avail_devices:
+                    # gym parallel env
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                        device=device,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device=device)
+                        return penv
+
+                    penv = make_env()
+                    # warmup
+                    penv.rollout(2)
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    t0 = time.time()
+                    for _ in range(100):
+                        data = penv.rollout(100, break_when_any_done=False)
+                        pbar.update(100 * num_workers)
+                    log.write(
+                        f"gym penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    penv.close()
+                    del penv
+
+                for device in avail_devices:
+                    # async collector
+                    # + torchrl parallel env
+                    def make_env(envname=envname, gym_backend=gym_backend):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device="cpu")
+
+                    penv = ParallelEnv(
+                        num_workers // num_collectors,
+                        EnvCreator(make_env),
+                        device=device,
+                    )
+                    collector = MultiaSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=num_workers * 10_000,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    total_frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            total_frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"async collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                for device in avail_devices:
+                    # async collector
+                    # + gym async env
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device="cpu")
+                        return penv
+
+                    penv = EnvCreator(
+                        lambda num_workers=num_workers // num_collectors: make_env(
+                            num_workers=num_workers
+                        )
+                    )
+                    collector = MultiaSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv().action_spec),
+                        frames_per_batch=1024,
+                        total_frames=num_workers * 10_000,
+                        num_sub_threads=num_workers // num_collectors,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    total_frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            total_frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"async collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                for device in avail_devices:
+                    # sync collector
+                    # + torchrl parallel env
+                    def make_env(envname=envname, gym_backend=gym_backend):
+                        with set_gym_backend(gym_backend):
+                            return GymEnv(envname, device="cpu")
+
+                    penv = ParallelEnv(
+                        num_workers // num_collectors,
+                        EnvCreator(make_env),
+                        device=device,
+                    )
+                    collector = MultiSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv.action_spec),
+                        frames_per_batch=1024,
+                        total_frames=num_workers * 10_000,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    total_frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            total_frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"sync collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+
+                for device in avail_devices:
+                    # sync collector
+                    # + gym async env
+                    def make_env(
+                        envname=envname,
+                        num_workers=num_workers,
+                        gym_backend=gym_backend,
+                    ):
+                        with set_gym_backend(gym_backend):
+                            penv = GymEnv(envname, num_envs=num_workers, device="cpu")
+                        return penv
+
+                    penv = EnvCreator(
+                        lambda num_workers=num_workers // num_collectors: make_env(
+                            num_workers=num_workers
+                        )
+                    )
+                    collector = MultiSyncDataCollector(
+                        [penv] * num_collectors,
+                        policy=RandomPolicy(penv().action_spec),
+                        frames_per_batch=1024,
+                        total_frames=num_workers * 10_000,
+                        num_sub_threads=num_workers // num_collectors,
+                        device=device,
+                    )
+                    pbar = tqdm.tqdm(total=num_workers * 10_000)
+                    total_frames = 0
+                    for i, data in enumerate(collector):
+                        if i == num_collectors:
+                            t0 = time.time()
+                        if i >= num_collectors:
+                            total_frames += data.numel()
+                            pbar.update(data.numel())
+                            pbar.set_description(
+                                f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
+                            )
+                    log.write(
+                        f"sync collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
+                    )
+                    log.flush()
+                    collector.shutdown()
+                    del collector
+    exit()
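Aside (illustrative, not part of the wheel contents): every throughput loop above reduces to the same recipe, wrap a batched env in a collector, iterate, and divide the frames collected by the elapsed time. A minimal sketch with SyncDataCollector, assuming gymnasium provides CartPole-v1:

# Minimal sketch of the frames-per-second measurement used above.
import time

from torchrl.collectors import SyncDataCollector
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.modules import RandomPolicy

if __name__ == "__main__":
    penv = ParallelEnv(4, EnvCreator(lambda: GymEnv("CartPole-v1", device="cpu")))
    collector = SyncDataCollector(
        penv,
        RandomPolicy(penv.action_spec),
        frames_per_batch=1024,
        total_frames=40_000,
    )
    frames, t0 = 0, time.time()
    for batch in collector:
        frames += batch.numel()  # frames_per_batch frames per iteration
    print(f"{frames / (time.time() - t0):.1f} fps")
    collector.shutdown()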