torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
torchrl/collectors/_runner.py
@@ -0,0 +1,581 @@
+from __future__ import annotations
+
+import contextlib
+import queue
+from collections.abc import Callable
+from functools import partial
+from multiprocessing import connection, queues
+from typing import Any
+
+import numpy as np
+import torch
+from tensordict import TensorDict, TensorDictBase
+
+from torchrl import logger as torchrl_logger
+from torchrl._utils import timeit, VERBOSE
+from torchrl.collectors._base import BaseCollector, ProfileConfig
+from torchrl.collectors._constants import (
+    _MAX_IDLE_COUNT,
+    _MIN_TIMEOUT,
+    _TIMEOUT,
+    DEFAULT_EXPLORATION_TYPE,
+)
+from torchrl.collectors._single import Collector
+
+from torchrl.collectors.utils import (
+    _cast,
+    _make_policy_factory,
+    _map_to_cpu_if_needed,
+    _TrajectoryPool,
+)
+from torchrl.data import ReplayBuffer
+from torchrl.envs import EnvBase, EnvCreator
+from torchrl.envs.utils import ExplorationType
+from torchrl.weight_update import WeightSyncScheme
+
+
+class _WorkerProfiler:
+    """Helper class for profiling worker rollouts.
+
+    Manages the PyTorch profiler lifecycle for a worker process,
+    handling warmup, active profiling, and trace export.
+    """
+
+    def __init__(
+        self,
+        profile_config: ProfileConfig,
+        worker_idx: int,
+    ):
+        self.config = profile_config
+        self.worker_idx = worker_idx
+        self.rollout_count = 0
+        self._profiler = None
+        self._stopped = False
+        self._active = False
+
+        # Check if this worker should be profiled
+        if not self.config.should_profile_worker(worker_idx):
+            return
+
+        # Set up profiler schedule
+        # - skip_first: warmup rollouts (profiler runs but data discarded)
+        # - wait: 0 (no wait between cycles)
+        # - warmup: 0 (we handle warmup via skip_first)
+        # - active: num_rollouts - warmup_rollouts
+        # - repeat: 1 (single profiling cycle)
+        active_rollouts = self.config.num_rollouts - self.config.warmup_rollouts
+        profiler_schedule = torch.profiler.schedule(
+            skip_first=self.config.warmup_rollouts,
+            wait=0,
+            warmup=0,
+            active=active_rollouts,
+            repeat=1,
+        )
+
+        # Get activities
+        activities = self.config.get_activities()
+        if not activities:
+            torchrl_logger.warning(
+                f"Worker {worker_idx}: No profiler activities available. Profiling disabled."
+            )
+            return
+
+        # Determine trace handler
+        if self.config.on_trace_ready is not None:
+            on_trace_ready = self.config.on_trace_ready
+        else:
+            save_path = self.config.get_save_path(worker_idx)
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+
+            def on_trace_ready(prof, save_path=save_path):
+                prof.export_chrome_trace(str(save_path))
+                torchrl_logger.info(
+                    f"Worker {worker_idx}: Profiling trace saved to {save_path}"
+                )
+
+        self._profiler = torch.profiler.profile(
+            activities=activities,
+            schedule=profiler_schedule,
+            on_trace_ready=on_trace_ready,
+            record_shapes=self.config.record_shapes,
+            profile_memory=self.config.profile_memory,
+            with_stack=self.config.with_stack,
+            with_flops=self.config.with_flops,
+        )
+        self._active = True
+
+    def start(self) -> None:
+        """Start the profiler."""
+        if self._profiler is not None and not self._stopped:
+            self._profiler.start()
+            torchrl_logger.info(
+                f"Worker {self.worker_idx}: Profiling started. "
+                f"Will profile rollouts {self.config.warmup_rollouts} to {self.config.num_rollouts - 1}."
+            )
+
+    def step(self) -> bool:
+        """Step the profiler after a rollout.
+
+        Returns:
+            True if profiling is complete.
+        """
+        if self._profiler is None or self._stopped:
+            return False
+
+        self.rollout_count += 1
+        self._profiler.step()
+
+        # Check if profiling is complete
+        if self.rollout_count >= self.config.num_rollouts:
+            self.stop()
+            return True
+
+        return False
+
+    def stop(self) -> None:
+        """Stop the profiler and export trace."""
+        if self._profiler is not None and not self._stopped:
+            self._profiler.stop()
+            self._stopped = True
+            torchrl_logger.info(
+                f"Worker {self.worker_idx}: Profiling complete after {self.rollout_count} rollouts."
+            )
+
+    @property
+    def is_active(self) -> bool:
+        """Check if profiling is active."""
+        return self._active and not self._stopped
+
+    @contextlib.contextmanager
+    def profile_rollout(self):
+        """Context manager for profiling a single rollout."""
+        if self._profiler is not None and not self._stopped:
+            with torch.profiler.record_function(f"worker_{self.worker_idx}_rollout"):
+                yield
+        else:
+            yield
+
+
+def _main_async_collector(
+    pipe_child: connection.Connection,
+    queue_out: queues.Queue,
+    create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
+    create_env_kwargs: dict[str, Any],
+    policy: Callable[[TensorDictBase], TensorDictBase],
+    max_frames_per_traj: int,
+    frames_per_batch: int,
+    reset_at_each_iter: bool,
+    storing_device: torch.device | str | int | None,
+    env_device: torch.device | str | int | None,
+    policy_device: torch.device | str | int | None,
+    idx: int = 0,
+    exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
+    reset_when_done: bool = True,
+    verbose: bool = VERBOSE,
+    interruptor=None,
+    set_truncated: bool = False,
+    use_buffers: bool | None = None,
+    replay_buffer: ReplayBuffer | None = None,
+    extend_buffer: bool = True,
+    traj_pool: _TrajectoryPool = None,
+    trust_policy: bool = False,
+    compile_policy: bool = False,
+    cudagraph_policy: bool = False,
+    no_cuda_sync: bool = False,
+    policy_factory: Callable | None = None,
+    collector_class: type | Callable[[], BaseCollector] | None = None,
+    postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
+    weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
+    worker_idx: int | None = None,
+    init_random_frames: int | None = None,
+    profile_config: ProfileConfig | None = None,
+) -> None:
+    if collector_class is None:
+        collector_class = Collector
+    # init variables that will be cleared when closing
+    collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None
+
+    # Make a policy-factory out of the policy
+    policy_factory = partial(
+        _make_policy_factory,
+        policy=policy,
+        policy_factory=policy_factory,
+        weight_sync_scheme=weight_sync_schemes.get("policy")
+        if weight_sync_schemes
+        else None,
+        worker_idx=worker_idx,
+        pipe=pipe_child,
+    )
+    policy = None
+    # Store the original init_random_frames for run_free mode logic
+    original_init_random_frames = (
+        init_random_frames if init_random_frames is not None else 0
+    )
+    try:
+        collector_class._ignore_rb = extend_buffer
+        inner_collector = collector_class(
+            create_env_fn,
+            create_env_kwargs=create_env_kwargs,
+            policy=policy,
+            policy_factory=policy_factory,
+            total_frames=-1,
+            max_frames_per_traj=max_frames_per_traj,
+            frames_per_batch=frames_per_batch,
+            reset_at_each_iter=reset_at_each_iter,
+            postproc=postproc,
+            split_trajs=False,
+            storing_device=storing_device,
+            policy_device=policy_device,
+            env_device=env_device,
+            exploration_type=exploration_type,
+            reset_when_done=reset_when_done,
+            return_same_td=replay_buffer is None,
+            interruptor=interruptor,
+            set_truncated=set_truncated,
+            use_buffers=use_buffers,
+            replay_buffer=replay_buffer,
+            extend_buffer=extend_buffer,
+            traj_pool=traj_pool,
+            trust_policy=trust_policy,
+            compile_policy=compile_policy,
+            cudagraph_policy=cudagraph_policy,
+            no_cuda_sync=no_cuda_sync,
+            # We don't pass the weight sync scheme as only the sender has the weight sync scheme within.
+            # weight_sync_schemes=weight_sync_schemes,
+            worker_idx=worker_idx,
+            # init_random_frames is passed; inner collector will use _should_use_random_frames()
+            # which checks replay_buffer.write_count when replay_buffer is provided
+            init_random_frames=init_random_frames,
+        )
+        # Set up weight receivers for worker process using the standard register_scheme_receiver API.
+        # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes.
+        if weight_sync_schemes:
+            inner_collector.register_scheme_receiver(weight_sync_schemes)
+
+        use_buffers = inner_collector._use_buffers
+        if verbose:
+            torchrl_logger.debug("Sync data collector created")
+
+        # Set up profiler for this worker if configured
+        worker_profiler = None
+        if profile_config is not None:
+            worker_profiler = _WorkerProfiler(profile_config, worker_idx)
+            if worker_profiler.is_active:
+                worker_profiler.start()
+
+        dc_iter = iter(inner_collector)
+        j = 0
+        pipe_child.send("instantiated")
+    except Exception as e:
+        # Send error information to main process
+        # We send a dict with the exception info so we can recreate it in the main process
+        import traceback
+
+        error_info = {
+            "error": True,
+            "exception_type": type(e).__name__,
+            "exception_module": type(e).__module__,
+            "exception_msg": str(e),
+            "traceback": traceback.format_exc(),
+        }
+        try:
+            pipe_child.send(error_info)
+        except Exception:
+            # If pipe is broken, nothing we can do
+            pass
+        return
+
+    has_timed_out = False
+    counter = 0
+    run_free = False
+    while True:
+        _timeout = _TIMEOUT if not has_timed_out else 1e-3
+        if not run_free and pipe_child.poll(_timeout):
+            counter = 0
+            try:
+                data_in, msg = pipe_child.recv()
+                if verbose:
+                    torchrl_logger.debug(f"mp worker {idx} received {msg}")
+            except EOFError:
+                raise
+        elif not run_free:
+            if verbose:
+                torchrl_logger.debug(f"poll failed, j={j}, worker={idx}")
+            # default is "continue" (after first iteration)
+            # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe
+            # in that case, the main process probably expects the worker to continue collect data
+            if has_timed_out:
+                counter = 0
+                # has_timed_out is True if the process failed to send data, which will
+                # typically occur if main has taken another batch (i.e. the queue is Full).
+                # In this case, msg is the previous msg sent by main, which will typically be "continue"
+                # If it's not the case, it is not expected that has_timed_out is True.
+                if msg not in ("continue", "continue_random"):
+                    raise RuntimeError(f"Unexpected message after time out: msg={msg}")
+            else:
+                # if has_timed_out is False, then the time out does not come from the fact that the queue is Full.
+                # this means that our process has been waiting for a command from main in vain, while main was not
+                # receiving data.
+                # This will occur if main is busy doing something else (e.g. computing loss etc).
+
+                counter += _timeout
+                if verbose:
+                    torchrl_logger.debug(f"mp worker {idx} has counter {counter}")
+                if counter >= (_MAX_IDLE_COUNT * _TIMEOUT):
+                    raise RuntimeError(
+                        f"This process waited for {counter} seconds "
+                        f"without receiving a command from main. Consider increasing the maximum idle count "
+                        f"if this is expected via the environment variable MAX_IDLE_COUNT "
+                        f"(current value is {_MAX_IDLE_COUNT})."
+                        f"\nIf this occurs at the end of a function or program, it means that your collector has not been "
+                        f"collected, consider calling `collector.shutdown()` before ending the program."
+                    )
+                continue
+        else:
+            # placeholder, will be checked after
+            msg = "continue"
+        if msg == "run_free":
+            run_free = True
+            msg = "continue"
+        if run_free:
+            # Capture shutdown / update / seed signal, but continue should not be expected
+            if pipe_child.poll(1e-4):
+                data_in, msg = pipe_child.recv()
+                if msg == "continue":
+                    # Switch back to run_free = False
+                    run_free = False
+                if msg == "pause":
+                    queue_out.put((idx, "paused"), timeout=_TIMEOUT)
+                    while not pipe_child.poll(1e-2):
+                        continue
+                    data_in, msg = pipe_child.recv()
+                    if msg != "restart":
+                        raise RuntimeError(f"Expected msg='restart', got {msg=}")
+                    msg = "continue"
+            else:
+                data_in = None
+                # In run_free mode, determine msg based on replay_buffer.write_count for random frames
+                if (
+                    replay_buffer is not None
+                    and original_init_random_frames > 0
+                    and replay_buffer.write_count < original_init_random_frames
+                ):
+                    msg = "continue_random"
+                else:
+                    msg = "continue"
+            # Note: Weight updates are handled by background threads in weight sync schemes.
+            # The scheme's background receiver thread listens for "receive" instructions.
+
+        if msg == "enable_profile":
+            # Handle profile configuration sent after worker startup
+            if worker_profiler is None or not worker_profiler.is_active:
+                worker_profiler = _WorkerProfiler(data_in, worker_idx)
+                if worker_profiler.is_active:
+                    worker_profiler.start()
+            pipe_child.send((j, "profile_enabled"))
+            has_timed_out = False
+            continue
+
+        if msg == "update":
+            # Legacy - weight updater
+            with timeit(f"worker/{idx}/update") as update_timer:
+                torchrl_logger.debug(
+                    f"mp worker {idx}: Received weight update request..."
+                )
+                inner_collector.update_policy_weights_(policy_weights=data_in)
+            torchrl_logger.debug(
+                f"mp worker {idx}: Weight update completed in {update_timer.elapsed():.3f}s"
+            )
+            pipe_child.send((j, "updated"))
+            has_timed_out = False
+            continue
+
+        # Note: Weight updates are now handled by background threads in the weight sync schemes.
+        # The scheme's background receiver thread listens for "receive" instructions and
+        # applies weights automatically. No explicit message handling needed here.
+
+        if msg in ("continue", "continue_random"):
+            # When in run_free mode with a replay_buffer, the inner collector uses
+            # _should_use_random_frames() which checks replay_buffer.write_count.
+            # So we don't override init_random_frames. Otherwise, we use the message
+            # to control whether random frames are used.
+            if not run_free or replay_buffer is None:
+                if msg == "continue_random":
+                    inner_collector.init_random_frames = float("inf")
+                else:
+                    inner_collector.init_random_frames = -1
+
+            # Debug logging for rollout timing
+            # Use profiler context if profiling is active
+            profile_ctx = (
+                worker_profiler.profile_rollout()
+                if worker_profiler is not None and worker_profiler.is_active
+                else contextlib.nullcontext()
+            )
+            with profile_ctx:
+                with timeit(f"worker/{idx}/rollout") as rollout_timer:
+                    torchrl_logger.debug(
+                        f"mp worker {idx}: Starting rollout (j={j})..."
+                    )
+                    next_data = next(dc_iter)
+                torchrl_logger.debug(
+                    f"mp worker {idx}: Rollout completed in {rollout_timer.elapsed():.3f}s, "
+                    f"frames={next_data.numel() if hasattr(next_data, 'numel') else 'N/A'}"
+                )
+
+            # Step the profiler after each rollout
+            if worker_profiler is not None and worker_profiler.is_active:
+                worker_profiler.step()
+            if pipe_child.poll(_MIN_TIMEOUT):
+                # in this case, main send a message to the worker while it was busy collecting trajectories.
+                # In that case, we skip the collected trajectory and get the message from main. This is faster than
+                # sending the trajectory in the queue until timeout when it's never going to be received.
+                continue
+
+            if replay_buffer is not None:
+                if extend_buffer:
+                    next_data.names = None
+                    replay_buffer.extend(next_data)
+
+                if run_free:
+                    continue
+
+                try:
+                    queue_out.put((idx, j), timeout=_TIMEOUT)
+                    if verbose:
+                        torchrl_logger.debug(f"mp worker {idx} successfully sent data")
+                    j += 1
+                    has_timed_out = False
+                    continue
+                except queue.Full:
+                    has_timed_out = True
+                    continue
+
+            if j == 0 or not use_buffers:
+                collected_tensordict = next_data
+                if (
+                    storing_device is not None
+                    and collected_tensordict.device != storing_device
+                ):
+                    raise RuntimeError(
+                        f"expected device to be {storing_device} but got {collected_tensordict.device}"
+                    )
+                if use_buffers:
+                    # If policy and env are on cpu, we put in shared mem,
+                    # if policy is on cuda and env on cuda, we are fine with this
+                    # If policy is on cuda and env on cpu (or opposite) we put tensors that
+                    # are on cpu in shared mem.
+                    MPS_ERROR = (
+                        "tensors on mps device cannot be put in shared memory. Make sure "
+                        "the shared device (aka storing_device) is set to CPU."
+                    )
+                    if collected_tensordict.device is not None:
+                        # placeholder in case we need different behaviors
+                        if collected_tensordict.device.type in ("cpu",):
+                            collected_tensordict.share_memory_()
+                        elif collected_tensordict.device.type in ("mps",):
+                            raise RuntimeError(MPS_ERROR)
+                        elif collected_tensordict.device.type == "cuda":
+                            collected_tensordict.share_memory_()
+                        else:
+                            raise NotImplementedError(
+                                f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
+                            )
+                    else:
+                        # make sure each cpu tensor is shared - assuming non-cpu devices are shared
+                        def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+                            if x.device.type in ("cpu",):
+                                x.share_memory_()
+                            if x.device.type in ("mps",):
+                                RuntimeError(MPS_ERROR)
+
+                        collected_tensordict.apply(cast_tensor, filter_empty=True)
+                data = (collected_tensordict, idx)
+            else:
+                if next_data is not collected_tensordict:
+                    raise RuntimeError(
+                        "Collector should return the same tensordict modified in-place."
+                    )
+                data = idx  # flag the worker that has sent its data
+            try:
+                queue_out.put((data, j), timeout=_TIMEOUT)
+                if verbose:
+                    torchrl_logger.debug(f"mp worker {idx} successfully sent data")
+                j += 1
+                has_timed_out = False
+                continue
+            except queue.Full:
+                if verbose:
+                    torchrl_logger.debug(f"mp worker {idx} has timed out")
+                has_timed_out = True
+                continue
+
+        if msg == "seed":
+            data_in, static_seed = data_in
+            new_seed = inner_collector.set_seed(data_in, static_seed=static_seed)
+            torch.manual_seed(data_in)
+            np.random.seed(data_in)
+            pipe_child.send((new_seed, "seeded"))
+            has_timed_out = False
+            continue
+
+        elif msg == "reset":
+            inner_collector.reset()
+            pipe_child.send((j, "reset"))
+            continue
+
+        elif msg == "state_dict":
+            from torch.utils._pytree import tree_map
+
+            state_dict = inner_collector.state_dict()
+            # Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility
+            # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth)
+            state_dict = tree_map(_map_to_cpu_if_needed, state_dict)
+            state_dict = TensorDict(state_dict)
+            state_dict = state_dict.clone().apply(_cast, state_dict).to_dict()
+            pipe_child.send((state_dict, "state_dict"))
+            has_timed_out = False
+            continue
+
+        elif msg == "load_state_dict":
+            state_dict = data_in
+            inner_collector.load_state_dict(state_dict)
+            del state_dict
+            pipe_child.send((j, "loaded"))
+            has_timed_out = False
+            continue
+
+        elif msg == "getattr_policy":
+            attr_name = data_in
+            try:
+                result = getattr(inner_collector.policy, attr_name)
+                pipe_child.send((result, "getattr_policy"))
+            except AttributeError as e:
+                pipe_child.send((e, "getattr_policy"))
+            has_timed_out = False
+            continue
+
+        elif msg == "getattr_env":
+            attr_name = data_in
+            try:
+                result = getattr(inner_collector.env, attr_name)
+                pipe_child.send((result, "getattr_env"))
+            except AttributeError as e:
+                pipe_child.send((e, "getattr_env"))
+            has_timed_out = False
+            continue
+
+        elif msg == "close":
+            # Stop profiler if active
+            if worker_profiler is not None and worker_profiler.is_active:
+                worker_profiler.stop()
+            del collected_tensordict, data, next_data, data_in
+            inner_collector.shutdown()
+            del inner_collector, dc_iter
+            pipe_child.send("closed")
+            if verbose:
+                torchrl_logger.debug(f"collector {idx} closed")
+            break
+
+        else:
+            raise Exception(f"Unrecognized message {msg}")