torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""Constants and helper classes for collectors."""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from multiprocessing.managers import SyncManager
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from torch import multiprocessing as mp
|
|
14
|
+
|
|
15
|
+
from torchrl.envs.utils import ExplorationType
|
|
16
|
+
|
|
17
|
+
try:
    # torch.compiler.cudagraph_mark_step_begin only exists in recent PyTorch
    # releases; fall back to a stub that fails loudly if it is ever called.
    from torch.compiler import cudagraph_mark_step_begin
except ImportError:

    def cudagraph_mark_step_begin():
        """Placeholder for missing cudagraph_mark_step_begin method."""
        raise NotImplementedError("cudagraph_mark_step_begin not implemented.")
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Names exported by this module (constants, helper classes, and the
# cudagraph marker resolved above).
__all__ = [
    "_TIMEOUT",
    "INSTANTIATE_TIMEOUT",
    "_MIN_TIMEOUT",
    "_MAX_IDLE_COUNT",
    "WEIGHT_SYNC_TIMEOUT",
    "DEFAULT_EXPLORATION_TYPE",
    "_is_osx",
    "_Interruptor",
    "_InterruptorManager",
    "cudagraph_mark_step_begin",
]
|
38
|
+
|
|
39
|
+
# Default timeout (seconds) used when polling inter-process queues.
_TIMEOUT = 1.0
# Time (seconds) granted to worker processes to instantiate their environment.
INSTANTIATE_TIMEOUT = 20
_MIN_TIMEOUT = 1e-3  # should be several orders of magnitude inferior wrt time spent collecting a trajectory
# Timeout for weight synchronization during collector init.
# Increase this when using many collectors across different CUDA devices.
WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))
# Maximum number of consecutive times a worker may time out on its queue before
# the collector raises; defaults to effectively unbounded (int64 max).
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))

# Exploration mode collectors use unless the caller overrides it.
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM

# True when running on macOS.
_is_osx = sys.platform.startswith("darwin")
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class _Interruptor:
|
|
54
|
+
"""A class for managing the collection state of a process.
|
|
55
|
+
|
|
56
|
+
This class provides methods to start and stop collection, and to check
|
|
57
|
+
whether collection has been stopped. The collection state is protected
|
|
58
|
+
by a lock to ensure thread-safety.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
# interrupter vs interruptor: google trends seems to indicate that "or" is more
|
|
62
|
+
# widely used than "er" even if my IDE complains about that...
|
|
63
|
+
def __init__(self):
|
|
64
|
+
self._collect = True
|
|
65
|
+
self._lock = mp.Lock()
|
|
66
|
+
|
|
67
|
+
def start_collection(self):
|
|
68
|
+
with self._lock:
|
|
69
|
+
self._collect = True
|
|
70
|
+
|
|
71
|
+
def stop_collection(self):
|
|
72
|
+
with self._lock:
|
|
73
|
+
self._collect = False
|
|
74
|
+
|
|
75
|
+
def collection_stopped(self):
|
|
76
|
+
with self._lock:
|
|
77
|
+
return self._collect is False
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class _InterruptorManager(SyncManager):
    """A custom SyncManager for managing the collection state of a process.

    This class extends the SyncManager class and allows to share an Interruptor object
    between processes.
    """


