torchrl 0.11.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Gym-specific transforms."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from tensordict import TensorDictBase
|
|
14
|
+
from tensordict.utils import expand_as_right, NestedKey
|
|
15
|
+
from torchrl.data.tensor_specs import Unbounded
|
|
16
|
+
|
|
17
|
+
from torchrl.envs.transforms.transforms import FORWARD_NOT_IMPLEMENTED, Transform
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EndOfLifeTransform(Transform):
    """Registers the end-of-life signal from a Gym env with a `lives` method.

    Proposed by DeepMind for the DQN and co. It helps value estimation.

    At every step the number of lives is read from the wrapped env and compared
    with the value stored at the previous step. A decrease (or an episode end,
    as signaled by ``done_key``) sets the end-of-life flag.

    Args:
        eol_key (NestedKey, optional): the key where the end-of-life signal should
            be written. Defaults to ``"end-of-life"``.
        lives_key (NestedKey, optional): the key where the current number of
            lives is written. Its previous value is compared with the new one
            to detect the loss of a life. Defaults to ``"lives"``.
        done_key (NestedKey, optional): a "done" key in the parent env done_spec,
            where the done value can be retrieved. This key must be unique and its
            shape must match the shape of the end-of-life entry. Defaults to ``"done"``.
        eol_attribute (str, optional): the location of the "lives" in the gym env.
            Defaults to ``"unwrapped.ale.lives"``. Supported attribute types are
            integer/array-like objects or callables that return these values.

    .. note::
      This transform should be used with gym envs that have a ``env.unwrapped.ale.lives``.

    Examples:
        >>> from torchrl.envs.libs.gym import GymEnv
        >>> from torchrl.envs.transforms.transforms import TransformedEnv
        >>> env = GymEnv("ALE/Breakout-v5")
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)
        >>> eol_transform = EndOfLifeTransform()
        >>> env = TransformedEnv(env, eol_transform)
        >>> env.rollout(100)
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.int64, is_shared=False),
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                end-of-life: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        end-of-life: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        lives: Tensor(shape=torch.Size([100]), device=cpu, dtype=torch.int64, is_shared=False),
                        pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                        reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=cpu,
                    is_shared=False),
                pixels: Tensor(shape=torch.Size([100, 210, 160, 3]), device=cpu, dtype=torch.uint8, is_shared=False),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([100]),
            device=cpu,
            is_shared=False)

    The typical usage of this transform is to replace the "done" state by "end-of-life"
    within the loss module. The end-of-life signal isn't registered within the ``done_spec``
    because it should not instruct the env to reset.

    Examples:
        >>> from torchrl.objectives import DQNLoss
        >>> module = torch.nn.Identity()  # used as a placeholder
        >>> loss = DQNLoss(module, action_space="categorical")
        >>> loss.set_keys(done="end-of-life", terminated="end-of-life")
        >>> # equivalently
        >>> eol_transform.register_keys(loss)
    """

    NO_PARENT_ERR = "The {} transform is being executed without a parent env. This is currently not supported."

    def __init__(
        self,
        eol_key: NestedKey = "end-of-life",
        lives_key: NestedKey = "lives",
        done_key: NestedKey = "done",
        eol_attribute="unwrapped.ale.lives",
    ):
        # The done entry is read; the end-of-life flag and the lives counter are written.
        super().__init__(in_keys=[done_key], out_keys=[eol_key, lives_key])
        self.eol_key = eol_key
        self.lives_key = lives_key
        self.done_key = done_key
        # Dotted path, resolved one attribute at a time in ``_get_lives``.
        self.eol_attribute = eol_attribute.split(".")

    def _get_lives(self):
        """Reads the current number of lives from the parent's base env.

        The dotted ``eol_attribute`` path is followed starting from the base env.
        When an intermediate value is a list (one entry per sub-env, e.g. for a
        ``SerialEnv``), the lookup is mapped over its elements. A callable found
        at the end of the path is called.

        Returns:
            the value (or call result) found at the end of the path. When the
            path resolves to a list, a tensor built from its entries is
            returned, so it can be compared with the stored lives tensor.
        """
        from torchrl.envs.libs.gym import GymWrapper

        base_env = self.parent.base_env
        if not isinstance(base_env, GymWrapper):
            warnings.warn(
                f"The base_env is not a gym env. Compatibility of {type(self)} is not guaranteed with "
                f"environment types that do not inherit from GymWrapper.",
                category=UserWarning,
            )
        # getattr falls back on _env by default
        lives = getattr(base_env, self.eol_attribute[0])
        for att in self.eol_attribute[1:]:
            if isinstance(lives, list):
                # For SerialEnv (and who knows Parallel one day)
                lives = [getattr(_lives, att) for _lives in lives]
            else:
                lives = getattr(lives, att)
        if callable(lives):
            lives = lives()
        elif isinstance(lives, list):
            # One entry per sub-env: call them if they are callables, then
            # stack into a tensor (a bare list cannot be compared with the
            # previous lives tensor in ``_step``).
            if all(callable(_lives) for _lives in lives):
                lives = [_lives() for _lives in lives]
            lives = torch.as_tensor(lives)
        return lives

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # All the work happens in ``_step`` and ``_reset``, which need the parent env.
        return next_tensordict

    def _step(self, tensordict, next_tensordict):
        """Writes the end-of-life flag and the updated lives count in ``next_tensordict``.

        Raises:
            RuntimeError: if the transform has no parent env.
            KeyError: if ``done_key`` is missing from ``next_tensordict``.
        """
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))

        lives = self._get_lives()
        # A life was lost if the stored lives count is larger than the current one.
        end_of_life = torch.as_tensor(
            tensordict.get(self.lives_key) > lives, device=parent.device
        )
        done = next_tensordict.get(self.done_key, None)  # TODO: None soon to be removed
        if done is None:
            raise KeyError(
                f"The done value pointed by {self.done_key} cannot be found in tensordict with keys {next_tensordict.keys(True, True)}. "
                f"Make sure to pass the appropriate done_key to the {type(self)} transform."
            )
        # The end of the episode also counts as an end of life.
        end_of_life = expand_as_right(end_of_life, done) | done
        next_tensordict.set(self.eol_key, end_of_life)
        next_tensordict.set(self.lives_key, lives)
        return next_tensordict

    def _reset(self, tensordict, tensordict_reset):
        """Initializes the end-of-life flag to ``False`` and records the current lives.

        Raises:
            RuntimeError: if the transform has no parent env.
        """
        parent = self.parent
        if parent is None:
            raise RuntimeError(self.NO_PARENT_ERR.format(type(self)))
        lives = self._get_lives()
        # Created on the parent device, consistently with ``_step``.
        end_of_life = torch.as_tensor(False, device=parent.device)
        tensordict_reset.set(
            self.eol_key,
            end_of_life.expand(parent.full_done_spec[self.done_key].shape),
        )
        tensordict_reset.set(self.lives_key, lives)
        return tensordict_reset

    def transform_observation_spec(self, observation_spec):
        """Adds the end-of-life and lives entries to the observation spec.

        The end-of-life spec is a clone of the parent's ``done_key`` spec. The
        lives spec is an unbounded int64 spec shaped like the parent batch size.
        """
        full_done_spec = self.parent.output_spec["full_done_spec"]
        observation_spec[self.eol_key] = full_done_spec[self.done_key].clone()
        observation_spec[self.lives_key] = Unbounded(
            self.parent.batch_size,
            device=self.parent.device,
            dtype=torch.int64,
        )
        return observation_spec

    def register_keys(
        self, loss_or_advantage: torchrl.objectives.common.LossModule  # noqa
    ):
        """Registers the end-of-life key at appropriate places within the loss.

        Args:
            loss_or_advantage (torchrl.objectives.LossModule or torchrl.objectives.value.ValueEstimatorBase): a module to instruct what the end-of-life key is.

        """
        loss_or_advantage.set_keys(done=self.eol_key, terminated=self.eol_key)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # This transform only works inside an env (it queries the parent for lives).
        raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from contextlib import nullcontext
