torchrl 0.11.0__cp314-cp314t-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314t-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,433 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import contextlib
|
|
8
|
+
from collections.abc import Callable, Sequence
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from pyvers import implement_for
|
|
12
|
+
|
|
13
|
+
from tensordict import NestedKey, pad, set_lazy_legacy, TensorDict, TensorDictBase
|
|
14
|
+
from tensordict.utils import Buffer
|
|
15
|
+
from torch import multiprocessing as mp, nn as nn
|
|
16
|
+
from torch.nn import Parameter
|
|
17
|
+
|
|
18
|
+
_NON_NN_POLICY_WEIGHTS = (
|
|
19
|
+
"The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and "
|
|
20
|
+
"update_policy_weights_ will be a no-op. Consider passing a local/weight_updater object "
|
|
21
|
+
"to your collector to handle the weight updates."
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _stack_output(fun) -> Callable:
|
|
26
|
+
def stacked_output_fun(*args, **kwargs):
|
|
27
|
+
out = fun(*args, **kwargs)
|
|
28
|
+
return tuple(torch.stack(_o, 0) for _o in out)
|
|
29
|
+
|
|
30
|
+
return stacked_output_fun
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _stack_output_zip(fun) -> Callable:
|
|
34
|
+
def stacked_output_fun(*args, **kwargs):
|
|
35
|
+
out = fun(*args, **kwargs)
|
|
36
|
+
return tuple(torch.stack(_o, 0) for _o in zip(*out))
|
|
37
|
+
|
|
38
|
+
return stacked_output_fun
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@set_lazy_legacy(False)
def split_trajectories(
    rollout_tensordict: TensorDictBase,
    *,
    prefix=None,
    trajectory_key: NestedKey | None = None,
    done_key: NestedKey | None = None,
    as_nested: bool | torch.layout = False,
) -> TensorDictBase:
    """A util function for trajectory separation.

    Takes a tensordict with a key traj_ids that indicates the id of each trajectory.

    From there, builds a B x T x ... zero-padded tensordict with B batches on max duration T

    Args:
        rollout_tensordict (TensorDictBase): a rollout with adjacent trajectories
            along the last dimension.

    Keyword Args:
        prefix (NestedKey, optional): the prefix used to read and write meta-data,
            such as ``"traj_ids"`` (the optional integer id of each trajectory)
            and the ``"mask"`` entry indicating which data are valid and which
            aren't. Defaults to ``"collector"`` if the input has a ``"collector"``
            entry, ``()`` (no prefix) otherwise.
            ``prefix`` is kept as a legacy feature and will be deprecated eventually.
            Prefer ``trajectory_key`` or ``done_key`` whenever possible.
        trajectory_key (NestedKey, optional): the key pointing to the trajectory
            ids. Supersedes ``done_key`` and ``prefix``. If not provided, defaults
            to ``(prefix, "traj_ids")``.
        done_key (NestedKey, optional): the key pointing to the ``"done"`` signal,
            if the trajectory could not be directly recovered. Defaults to ``"done"``.
        as_nested (bool or torch.layout, optional): whether to return the results as nested
            tensors. Defaults to ``False``. If a ``torch.layout`` is provided, it will be used
            to construct the nested tensor, otherwise the default layout will be used.

            .. note:: Using ``split_trajectories(tensordict, as_nested=True).to_padded_tensor(mask=mask_key)``
                should result in the exact same result as ``as_nested=False``. Since this is an experimental
                feature and relies on nested_tensors, which API may change in the future, we made this
                an optional feature. The runtime should be faster with ``as_nested=True``.

            .. note:: Providing a layout lets the user control whether the nested tensor is to be used
                with ``torch.strided`` or ``torch.jagged`` layout. While the former has slightly more
                capabilities at the time of writing, the second will be the main focus of the PyTorch team
                in the future due to its better compatibility with :func:`~torch.compile`.

    Returns:
        A new tensordict with a leading dimension corresponding to the trajectory.
        A ``"mask"`` boolean entry sharing the ``trajectory_key`` prefix
        and the tensordict shape is also added. It indicates the valid elements of the tensordict,
        as well as a ``"traj_ids"`` entry if ``trajectory_key`` could not be found.

    Raises:
        RuntimeError: if the shifted done signal does not have the same shape as
            the done entry, or if a ``torch.layout`` is requested on the fallback
            nested-tensor path with torch < 2.4.

    Examples:
        >>> from tensordict import TensorDict
        >>> import torch
        >>> from torchrl.collectors.utils import split_trajectories
        >>> obs = torch.cat([torch.arange(10), torch.arange(5)])
        >>> obs_ = torch.cat([torch.arange(1, 11), torch.arange(1, 6)])
        >>> done = torch.zeros(15, dtype=torch.bool)
        >>> done[9] = True
        >>> trajectory_id = torch.cat([torch.zeros(10, dtype=torch.int32),
        ...     torch.ones(5, dtype=torch.int32)])
        >>> data = TensorDict({"obs": obs, ("next", "obs"): obs_, ("next", "done"): done, "trajectory": trajectory_id}, batch_size=[15])
        >>> data_split = split_trajectories(data, done_key="done")
        >>> print(data_split)
        TensorDict(
            fields={
                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([2, 10]),
                    device=None,
                    is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
                traj_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False)
        >>> # check that split_trajectories got the trajectories right with the done signal
        >>> assert (data_split["traj_ids"] == data_split["trajectory"]).all()
        >>> print(data_split["mask"])
        tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
                [ True,  True,  True,  True,  True, False, False, False, False, False]])
        >>> data_split = split_trajectories(data, trajectory_key="trajectory")
        >>> print(data_split)
        TensorDict(
            fields={
                mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.bool, is_shared=False),
                        obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([2, 10]),
                    device=None,
                    is_shared=False),
                obs: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
                trajectory: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([2, 10]),
            device=None,
            is_shared=False)

    """
    # Resolve where the trajectory ids are read from and where the mask is written.
    mask_key = None
    if trajectory_key is not None:
        from torchrl.envs.utils import _replace_last

        traj_ids_key = trajectory_key
        mask_key = _replace_last(trajectory_key, "mask")
    else:
        if prefix is None and "collector" in rollout_tensordict.keys():
            prefix = "collector"
        if prefix is None:
            traj_ids_key = "traj_ids"
            mask_key = "mask"
        else:
            traj_ids_key = (prefix, "traj_ids")
            mask_key = (prefix, "mask")

    # Work on a copy so the entries written below do not leak into the input.
    rollout_tensordict = rollout_tensordict.copy()
    traj_ids = rollout_tensordict.get(traj_ids_key, None)
    if traj_ids is None:
        # No ids available: rebuild them from the ("next", done_key) signal.
        if done_key is None:
            done_key = "done"
        done_key = ("next", done_key)
        done = rollout_tensordict.get(done_key)
        # Shift the done signal by one step along the last batch dimension
        # (drop the last step, left-pad with a zero) so that the step *after*
        # a done starts a new trajectory.
        idx = (slice(None),) * (rollout_tensordict.ndim - 1) + (slice(None, -1),)
        done_sel = done[idx]
        pads = [1, 0]
        pads = [0, 0] * (done.ndim - rollout_tensordict.ndim) + pads
        done_sel = torch.nn.functional.pad(done_sel, pads)
        if done_sel.shape != done.shape:
            raise RuntimeError(
                f"done and done_sel have different shape {done.shape} - {done_sel.shape} "
            )
        traj_ids = done_sel.cumsum(rollout_tensordict.ndim - 1)
        traj_ids = traj_ids.squeeze(-1)
        if rollout_tensordict.ndim > 1:
            # Offset each row by the ids already used so ids are unique across rows.
            for i in range(1, rollout_tensordict.shape[0]):
                traj_ids[i] += traj_ids[i - 1].max() + 1
        rollout_tensordict.set(traj_ids_key, traj_ids)

    # Length of every run of identical consecutive ids, in order of appearance.
    # Counting runs directly (rather than counting every occurrence of each id)
    # is a single pass and stays correct if an id shows up in separate runs.
    _, run_lengths = traj_ids.reshape(-1).unique_consecutive(return_counts=True)
    splits = run_lengths.tolist()
    # if all splits are identical then we can skip this function
    if len(set(splits)) == 1 and splits[0] == traj_ids.shape[-1]:
        rollout_tensordict.set(
            mask_key,
            torch.ones(
                rollout_tensordict.shape,
                device=rollout_tensordict.device,
                dtype=torch.bool,
            ),
        )
        if rollout_tensordict.ndimension() == 1:
            rollout_tensordict = rollout_tensordict.unsqueeze(0)
        return rollout_tensordict

    out_splits = rollout_tensordict.reshape(-1)

    if as_nested:
        if hasattr(torch, "_nested_compute_contiguous_strides_offsets"):
            # NOTE(review): this fast path builds the nested view directly from
            # the flat buffer and does not consult a user-provided layout.

            def nest(x, splits=splits):
                # Convert splits into shapes
                shape = torch.tensor([[int(split), *x.shape[1:]] for split in splits])
                return torch._nested_view_from_buffer(
                    x.reshape(-1),
                    shape,
                    *torch._nested_compute_contiguous_strides_offsets(shape),
                )

            return out_splits._fast_apply(
                nest,
                batch_size=[len(splits), -1],
            )
        else:
            out_splits = out_splits.split(splits, 0)

            # ``as_nested=True`` means "use the default layout"; only a real
            # ``torch.layout`` is forwarded. (``as_nested is not bool`` compared
            # against the type object and therefore never filtered out ``True``.)
            layout = None if isinstance(as_nested, bool) else as_nested

            if torch.__version__ < "2.4":
                # No explicit layout can be requested before torch 2.4
                if layout is not None:
                    raise RuntimeError(
                        f"layout={layout} is only available for torch>=v2.4"
                    )

                def nest(*x):
                    return torch.nested.nested_tensor(list(x))

            else:

                def nest(*x):
                    return torch.nested.nested_tensor(list(x), layout=layout)

            return out_splits[0]._fast_apply(
                nest,
                *out_splits[1:],
                batch_size=[len(out_splits), *out_splits[0].batch_size[:-1], -1],
            )

    out_splits = out_splits.split(splits, 0)

    # Mark every real step as valid; the padding added below is filled with
    # zeros, i.e. ``False`` in the mask.
    for out_split in out_splits:
        out_split.set(
            mask_key,
            torch.ones(
                out_split.shape,
                dtype=torch.bool,
                device=out_split.device,
            ),
        )
    max_len = max(out_split.shape[0] for out_split in out_splits)
    td = torch.stack(
        [pad(out_split, [0, max_len - out_split.shape[0]]) for out_split in out_splits],
        0,
    )
    return td
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@implement_for("torch", "2.5.0")
def _cast(
    p: nn.Parameter | torch.Tensor,
    param_maybe_buffer: nn.Parameter | torch.Tensor | None = None,
) -> nn.Parameter | torch.Tensor:
    """Wrap ``p`` in the same container type as ``param_maybe_buffer``.

    Args:
        p: the tensor to wrap.
        param_maybe_buffer: the reference object whose type decides the
            wrapping. When omitted, ``p`` is its own reference and its
            ``.data`` is the tensor that gets wrapped.

    Returns:
        A ``Parameter`` with ``requires_grad=False`` if the reference is a
        ``Parameter``, a ``Buffer`` if the reference is a ``Buffer``, and the
        tensor itself otherwise.

    Raises:
        RuntimeError: if the reference is neither a ``Parameter`` nor a
            ``Buffer`` and the tensor requires gradients.
    """
    reference = param_maybe_buffer
    if reference is None:
        # Single-argument call: use ``p`` as the reference, wrap its raw data.
        reference, p = p, p.data

    if isinstance(reference, Parameter):
        # Gradients are switched off to avoid serialization issues.
        return Parameter(p, requires_grad=False)
    if isinstance(reference, Buffer):
        return Buffer(p)
    if not p.requires_grad:
        return p
    raise RuntimeError(f"Cannot cast tensor {p} with gradients")
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _make_meta_policy(policy: nn.Module):
    """Build a context manager that swaps a policy's weights for meta-device stand-ins.

    Weight sync schemes use this to ship the policy structure without its
    weights; the schemes themselves distribute the actual values.

    Args:
        policy: Policy module to temporarily modify.

    Returns:
        A context manager that temporarily replaces the policy's parameters
        and buffers with meta-device versions. The original parameters are
        put back on the policy when the context exits.
    """
    params_and_buffers = TensorDict.from_module(policy, as_module=True)
    meta_values = params_and_buffers.data.to("meta")
    # ``_cast`` receives each meta tensor together with the matching original
    # entry, which serves as the reference for the wrapper type.
    meta_state = meta_values.apply(_cast, params_and_buffers)
    return meta_state.to_module(policy)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@implement_for("torch", None, "2.8")
def _make_meta_policy_cm(
    policy: nn.Module, *, mp_start_method: str
) -> contextlib.AbstractContextManager:
    """Return the context manager used to make a policy 'stateless' for worker pickling.

    This variant is registered for PyTorch versions below 2.8. There, pickling
    meta-device storages under the ``spawn`` start method may fail (e.g.,
    triggering ``_share_filename_: only available on CPU``), so for ``spawn``
    the parameters/buffers are left untouched and a no-op context manager is
    returned instead.

    Args:
        policy: Policy module whose weights may be swapped for meta tensors.
        mp_start_method: Name of the multiprocessing start method in use.

    Returns:
        ``contextlib.nullcontext()`` when ``mp_start_method`` is ``"spawn"``,
        otherwise the result of ``_make_meta_policy(policy)``.
    """
    if mp_start_method != "spawn":
        return _make_meta_policy(policy)
    # ``spawn``: skip the meta conversion altogether.
    return contextlib.nullcontext()
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
@implement_for("torch", "2.8")
def _make_meta_policy_cm(  # noqa: F811
    policy: nn.Module, *, mp_start_method: str
) -> contextlib.AbstractContextManager:
    """Return the context manager used to make a policy 'stateless' for worker pickling.

    This variant is registered for PyTorch >= 2.8, where meta-device policy
    structures can be pickled reliably under ``spawn``.

    Args:
        policy: Policy module whose weights are swapped for meta tensors.
        mp_start_method: Accepted for signature parity with the variant for
            older PyTorch versions; this implementation does not read it.

    Returns:
        The result of ``_make_meta_policy(policy)``.
    """
    meta_policy_cm = _make_meta_policy(policy)
    return meta_policy_cm
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
@implement_for("torch", None, "2.5.0")
def _cast(  # noqa
    p: nn.Parameter | torch.Tensor,
    param_maybe_buffer: nn.Parameter | torch.Tensor | None = None,
) -> nn.Parameter | torch.Tensor:
    """Re-wrap the raw data of ``p`` according to the type of a reference tensor.

    This variant is registered for PyTorch versions below 2.5.0 and only
    distinguishes ``Parameter`` from everything else.

    Args:
        p: tensor whose ``.data`` is re-wrapped.
        param_maybe_buffer: tensor whose type selects the wrapper. When
            ``None``, ``p`` itself is used as the reference.

    Returns:
        A ``Parameter`` created with ``requires_grad=False`` if the reference
        is a ``Parameter``, otherwise the bare data tensor.

    Raises:
        RuntimeError: if the bare data tensor reports ``requires_grad``.
    """
    reference = p if param_maybe_buffer is None else param_maybe_buffer
    p = p.data
    if isinstance(reference, Parameter):
        # A gradient-free Parameter avoids serialization issues.
        return Parameter(p, requires_grad=False)
    # NOTE(review): ``p`` is already the ``.data`` view at this point, so this
    # guard is presumably purely defensive -- confirm before relying on it.
    if p.requires_grad:
        raise RuntimeError(f"Cannot cast tensor {p} with gradients")
    return p
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _map_to_cpu_if_needed(x):
|
|
345
|
+
"""Map tensors on exotic devices (MPS, NPU, etc.) to CPU.
|
|
346
|
+
|
|
347
|
+
CPU and CUDA tensors are kept as-is since they can be shared across processes.
|
|
348
|
+
Only exotic devices that don't support multiprocessing are mapped to CPU.
|
|
349
|
+
"""
|
|
350
|
+
if isinstance(x, torch.Tensor):
|
|
351
|
+
# CPU and CUDA can be shared across processes
|
|
352
|
+
if x.device.type in ("cpu", "cuda"):
|
|
353
|
+
return x
|
|
354
|
+
# Exotic devices (MPS, NPU, etc.) need to be mapped to CPU
|
|
355
|
+
return x.cpu()
|
|
356
|
+
return x
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def _make_meta_params(param):
|
|
360
|
+
is_param = isinstance(param, Parameter)
|
|
361
|
+
|
|
362
|
+
pd = param.detach().to("meta")
|
|
363
|
+
|
|
364
|
+
if is_param:
|
|
365
|
+
pd = Parameter(pd, requires_grad=False)
|
|
366
|
+
return pd
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class _TrajectoryPool:
|
|
370
|
+
def __init__(self, ctx=None, lock: bool = False):
|
|
371
|
+
self.ctx = ctx
|
|
372
|
+
self._traj_id = torch.zeros((), device="cpu", dtype=torch.int)
|
|
373
|
+
# Only use shared memory when multiprocessing context is provided
|
|
374
|
+
# This avoids issues with shared memory when the mp subsystem is in a bad state
|
|
375
|
+
if ctx is not None:
|
|
376
|
+
self._traj_id = self._traj_id.share_memory_()
|
|
377
|
+
if ctx is None:
|
|
378
|
+
self.lock = contextlib.nullcontext() if not lock else mp.RLock()
|
|
379
|
+
else:
|
|
380
|
+
self.lock = contextlib.nullcontext() if not lock else ctx.RLock()
|
|
381
|
+
|
|
382
|
+
def get_traj_and_increment(self, n=1, device=None):
|
|
383
|
+
with self.lock:
|
|
384
|
+
v = self._traj_id.item()
|
|
385
|
+
out = torch.arange(v, v + n).to(device)
|
|
386
|
+
self._traj_id.copy_(1 + out[-1].item())
|
|
387
|
+
return out
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _map_weight(
|
|
391
|
+
weight,
|
|
392
|
+
policy_device,
|
|
393
|
+
):
|
|
394
|
+
|
|
395
|
+
is_param = isinstance(weight, Parameter)
|
|
396
|
+
is_buffer = isinstance(weight, Buffer)
|
|
397
|
+
weight = weight.data
|
|
398
|
+
if weight.device != policy_device:
|
|
399
|
+
weight = weight.to(policy_device)
|
|
400
|
+
elif weight.device.type in ("cpu",):
|
|
401
|
+
weight = weight.share_memory_()
|
|
402
|
+
if is_param:
|
|
403
|
+
weight = Parameter(weight, requires_grad=False)
|
|
404
|
+
elif is_buffer:
|
|
405
|
+
weight = Buffer(weight)
|
|
406
|
+
return weight
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def _make_policy_factory(
|
|
410
|
+
*, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None
|
|
411
|
+
):
|
|
412
|
+
has_policy_factory = policy_factory is not None and (
|
|
413
|
+
(isinstance(policy_factory, Sequence) and any(policy_factory))
|
|
414
|
+
or not isinstance(policy_factory, Sequence)
|
|
415
|
+
)
|
|
416
|
+
if policy is not None and has_policy_factory:
|
|
417
|
+
raise ValueError("policy cannot be used with policy_factory")
|
|
418
|
+
elif has_policy_factory:
|
|
419
|
+
if isinstance(policy_factory, Sequence):
|
|
420
|
+
return policy_factory
|
|
421
|
+
else:
|
|
422
|
+
policy = policy_factory()
|
|
423
|
+
|
|
424
|
+
if weight_sync_scheme is not None:
|
|
425
|
+
# Initialize the receiver on the worker side
|
|
426
|
+
weight_sync_scheme.init_on_receiver(
|
|
427
|
+
model=policy,
|
|
428
|
+
model_id="policy",
|
|
429
|
+
worker_idx=worker_idx,
|
|
430
|
+
)
|
|
431
|
+
# Synchronize initial weights
|
|
432
|
+
weight_sync_scheme.connect(worker_idx=worker_idx)
|
|
433
|
+
return policy
|