# Register _Interruptor with the manager so child processes can receive a
# proxy to a single shared instance.
_InterruptorManager.register("_Interruptor", _Interruptor)
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import warnings
|
|
5
|
+
from collections import defaultdict, OrderedDict
|
|
6
|
+
from collections.abc import Iterator, Sequence
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
from queue import Empty
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from tensordict import TensorDictBase
|
|
13
|
+
from tensordict.nn import TensorDictModuleBase
|
|
14
|
+
from torchrl._utils import (
|
|
15
|
+
_check_for_faulty_process,
|
|
16
|
+
accept_remote_rref_udf_invocation,
|
|
17
|
+
logger as torchrl_logger,
|
|
18
|
+
)
|
|
19
|
+
from torchrl.collectors._base import _make_legacy_metaclass
|
|
20
|
+
from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT
|
|
21
|
+
from torchrl.collectors._multi_base import _MultiCollectorMeta, MultiCollector
|
|
22
|
+
from torchrl.collectors.utils import split_trajectories
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@accept_remote_rref_udf_invocation
|
|
26
|
+
class MultiAsyncCollector(MultiCollector):
|
|
27
|
+
"""Runs a given number of DataCollectors on separate processes asynchronously.
|
|
28
|
+
|
|
29
|
+
.. aafig::
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
+----------------------------------------------------------------------+
|
|
33
|
+
| "MultiConcurrentCollector" | |
|
|
34
|
+
|~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| |
|
|
35
|
+
| "Collector 1" | "Collector 2" | "Collector 3" | "Main" |
|
|
36
|
+
|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~|
|
|
37
|
+
| "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | |
|
|
38
|
+
|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~|
|
|
39
|
+
|"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | |
|
|
40
|
+
| | | | | | | |
|
|
41
|
+
| "actor" | | | "actor" | |
|
|
42
|
+
| | | | | |
|
|
43
|
+
| "step" | "step" | "actor" | | |
|
|
44
|
+
| | | | | |
|
|
45
|
+
| | | | "step" | "step" | |
|
|
46
|
+
| | | | | | |
|
|
47
|
+
| "actor | "step" | "step" | "actor" | |
|
|
48
|
+
| | | | | |
|
|
49
|
+
| "yield batch 1" | "actor" | |"collect, train"|
|
|
50
|
+
| | | | |
|
|
51
|
+
| "step" | "step" | | "yield batch 2" |"collect, train"|
|
|
52
|
+
| | | | | |
|
|
53
|
+
| | | "yield batch 3" | |"collect, train"|
|
|
54
|
+
| | | | | |
|
|
55
|
+
+----------------------------------------------------------------------+
|
|
56
|
+
|
|
57
|
+
Environment types can be identical or different.
|
|
58
|
+
|
|
59
|
+
The collection keeps on occurring on all processes even between the time
|
|
60
|
+
the batch of rollouts is collected and the next call to the iterator.
|
|
61
|
+
This class can be safely used with offline RL sota-implementations.
|
|
62
|
+
|
|
63
|
+
.. note:: Python requires multiprocessed code to be instantiated within a main guard:
|
|
64
|
+
|
|
65
|
+
>>> from torchrl.collectors import MultiAsyncCollector
|
|
66
|
+
>>> if __name__ == "__main__":
|
|
67
|
+
... # Create your collector here
|
|
68
|
+
|
|
69
|
+
See https://docs.python.org/3/library/multiprocessing.html for more info.
|
|
70
|
+
|
|
71
|
+
Examples:
|
|
72
|
+
>>> from torchrl.envs.libs.gym import GymEnv
|
|
73
|
+
>>> from tensordict.nn import TensorDictModule
|
|
74
|
+
>>> from torch import nn
|
|
75
|
+
>>> from torchrl.collectors import MultiAsyncCollector
|
|
76
|
+
>>> if __name__ == "__main__":
|
|
77
|
+
... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
|
|
78
|
+
... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
|
|
79
|
+
... collector = MultiAsyncCollector(
|
|
80
|
+
... create_env_fn=[env_maker, env_maker],
|
|
81
|
+
... policy=policy,
|
|
82
|
+
... total_frames=2000,
|
|
83
|
+
... max_frames_per_traj=50,
|
|
84
|
+
... frames_per_batch=200,
|
|
85
|
+
... init_random_frames=-1,
|
|
86
|
+
... reset_at_each_iter=False,
|
|
87
|
+
... device="cpu",
|
|
88
|
+
... storing_device="cpu",
|
|
89
|
+
... cat_results="stack",
|
|
90
|
+
... )
|
|
91
|
+
... for i, data in enumerate(collector):
|
|
92
|
+
... if i == 2:
|
|
93
|
+
... print(data)
|
|
94
|
+
... break
|
|
95
|
+
... collector.shutdown()
|
|
96
|
+
... del collector
|
|
97
|
+
TensorDict(
|
|
98
|
+
fields={
|
|
99
|
+
action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
100
|
+
collector: TensorDict(
|
|
101
|
+
fields={
|
|
102
|
+
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
103
|
+
batch_size=torch.Size([200]),
|
|
104
|
+
device=cpu,
|
|
105
|
+
is_shared=False),
|
|
106
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
107
|
+
next: TensorDict(
|
|
108
|
+
fields={
|
|
109
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
110
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
111
|
+
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
112
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
113
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
114
|
+
batch_size=torch.Size([200]),
|
|
115
|
+
device=cpu,
|
|
116
|
+
is_shared=False),
|
|
117
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
118
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
119
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
120
|
+
batch_size=torch.Size([200]),
|
|
121
|
+
device=cpu,
|
|
122
|
+
is_shared=False)
|
|
123
|
+
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
__doc__ += MultiCollector.__doc__
|
|
127
|
+
|
|
128
|
+
def __init__(self, *args, **kwargs):
|
|
129
|
+
super().__init__(*args, **kwargs)
|
|
130
|
+
self.out_tensordicts = defaultdict(lambda: None)
|
|
131
|
+
self.running = False
|
|
132
|
+
|
|
133
|
+
if self.postprocs is not None and self.replay_buffer is None:
|
|
134
|
+
postproc = self.postprocs
|
|
135
|
+
self.postprocs = {}
|
|
136
|
+
for _device in self.storing_device:
|
|
137
|
+
if _device not in self.postprocs:
|
|
138
|
+
if hasattr(postproc, "to"):
|
|
139
|
+
postproc = deepcopy(postproc).to(_device)
|
|
140
|
+
self.postprocs[_device] = postproc
|
|
141
|
+
|
|
142
|
+
    # for RPC
    def next(self):
        """Fetch the next batch; thin override so RPC can invoke it remotely."""
        return super().next()
|
145
|
+
|
|
146
|
+
# for RPC
|
|
147
|
+
def shutdown(
|
|
148
|
+
self,
|
|
149
|
+
timeout: float | None = None,
|
|
150
|
+
close_env: bool = True,
|
|
151
|
+
raise_on_error: bool = True,
|
|
152
|
+
) -> None:
|
|
153
|
+
if hasattr(self, "out_tensordicts"):
|
|
154
|
+
del self.out_tensordicts
|
|
155
|
+
if not close_env:
|
|
156
|
+
raise RuntimeError(
|
|
157
|
+
f"Cannot shutdown {type(self).__name__} collector without environment being closed."
|
|
158
|
+
)
|
|
159
|
+
return super().shutdown(timeout=timeout, raise_on_error=raise_on_error)
|
|
160
|
+
|
|
161
|
+
    # for RPC
    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Set the seed of the collector; thin override for remote RPC invocation."""
        return super().set_seed(seed, static_seed)
