torchrl-0.11.0-cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
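
The largest single addition is the 2,009-line collector module shown in the hunk below, which defines the single-process `Collector` class and its profiling helper (it appears to correspond to `torchrl/collectors/_single.py` in the listing above). For orientation, here is a minimal, hedged sketch of the replay-buffer workflow that the class docstring describes but does not demonstrate: the collector fills a `ReplayBuffer` instead of yielding batches. The environment name, buffer size and sampling batch size are illustrative, and the example assumes a Gym backend is installed.

from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
rb = ReplayBuffer(storage=LazyTensorStorage(10_000))

collector = Collector(
    create_env_fn=env_maker,
    policy=policy,
    frames_per_batch=200,
    total_frames=2_000,
    replay_buffer=rb,    # the collector extends the buffer instead of yielding data
    extend_buffer=True,  # whole rollouts are written at once (the default)
)
for _ in collector:      # each iteration runs a rollout and stores it in `rb`
    batch = rb.sample(64)
del collector
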
|
@@ -0,0 +1,2009 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import threading
|
|
5
|
+
import warnings
|
|
6
|
+
import weakref
|
|
7
|
+
from collections import OrderedDict
|
|
8
|
+
from collections.abc import Callable, Iterator, Sequence
|
|
9
|
+
from textwrap import indent
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
|
|
15
|
+
from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase
|
|
16
|
+
from torch import nn
|
|
17
|
+
from torchrl import compile_with_warmup
|
|
18
|
+
from torchrl._utils import (
|
|
19
|
+
_ends_with,
|
|
20
|
+
_make_ordinal_device,
|
|
21
|
+
_replace_last,
|
|
22
|
+
accept_remote_rref_udf_invocation,
|
|
23
|
+
prod,
|
|
24
|
+
RL_WARNINGS,
|
|
25
|
+
)
|
|
26
|
+
from torchrl.collectors._base import _LegacyCollectorMeta, BaseCollector, ProfileConfig
|
|
27
|
+
from torchrl.collectors._constants import (
|
|
28
|
+
cudagraph_mark_step_begin,
|
|
29
|
+
DEFAULT_EXPLORATION_TYPE,
|
|
30
|
+
ExplorationType,
|
|
31
|
+
)
|
|
32
|
+
from torchrl.collectors.utils import _TrajectoryPool, split_trajectories
|
|
33
|
+
from torchrl.collectors.weight_update import WeightUpdaterBase
|
|
34
|
+
from torchrl.data import ReplayBuffer
|
|
35
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
36
|
+
from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv
|
|
37
|
+
from torchrl.envs.common import _do_nothing
|
|
38
|
+
from torchrl.envs.llm.transforms import PolicyVersion
|
|
39
|
+
from torchrl.envs.utils import (
|
|
40
|
+
_aggregate_end_of_traj,
|
|
41
|
+
_make_compatible_policy,
|
|
42
|
+
set_exploration_type,
|
|
43
|
+
)
|
|
44
|
+
from torchrl.modules import RandomPolicy, set_exploration_modules_spec_from_env
|
|
45
|
+
from torchrl.weight_update.utils import _resolve_model
|
|
46
|
+
from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class _CollectorProfiler:
|
|
50
|
+
"""Helper class for profiling collector rollouts in single-process mode.
|
|
51
|
+
|
|
52
|
+
Manages the PyTorch profiler lifecycle for the Collector class.
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
def __init__(self, profile_config: ProfileConfig):
|
|
56
|
+
self.config = profile_config
|
|
57
|
+
self.rollout_count = 0
|
|
58
|
+
self._profiler = None
|
|
59
|
+
self._stopped = False
|
|
60
|
+
self._active = False
|
|
61
|
+
|
|
62
|
+
# Set up profiler schedule
|
|
63
|
+
active_rollouts = self.config.num_rollouts - self.config.warmup_rollouts
|
|
64
|
+
profiler_schedule = torch.profiler.schedule(
|
|
65
|
+
skip_first=self.config.warmup_rollouts,
|
|
66
|
+
wait=0,
|
|
67
|
+
warmup=0,
|
|
68
|
+
active=active_rollouts,
|
|
69
|
+
repeat=1,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Get activities
|
|
73
|
+
activities = self.config.get_activities()
|
|
74
|
+
if not activities:
|
|
75
|
+
return
|
|
76
|
+
|
|
77
|
+
# Determine trace handler
|
|
78
|
+
if self.config.on_trace_ready is not None:
|
|
79
|
+
on_trace_ready = self.config.on_trace_ready
|
|
80
|
+
else:
|
|
81
|
+
save_path = self.config.get_save_path(
|
|
82
|
+
0
|
|
83
|
+
) # Use worker_idx 0 for single-process
|
|
84
|
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
85
|
+
|
|
86
|
+
from torchrl import logger as torchrl_logger
|
|
87
|
+
|
|
88
|
+
def on_trace_ready(prof, save_path=save_path):
|
|
89
|
+
prof.export_chrome_trace(str(save_path))
|
|
90
|
+
torchrl_logger.info(f"Collector: Profiling trace saved to {save_path}")
|
|
91
|
+
|
|
92
|
+
self._profiler = torch.profiler.profile(
|
|
93
|
+
activities=activities,
|
|
94
|
+
schedule=profiler_schedule,
|
|
95
|
+
on_trace_ready=on_trace_ready,
|
|
96
|
+
record_shapes=self.config.record_shapes,
|
|
97
|
+
profile_memory=self.config.profile_memory,
|
|
98
|
+
with_stack=self.config.with_stack,
|
|
99
|
+
with_flops=self.config.with_flops,
|
|
100
|
+
)
|
|
101
|
+
self._active = True
|
|
102
|
+
|
|
103
|
+
def start(self) -> None:
|
|
104
|
+
"""Start the profiler."""
|
|
105
|
+
from torchrl import logger as torchrl_logger
|
|
106
|
+
|
|
107
|
+
if self._profiler is not None and not self._stopped:
|
|
108
|
+
self._profiler.start()
|
|
109
|
+
torchrl_logger.info(
|
|
110
|
+
f"Collector: Profiling started. "
|
|
111
|
+
f"Will profile rollouts {self.config.warmup_rollouts} to {self.config.num_rollouts - 1}."
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
def step(self) -> bool:
|
|
115
|
+
"""Step the profiler after a rollout.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
True if profiling is complete.
|
|
119
|
+
"""
|
|
120
|
+
if self._profiler is None or self._stopped:
|
|
121
|
+
return False
|
|
122
|
+
|
|
123
|
+
self.rollout_count += 1
|
|
124
|
+
self._profiler.step()
|
|
125
|
+
|
|
126
|
+
# Check if profiling is complete
|
|
127
|
+
if self.rollout_count >= self.config.num_rollouts:
|
|
128
|
+
self.stop()
|
|
129
|
+
return True
|
|
130
|
+
|
|
131
|
+
return False
|
|
132
|
+
|
|
133
|
+
def stop(self) -> None:
|
|
134
|
+
"""Stop the profiler and export trace."""
|
|
135
|
+
from torchrl import logger as torchrl_logger
|
|
136
|
+
|
|
137
|
+
if self._profiler is not None and not self._stopped:
|
|
138
|
+
self._profiler.stop()
|
|
139
|
+
self._stopped = True
|
|
140
|
+
torchrl_logger.info(
|
|
141
|
+
f"Collector: Profiling complete after {self.rollout_count} rollouts."
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
@property
|
|
145
|
+
def is_active(self) -> bool:
|
|
146
|
+
"""Check if profiling is active."""
|
|
147
|
+
return self._active and not self._stopped
|
|
148
|
+
|
|
149
|
+
@contextlib.contextmanager
|
|
150
|
+
def profile_rollout(self):
|
|
151
|
+
"""Context manager for profiling a single rollout."""
|
|
152
|
+
if self._profiler is not None and not self._stopped:
|
|
153
|
+
with torch.profiler.record_function("collector_rollout"):
|
|
154
|
+
yield
|
|
155
|
+
else:
|
|
156
|
+
yield
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _cuda_sync_if_initialized():
|
|
160
|
+
"""Synchronize CUDA only if it has been initialized.
|
|
161
|
+
|
|
162
|
+
This is a safe alternative to calling `torch.cuda.synchronize()` directly.
|
|
163
|
+
In forked subprocesses on machines with CUDA, calling `synchronize()` will
|
|
164
|
+
fail with "Cannot re-initialize CUDA in forked subprocess" if CUDA was
|
|
165
|
+
initialized in the parent process before fork. By checking
|
|
166
|
+
`is_initialized()` first, we skip the sync in such cases since no CUDA
|
|
167
|
+
operations have occurred in this process.
|
|
168
|
+
"""
|
|
169
|
+
if torch.cuda.is_initialized():
|
|
170
|
+
torch.cuda.synchronize()
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@accept_remote_rref_udf_invocation
|
|
174
|
+
class Collector(BaseCollector):
|
|
175
|
+
"""Generic data collector for RL problems. Requires an environment constructor and a policy.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
create_env_fn (Callable or EnvBase): a callable that returns an instance of
|
|
179
|
+
:class:`~torchrl.envs.EnvBase` class, or the env itself.
|
|
180
|
+
policy (Callable): Policy to be executed in the environment.
|
|
181
|
+
Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
|
|
182
|
+
If ``None`` is provided, the policy used will be a
|
|
183
|
+
:class:`~torchrl.collectors.RandomPolicy` instance with the environment
|
|
184
|
+
``action_spec``.
|
|
185
|
+
Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
|
|
186
|
+
This is the recommended usage of the collector.
|
|
187
|
+
Other callables are accepted too:
|
|
188
|
+
If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
|
|
189
|
+
instances) it will be wrapped in a `nn.Module` first.
|
|
190
|
+
Then, the collector will try to assess if these
|
|
191
|
+
modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
|
|
192
|
+
|
|
193
|
+
- If the policy forward signature matches any of ``forward(self, tensordict)``,
|
|
194
|
+
``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
|
|
195
|
+
any typing with a single argument typed as a subclass of ``TensorDictBase``)
|
|
196
|
+
then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
|
|
197
|
+
|
|
198
|
+
- In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
|
|
199
|
+
|
|
200
|
+
.. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
|
|
201
|
+
pickled directly), the ``policy_factory`` should be used instead.
|
|
202
|
+
|
|
203
|
+
Keyword Args:
|
|
204
|
+
policy_factory (Callable[[], Callable], optional): a callable that returns
|
|
205
|
+
a policy instance. This is exclusive with the `policy` argument.
|
|
206
|
+
|
|
207
|
+
.. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
|
|
208
|
+
|
|
209
|
+
frames_per_batch (int): A keyword-only argument representing the total
|
|
210
|
+
number of elements in a batch.
|
|
211
|
+
total_frames (int): A keyword-only argument representing the total
|
|
212
|
+
number of frames returned by the collector
|
|
213
|
+
during its lifespan. If the ``total_frames`` is not divisible by
|
|
214
|
+
``frames_per_batch``, an exception is raised.
|
|
215
|
+
Endless collectors can be created by passing ``total_frames=-1``.
|
|
216
|
+
Defaults to ``-1`` (endless collector).
|
|
217
|
+
device (int, str or torch.device, optional): The generic device of the
|
|
218
|
+
collector. The ``device`` args fills any non-specified device: if
|
|
219
|
+
``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
|
|
220
|
+
``env_device`` is not specified, its value will be set to ``device``.
|
|
221
|
+
Defaults to ``None`` (No default device).
|
|
222
|
+
storing_device (int, str or torch.device, optional): The device on which
|
|
223
|
+
the output :class:`~tensordict.TensorDict` will be stored.
|
|
224
|
+
If ``device`` is passed and ``storing_device`` is ``None``, it will
|
|
225
|
+
default to the value indicated by ``device``.
|
|
226
|
+
For long trajectories, it may be necessary to store the data on a different
|
|
227
|
+
device than the one where the policy and env are executed.
|
|
228
|
+
Defaults to ``None`` (the output tensordict isn't on a specific device,
|
|
229
|
+
leaf tensors sit on the device where they were created).
|
|
230
|
+
env_device (int, str or torch.device, optional): The device on which
|
|
231
|
+
the environment should be cast (or executed if that functionality is
|
|
232
|
+
supported). If not specified and the env has a non-``None`` device,
|
|
233
|
+
``env_device`` will default to that value. If ``device`` is passed
|
|
234
|
+
and ``env_device=None``, it will default to ``device``. If the value
|
|
235
|
+
as such specified of ``env_device`` differs from ``policy_device``
|
|
236
|
+
and one of them is not ``None``, the data will be cast to ``env_device``
|
|
237
|
+
before being passed to the env (i.e., passing different devices to
|
|
238
|
+
policy and env is supported). Defaults to ``None``.
|
|
239
|
+
policy_device (int, str or torch.device, optional): The device on which
|
|
240
|
+
the policy should be cast.
|
|
241
|
+
If ``device`` is passed and ``policy_device=None``, it will default
|
|
242
|
+
to ``device``. If the value as such specified of ``policy_device``
|
|
243
|
+
differs from ``env_device`` and one of them is not ``None``,
|
|
244
|
+
the data will be cast to ``policy_device`` before being passed to
|
|
245
|
+
the policy (i.e., passing different devices to policy and env is
|
|
246
|
+
supported). Defaults to ``None``.
|
|
247
|
+
create_env_kwargs (dict, optional): Dictionary of kwargs for
|
|
248
|
+
``create_env_fn``.
|
|
249
|
+
max_frames_per_traj (int, optional): Maximum steps per trajectory.
|
|
250
|
+
Note that a trajectory can span across multiple batches (unless
|
|
251
|
+
``reset_at_each_iter`` is set to ``True``, see below).
|
|
252
|
+
Once a trajectory reaches ``n_steps``, the environment is reset.
|
|
253
|
+
If the environment wraps multiple environments together, the number
|
|
254
|
+
of steps is tracked for each environment independently. Negative
|
|
255
|
+
values are allowed, in which case this argument is ignored.
|
|
256
|
+
Defaults to ``None`` (i.e., no maximum number of steps).
|
|
257
|
+
init_random_frames (int, optional): Number of frames for which the
|
|
258
|
+
policy is ignored before it is called. This feature is mainly
|
|
259
|
+
intended to be used in offline/model-based settings, where a
|
|
260
|
+
batch of random trajectories can be used to initialize training.
|
|
261
|
+
If provided, it will be rounded up to the closest multiple of frames_per_batch.
|
|
262
|
+
Defaults to ``None`` (i.e. no random frames).
|
|
263
|
+
reset_at_each_iter (bool, optional): Whether environments should be reset
|
|
264
|
+
at the beginning of a batch collection.
|
|
265
|
+
Defaults to ``False``.
|
|
266
|
+
postproc (Callable, optional): A post-processing transform, such as
|
|
267
|
+
a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
|
|
268
|
+
instance.
|
|
269
|
+
|
|
270
|
+
.. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
|
|
271
|
+
as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`.
|
|
272
|
+
|
|
273
|
+
Defaults to ``None``.
|
|
274
|
+
split_trajs (bool, optional): Boolean indicating whether the resulting
|
|
275
|
+
TensorDict should be split according to the trajectories.
|
|
276
|
+
See :func:`~torchrl.collectors.utils.split_trajectories` for more
|
|
277
|
+
information.
|
|
278
|
+
Defaults to ``False``.
|
|
279
|
+
exploration_type (ExplorationType, optional): interaction mode to be used when
|
|
280
|
+
collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
|
|
281
|
+
``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
|
|
282
|
+
or ``torchrl.envs.utils.ExplorationType.MEAN``.
|
|
283
|
+
return_same_td (bool, optional): if ``True``, the same TensorDict
|
|
284
|
+
will be returned at each iteration, with its values
|
|
285
|
+
updated. This feature should be used cautiously: if the same
|
|
286
|
+
tensordict is added to a replay buffer for instance,
|
|
287
|
+
the whole content of the buffer will be identical.
|
|
288
|
+
Default is ``False``.
|
|
289
|
+
interruptor (_Interruptor, optional):
|
|
290
|
+
An _Interruptor object that can be used from outside the class to control rollout collection.
|
|
291
|
+
The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement
|
|
292
|
+
strategies such as preeptively stopping rollout collection.
|
|
293
|
+
Default is ``False``.
|
|
294
|
+
set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
|
|
295
|
+
``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
|
|
296
|
+
a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
|
|
297
|
+
Truncated keys can be set through ``env.add_truncated_keys``.
|
|
298
|
+
Defaults to ``False``.
|
|
299
|
+
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
|
|
300
|
+
This isn't compatible with environments with dynamic specs. Defaults to ``True``
|
|
301
|
+
for envs without dynamic specs, ``False`` for others.
|
|
302
|
+
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
|
|
303
|
+
but populate the buffer instead.
|
|
304
|
+
Defaults to ``None``.
|
|
305
|
+
|
|
306
|
+
.. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts.
|
|
307
|
+
If the buffer needs to be populated with individual frames as they are collected,
|
|
308
|
+
set ``extend_buffer=False`` (deprecated).
|
|
309
|
+
|
|
310
|
+
.. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires
|
|
311
|
+
`extend_buffer=True`, as the whole batch needs to be observed to apply these transforms.
|
|
312
|
+
|
|
313
|
+
extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
|
|
314
|
+
with single steps. Defaults to `True`.
|
|
315
|
+
|
|
316
|
+
.. note:: Setting this to `False` is deprecated and will be removed in a future version.
|
|
317
|
+
Extending the buffer with entire rollouts is the recommended approach for better
|
|
318
|
+
compatibility with postprocessing and trajectory splitting.
|
|
319
|
+
trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be
|
|
320
|
+
assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules
|
|
321
|
+
and ``False`` otherwise.
|
|
322
|
+
compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
|
|
323
|
+
using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
|
|
324
|
+
will be used to compile the policy.
|
|
325
|
+
cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
|
|
326
|
+
in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
|
|
327
|
+
If a dictionary of kwargs is passed, it will be used to wrap the policy.
|
|
328
|
+
no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed.
|
|
329
|
+
For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
|
|
330
|
+
or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_) cuda synchronization may cause unexpected
|
|
331
|
+
crashes.
|
|
332
|
+
Defaults to ``False``.
|
|
333
|
+
weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
|
|
334
|
+
or its subclass, responsible for updating the policy weights on remote inference workers.
|
|
335
|
+
This is typically not used in :class:`~torchrl.collectors.Collector` as it operates in a single-process environment.
|
|
336
|
+
Consider using a constructor if the updater needs to be serialized.
|
|
337
|
+
weight_sync_schemes (dict[str, WeightSyncScheme], optional): **Not supported for Collector**.
|
|
338
|
+
Collector is a leaf collector and cannot send weights to sub-collectors.
|
|
339
|
+
Providing this parameter will raise a ValueError.
|
|
340
|
+
Use ``weight_recv_schemes`` if you need to receive weights from a parent collector.
|
|
341
|
+
weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for
|
|
342
|
+
RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy")
|
|
343
|
+
and values are WeightSyncScheme instances configured to receive weights.
|
|
344
|
+
This enables cascading weight updates in hierarchies like:
|
|
345
|
+
RPCDataCollector -> MultiSyncCollector -> Collector.
|
|
346
|
+
Defaults to ``None``.
|
|
347
|
+
track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
|
|
348
|
+
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
|
|
349
|
+
Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
|
|
350
|
+
the policy version.
|
|
351
|
+
Defaults to `False`.
|
|
352
|
+
|
|
353
|
+
Examples:
|
|
354
|
+
>>> from torchrl.envs.libs.gym import GymEnv
|
|
355
|
+
>>> from tensordict.nn import TensorDictModule
|
|
356
|
+
>>> from torch import nn
|
|
357
|
+
>>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
|
|
358
|
+
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
|
|
359
|
+
>>> collector = Collector(
|
|
360
|
+
... create_env_fn=env_maker,
|
|
361
|
+
... policy=policy,
|
|
362
|
+
... total_frames=2000,
|
|
363
|
+
... max_frames_per_traj=50,
|
|
364
|
+
... frames_per_batch=200,
|
|
365
|
+
... init_random_frames=-1,
|
|
366
|
+
... reset_at_each_iter=False,
|
|
367
|
+
... device="cpu",
|
|
368
|
+
... storing_device="cpu",
|
|
369
|
+
... )
|
|
370
|
+
>>> for i, data in enumerate(collector):
|
|
371
|
+
... if i == 2:
|
|
372
|
+
... print(data)
|
|
373
|
+
... break
|
|
374
|
+
TensorDict(
|
|
375
|
+
fields={
|
|
376
|
+
action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
377
|
+
collector: TensorDict(
|
|
378
|
+
fields={
|
|
379
|
+
traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
|
|
380
|
+
batch_size=torch.Size([200]),
|
|
381
|
+
device=cpu,
|
|
382
|
+
is_shared=False),
|
|
383
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
384
|
+
next: TensorDict(
|
|
385
|
+
fields={
|
|
386
|
+
done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
387
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
388
|
+
reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
389
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
390
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
391
|
+
batch_size=torch.Size([200]),
|
|
392
|
+
device=cpu,
|
|
393
|
+
is_shared=False),
|
|
394
|
+
observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
395
|
+
step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
396
|
+
truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
397
|
+
batch_size=torch.Size([200]),
|
|
398
|
+
device=cpu,
|
|
399
|
+
is_shared=False)
|
|
400
|
+
>>> del collector
|
|
401
|
+
|
|
402
|
+
The collector delivers batches of data that are marked with a ``"time"``
|
|
403
|
+
dimension.
|
|
404
|
+
|
|
405
|
+
Examples:
|
|
406
|
+
>>> assert data.names[-1] == "time"
|
|
407
|
+
|
|
408
|
+
"""
|
|
409
|
+
|
|
410
|
+
_ignore_rb: bool = False
|
|
411
|
+
|
|
412
|
+
def __init__(
|
|
413
|
+
self,
|
|
414
|
+
create_env_fn: (
|
|
415
|
+
EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821
|
|
416
|
+
), # noqa: F821
|
|
417
|
+
policy: None
|
|
418
|
+
| (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None,
|
|
419
|
+
*,
|
|
420
|
+
policy_factory: Callable[[], Callable] | None = None,
|
|
421
|
+
frames_per_batch: int,
|
|
422
|
+
total_frames: int = -1,
|
|
423
|
+
device: DEVICE_TYPING | None = None,
|
|
424
|
+
storing_device: DEVICE_TYPING | None = None,
|
|
425
|
+
policy_device: DEVICE_TYPING | None = None,
|
|
426
|
+
env_device: DEVICE_TYPING | None = None,
|
|
427
|
+
create_env_kwargs: dict[str, Any] | None = None,
|
|
428
|
+
max_frames_per_traj: int | None = None,
|
|
429
|
+
init_random_frames: int | None = None,
|
|
430
|
+
reset_at_each_iter: bool = False,
|
|
431
|
+
postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
|
|
432
|
+
split_trajs: bool | None = None,
|
|
433
|
+
exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE,
|
|
434
|
+
return_same_td: bool = False,
|
|
435
|
+
reset_when_done: bool = True,
|
|
436
|
+
interruptor=None,
|
|
437
|
+
set_truncated: bool = False,
|
|
438
|
+
use_buffers: bool | None = None,
|
|
439
|
+
replay_buffer: ReplayBuffer | None = None,
|
|
440
|
+
extend_buffer: bool = True,
|
|
441
|
+
local_init_rb: bool | None = None,
|
|
442
|
+
trust_policy: bool | None = None,
|
|
443
|
+
compile_policy: bool | dict[str, Any] | None = None,
|
|
444
|
+
cudagraph_policy: bool | dict[str, Any] | None = None,
|
|
445
|
+
no_cuda_sync: bool = False,
|
|
446
|
+
weight_updater: WeightUpdaterBase
|
|
447
|
+
| Callable[[], WeightUpdaterBase]
|
|
448
|
+
| None = None,
|
|
449
|
+
weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
|
|
450
|
+
weight_recv_schemes: dict[str, WeightSyncScheme] | None = None,
|
|
451
|
+
track_policy_version: bool = False,
|
|
452
|
+
worker_idx: int | None = None,
|
|
453
|
+
**kwargs,
|
|
454
|
+
):
|
|
455
|
+
self.closed = True
|
|
456
|
+
self.worker_idx = worker_idx
|
|
457
|
+
|
|
458
|
+
# Note: weight_sync_schemes can be used to send weights to components
|
|
459
|
+
# within the environment (e.g., RayModuleTransform), not just sub-collectors
|
|
460
|
+
|
|
461
|
+
# Initialize environment
|
|
462
|
+
env = self._init_env(create_env_fn, create_env_kwargs)
|
|
463
|
+
|
|
464
|
+
# Initialize policy
|
|
465
|
+
policy = self._init_policy(policy, policy_factory, env, trust_policy)
|
|
466
|
+
self._read_compile_kwargs(compile_policy, cudagraph_policy)
|
|
467
|
+
|
|
468
|
+
# Handle trajectory pool and validate kwargs
|
|
469
|
+
self._traj_pool_val = kwargs.pop("traj_pool", None)
|
|
470
|
+
if kwargs:
|
|
471
|
+
raise TypeError(
|
|
472
|
+
f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}."
|
|
473
|
+
)
|
|
474
|
+
|
|
475
|
+
# Set up devices and synchronization
|
|
476
|
+
self._setup_devices(
|
|
477
|
+
device=device,
|
|
478
|
+
storing_device=storing_device,
|
|
479
|
+
policy_device=policy_device,
|
|
480
|
+
env_device=env_device,
|
|
481
|
+
no_cuda_sync=no_cuda_sync,
|
|
482
|
+
)
|
|
483
|
+
|
|
484
|
+
self.env: EnvBase = env
|
|
485
|
+
del env
|
|
486
|
+
|
|
487
|
+
# Set up policy version tracking
|
|
488
|
+
self._setup_policy_version_tracking(track_policy_version)
|
|
489
|
+
|
|
490
|
+
# Set up replay buffer
|
|
491
|
+
self._setup_replay_buffer(
|
|
492
|
+
replay_buffer=replay_buffer,
|
|
493
|
+
extend_buffer=extend_buffer,
|
|
494
|
+
local_init_rb=local_init_rb,
|
|
495
|
+
postproc=postproc,
|
|
496
|
+
split_trajs=split_trajs,
|
|
497
|
+
return_same_td=return_same_td,
|
|
498
|
+
use_buffers=use_buffers,
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
self.closed = False
|
|
502
|
+
|
|
503
|
+
# Validate reset_when_done
|
|
504
|
+
if not reset_when_done:
|
|
505
|
+
raise ValueError("reset_when_done is deprecated.")
|
|
506
|
+
self.reset_when_done = reset_when_done
|
|
507
|
+
self.n_env = self.env.batch_size.numel()
|
|
508
|
+
|
|
509
|
+
# Register collector with policy and env
|
|
510
|
+
if hasattr(policy, "register_collector"):
|
|
511
|
+
policy.register_collector(self)
|
|
512
|
+
if hasattr(self.env, "register_collector"):
|
|
513
|
+
self.env.register_collector(self)
|
|
514
|
+
|
|
515
|
+
# Set up policy and weights
|
|
516
|
+
self._setup_policy_and_weights(policy)
|
|
517
|
+
|
|
518
|
+
# Apply environment device
|
|
519
|
+
self._apply_env_device()
|
|
520
|
+
|
|
521
|
+
# Set up max frames per trajectory
|
|
522
|
+
self._setup_max_frames_per_traj(max_frames_per_traj)
|
|
523
|
+
|
|
524
|
+
# Validate and set total frames
|
|
525
|
+
self.reset_at_each_iter = reset_at_each_iter
|
|
526
|
+
self._setup_total_frames(total_frames, frames_per_batch)
|
|
527
|
+
|
|
528
|
+
# Set up init random frames
|
|
529
|
+
self._setup_init_random_frames(init_random_frames, frames_per_batch)
|
|
530
|
+
|
|
531
|
+
# Set up postproc
|
|
532
|
+
self._setup_postproc(postproc)
|
|
533
|
+
|
|
534
|
+
# Calculate frames per batch
|
|
535
|
+
self._setup_frames_per_batch(frames_per_batch)
|
|
536
|
+
|
|
537
|
+
# Set exploration and other options
|
|
538
|
+
self.exploration_type = (
|
|
539
|
+
exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE
|
|
540
|
+
)
|
|
541
|
+
self.return_same_td = return_same_td
|
|
542
|
+
self.set_truncated = set_truncated
|
|
543
|
+
|
|
544
|
+
# Create shuttle and rollout buffers
|
|
545
|
+
self._make_shuttle()
|
|
546
|
+
self._maybe_make_final_rollout(make_rollout=self._use_buffers)
|
|
547
|
+
self._set_truncated_keys()
|
|
548
|
+
|
|
549
|
+
# Set split trajectories option
|
|
550
|
+
if split_trajs is None:
|
|
551
|
+
split_trajs = False
|
|
552
|
+
self.split_trajs = split_trajs
|
|
553
|
+
self._exclude_private_keys = True
|
|
554
|
+
|
|
555
|
+
# Set up interruptor and frame tracking
|
|
556
|
+
self.interruptor = interruptor
|
|
557
|
+
self._frames = 0
|
|
558
|
+
self._iter = -1
|
|
559
|
+
|
|
560
|
+
# Set up weight synchronization
|
|
561
|
+
self._setup_weight_sync(weight_updater, weight_sync_schemes)
|
|
562
|
+
|
|
563
|
+
# Set up weight receivers if provided
|
|
564
|
+
if weight_recv_schemes is not None:
|
|
565
|
+
self.register_scheme_receiver(weight_recv_schemes)
|
|
566
|
+
|
|
567
|
+
def _init_env(
|
|
568
|
+
self,
|
|
569
|
+
create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],
|
|
570
|
+
create_env_kwargs: dict[str, Any] | None,
|
|
571
|
+
) -> EnvBase:
|
|
572
|
+
"""Initialize and configure the environment."""
|
|
573
|
+
from torchrl.envs.batched_envs import BatchedEnvBase
|
|
574
|
+
|
|
575
|
+
if create_env_kwargs is None:
|
|
576
|
+
create_env_kwargs = {}
|
|
577
|
+
|
|
578
|
+
if not isinstance(create_env_fn, EnvBase):
|
|
579
|
+
env = create_env_fn(**create_env_kwargs)
|
|
580
|
+
else:
|
|
581
|
+
env = create_env_fn
|
|
582
|
+
if create_env_kwargs:
|
|
583
|
+
if not isinstance(env, BatchedEnvBase):
|
|
584
|
+
raise RuntimeError(
|
|
585
|
+
"kwargs were passed to Collector but they can't be set "
|
|
586
|
+
f"on environment of type {type(create_env_fn)}."
|
|
587
|
+
)
|
|
588
|
+
env.update_kwargs(create_env_kwargs)
|
|
589
|
+
return env
|
|
590
|
+
|
|
591
|
+
def _init_policy(
|
|
592
|
+
self,
|
|
593
|
+
policy: TensorDictModule | Callable | None,
|
|
594
|
+
policy_factory: Callable[[], Callable] | None,
|
|
595
|
+
env: EnvBase,
|
|
596
|
+
trust_policy: bool | None,
|
|
597
|
+
) -> TensorDictModule | Callable:
|
|
598
|
+
"""Initialize and configure the policy before device placement / wrapping."""
|
|
599
|
+
if policy is None:
|
|
600
|
+
if policy_factory is not None:
|
|
601
|
+
policy = policy_factory()
|
|
602
|
+
else:
|
|
603
|
+
policy = RandomPolicy(env.full_action_spec)
|
|
604
|
+
elif policy_factory is not None:
|
|
605
|
+
raise TypeError("policy_factory cannot be used with policy argument.")
|
|
606
|
+
|
|
607
|
+
if trust_policy is None:
|
|
608
|
+
trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule))
|
|
609
|
+
self.trust_policy = trust_policy
|
|
610
|
+
|
|
611
|
+
return policy
|
|
612
|
+
|
|
613
|
+
def _setup_devices(
|
|
614
|
+
self,
|
|
615
|
+
device: DEVICE_TYPING | None,
|
|
616
|
+
storing_device: DEVICE_TYPING | None,
|
|
617
|
+
policy_device: DEVICE_TYPING | None,
|
|
618
|
+
env_device: DEVICE_TYPING | None,
|
|
619
|
+
no_cuda_sync: bool,
|
|
620
|
+
) -> None:
|
|
621
|
+
"""Set up devices and synchronization functions."""
|
|
622
|
+
storing_device, policy_device, env_device = self._get_devices(
|
|
623
|
+
storing_device=storing_device,
|
|
624
|
+
policy_device=policy_device,
|
|
625
|
+
env_device=env_device,
|
|
626
|
+
device=device,
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
self.storing_device = storing_device
|
|
630
|
+
self._sync_storage = self._get_sync_fn(storing_device)
|
|
631
|
+
|
|
632
|
+
self.env_device = env_device
|
|
633
|
+
self._sync_env = self._get_sync_fn(env_device)
|
|
634
|
+
|
|
635
|
+
self.policy_device = policy_device
|
|
636
|
+
self._sync_policy = self._get_sync_fn(policy_device)
|
|
637
|
+
|
|
638
|
+
self.device = device
|
|
639
|
+
self.no_cuda_sync = no_cuda_sync
|
|
640
|
+
self._cast_to_policy_device = self.policy_device != self.env_device
|
|
641
|
+
|
|
642
|
+
def _get_sync_fn(self, device: torch.device | None) -> Callable:
|
|
643
|
+
"""Get the appropriate synchronization function for a device."""
|
|
644
|
+
if device is not None and device.type != "cuda":
|
|
645
|
+
# When destination is not CUDA, we may need to sync to wait for
|
|
646
|
+
# async GPU→CPU transfers to complete before proceeding.
|
|
647
|
+
if torch.cuda.is_available():
|
|
648
|
+
# Return a safe wrapper that only syncs if CUDA was actually
|
|
649
|
+
# initialized. This avoids "Cannot re-initialize CUDA in forked
|
|
650
|
+
# subprocess" errors when using fork start method on GPU machines
|
|
651
|
+
# with CPU-only collectors.
|
|
652
|
+
return _cuda_sync_if_initialized
|
|
653
|
+
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
|
|
654
|
+
return torch.mps.synchronize
|
|
655
|
+
elif hasattr(torch, "npu") and torch.npu.is_available():
|
|
656
|
+
return torch.npu.synchronize
|
|
657
|
+
elif device.type == "cpu":
|
|
658
|
+
return _do_nothing
|
|
659
|
+
else:
|
|
660
|
+
raise RuntimeError("Non supported device")
|
|
661
|
+
else:
|
|
662
|
+
return _do_nothing
|
|
663
|
+
|
|
664
|
+
def _setup_policy_version_tracking(
|
|
665
|
+
self, track_policy_version: bool | PolicyVersion
|
|
666
|
+
) -> None:
|
|
667
|
+
"""Set up policy version tracking if requested."""
|
|
668
|
+
self.policy_version_tracker = track_policy_version
|
|
669
|
+
if isinstance(track_policy_version, bool) and track_policy_version:
|
|
670
|
+
from torchrl.envs.batched_envs import BatchedEnvBase
|
|
671
|
+
|
|
672
|
+
if isinstance(self.env, BatchedEnvBase):
|
|
673
|
+
raise RuntimeError(
|
|
674
|
+
"BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
|
|
675
|
+
"and pass that transform to the collector."
|
|
676
|
+
)
|
|
677
|
+
self.policy_version_tracker = PolicyVersion()
|
|
678
|
+
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
|
|
679
|
+
elif hasattr(track_policy_version, "increment_version"):
|
|
680
|
+
self.policy_version_tracker = track_policy_version
|
|
681
|
+
self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore
|
|
682
|
+
else:
|
|
683
|
+
self.policy_version_tracker = None
|
|
684
|
+
|
|
685
|
+
def _setup_replay_buffer(
|
|
686
|
+
self,
|
|
687
|
+
replay_buffer: ReplayBuffer | None,
|
|
688
|
+
extend_buffer: bool,
|
|
689
|
+
local_init_rb: bool | None,
|
|
690
|
+
postproc: Callable | None,
|
|
691
|
+
split_trajs: bool | None,
|
|
692
|
+
return_same_td: bool,
|
|
693
|
+
use_buffers: bool | None,
|
|
694
|
+
) -> None:
|
|
695
|
+
"""Set up replay buffer configuration and validate compatibility."""
|
|
696
|
+
self.replay_buffer = replay_buffer
|
|
697
|
+
self.extend_buffer = extend_buffer
|
|
698
|
+
|
|
699
|
+
# Handle local_init_rb deprecation
|
|
700
|
+
if local_init_rb is None:
|
|
701
|
+
local_init_rb = False
|
|
702
|
+
if replay_buffer is not None and not local_init_rb:
|
|
703
|
+
warnings.warn(
|
|
704
|
+
"local_init_rb=False is deprecated and will be removed in v0.12. "
|
|
705
|
+
"The new storage-level initialization provides better performance.",
|
|
706
|
+
FutureWarning,
|
|
707
|
+
)
|
|
708
|
+
self.local_init_rb = local_init_rb
|
|
709
|
+
|
|
710
|
+
# Validate replay buffer compatibility
|
|
711
|
+
if self.replay_buffer is not None and not self._ignore_rb:
|
|
712
|
+
if postproc is not None and not self.extend_buffer:
|
|
713
|
+
raise TypeError(
|
|
714
|
+
"postproc must be None when a replay buffer is passed, or extend_buffer must be set to True."
|
|
715
|
+
)
|
|
716
|
+
if split_trajs not in (None, False) and not self.extend_buffer:
|
|
717
|
+
raise TypeError(
|
|
718
|
+
"split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True."
|
|
719
|
+
)
|
|
720
|
+
if return_same_td:
|
|
721
|
+
raise TypeError(
|
|
722
|
+
"return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True."
|
|
723
|
+
)
|
|
724
|
+
if use_buffers:
|
|
725
|
+
raise TypeError("replay_buffer is exclusive with use_buffers.")
|
|
726
|
+
|
|
727
|
+
if use_buffers is None:
|
|
728
|
+
use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None
|
|
729
|
+
self._use_buffers = use_buffers

    def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None:
        """Set up policy, wrapped policy, and extract weights."""
        # Store weak reference to original policy before any transformations
        # This allows update_policy_weights_ to sync from the original when no scheme is configured
        if isinstance(policy, nn.Module):
            self._orig_policy_ref = weakref.ref(policy)
        else:
            self._orig_policy_ref = None

        # Check if policy has meta-device parameters (sent from weight sync schemes)
        # In that case, skip device placement - weights will come from the receiver
        has_meta_params = False
        if isinstance(policy, nn.Module):
            for p in policy.parameters():
                if p.device.type == "meta":
                    has_meta_params = True
                    break

        if has_meta_params:
            # Policy has meta params - sent from weight sync schemes
            # Skip device placement, weights will come from receiver
            # Keep policy on meta device until weights are loaded
            if not self.trust_policy:
                self.policy = policy
                env = getattr(self, "env", None)
                try:
                    wrapped_policy = _make_compatible_policy(
                        policy=policy,
                        observation_spec=getattr(env, "observation_spec", None),
                        env=self.env,
                    )
                except (TypeError, AttributeError, ValueError) as err:
                    raise TypeError(
                        "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details."
                    ) from err
                self._wrapped_policy = wrapped_policy
            else:
                self.policy = self._wrapped_policy = policy

            # Auto-configure exploration modules if needed (e.g. spec=None)
            if isinstance(self.policy, nn.Module):
                set_exploration_modules_spec_from_env(self.policy, self.env)

            # For meta-parameter policies, keep the internal (worker-side) policy
            # as the reference for collector state_dict / load_state_dict.
            if isinstance(self.policy, nn.Module):
                self._policy_w_state_dict = self.policy

            # Don't extract weights yet - they're on meta device (empty)
            self.policy_weights = TensorDict()
            self.get_weights_fn = None
        else:
            # Normal path: move policy to correct device
            policy, self.get_weights_fn = self._get_policy_and_device(policy=policy)

            if not self.trust_policy:
                self.policy = policy
                env = getattr(self, "env", None)
                try:
                    wrapped_policy = _make_compatible_policy(
                        policy=policy,
                        observation_spec=getattr(env, "observation_spec", None),
                        env=self.env,
                    )
                except (TypeError, AttributeError, ValueError) as err:
                    raise TypeError(
                        "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details."
                    ) from err
                self._wrapped_policy = wrapped_policy
            else:
                self.policy = self._wrapped_policy = policy

            # Auto-configure exploration modules if needed (e.g. spec=None)
            if isinstance(self.policy, nn.Module):
                set_exploration_modules_spec_from_env(self.policy, self.env)

            # Use the internal, unwrapped policy (cast to the correct device) as the
            # reference for state_dict / load_state_dict and legacy weight extractors.
            if isinstance(self.policy, nn.Module):
                self._policy_w_state_dict = self.policy

            # Extract policy weights from the uncompiled wrapped policy
            # Access _wrapped_policy_uncompiled directly to avoid triggering compilation.
            if isinstance(self._wrapped_policy_uncompiled, nn.Module):
                self.policy_weights = TensorDict.from_module(
                    self._wrapped_policy_uncompiled, as_module=True
                ).data
            else:
                self.policy_weights = TensorDict()

        # If policy doesn't have meta params, compile immediately
        # Otherwise, defer until first use (after weights are loaded)
        if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy):
            self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy(
                self._wrapped_policy_uncompiled
            )

    def _compile_wrapped_policy(self, policy):
        """Apply compilation and/or cudagraph to a policy."""
        if self.compiled_policy:
            policy = compile_with_warmup(policy, **self.compiled_policy_kwargs)
        if self.cudagraphed_policy:
            policy = CudaGraphModule(
                policy,
                in_keys=[],
                out_keys=[],
                device=self.policy_device,
                **self.cudagraphed_policy_kwargs,
            )
        return policy

    @property
    def _wrapped_policy(self):
        """Returns the compiled policy, compiling it lazily if needed."""
        if (policy := self._wrapped_policy_maybe_compiled) is None:
            if self.compiled_policy or self.cudagraphed_policy:
                policy = (
                    self._wrapped_policy_maybe_compiled
                ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled)
            else:
                policy = (
                    self._wrapped_policy_maybe_compiled
                ) = self._wrapped_policy_uncompiled
        return policy

    @property
    def _orig_policy(self):
        """Returns the original policy passed to the collector, if still alive."""
        if self._orig_policy_ref is not None:
            return self._orig_policy_ref()
        return None

    @_wrapped_policy.setter
    def _wrapped_policy(self, value):
        """Allow setting the wrapped policy during initialization."""
        self._wrapped_policy_uncompiled = value
        self._wrapped_policy_maybe_compiled = None

    def _apply_env_device(self) -> None:
        """Apply device to environment if specified."""
        if self.env_device:
            self.env: EnvBase = self.env.to(self.env_device)
        elif self.env.device is not None:
            # Use the device of the env if none was provided
            self.env_device = self.env.device

        # Check if we need to cast to env device
        self._cast_to_env_device = self._cast_to_policy_device or (
            self.env.device != self.storing_device
        )

    def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None:
        """Set up maximum frames per trajectory and add StepCounter if needed."""
        self.max_frames_per_traj = (
            int(max_frames_per_traj) if max_frames_per_traj is not None else 0
        )
        if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0:
            # Check that there is no StepCounter yet
            for key in self.env.output_spec.keys(True, True):
                if isinstance(key, str):
                    key = (key,)
                if "step_count" in key:
                    raise ValueError(
                        "A 'step_count' key is already present in the environment "
                        "and the 'max_frames_per_traj' argument may conflict with "
                        "a 'StepCounter' that has already been set. "
                        "Possible solutions: Set max_frames_per_traj to 0 or "
                        "remove the StepCounter limit from the environment transforms."
                    )
            self.env = TransformedEnv(
                self.env, StepCounter(max_steps=self.max_frames_per_traj)
            )

    def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None:
        """Validate and set total frames."""
        if total_frames is None or total_frames < 0:
            total_frames = float("inf")
        else:
            remainder = total_frames % frames_per_batch
            if remainder != 0 and RL_WARNINGS:
                warnings.warn(
                    f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
                    f"This means {frames_per_batch - remainder} additional frames will be collected. "
                    "To silence this message, set the environment variable RL_WARNINGS to False."
                )
        self.total_frames = (
            int(total_frames) if total_frames != float("inf") else total_frames
        )
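
    # Worked example of the divisibility check above (numbers assumed purely for
    # illustration): with total_frames=1000 and frames_per_batch=300,
    # remainder = 1000 % 300 = 100, so 300 - 100 = 200 extra frames are collected
    # and the warning above fires unless RL_WARNINGS is disabled.
    #
    #   >>> total_frames, frames_per_batch = 1000, 300
    #   >>> remainder = total_frames % frames_per_batch
    #   >>> frames_per_batch - remainder
    #   200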

    def _setup_init_random_frames(
        self, init_random_frames: int | None, frames_per_batch: int
    ) -> None:
        """Set up initial random frames."""
        self.init_random_frames = (
            int(init_random_frames) if init_random_frames not in (None, -1) else 0
        )
        if (
            init_random_frames not in (-1, None, 0)
            and init_random_frames % frames_per_batch != 0
            and RL_WARNINGS
        ):
            warnings.warn(
                f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), "
                f"this results in more init_random_frames than requested "
                f"({-(-init_random_frames // frames_per_batch) * frames_per_batch}). "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )

    def _setup_postproc(self, postproc: Callable | None) -> None:
        """Set up post-processing transform."""
        self.postproc = postproc
        if (
            self.postproc is not None
            and hasattr(self.postproc, "to")
            and self.storing_device
        ):
            postproc = self.postproc.to(self.storing_device)
            if postproc is not self.postproc and postproc is not None:
                self.postproc = postproc

    def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
        """Calculate and validate frames per batch."""
        if frames_per_batch % self.n_env != 0 and RL_WARNINGS:
            warnings.warn(
                f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
                f"this results in more frames_per_batch per iteration than requested "
                f"({-(-frames_per_batch // self.n_env) * self.n_env}). "
                "To silence this message, set the environment variable RL_WARNINGS to False."
            )
        self.frames_per_batch = -(-frames_per_batch // self.n_env)
        self.requested_frames_per_batch = self.frames_per_batch * self.n_env
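
    # Worked example of the ceil-division above (numbers assumed purely for
    # illustration): -(-frames_per_batch // n_env) rounds up, so with
    # frames_per_batch=100 and n_env=3 each env contributes 34 steps and 102
    # frames are actually collected per iteration.
    #
    #   >>> frames_per_batch, n_env = 100, 3
    #   >>> per_env = -(-frames_per_batch // n_env)
    #   >>> per_env, per_env * n_env
    #   (34, 102)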

    def _setup_weight_sync(
        self,
        weight_updater: WeightUpdaterBase | Callable | None,
        weight_sync_schemes: dict[str, WeightSyncScheme] | None,
    ) -> None:
        """Set up weight synchronization system."""
        if weight_sync_schemes is not None:
            # Use new simplified weight synchronization system
            self._weight_sync_schemes = weight_sync_schemes
            # Initialize and synchronize schemes that need sender-side setup
            # (e.g., RayModuleTransformScheme for updating transforms in the env)
            for model_id, scheme in weight_sync_schemes.items():
                if not scheme.initialized_on_sender:
                    scheme.init_on_sender(model_id=model_id, context=self)
                if not scheme.synchronized_on_sender:
                    scheme.connect()
            self.weight_updater = None  # Don't use legacy system
        elif weight_updater is not None:
            # Use legacy weight updater system if explicitly provided
            if not isinstance(weight_updater, WeightUpdaterBase):
                if callable(weight_updater):
                    weight_updater = weight_updater()
                else:
                    raise TypeError(
                        f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
                    )
            warnings.warn(
                "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
                "This will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )
            self.weight_updater = weight_updater
            self._weight_sync_schemes = None
        else:
            # No weight sync needed for single-process collectors
            self.weight_updater = None
            self._weight_sync_schemes = None
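
    # Minimal usage sketch for the two code paths above (the concrete ``scheme``
    # instance is assumed; any ``WeightSyncScheme`` is handled the same way):
    #
    #   >>> collector = Collector(env, policy, frames_per_batch=100,
    #   ...     weight_sync_schemes={"policy": scheme})   # new scheme-based path
    #   >>> collector.update_policy_weights_()            # routed through the scheme
    #
    # Passing ``weight_updater=...`` instead selects the legacy path and emits the
    # DeprecationWarning above.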

    @property
    def _traj_pool(self):
        pool = getattr(self, "_traj_pool_val", None)
        if pool is None:
            pool = self._traj_pool_val = _TrajectoryPool()
        return pool

    def _make_shuttle(self):
        # Shuttle is a deviceless tensordict that just carries data from env to policy and policy to env
        with torch.no_grad():
            self._carrier = self.env.reset()
        if self.policy_device != self.env_device or self.env_device is None:
            self._shuttle_has_no_device = True
            self._carrier.clear_device_()
        else:
            self._shuttle_has_no_device = False

        traj_ids = self._traj_pool.get_traj_and_increment(
            self.n_env, device=self.storing_device
        ).view(self.env.batch_size)
        self._carrier.set(
            ("collector", "traj_ids"),
            traj_ids,
        )

    def _maybe_make_final_rollout(self, make_rollout: bool):
        if make_rollout:
            with torch.no_grad():
                self._final_rollout = self.env.fake_tensordict()

            # If storing device is not None, we use this to cast the storage.
            # If it is None and the env and policy are on the same device,
            # the storing device is already the same as those, so we don't need
            # to consider this use case.
            # In all other cases, we can't really put a device on the storage,
            # since at least one data source has a device that is not clear.
            if self.storing_device:
                self._final_rollout = self._final_rollout.to(
                    self.storing_device, non_blocking=True
                )
            else:
                # erase all devices
                self._final_rollout.clear_device_()

        # Check if policy has meta-device parameters (not yet initialized)
        has_meta_params = False
        if hasattr(self, "_wrapped_policy_uncompiled") and isinstance(
            self._wrapped_policy_uncompiled, nn.Module
        ):
            for p in self._wrapped_policy_uncompiled.parameters():
                if p.device.type == "meta":
                    has_meta_params = True
                    break

        # If the policy has a valid spec, we use it
        self._policy_output_keys = set()
        _policy_to_check = (
            self._wrapped_policy_uncompiled if has_meta_params else self._wrapped_policy
        )
        _has_spec = hasattr(_policy_to_check, "spec")
        _spec_not_none = False
        _all_values_not_none = False
        if _has_spec:
            _spec = _policy_to_check.spec
            _spec_not_none = _spec is not None
            if _spec_not_none:
                _all_values_not_none = all(
                    v is not None for v in _spec.values(True, True)
                )
        _condition = (
            make_rollout and _has_spec and _spec_not_none and _all_values_not_none
        )
        if _condition:
            if any(
                key not in self._final_rollout.keys(isinstance(key, tuple))
                for key in (
                    self._wrapped_policy_uncompiled
                    if has_meta_params
                    else self._wrapped_policy
                ).spec.keys(True, True)
            ):
                # if policy spec is non-empty, all the values are not None and the keys
                # match the out_keys we assume the user has given all relevant information
                # the policy could have more keys than the env:
                policy_spec = (
                    self._wrapped_policy_uncompiled
                    if has_meta_params
                    else self._wrapped_policy
                ).spec
                if policy_spec.ndim < self._final_rollout.ndim:
                    policy_spec = policy_spec.expand(self._final_rollout.shape)
                for key, spec in policy_spec.items(True, True):
                    self._policy_output_keys.add(key)
                    if key in self._final_rollout.keys(True):
                        continue
                    self._final_rollout.set(key, spec.zero())
        elif (
            not make_rollout
            and hasattr(
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy,
                "out_keys",
            )
            and (
                self._wrapped_policy_uncompiled
                if has_meta_params
                else self._wrapped_policy
            ).out_keys
        ):
            self._policy_output_keys = list(
                (
                    self._wrapped_policy_uncompiled
                    if has_meta_params
                    else self._wrapped_policy
                ).out_keys
            )
        elif has_meta_params:
            # Policy has meta params and no spec/out_keys - defer initialization
            # Mark that we need to initialize later when weights are loaded
            self._policy_output_keys = set()
            if make_rollout:
                # We'll populate keys on first actual rollout after weights are loaded
                self._final_rollout_needs_init = True
        else:
            if make_rollout:
                # otherwise, we perform a small number of steps with the policy to
                # determine the relevant keys with which to pre-populate _final_rollout.
                # This is the safest thing to do if the spec has None fields or if there is
                # no spec at all.
                # See #505 for additional context.
                self._final_rollout.update(self._carrier.copy())
            with torch.no_grad():
                policy_input = self._carrier.copy()
                if self.policy_device:
                    policy_input = policy_input.to(self.policy_device)
                # we cast to policy device, we'll deal with the device later
                policy_input_copy = policy_input.copy()
                policy_input_clone = (
                    policy_input.clone()
                )  # to test if values have changed in-place
                if self.compiled_policy:
                    cudagraph_mark_step_begin()
                policy_output = self._wrapped_policy(policy_input)

                # check that we don't have exclusive keys, because they don't appear in keys
                def check_exclusive(val):
                    if (
                        isinstance(val, LazyStackedTensorDict)
                        and val._has_exclusive_keys
                    ):
                        raise RuntimeError(
                            "LazyStackedTensorDict with exclusive keys are not permitted in collectors. "
                            "Consider using a placeholder for missing keys."
                        )

                policy_output._fast_apply(
                    check_exclusive, call_on_nested=True, filter_empty=True
                )

                # Use apply, because it works well with lazy stacks
                # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit
                # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has
                # changed them here).
                # This will cause a failure to update entries when policy and env device mismatch and
                # casting is necessary.
                def filter_policy(name, value_output, value_input, value_input_clone):
                    if (value_input is None) or (
                        (value_output is not value_input)
                        and (
                            value_output.device != value_input_clone.device
                            or ~torch.isclose(value_output, value_input_clone).any()
                        )
                    ):
                        return value_output

                filtered_policy_output = policy_output.apply(
                    filter_policy,
                    policy_input_copy,
                    policy_input_clone,
                    default=None,
                    filter_empty=True,
                    named=True,
                )
                self._policy_output_keys = list(
                    self._policy_output_keys.union(
                        set(filtered_policy_output.keys(True, True))
                    )
                )
                if make_rollout:
                    self._final_rollout.update(
                        policy_output.select(*self._policy_output_keys)
                    )
                del filtered_policy_output, policy_output, policy_input

        _env_output_keys = []
        for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]:
            _env_output_keys += list(self.env.output_spec[spec].keys(True, True))
        self._env_output_keys = _env_output_keys
        if make_rollout:
            self._final_rollout = (
                self._final_rollout.unsqueeze(-1)
                .expand(*self.env.batch_size, self.frames_per_batch)
                .clone()
                .zero_()
            )

            # in addition to outputs of the policy, we add traj_ids to
            # _final_rollout which will be collected during rollout
            self._final_rollout.set(
                ("collector", "traj_ids"),
                torch.zeros(
                    *self._final_rollout.batch_size,
                    dtype=torch.int64,
                    device=self.storing_device,
                ),
            )
            self._final_rollout.refine_names(..., "time")

    def _set_truncated_keys(self):
        self._truncated_keys = []
        if self.set_truncated:
            if not any(_ends_with(key, "truncated") for key in self.env.done_keys):
                raise RuntimeError(
                    "set_truncated was set to True but no truncated key could be found "
                    "in the environment. Make sure the truncated keys are properly set using "
                    "`env.add_truncated_keys()` before passing the env to the collector."
                )
            self._truncated_keys = [
                key for key in self.env.done_keys if _ends_with(key, "truncated")
            ]

    @classmethod
    def _get_devices(
        cls,
        *,
        storing_device: torch.device,
        policy_device: torch.device,
        env_device: torch.device,
        device: torch.device,
    ):
        device = _make_ordinal_device(torch.device(device) if device else device)
        storing_device = _make_ordinal_device(
            torch.device(storing_device) if storing_device else device
        )
        policy_device = _make_ordinal_device(
            torch.device(policy_device) if policy_device else device
        )
        env_device = _make_ordinal_device(
            torch.device(env_device) if env_device else device
        )
        if storing_device is None and (env_device == policy_device):
            storing_device = env_device
        return storing_device, policy_device, env_device
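
    # Example of the resolution above (devices assumed purely for illustration):
    # a single ``device`` fills every unset slot, whereas mismatching env and
    # policy devices leave ``storing_device`` as None.
    #
    #   >>> s, p, e = Collector._get_devices(
    #   ...     storing_device=None, policy_device=None, env_device=None, device="cuda:0")
    #   >>> s == p == e == torch.device("cuda:0")
    #   True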

    # for RPC
    def next(self):
        return super().next()

    # for RPC
    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
        **kwargs,
    ) -> None:
        if "policy_weights" in kwargs:
            warnings.warn(
                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
                DeprecationWarning,
            )
            policy_or_weights = kwargs.pop("policy_weights")

        super().update_policy_weights_(
            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
        )

    def _maybe_fallback_update(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        model_id: str | None = None,
    ) -> None:
        """Copy weights from the original policy to the internal policy when no scheme is configured."""
        if model_id is not None and model_id != "policy":
            return

        # Get source weights - either from argument or from original policy
        if policy_or_weights is not None:
            weights = self._extract_weights_if_needed(policy_or_weights, "policy")
        elif self._orig_policy is not None:
            weights = TensorDict.from_module(self._orig_policy)
        else:
            return

        # Apply to internal policy
        if (
            hasattr(self, "_policy_w_state_dict")
            and self._policy_w_state_dict is not None
        ):
            TensorDict.from_module(self._policy_w_state_dict).data.update_(weights.data)

    def set_seed(self, seed: int, static_seed: bool = False) -> int:
        """Sets the seeds of the environments stored in the DataCollector.

        Args:
            seed (int): integer representing the seed to be used for the environment.
            static_seed (bool, optional): if ``True``, the seed is not incremented.
                Defaults to ``False``.

        Returns:
            Output seed. This is useful when more than one environment is contained in the DataCollector, as the
            seed will be incremented for each of these. The resulting seed is the seed of the last environment.

        Examples:
            >>> from torchrl.envs import ParallelEnv
            >>> from torchrl.envs.libs.gym import GymEnv
            >>> from tensordict.nn import TensorDictModule
            >>> from torch import nn
            >>> env_fn = lambda: GymEnv("Pendulum-v1")
            >>> env_fn_parallel = ParallelEnv(6, env_fn)
            >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
            >>> collector = Collector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
            >>> out_seed = collector.set_seed(1)  # out_seed = 6

        """
        out = self.env.set_seed(seed, static_seed=static_seed)
        return out

    def _increment_frames(self, numel):
        self._frames += numel
        completed = self._frames >= self.total_frames
        if completed:
            self.env.close()
        return completed

    def iterator(self) -> Iterator[TensorDictBase]:
        """Iterates through the DataCollector.

        Yields: TensorDictBase objects containing (chunks of) trajectories

        """
        if (
            not self.no_cuda_sync
            and self.storing_device
            and self.storing_device.type == "cuda"
        ):
            stream = torch.cuda.Stream(self.storing_device, priority=-1)
            event = stream.record_event()
            streams = [stream]
            events = [event]
        elif not self.no_cuda_sync and self.storing_device is None:
            streams = []
            events = []
            # this way of checking cuda is robust to lazy stacks with mismatching shapes
            cuda_devices = set()

            def cuda_check(tensor: torch.Tensor):
                if tensor.is_cuda:
                    cuda_devices.add(tensor.device)

            if not self._use_buffers:
                # This may be a bit dangerous as `torch.device("cuda")` may not have a precise
                # device associated, whereas `tensor.device` always has
                for spec in self.env.specs.values(True, True):
                    if spec.device is not None and spec.device.type == "cuda":
                        if ":" not in str(spec.device):
                            raise RuntimeError(
                                "A cuda spec did not have a device associated. Make sure to "
                                "pass `'cuda:device_num'` to each spec device."
                            )
                        cuda_devices.add(spec.device)
            else:
                self._final_rollout.apply(cuda_check, filter_empty=True)
            for device in cuda_devices:
                streams.append(torch.cuda.Stream(device, priority=-1))
                events.append(streams[-1].record_event())
        else:
            streams = []
            events = []

        # Set up profiler if configured
        profiler = None
        if self._profile_config is not None:
            profiler = _CollectorProfiler(self._profile_config)
            if profiler.is_active:
                profiler.start()

        with contextlib.ExitStack() as stack:
            for stream in streams:
                stack.enter_context(torch.cuda.stream(stream))

            while self._frames < self.total_frames:
                self._iter += 1

                # Use profiler context if profiling is active
                profile_ctx = (
                    profiler.profile_rollout()
                    if profiler is not None and profiler.is_active
                    else contextlib.nullcontext()
                )

                with profile_ctx:
                    tensordict_out = self.rollout()

                # Step the profiler after each rollout
                if profiler is not None and profiler.is_active:
                    profiler.step()

                if tensordict_out is None:
                    # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out
                    # frames are updated within the rollout function
                    yield
                    continue
                self._increment_frames(tensordict_out.numel())
                tensordict_out = self._postproc(tensordict_out)
                if self.return_same_td:
                    # This is used with multiprocessed collectors to use the buffers
                    # stored in the tensordict.
                    if events:
                        for event in events:
                            event.record()
                            event.synchronize()
                    yield tensordict_out
                elif self.replay_buffer is not None and not self._ignore_rb:
                    self.replay_buffer.extend(tensordict_out)
                    yield
                else:
                    # we must clone the values, as the tensordict is updated in-place.
                    # otherwise the following code may break:
                    # >>> for i, data in enumerate(collector):
                    # >>>     if i == 0:
                    # >>>         data0 = data
                    # >>>     elif i == 1:
                    # >>>         data1 = data
                    # >>>     else:
                    # >>>         break
                    # >>> assert data0["done"] is not data1["done"]
                    yield tensordict_out.clone()

        # Stop profiler if it hasn't been stopped yet
        if profiler is not None and profiler.is_active:
            profiler.stop()
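
    # Typical consumption of the iterator above (sketch; the collector is assumed
    # to have been built with total_frames and frames_per_batch set, and
    # ``train`` is a hypothetical training step, not part of this module):
    #
    #   >>> for i, data in enumerate(collector):
    #   ...     # data is a TensorDict batch of trajectories, or None when a
    #   ...     # replay buffer consumes the frames directly (see the branches above)
    #   ...     train(data)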

    def start(self):
        """Starts the collector in a separate thread for asynchronous data collection.

        The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
        collection from training, allowing your training loop to run independently of the data collection process.

        Raises:
            RuntimeError: If no replay buffer is defined during the collector's initialization.

        Example:
            >>> import time
            >>> from functools import partial
            >>>
            >>> import tqdm
            >>>
            >>> from torchrl.collectors import Collector
            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
            >>> from torchrl.envs import GymEnv, set_gym_backend
            >>> from torchrl.modules import RandomPolicy
            >>> import ale_py
            >>>
            >>> # Set the gym backend to gymnasium
            >>> set_gym_backend("gymnasium").set()
            >>>
            >>> if __name__ == "__main__":
            ...     # Create a random policy for the Pong environment
            ...     env = GymEnv("ALE/Pong-v5")
            ...     policy = RandomPolicy(env.action_spec)
            ...
            ...     # Initialize a shared replay buffer
            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
            ...
            ...     # Create a synchronous data collector
            ...     collector = Collector(
            ...         env,
            ...         policy=policy,
            ...         replay_buffer=rb,
            ...         frames_per_batch=256,
            ...         total_frames=-1,
            ...     )
            ...
            ...     # Progress bar to track the number of collected frames
            ...     pbar = tqdm.tqdm(total=100_000)
            ...
            ...     # Start the collector asynchronously
            ...     collector.start()
            ...
            ...     # Track the write count of the replay buffer
            ...     prec_wc = 0
            ...     while True:
            ...         wc = rb.write_count
            ...         c = wc - prec_wc
            ...         prec_wc = wc
            ...
            ...         # Update the progress bar
            ...         pbar.update(c)
            ...         pbar.set_description(f"Write Count: {rb.write_count}")
            ...
            ...         # Check the write count every 0.5 seconds
            ...         time.sleep(0.5)
            ...
            ...         # Stop when the desired number of frames is reached
            ...         if rb.write_count >= 100_000:
            ...             break
            ...
            ...     # Shut down the collector
            ...     collector.async_shutdown()
        """
        if self.replay_buffer is None:
            raise RuntimeError("Replay buffer must be defined for execution.")
        if not self.is_running():
            self._stop = False
            self._thread = threading.Thread(target=self._run_iterator)
            self._thread.daemon = (
                True  # So that the thread dies when the main program exits
            )
            self._thread.start()

    def _run_iterator(self):
        for _ in self:
            if self._stop:
                return

    def is_running(self):
        return hasattr(self, "_thread") and self._thread.is_alive()

    def _should_use_random_frames(self) -> bool:
        """Determine if random frames should be used instead of the policy.

        When a replay buffer is provided, uses `replay_buffer.write_count` as the
        global step counter to support `.start()` mode where `_frames` isn't updated
        until after collection. Otherwise, uses the internal `_frames` counter.

        Returns:
            bool: True if random frames should be used, False otherwise.
        """
        if self.init_random_frames is None or self.init_random_frames <= 0:
            return False
        # Use replay_buffer.write_count when available for accurate counting in .start() mode
        if self.replay_buffer is not None:
            return self.replay_buffer.write_count < self.init_random_frames
        return self._frames < self.init_random_frames
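
    # Example of the threshold above (numbers assumed purely for illustration):
    #
    #   init_random_frames=1000, frames collected so far=500   -> random actions
    #   init_random_frames=1000, frames collected so far=1200  -> policy actions
    #
    # With a replay buffer, ``rb.write_count`` plays the role of the frame counter.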

    def async_shutdown(
        self, timeout: float | None = None, close_env: bool = True
    ) -> None:
        """Finishes processes started by ray.init() during async execution."""
        self._stop = True
        if hasattr(self, "_thread") and self._thread.is_alive():
            self._thread.join(timeout=timeout)
        self.shutdown(close_env=close_env)

    def _postproc(self, tensordict_out):
        if self.split_trajs:
            tensordict_out = split_trajectories(tensordict_out, prefix="collector")
        if self.postproc is not None:
            tensordict_out = self.postproc(tensordict_out)
        if self._exclude_private_keys:

            def is_private(key):
                if isinstance(key, str) and key.startswith("_"):
                    return True
                if isinstance(key, tuple) and any(_key.startswith("_") for _key in key):
                    return True
                return False

            excluded_keys = [
                key for key in tensordict_out.keys(True) if is_private(key)
            ]
            tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True)
        return tensordict_out
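
    # Example of the private-key filter above (key names assumed purely for
    # illustration): keys beginning with an underscore are dropped from the
    # yielded batch, so "_weight" or ("next", "_log_prob") would be excluded
    # while "action" and ("next", "reward") are kept.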

    def _update_traj_ids(self, env_output) -> None:
        # we can't use the reset keys because they're gone
        traj_sop = _aggregate_end_of_traj(
            env_output.get("next"), done_keys=self.env.done_keys
        )
        if traj_sop.any():
            device = self.storing_device

            traj_ids = self._carrier.get(("collector", "traj_ids"))
            if device is not None:
                traj_ids = traj_ids.to(device)
                traj_sop = traj_sop.to(device)
            elif traj_sop.device != traj_ids.device:
                traj_sop = traj_sop.to(traj_ids.device)

            pool = self._traj_pool
            new_traj = pool.get_traj_and_increment(
                traj_sop.sum(), device=traj_sop.device
            )
            traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
            self._carrier.set(("collector", "traj_ids"), traj_ids)

    @torch.no_grad()
    def rollout(self) -> TensorDictBase:
        """Computes a rollout in the environment using the provided policy.

        Returns:
            TensorDictBase containing the computed rollout.

        """
        if self.reset_at_each_iter:
            self._carrier.update(self.env.reset())

        # self._shuttle.fill_(("collector", "step_count"), 0)
        if self._use_buffers:
            self._final_rollout.fill_(("collector", "traj_ids"), -1)
        else:
            pass
        tensordicts = []
        with set_exploration_type(self.exploration_type):
            for t in range(self.frames_per_batch):
                if self._should_use_random_frames():
                    self.env.rand_action(self._carrier)
                    if (
                        self.policy_device is not None
                        and self.policy_device != self.env_device
                    ):
                        # TODO: This may break with exclusive / ragged lazy stacks
                        self._carrier.apply(
                            lambda name, val: val.to(
                                device=self.policy_device, non_blocking=True
                            )
                            if name in self._policy_output_keys
                            else val,
                            out=self._carrier,
                            named=True,
                            nested_keys=True,
                        )
                else:
                    if self._cast_to_policy_device:
                        if self.policy_device is not None:
                            # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
                            non_blocking = (
                                not self.no_cuda_sync
                                or self.policy_device.type == "cuda"
                            )
                            policy_input = self._carrier.to(
                                self.policy_device,
                                non_blocking=non_blocking,
                            )
                            if not self.no_cuda_sync:
                                self._sync_policy()
                        elif self.policy_device is None:
                            # we know the tensordict has a device otherwise we would not be here
                            # we can pass this, clear_device_ must have been called earlier
                            # policy_input = self._shuttle.clear_device_()
                            policy_input = self._carrier
                    else:
                        policy_input = self._carrier
                    # we still do the assignment for security
                    if self.compiled_policy:
                        cudagraph_mark_step_begin()
                    policy_output = self._wrapped_policy(policy_input)
                    if self.compiled_policy:
                        policy_output = policy_output.clone()
                    if self._carrier is not policy_output:
                        # ad-hoc update shuttle
                        self._carrier.update(
                            policy_output, keys_to_update=self._policy_output_keys
                        )

                if self._cast_to_env_device:
                    if self.env_device is not None:
                        non_blocking = (
                            not self.no_cuda_sync or self.env_device.type == "cuda"
                        )
                        env_input = self._carrier.to(
                            self.env_device, non_blocking=non_blocking
                        )
                        if not self.no_cuda_sync:
                            self._sync_env()
                    elif self.env_device is None:
                        # we know the tensordict has a device otherwise we would not be here
                        # we can pass this, clear_device_ must have been called earlier
                        # env_input = self._shuttle.clear_device_()
                        env_input = self._carrier
                else:
                    env_input = self._carrier
                env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

                if self._carrier is not env_output:
                    # ad-hoc update shuttle
                    next_data = env_output.get("next")
                    if self._shuttle_has_no_device:
                        # Make sure the carried data stays deviceless
                        next_data.clear_device_()
                    self._carrier.set("next", next_data)

                if (
                    self.replay_buffer is not None
                    and not self._ignore_rb
                    and not self.extend_buffer
                ):
                    self.replay_buffer.add(self._carrier)
                    if self._increment_frames(self._carrier.numel()):
                        return
                else:
                    if self.storing_device is not None:
                        non_blocking = (
                            not self.no_cuda_sync or self.storing_device.type == "cuda"
                        )
                        tensordicts.append(
                            self._carrier.to(
                                self.storing_device, non_blocking=non_blocking
                            )
                        )
                        if not self.no_cuda_sync:
                            self._sync_storage()
                    else:
                        tensordicts.append(self._carrier)

                # carry over collector data without messing up devices
                collector_data = self._carrier.get("collector").copy()
                self._carrier = env_next_output
                if self._shuttle_has_no_device:
                    self._carrier.clear_device_()
                self._carrier.set("collector", collector_data)
                self._update_traj_ids(env_output)

                if (
                    self.interruptor is not None
                    and self.interruptor.collection_stopped()
                ):
                    if (
                        self.replay_buffer is not None
                        and not self._ignore_rb
                        and not self.extend_buffer
                    ):
                        return
                    result = self._final_rollout
                    if self._use_buffers:
                        try:
                            torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout[..., : t + 1],
                            )
                        except RuntimeError:
                            with self._final_rollout.unlock_():
                                torch.stack(
                                    tensordicts,
                                    self._final_rollout.ndim - 1,
                                    out=self._final_rollout[..., : t + 1],
                                )
                    else:
                        result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                    break
            else:
                if self._use_buffers:
                    result = self._final_rollout
                    try:
                        result = torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout,
                        )

                    except RuntimeError:
                        with self._final_rollout.unlock_():
                            result = torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout,
                            )
                elif (
                    self.replay_buffer is not None
                    and not self._ignore_rb
                    and not self.extend_buffer
                ):
                    return
                else:
                    result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                    result.refine_names(..., "time")

        return self._maybe_set_truncated(result)

    def _maybe_set_truncated(self, final_rollout):
        last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,)
        for truncated_key in self._truncated_keys:
            truncated = final_rollout["next", truncated_key]
            truncated[last_step] = True
            final_rollout["next", truncated_key] = truncated
            done = final_rollout["next", _replace_last(truncated_key, "done")]
            final_rollout["next", _replace_last(truncated_key, "done")] = (
                done | truncated
            )
        return final_rollout
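
    # Effect of the helper above on a rollout of length T (sketch, using the
    # default key names handled by the method): the last time step of every
    # truncated flag is forced to True and the matching done flag absorbs it.
    #
    #   rollout["next", "truncated"][..., -1] = True
    #   rollout["next", "done"][..., -1] |= rollout["next", "truncated"][..., -1]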

    @torch.no_grad()
    def reset(self, index=None, **kwargs) -> None:
        """Resets the environments to a new initial state."""
        # metadata
        collector_metadata = self._carrier.get("collector").clone()
        if index is not None:
            # check that the env supports partial reset
            if prod(self.env.batch_size) == 0:
                raise RuntimeError("resetting unique env with index is not permitted.")
            for reset_key, done_keys in zip(
                self.env.reset_keys, self.env.done_keys_groups
            ):
                _reset = torch.zeros(
                    self.env.full_done_spec[done_keys[0]].shape,
                    dtype=torch.bool,
                    device=self.env.device,
                )
                _reset[index] = 1
                self._carrier.set(reset_key, _reset)
        else:
            _reset = None
            self._carrier.zero_()

        self._carrier.update(self.env.reset(**kwargs), inplace=True)
        collector_metadata["traj_ids"] = (
            collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min()
        )
        self._carrier["collector"] = collector_metadata

    def shutdown(
        self,
        timeout: float | None = None,
        close_env: bool = True,
        raise_on_error: bool = True,
    ) -> None:
        """Shuts down all workers and/or closes the local environment.

        Args:
            timeout (float, optional): The timeout for closing pipes between workers.
                No effect for this class.
            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
        """
        try:
            if not self.closed:
                self.closed = True
                del self._carrier
                if self._use_buffers:
                    del self._final_rollout
                if close_env and not self.env.is_closed:
                    self.env.close(raise_if_closed=raise_on_error)
                del self.env
            return
        except Exception as e:
            if raise_on_error:
                raise e
            else:
                pass

    def __del__(self):
        try:
            self.shutdown()
        except Exception:
            # an AttributeError will typically be raised if the collector is deleted when the program ends.
            # In the future, insignificant changes to the close method may change the error type.
            # We explicitly assume that any error raised during closure in
            # __del__ will not affect the program.
            pass

    def state_dict(self) -> OrderedDict:
        """Returns the local state_dict of the data collector (environment and policy).

        Returns:
            an ordered dictionary with fields :obj:`"policy_state_dict"` and
            :obj:`"env_state_dict"`.

        """
        from torchrl.envs.batched_envs import BatchedEnvBase

        if isinstance(self.env, TransformedEnv):
            env_state_dict = self.env.transform.state_dict()
        elif isinstance(self.env, BatchedEnvBase):
            env_state_dict = self.env.state_dict()
        else:
            env_state_dict = OrderedDict()

        if hasattr(self, "_policy_w_state_dict"):
            policy_state_dict = self._policy_w_state_dict.state_dict()
            state_dict = OrderedDict(
                policy_state_dict=policy_state_dict,
                env_state_dict=env_state_dict,
            )
        else:
            state_dict = OrderedDict(env_state_dict=env_state_dict)

        state_dict.update({"frames": self._frames, "iter": self._iter})

        return state_dict

    def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
        """Loads a state_dict on the environment and policy.

        Args:
            state_dict (OrderedDict): ordered dictionary containing the fields
                :obj:`"policy_state_dict"` and :obj:`"env_state_dict"`.

        """
        strict = kwargs.get("strict", True)
        if strict or "env_state_dict" in state_dict:
            self.env.load_state_dict(state_dict["env_state_dict"], **kwargs)
        if strict or "policy_state_dict" in state_dict:
            if not hasattr(self, "_policy_w_state_dict"):
                raise ValueError(
                    "Underlying policy does not have state_dict to load policy_state_dict into."
                )
            self._policy_w_state_dict.load_state_dict(
                state_dict["policy_state_dict"], **kwargs
            )
        self._frames = state_dict["frames"]
        self._iter = state_dict["iter"]

    def __repr__(self) -> str:
        try:
            env_str = indent(f"env={self.env}", 4 * " ")
            policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ")
            td_out_str = repr(getattr(self, "_final_rollout", None))
            if len(td_out_str) > 50:
                td_out_str = td_out_str[:50] + "..."
            td_out_str = indent(f"td_out={td_out_str}", 4 * " ")
            string = (
                f"{self.__class__.__name__}("
                f"\n{env_str},"
                f"\n{policy_str},"
                f"\n{td_out_str},"
                f"\nexploration={self.exploration_type})"
            )
            return string
        except Exception:
            return f"{type(self).__name__}(not_init)"

    def increment_version(self):
        """Increment the policy version."""
        if self.policy_version_tracker is not None:
            if not hasattr(self.policy_version_tracker, "increment_version"):
                raise RuntimeError(
                    "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
                )
            self.policy_version_tracker.increment_version()

    @property
    def policy_version(self) -> str | int | None:
        """The current policy version."""
        if not hasattr(self.policy_version_tracker, "version"):
            return None
        return self.policy_version_tracker.version

    def get_policy_version(self) -> str | int | None:
        """Get the current policy version.

        This method exists to support remote calls in Ray actors, since properties
        cannot be accessed directly through Ray's RPC mechanism.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
        """
        return self.policy_version

    def getattr_policy(self, attr):
        """Get an attribute from the policy."""
        # send command to policy to return the attr
        return getattr(self._wrapped_policy, attr)

    def getattr_env(self, attr):
        """Get an attribute from the environment."""
        # send command to env to return the attr
        return getattr(self.env, attr)

    def getattr_rb(self, attr):
        """Get an attribute from the replay buffer."""
        # send command to rb to return the attr
        return getattr(self.replay_buffer, attr)

    def get_model(self, model_id: str):
        """Get model instance by ID (for weight sync schemes).

        Args:
            model_id: Model identifier (e.g., "policy", "value_net")

        Returns:
            The model instance

        Raises:
            ValueError: If model_id is not recognized
        """
        if model_id == "policy":
            # Return the unwrapped policy instance for weight synchronization
            # The unwrapped policy has the same parameter structure as what's
            # extracted in the main process, avoiding key mismatches when
            # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
            if hasattr(self, "policy") and self.policy is not None:
                return self.policy
            else:
                raise ValueError(f"No policy found for model_id '{model_id}'")
        else:
            return _resolve_model(self, model_id)

    def _receive_weights_scheme(self):
        return super()._receive_weights_scheme()


class SyncDataCollector(Collector, metaclass=_LegacyCollectorMeta):
    """Deprecated version of :class:`~torchrl.collectors.Collector`."""

    ...