|
|
9
|
+
from typing import overload, TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from tensordict import TensorDictBase
|
|
13
|
+
from tensordict.nn import TensorDictModuleBase
|
|
14
|
+
from torchrl._utils import logger as torchrl_logger
|
|
15
|
+
|
|
16
|
+
from torchrl.data.tensor_specs import TensorSpec
|
|
17
|
+
from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform
|
|
18
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from torchrl.weight_update import WeightSyncScheme
|
|
22
|
+
|
|
23
|
+
# Public names of this module (what ``from ... import *`` exposes).
__all__ = ["ModuleTransform", "RayModuleTransform"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RayModuleTransform(RayTransform):
|
|
27
|
+
"""Ray-based ModuleTransform for distributed processing.
|
|
28
|
+
|
|
29
|
+
This transform creates a Ray actor that wraps a ModuleTransform,
|
|
30
|
+
allowing module execution in a separate Ray worker process.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
weight_sync_scheme: Optional weight synchronization scheme for updating
|
|
34
|
+
the module's weights from a parent collector. When provided, the scheme
|
|
35
|
+
is initialized on the receiver side (the Ray actor) and can receive
|
|
36
|
+
weight updates via torch.distributed.
|
|
37
|
+
**kwargs: Additional arguments passed to RayTransform and ModuleTransform.
|
|
38
|
+
|
|
39
|
+
Example:
|
|
40
|
+
>>> from torchrl.weight_update import RayModuleTransformScheme
|
|
41
|
+
>>> scheme = RayModuleTransformScheme()
|
|
42
|
+
>>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme)
|
|
43
|
+
>>> # The scheme can then be registered with a collector for weight updates
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, *, weight_sync_scheme=None, **kwargs):
    """Builds the Ray-backed transform and optionally wires a weight-sync scheme.

    Args:
        weight_sync_scheme: optional scheme. When provided, it is bound to this
            transform, its sender side is initialized and it is connected, all
            after the parent ``RayTransform`` constructor has returned.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    # Stored before delegating to the parent constructor.
    self._weight_sync_scheme = weight_sync_scheme
    super().__init__(**kwargs)

    if weight_sync_scheme is None:
        return

    scheme = weight_sync_scheme
    # Give the scheme a handle on this transform.
    scheme._set_transform(self)
    scheme.init_on_sender()
    torchrl_logger.debug(
        "Setting up weight sync scheme on sender -- sender will do the remote call"
    )
    scheme.connect()
|
|
62
|
+
|
|
63
|
+
@property
def in_keys(self):
    """Input keys of the remote ``ModuleTransform``, fetched from the actor via ``self._ray.get``."""
    pending_keys = self._actor._getattr.remote("in_keys")
    return self._ray.get(pending_keys)
|
|
66
|
+
|
|
67
|
+
@property
|
|
68
|
+
def out_keys(self):
|
|
69
|
+
return self._ray.get(self._actor._getattr.remote("out_keys"))
|
|
70
|
+
|
|
71
|
+
def _create_actor(self, **kwargs):
|
|
72
|
+
import ray
|
|
73
|
+
|
|
74
|
+
remote = self._ray.remote(ModuleTransform)
|
|
75
|
+
ray_kwargs = {}
|
|
76
|
+
num_gpus = self._num_gpus
|
|
77
|
+
if num_gpus is not None:
|
|
78
|
+
ray_kwargs["num_gpus"] = num_gpus
|
|
79
|
+
num_cpus = self._num_cpus
|
|
80
|
+
if num_cpus is not None:
|
|
81
|
+
ray_kwargs["num_cpus"] = num_cpus
|
|
82
|
+
actor_name = self._actor_name
|
|
83
|
+
if actor_name is not None:
|
|
84
|
+
ray_kwargs["name"] = actor_name
|
|
85
|
+
if ray_kwargs:
|
|
86
|
+
remote = remote.options(**ray_kwargs)
|
|
87
|
+
actor = remote.remote(**kwargs)
|
|
88
|
+
# wait till the actor is ready
|
|
89
|
+
ray.get(actor._ready.remote())
|
|
90
|
+
return actor
|
|
91
|
+
|
|
92
|
+
@overload
|
|
93
|
+
def update_weights(self, state_dict: dict[str, torch.Tensor]) -> None:
|
|
94
|
+
...
|
|
95
|
+
|
|
96
|
+
@overload
|
|
97
|
+
def update_weights(self, params: TensorDictBase) -> None:
|
|
98
|
+
...
|
|
99
|
+
|
|
100
|
+
def update_weights(self, *args, **kwargs) -> None:
|
|
101
|
+
import ray
|
|
102
|
+
|
|
103
|
+
if self._update_weights_method == "tensordict":
|
|
104
|
+
try:
|
|
105
|
+
td = kwargs.get("params", args[0])
|
|
106
|
+
except IndexError:
|
|
107
|
+
raise ValueError("params must be provided")
|
|
108
|
+
return ray.get(self._actor._update_weights_tensordict.remote(params=td))
|
|
109
|
+
elif self._update_weights_method == "state_dict":
|
|
110
|
+
try:
|
|
111
|
+
state_dict = kwargs.get("state_dict", args[0])
|
|
112
|
+
except IndexError:
|
|
113
|
+
raise ValueError("state_dict must be provided")
|
|
114
|
+
return ray.get(
|
|
115
|
+
self._actor._update_weights_state_dict.remote(state_dict=state_dict)
|
|
116
|
+
)
|
|
117
|
+
else:
|
|
118
|
+
raise ValueError(
|
|
119
|
+
f"Invalid update_weights_method: {self._update_weights_method}"
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class ModuleTransform(Transform, metaclass=_RayServiceMetaClass):
    """A transform that wraps a module.

    With ``inverse=False`` the wrapped module is applied on the forward path
    (``forward``/``_call``). The module's ``in_keys``/``out_keys`` become the
    transform's ``in_keys``/``out_keys``, and the inverse path is a pass-through.

    With ``inverse=True`` the module is applied on the inverse path
    (``_inv_call``). The module's ``out_keys`` become ``in_keys_inv`` and its
    ``in_keys`` become ``out_keys_inv``, and the forward path is a pass-through.

    Keyword Args:
        module (TensorDictModuleBase): The module to wrap. Exclusive with `module_factory`. At least one of `module` or `module_factory` must be provided.
        module_factory (Callable[[], TensorDictModuleBase]): The factory to create the module. Exclusive with `module`. At least one of `module` or `module_factory` must be provided.
        no_grad (bool, optional): If ``True``, the module is called under ``torch.no_grad()``
            (no gradient tracking). Default is `False`.
        inverse (bool, optional): Whether to apply the module on the inverse path instead of the forward path. Default is `False`.
        device (torch.device, optional): If provided, the module is called on ``tensordict.to(device)``
            rather than on the input tensordict. Default is `None`.
        use_ray_service (bool, optional): Whether to use Ray service. Default is `False`.
        num_gpus (int, optional): The number of GPUs to use if using Ray. Default is `None`.
        num_cpus (int, optional): The number of CPUs to use if using Ray. Default is `None`.
        actor_name (str, optional): The name of the actor to use. Default is `None`. If an actor name is provided and
            an actor with this name already exists, the existing actor will be used.
        observation_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the observation
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (observation specs remain unchanged).
        done_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the done
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (done specs remain unchanged).
        reward_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the reward
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (reward specs remain unchanged).
        state_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the state
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (state specs remain unchanged).
        action_spec_transform (TensorSpec or Callable[[TensorSpec], TensorSpec]): either a new spec for the action
            after it has been transformed by the module, or a function that modifies the existing spec.
            Defaults to `None` (action specs remain unchanged).
    """

    # Class used by the metaclass when a Ray-backed instance is requested.
    _RayServiceClass = RayModuleTransform

    def __init__(
        self,
        *,
        module: TensorDictModuleBase | None = None,
        module_factory: Callable[[], TensorDictModuleBase] | None = None,
        no_grad: bool = False,
        inverse: bool = False,
        device: torch.device | None = None,
        use_ray_service: bool = False,  # noqa
        actor_name: str | None = None,  # noqa
        num_gpus: int | None = None,
        num_cpus: int | None = None,
        observation_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        action_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        reward_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        done_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
        state_spec_transform: TensorSpec
        | Callable[[TensorSpec], TensorSpec]
        | None = None,
    ):
        super().__init__()
        if module is None and module_factory is None:
            raise ValueError(
                "At least one of `module` or `module_factory` must be provided."
            )
        if module is not None and module_factory is not None:
            raise ValueError(
                "Only one of `module` or `module_factory` must be provided."
            )
        # The factory path lets the module be built where the transform lives
        # (e.g. inside a Ray actor) instead of being pickled from the caller.
        self.module = module if module is not None else module_factory()
        self.no_grad = no_grad
        self.inverse = inverse
        self.device = device
        self.observation_spec_transform = observation_spec_transform
        self.action_spec_transform = action_spec_transform
        self.reward_spec_transform = reward_spec_transform
        self.done_spec_transform = done_spec_transform
        self.state_spec_transform = state_spec_transform

    # --- keys: derived from the wrapped module, read-only ---------------------
    # The setters accept ``None`` (presumably what the base ``Transform`` assigns
    # during ``super().__init__()``) and reject any other value.

    @property
    def in_keys(self) -> list[str]:
        """The module's ``in_keys`` in forward mode; empty in inverse mode."""
        return self._in_keys()

    def _in_keys(self):
        return self.module.in_keys if not self.inverse else []

    @in_keys.setter
    def in_keys(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(f"in_keys {value} cannot be set for ModuleTransform")

    @property
    def out_keys(self) -> list[str]:
        """The module's ``out_keys`` in forward mode; empty in inverse mode."""
        return self._out_keys()

    def _out_keys(self):
        return self.module.out_keys if not self.inverse else []

    @out_keys.setter
    def out_keys(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(f"out_keys {value} cannot be set for ModuleTransform")

    @property
    def in_keys_inv(self) -> list[str]:
        """The module's ``out_keys`` in inverse mode; empty in forward mode."""
        return self._in_keys_inv()

    def _in_keys_inv(self):
        return self.module.out_keys if self.inverse else []

    @in_keys_inv.setter
    def in_keys_inv(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(f"in_keys_inv {value} cannot be set for ModuleTransform")

    @property
    def out_keys_inv(self) -> list[str]:
        """The module's ``in_keys`` in inverse mode; empty in forward mode."""
        return self._out_keys_inv()

    def _out_keys_inv(self):
        return self.module.in_keys if self.inverse else []

    @out_keys_inv.setter
    def out_keys_inv(self, value: list[str] | None):
        if value is not None:
            raise RuntimeError(
                f"out_keys_inv {value} cannot be set for ModuleTransform"
            )

    # --- module application ----------------------------------------------------

    def _run_module(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Call the wrapped module, honoring ``no_grad`` and ``device``.

        The device move happens *inside* the no-grad context, so the copy made
        by ``.to(device)`` is not tracked by autograd when ``no_grad`` is set.
        """
        with torch.no_grad() if self.no_grad else nullcontext():
            with (
                tensordict.to(self.device)
                if self.device is not None
                else nullcontext(tensordict)
            ) as td:
                return self.module(td)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Apply the module on the forward path (pass-through when ``inverse=True``)."""
        return self._call(tensordict)

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if self.inverse:
            return tensordict
        return self._run_module(tensordict)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self.inverse:
            return tensordict
        return self._run_module(tensordict)

    # --- weight updates ----------------------------------------------------------

    def _update_weights_tensordict(self, params: TensorDictBase) -> None:
        """Write ``params`` into the wrapped module via ``params.to_module``."""
        params.to_module(self.module)

    def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Load ``state_dict`` into the wrapped module via ``load_state_dict``."""
        self.module.load_state_dict(state_dict)

    def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> None:
        """Initialize weight sync scheme on the receiver side (called in Ray actor).

        This method is called by RayModuleTransform after the actor is created
        to set up the receiver side of the weight synchronization scheme.

        Args:
            scheme: The weight sync scheme instance (e.g., RayModuleTransformScheme).
            model_id: Identifier for the model being synchronized.
        """
        torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}")
        scheme.init_on_receiver(model_id=model_id, context=self)
        torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}")
        scheme.connect()
        self._weight_sync_scheme = scheme

    def _receive_weights_scheme(self):
        """Delegate to ``scheme.receive()`` on the registered scheme.

        ``_weight_sync_scheme`` is only set by ``_init_weight_sync_scheme``, which
        must therefore have been called first.
        """
        self._weight_sync_scheme.receive()

    # --- spec transforms ---------------------------------------------------------

    @staticmethod
    def _apply_user_spec_transform(spec_transform, spec: TensorSpec) -> TensorSpec:
        """Resolve a user-supplied spec override.

        ``None`` keeps ``spec`` unchanged. A ``TensorSpec`` replaces it. Anything
        else is treated as a callable and called as ``spec_transform(spec)``.
        """
        if spec_transform is None:
            return spec
        if isinstance(spec_transform, TensorSpec):
            return spec_transform
        return spec_transform(spec)

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        return self._apply_user_spec_transform(
            self.observation_spec_transform, observation_spec
        )

    def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
        return self._apply_user_spec_transform(self.action_spec_transform, action_spec)

    def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
        return self._apply_user_spec_transform(self.reward_spec_transform, reward_spec)

    def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec:
        return self._apply_user_spec_transform(self.done_spec_transform, done_spec)

    def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
        return self._apply_user_spec_transform(self.state_spec_transform, state_spec)
|