|
164
|
+
|
|
165
|
+
    # for RPC
    def state_dict(self) -> OrderedDict:
        """Return the collector state dict; thin override for remote RPC invocation."""
        return super().state_dict()
|
168
|
+
|
|
169
|
+
    # for RPC
    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Load a collector state dict; thin override for remote RPC invocation."""
        return super().load_state_dict(state_dict)
|
172
|
+
|
|
173
|
+
    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        """Push fresh policy weights to the workers.

        Accepts the deprecated ``policy_weights`` keyword (with a warning) and
        forwards everything to the parent implementation.
        """
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            # Legacy keyword takes over the positional argument.
            policy_or_weights = kwargs.pop("policy_weights")

        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )
|
191
|
+
|
|
192
|
+
    def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
        """Return the number of frames a worker collects per batch.

        Each async worker delivers the full requested batch, so the value is
        independent of ``worker_idx``.
        """
        return self.requested_frames_per_batch
|
194
|
+
|
|
195
|
+
    def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
        """Pop one worker result from the output queue.

        Returns an ``(idx, j, out)`` tuple where ``idx`` identifies the
        worker, ``j`` is that worker's iteration counter and ``out`` is the
        collected tensordict (may be a stale/None entry when a replay buffer
        is in use — the data went to the buffer directly).

        Raises:
            queue.Empty: if nothing arrives within ``timeout`` seconds.
        """
        new_data, j = self.queue_out.get(timeout=timeout)
        use_buffers = self._use_buffers
        if self.replay_buffer is not None:
            # Workers write straight into the replay buffer and only send
            # their index back.
            idx = new_data
        elif j == 0 or not use_buffers:
            try:
                # Full payload: (tensordict, worker index).
                data, idx = new_data
                self.out_tensordicts[idx] = data
                if use_buffers is None and j > 0:
                    # Worker still ships data past the first iteration, so it
                    # is not relying on shared buffers.
                    use_buffers = self._use_buffers = False
            except TypeError:
                # Bare index received: the worker communicates through shared
                # buffers. Only acceptable while the mode is still undecided.
                if use_buffers is None:
                    use_buffers = self._use_buffers = True
                    idx = new_data
                else:
                    raise
        else:
            idx = new_data
        out = self.out_tensordicts[idx]
        if not self.replay_buffer and (j == 0 or use_buffers):
            # we clone the data to make sure that we'll be working with a fixed copy
            out = out.clone()
        return idx, j, out
|
219
|
+
|
|
220
|
+
    @property
    def _queue_len(self) -> int:
        # Depth of the output queue: one pending batch at a time.
        return 1
|
223
|
+
|
|
224
|
+
def iterator(self) -> Iterator[TensorDictBase]:
    """Yield collected batches as workers deliver them (async, first-come first-served).

    Keeps pulling results from ``queue_out`` until ``total_frames`` frames
    have been gathered, re-dispatching a "continue" message to the worker
    that just delivered before yielding, so collection overlaps with the
    consumer's processing of the yielded batch.
    """
    if self.update_at_each_batch:
        self.update_policy_weights_()

    # Kick off collection on every worker before entering the main loop.
    for i in range(self.num_workers):
        if self._should_use_random_frames():
            self.pipes[i].send((None, "continue_random"))
        else:
            self.pipes[i].send((None, "continue"))
    self.running = True

    workers_frames = [0 for _ in range(self.num_workers)]
    _iter_start_time = time.time()
    while self._frames < self.total_frames:
        self._iter += 1
        counter = 0
        # Block until any worker delivers a batch, periodically checking
        # that no worker process has died.
        while True:
            try:
                idx, j, out = self._get_from_queue(timeout=_TIMEOUT)
                break
            except (TimeoutError, Empty):
                counter += _TIMEOUT
                _check_for_faulty_process(self.procs)
                # Debug logging for queue timeout
                if counter % (10 * _TIMEOUT) == 0:  # Log every 10 timeouts
                    _elapsed = time.time() - _iter_start_time
                    torchrl_logger.debug(
                        f"MultiAsyncCollector.iterator: Queue timeout, counter={counter:.1f}s, "
                        f"iter={self._iter}, frames={self._frames}, elapsed={_elapsed:.1f}s"
                    )
                if counter > (_TIMEOUT * _MAX_IDLE_COUNT):
                    _elapsed = time.time() - _iter_start_time
                    torchrl_logger.debug(
                        f"MultiAsyncCollector.iterator: CRITICAL - Max idle exceeded, "
                        f"counter={counter:.1f}s, iter={self._iter}, frames={self._frames}, elapsed={_elapsed:.1f}s"
                    )
                    raise RuntimeError(
                        f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. "
                        f"Increase the MAX_IDLE_COUNT environment variable to bypass this error."
                    )
        if self.replay_buffer is None:
            worker_frames = out.numel()
            if self.split_trajs:
                out = split_trajectories(out, prefix="collector")
        else:
            # Data went straight into the replay buffer; account for the
            # nominal per-worker batch size instead.
            worker_frames = self.frames_per_batch_worker()
        self._frames += worker_frames
        workers_frames[idx] = workers_frames[idx] + worker_frames
        if out is not None and self.postprocs:
            out = self.postprocs[out.device](out)

        # the function blocks here until the next item is asked, hence we send the message to the
        # worker to keep on working in the meantime before the yield statement
        if self._should_use_random_frames():
            msg = "continue_random"
        else:
            msg = "continue"
        self.pipes[idx].send((idx, msg))
        if out is not None and self._exclude_private_keys:
            excluded_keys = [key for key in out.keys() if key.startswith("_")]
            out = out.exclude(*excluded_keys)
        yield out

    # We don't want to shutdown yet, the user may want to call state_dict before
    # self._shutdown_main()
    self.running = False
|
|
290
|
+
|
|
291
|
+
def _shutdown_main(self, *args, **kwargs) -> None:
    """Release the cached per-worker output tensordicts, then run the regular shutdown."""
    try:
        del self.out_tensordicts
    except AttributeError:
        # Cache was never created (e.g. shutdown before the first iteration).
        pass
    return super()._shutdown_main(*args, **kwargs)
|
|
295
|
+
|
|
296
|
+
def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
    """Reset the environments and, if collection is running, tell workers to resume.

    Args:
        reset_idx: optional boolean mask of environments to reset, forwarded
            to the parent implementation.

    Raises:
        Exception: if the output queue is still full after waiting one
            timeout period.
    """
    super().reset(reset_idx)
    if self.queue_out.full():
        time.sleep(_TIMEOUT)  # wait until queue is empty
    if self.queue_out.full():
        raise Exception("self.queue_out is full")
    if not self.running:
        return
    msg = "continue_random" if self._should_use_random_frames() else "continue"
    for worker in range(self.num_workers):
        self.pipes[worker].send((worker, msg))
|
|
308
|
+
|
|
309
|
+
# for RPC
def _receive_weights_scheme(self):
    # Pure passthrough: re-declared here so the method is exposed on this
    # class for remote (RPC) invocation.
    return super()._receive_weights_scheme()
|
|
312
|
+
|
|
313
|
+
# for RPC
def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
    # Pure passthrough: re-declared here so the method is exposed on this
    # class for remote (RPC) invocation.
    return super().receive_weights(policy_or_weights)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# Metaclass for the deprecated alias below; presumably wraps
# _MultiCollectorMeta with deprecation behavior — confirm in
# _make_legacy_metaclass.
_LegacyMultiAsyncMeta = _make_legacy_metaclass(_MultiCollectorMeta)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class MultiaSyncDataCollector(MultiAsyncCollector, metaclass=_LegacyMultiAsyncMeta):
    """Deprecated version of :class:`~torchrl.collectors.MultiAsyncCollector`."""

    # Backward-compatibility alias: no behavior of its own; any deprecation
    # handling comes from the legacy metaclass.
    ...
|