torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .common import ModelBasedEnvBase
|
|
7
|
+
from .dreamer import DreamerDecoder, DreamerEnv
|
|
8
|
+
|
|
9
|
+
# Public names exported by ``torchrl.envs.model_based``.
__all__ = [
    "ModelBasedEnvBase",
    "DreamerDecoder",
    "DreamerEnv",
]
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import abc
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
from tensordict.nn import TensorDictModule
|
|
13
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
14
|
+
from torchrl.envs.common import EnvBase
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelBasedEnvBase(EnvBase):
    """Basic environment for Model Based RL algorithms.

    Wrapper around the model of the MBRL algorithm.
    It is meant to give an env framework to a world model (including but not
    limited to observations, reward, done state and safety constraints models)
    and to behave as a classical environment.

    This is a base class for other environments and it should not be used directly.

    Example:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.data import Composite, Unbounded
        >>> class MyMBEnv(ModelBasedEnvBase):
        ...     def __init__(self, world_model, device="cpu", batch_size=None):
        ...         super().__init__(world_model, device=device, batch_size=batch_size)
        ...         self.observation_spec = Composite(
        ...             hidden_observation=Unbounded((4,))
        ...         )
        ...         self.state_spec = Composite(
        ...             hidden_observation=Unbounded((4,)),
        ...         )
        ...         self.action_spec = Unbounded((1,))
        ...         self.reward_spec = Unbounded((1,))
        ...
        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
        ...         tensordict = TensorDict(
        ...             batch_size=self.batch_size,
        ...             device=self.device,
        ...         )
        ...         tensordict = tensordict.update(self.state_spec.rand())
        ...         tensordict = tensordict.update(self.observation_spec.rand())
        ...         return tensordict
        >>> # This environment is used as follows:
        >>> import torch.nn as nn
        >>> from torchrl.modules import MLP, WorldModelWrapper
        >>> world_model = WorldModelWrapper(
        ...     TensorDictModule(
        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
        ...         in_keys=["hidden_observation", "action"],
        ...         out_keys=["hidden_observation"],
        ...     ),
        ...     TensorDictModule(
        ...         nn.Linear(4, 1),
        ...         in_keys=["hidden_observation"],
        ...         out_keys=["reward"],
        ...     ),
        ... )
        >>> env = MyMBEnv(world_model)
        >>> tensordict = env.rollout(max_steps=10)
        >>> print(tensordict)
        TensorDict(
            fields={
                action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
                done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
                next: LazyStackedTensorDict(
                    fields={
                        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
                    batch_size=torch.Size([10]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)


    Properties:
        observation_spec (Composite): sampling spec of the observations;
        action_spec (TensorSpec): sampling spec of the actions;
        reward_spec (TensorSpec): sampling spec of the rewards;
        input_spec (Composite): sampling spec of the inputs;
        batch_size (torch.Size): batch_size to be used by the env. If not set, the env accepts tensordicts of all batch sizes.
        device (torch.device): device where the env input and output are expected to live

    Args:
        world_model (nn.Module): model that generates world states and its corresponding rewards;
            it is moved to the env device at construction.
        params (List[torch.Tensor], optional): list of parameters of the world model;
            when provided, they are passed as ``params=`` to the world model call in ``_step``.
        buffers (List[torch.Tensor], optional): list of buffers of the world model;
            passed as ``buffers=`` alongside ``params`` in ``_step``.
        device (torch.device, optional): device where the env input and output are expected to live.
            Defaults to ``"cpu"``.
        batch_size (torch.Size, optional): number of environments contained in the instance
        run_type_checks (bool, optional): whether to run type checks on the step of the env.
            Defaults to ``False``.
        allow_done_after_reset (bool, optional): whether to skip the done-state validation
            performed after ``reset()``; forwarded to :class:`~torchrl.envs.EnvBase`.
            Defaults to ``False``.

    Methods:
        step (TensorDict -> TensorDict): step in the environment
        reset (TensorDict, optional -> TensorDict): reset the environment
        set_seed (int -> int): sets the seed of the environment
        rand_step (TensorDict, optional -> TensorDict): random step given the action spec
        rollout (Callable, ... -> TensorDict): executes a rollout in the environment with the given policy (or random
            steps if no policy is provided)

    """

    def __init__(
        self,
        world_model: TensorDictModule,
        params: list[torch.Tensor] | None = None,
        buffers: list[torch.Tensor] | None = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: torch.Size | None = None,
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )
        # Keep the model on the env device so its outputs land where the env expects them.
        self.world_model = world_model.to(self.device)
        self.world_model_params = params
        self.world_model_buffers = buffers

    @classmethod
    def __new__(cls, *args, **kwargs):
        # NOTE(review): ``__new__`` is implicitly a static method in Python; with the
        # ``@classmethod`` decorator, ``cls`` is bound AND also forwarded again inside
        # ``args``. Kept unchanged for compatibility — confirm ``EnvBase.__new__``
        # ignores positional args before removing the decorator.
        # Per the flag names, this presumably disables in-place updates and
        # batch-size locking for model-based envs.
        return super().__new__(
            cls, *args, _inplace_update=False, _batch_locked=False, **kwargs
        )

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of the environment from the specs of the given environment.

        The input and output specs of ``env`` are cloned (so ``env`` is not
        mutated), moved to this env's device when one is set, and installed
        directly. ``empty_cache()`` is then called, presumably to invalidate
        anything derived from the previous specs.
        """
        device = self.device
        output_spec = env.output_spec.clone()
        input_spec = env.input_spec.clone()
        if device is not None:
            output_spec = output_spec.to(device)
            input_spec = input_spec.to(device)
        # Write through __dict__ to bypass the public spec setters.
        self.__dict__["_output_spec"] = output_spec
        self.__dict__["_input_spec"] = input_spec
        self.empty_cache()

    def _step(
        self,
        tensordict: TensorDict,
    ) -> TensorDict:
        """Runs the world model and returns the observation, done and reward entries.

        The input is shallow-copied (``clone(recurse=False)``) so the caller's
        tensordict is not modified. Keys that the model did not produce are
        silently skipped (``strict=False``).
        """
        # step method requires to be immutable
        tensordict_out = tensordict.clone(recurse=False)
        # Compute world state
        if self.world_model_params is not None:
            tensordict_out = self.world_model(
                tensordict_out,
                params=self.world_model_params,
                buffers=self.world_model_buffers,
            )
        else:
            tensordict_out = self.world_model(tensordict_out)
        # done can be missing, it will be filled by `step`
        # Convert to list for torch.compile compatibility (dynamo can't unpack _CompositeSpecKeysView)
        keys_to_select = (
            list(self.observation_spec.keys())
            + list(self.full_done_spec.keys())
            + list(self.full_reward_spec.keys())
        )
        tensordict_out = tensordict_out.select(*keys_to_select, strict=False)
        return tensordict_out

    @abc.abstractmethod
    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Builds the initial state; must be implemented by subclasses."""
        raise NotImplementedError

    def _set_seed(self, seed: int | None) -> None:
        """Does not seed anything; only emits a warning."""
        warnings.warn("Set seed isn't needed for model based environments")
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from tensordict import TensorDict
|
|
9
|
+
from tensordict.nn import TensorDictModule
|
|
10
|
+
from torchrl.data.tensor_specs import Composite
|
|
11
|
+
from torchrl.data.utils import DEVICE_TYPING
|
|
12
|
+
from torchrl.envs.common import EnvBase
|
|
13
|
+
from torchrl.envs.model_based import ModelBasedEnvBase
|
|
14
|
+
from torchrl.envs.transforms.transforms import Transform
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DreamerEnv(ModelBasedEnvBase):
    """Dreamer simulation environment.

    This environment is used for imagination rollouts in Dreamer training.
    It never terminates (done is always False) since imagination runs for a
    fixed horizon. The done-checking methods are overridden to avoid CUDA
    synchronization overhead from Python control flow on CUDA tensors.

    Args:
        world_model (TensorDictModule): model that produces the next latent
            state and reward during imagination.
        prior_shape (tuple of int): stored as ``self.prior_shape``; not used
            directly by this class.
        belief_shape (tuple of int): stored as ``self.belief_shape``; not used
            directly by this class.
        obs_decoder (TensorDictModule, optional): module used by
            :meth:`decode_obs` to reconstruct observations. If ``None``,
            :meth:`decode_obs` raises a ``ValueError``.
        device (torch.device, optional): device of the env. Defaults to ``"cpu"``.
        batch_size (torch.Size, optional): batch size of the env.
    """

    def __init__(
        self,
        world_model: TensorDictModule,
        prior_shape: tuple[int, ...],
        belief_shape: tuple[int, ...],
        obs_decoder: TensorDictModule | None = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: torch.Size | None = None,
    ):
        super().__init__(
            world_model,
            device=device,
            batch_size=batch_size,
            # Skip done validation in reset() — imagination never terminates.
            allow_done_after_reset=True,
        )
        self.obs_decoder = obs_decoder
        self.prior_shape = prior_shape
        self.belief_shape = belief_shape

    def any_done(self, tensordict) -> bool:
        """Returns False — imagination rollouts never terminate.

        Overridden to avoid CUDA sync from `done.any()` in parent class.
        """
        return False

    def maybe_reset(self, tensordict):
        """No-op — imagination rollouts don't need partial resets.

        Overridden to avoid CUDA sync from done checks in parent class.
        """
        return tensordict

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of the environment from the specs of the given environment.

        In addition to the base behavior, the state spec is rebuilt from the
        ``"state"`` and ``"belief"`` entries of the observation spec, with the
        batch shape of ``env``.
        """
        super().set_specs_from_env(env)
        self.action_spec = self.action_spec.to(self.device)
        self.state_spec = Composite(
            state=self.observation_spec["state"],
            belief=self.observation_spec["belief"],
            shape=env.batch_size,
        )

    def _reset(self, tensordict=None, **kwargs) -> TensorDict:
        """Returns the initial imagination state.

        If ``tensordict`` is given, a clone of it is returned as the starting
        point. Otherwise, state, action, ``("next", "reward")`` and observation
        entries are sampled at random from the specs.
        """
        if tensordict is not None:
            return tensordict.clone()

        batch_size = []
        device = self.device
        td = self.state_spec.rand(shape=batch_size)
        # why don't we reuse actions taken at those steps?
        td.set("action", self.action_spec.rand(shape=batch_size))
        td[("next", "reward")] = self.reward_spec.rand(shape=batch_size)
        td.update(self.observation_spec.rand(shape=batch_size))
        if device is not None:
            td = td.to(device, non_blocking=True)
            # Non-blocking copies must be synchronized before the data is read.
            # NOTE(review): the MPS branch is not gated on ``device.type == "cpu"``
            # like the CUDA one — confirm whether that is intended.
            if torch.cuda.is_available() and device.type == "cpu":
                torch.cuda.synchronize()
            elif torch.backends.mps.is_available():
                torch.mps.synchronize()
        return td

    def decode_obs(self, tensordict: TensorDict, compute_latents=False) -> TensorDict:
        """Decodes observations from latent states with ``obs_decoder``.

        Args:
            tensordict (TensorDict): input holding the latent states.
            compute_latents (bool, optional): if ``True``, the world model is
                run on ``tensordict`` first. Defaults to ``False``.

        Raises:
            ValueError: if no observation decoder was provided.
        """
        if self.obs_decoder is None:
            raise ValueError("No observation decoder provided")
        if compute_latents:
            tensordict = self.world_model(tensordict)
        return self.obs_decoder(tensordict)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class DreamerDecoder(Transform):
    """A transform to record the decoded observations in Dreamer.

    The transform calls the ``obs_decoder`` of the parent's base env on every
    step and reset output. The observation spec is returned unchanged, so
    decoded entries are not registered in it.

    Examples:
        >>> model_based_env = DreamerEnv(...)
        >>> model_based_env_eval = model_based_env.append_transform(DreamerDecoder())
    """

    def _call(self, next_tensordict):
        """Decodes ``next_tensordict`` with the base env's observation decoder.

        Raises:
            ValueError: if the base env has no observation decoder, matching
                ``DreamerEnv.decode_obs`` instead of failing with a
                ``'NoneType' object is not callable`` error.
        """
        obs_decoder = self.parent.base_env.obs_decoder
        if obs_decoder is None:
            raise ValueError("No observation decoder provided")
        return obs_decoder(next_tensordict)

    def _reset(self, tensordict, tensordict_reset):
        """Decodes the reset output the same way as step outputs."""
        return self._call(tensordict_reset)

    def transform_observation_spec(self, observation_spec):
        """Returns ``observation_spec`` unchanged."""
        return observation_spec
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from .gym_transforms import EndOfLifeTransform
|
|
7
|
+
from .module import ModuleTransform
|
|
8
|
+
from .r3m import R3MTransform
|
|
9
|
+
from .ray_service import RayTransform
|
|
10
|
+
from .rb_transforms import MultiStepTransform
|
|
11
|
+
from .transforms import (
|
|
12
|
+
ActionDiscretizer,
|
|
13
|
+
ActionMask,
|
|
14
|
+
AutoResetEnv,
|
|
15
|
+
AutoResetTransform,
|
|
16
|
+
BatchSizeTransform,
|
|
17
|
+
BinarizeReward,
|
|
18
|
+
BurnInTransform,
|
|
19
|
+
CatFrames,
|
|
20
|
+
CatTensors,
|
|
21
|
+
CenterCrop,
|
|
22
|
+
ClipTransform,
|
|
23
|
+
Compose,
|
|
24
|
+
ConditionalPolicySwitch,
|
|
25
|
+
ConditionalSkip,
|
|
26
|
+
Crop,
|
|
27
|
+
DeviceCastTransform,
|
|
28
|
+
DiscreteActionProjection,
|
|
29
|
+
DoubleToFloat,
|
|
30
|
+
DTypeCastTransform,
|
|
31
|
+
ExcludeTransform,
|
|
32
|
+
FiniteTensorDictCheck,
|
|
33
|
+
FlattenObservation,
|
|
34
|
+
FrameSkipTransform,
|
|
35
|
+
GrayScale,
|
|
36
|
+
gSDENoise,
|
|
37
|
+
Hash,
|
|
38
|
+
InitTracker,
|
|
39
|
+
LineariseRewards,
|
|
40
|
+
MultiAction,
|
|
41
|
+
NoopResetEnv,
|
|
42
|
+
ObservationNorm,
|
|
43
|
+
ObservationTransform,
|
|
44
|
+
PermuteTransform,
|
|
45
|
+
PinMemoryTransform,
|
|
46
|
+
RandomCropTensorDict,
|
|
47
|
+
RemoveEmptySpecs,
|
|
48
|
+
RenameTransform,
|
|
49
|
+
Resize,
|
|
50
|
+
Reward2GoTransform,
|
|
51
|
+
RewardClipping,
|
|
52
|
+
RewardScaling,
|
|
53
|
+
RewardSum,
|
|
54
|
+
SelectTransform,
|
|
55
|
+
SignTransform,
|
|
56
|
+
SqueezeTransform,
|
|
57
|
+
Stack,
|
|
58
|
+
StepCounter,
|
|
59
|
+
TargetReturn,
|
|
60
|
+
TensorDictPrimer,
|
|
61
|
+
TimeMaxPool,
|
|
62
|
+
Timer,
|
|
63
|
+
Tokenizer,
|
|
64
|
+
ToTensorImage,
|
|
65
|
+
TrajCounter,
|
|
66
|
+
Transform,
|
|
67
|
+
TransformedEnv,
|
|
68
|
+
UnaryTransform,
|
|
69
|
+
UnsqueezeTransform,
|
|
70
|
+
VecGymEnvTransform,
|
|
71
|
+
VecNorm,
|
|
72
|
+
)
|
|
73
|
+
from .vc1 import VC1Transform
|
|
74
|
+
from .vecnorm import VecNormV2
|
|
75
|
+
from .vip import VIPRewardTransform, VIPTransform
|
|
76
|
+
|
|
77
|
+
# Public API of the transforms package, in case-sensitive sorted order
# (uppercase sorts before lowercase, hence "gSDENoise" comes last).
__all__ = [
    "ActionDiscretizer", "ActionMask", "AutoResetEnv", "AutoResetTransform",
    "BatchSizeTransform", "BinarizeReward", "BurnInTransform",
    "CatFrames", "CatTensors", "CenterCrop", "ClipTransform", "Compose",
    "ConditionalPolicySwitch", "ConditionalSkip", "Crop",
    "DTypeCastTransform", "DeviceCastTransform", "DiscreteActionProjection",
    "DoubleToFloat",
    "EndOfLifeTransform", "ExcludeTransform",
    "FiniteTensorDictCheck", "FlattenObservation", "FrameSkipTransform",
    "GrayScale",
    "Hash",
    "InitTracker",
    "LineariseRewards",
    "ModuleTransform", "MultiAction", "MultiStepTransform",
    "NoopResetEnv",
    "ObservationNorm", "ObservationTransform",
    "PermuteTransform", "PinMemoryTransform",
    "R3MTransform", "RandomCropTensorDict", "RayTransform", "RemoveEmptySpecs",
    "RenameTransform", "Resize", "Reward2GoTransform", "RewardClipping",
    "RewardScaling", "RewardSum",
    "SelectTransform", "SignTransform", "SqueezeTransform", "Stack",
    "StepCounter",
    "TargetReturn", "TensorDictPrimer", "TimeMaxPool", "Timer", "ToTensorImage",
    "Tokenizer", "TrajCounter", "Transform", "TransformedEnv",
    "UnaryTransform", "UnsqueezeTransform",
    "VC1Transform", "VIPRewardTransform", "VIPTransform", "VecGymEnvTransform",
    "VecNorm", "VecNormV2",
    "gSDENoise",
]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# adapted from torchvision
def _get_image_num_channels(img: Tensor) -> int:
    """Return the number of channels of an image tensor.

    A 2-D tensor counts as a single-channel image. For tensors with more
    dimensions, the channel count is the size of dimension ``-3``.

    Raises:
        TypeError: if ``img`` has fewer than 2 dimensions.
    """
    ndim = img.ndim
    if ndim < 2:
        raise TypeError(f"Input ndim should be 2 or more. Got {ndim}")
    return 1 if ndim == 2 else img.shape[-3]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _assert_channels(img: Tensor, permitted: list[int]) -> None:
    """Check that ``img`` has one of the ``permitted`` channel counts.

    The channel count comes from ``_get_image_num_channels`` (which raises
    ``TypeError`` for tensors with fewer than 2 dimensions).

    Raises:
        TypeError: if the channel count is not in ``permitted``.
    """
    num_channels = _get_image_num_channels(img)
    if num_channels in permitted:
        return
    raise TypeError(
        f"Input image tensor permitted channel values are {permitted}, but found "
        f"{num_channels} (full shape: {img.shape})"
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
    """Turns an RGB image into grayscale.

    Luma is computed over the channel dimension (``dim=-3``) as
    ``0.2989 * R + 0.587 * G + 0.114 * B`` and cast back to ``img.dtype``
    (for integer dtypes this cast truncates the fractional part).

    Args:
        img: tensor with at least 3 dimensions whose dimension ``-3`` holds
            exactly 3 (RGB) channels.
        num_output_channels: 1 to return a single luma channel, or 3 to
            return the luma repeated over three channels.

    Returns:
        A tensor shaped like ``img`` except that dimension ``-3`` has
        ``num_output_channels`` entries. For 3 output channels, the result is
        an expanded view of the single luma channel (no copy).

    Raises:
        TypeError: if ``img`` has fewer than 3 dimensions or does not have
            exactly 3 channels.
        ValueError: if ``num_output_channels`` is not 1 or 3.
    """
    if img.ndim < 3:
        # f-string so the count is separated from the text by a space (the
        # previous implicit concatenation produced e.g. "...but found2").
        raise TypeError(
            f"Input image tensor should have at least 3 dimensions, but found {img.ndim}"
        )
    _assert_channels(img, [3])

    if num_output_channels not in (1, 3):
        raise ValueError("num_output_channels should be either 1 or 3")

    r, g, b = img.unbind(dim=-3)
    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
    # Restore the (now size-1) channel dimension.
    l_img = l_img.unsqueeze(dim=-3)

    if num_output_channels == 3:
        return l_img.expand(img.shape)

    return l_img
|