torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def check_finite(tensor: torch.Tensor):
    """Raise an error if a tensor has non-finite elements.

    Args:
        tensor (torch.Tensor): the tensor to inspect.

    Raises:
        ValueError: if any element of ``tensor`` is NaN, ``+inf`` or ``-inf``.
    """
    all_finite = torch.isfinite(tensor).all()
    if all_finite:
        return
    raise ValueError("Encountered a non-finite tensor.")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _init_first(fun):
|
|
18
|
+
def new_fun(self, *args, **kwargs):
|
|
19
|
+
if not self.initialized:
|
|
20
|
+
self._init()
|
|
21
|
+
return fun(self, *args, **kwargs)
|
|
22
|
+
|
|
23
|
+
return new_fun
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class _set_missing_tolerance:
|
|
27
|
+
"""Context manager to change the transform tolerance to missing values.
|
|
28
|
+
|
|
29
|
+
If a transform has a missing_tolerance of True, it will not raise an error if a key is missing during reset.
|
|
30
|
+
|
|
31
|
+
This is implemented via :meth:`~torchrl.envs.transforms.Transform.set_missing_tolerance`.
|
|
32
|
+
|
|
33
|
+
The way this is handled is that, if `_reset` calls the default `_call` method, it will not raise an error if an input key is missing.
|
|
34
|
+
|
|
35
|
+
For custom `_reset` methods, you should implement this yourself:
|
|
36
|
+
|
|
37
|
+
Exmples:
|
|
38
|
+
>>> def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
|
39
|
+
... with _set_missing_tolerance(self, True):
|
|
40
|
+
... tensordict_reset = self.foo(tensordict, tensordict_reset)
|
|
41
|
+
... return tensordict_reset
|
|
42
|
+
>>> def foo(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
|
43
|
+
... if self.input_keys[0] not in tensordict_reset and self.missing_tolerance:
|
|
44
|
+
... return tensordict_reset
|
|
45
|
+
... else:
|
|
46
|
+
... # your code here
|
|
47
|
+
|
|
48
|
+
Because `missing_tolerance` will be turned off during calls to `_step`, you can be sure that an appropriate KeyError will be raised
|
|
49
|
+
if the input key is missing at that time.
|
|
50
|
+
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
def __init__(self, transform, mode):
|
|
54
|
+
self.transform = transform
|
|
55
|
+
self.mode = mode
|
|
56
|
+
|
|
57
|
+
def __enter__(self):
|
|
58
|
+
self.exit_mode = self.transform.missing_tolerance
|
|
59
|
+
if self.mode != self.exit_mode:
|
|
60
|
+
self.transform.set_missing_tolerance(self.mode)
|
|
61
|
+
|
|
62
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
63
|
+
if self.mode != self.exit_mode:
|
|
64
|
+
self.transform.set_missing_tolerance(self.exit_mode)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_reset(reset_key, tensordict):
|
|
68
|
+
_reset = tensordict.get(reset_key, None)
|
|
69
|
+
# reset key must be unraveled already
|
|
70
|
+
parent_td = (
|
|
71
|
+
tensordict.get(reset_key[:-1], None)
|
|
72
|
+
if isinstance(reset_key, tuple)
|
|
73
|
+
else tensordict
|
|
74
|
+
)
|
|
75
|
+
if parent_td is None:
|
|
76
|
+
# we do this just in case the nested td wasn't found
|
|
77
|
+
parent_td = tensordict
|
|
78
|
+
if _reset is None:
|
|
79
|
+
_reset = torch.ones(
|
|
80
|
+
(),
|
|
81
|
+
dtype=torch.bool,
|
|
82
|
+
device=parent_td.device,
|
|
83
|
+
).expand(parent_td.batch_size)
|
|
84
|
+
if _reset.ndim > parent_td.ndim:
|
|
85
|
+
_reset = _reset.flatten(parent_td.ndim, -1).any(-1)
|
|
86
|
+
return _reset
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _stateless_param(param):
|
|
90
|
+
is_param = isinstance(param, nn.Parameter)
|
|
91
|
+
param = param.data.to("meta")
|
|
92
|
+
if is_param:
|
|
93
|
+
return nn.Parameter(param, requires_grad=False)
|
|
94
|
+
return param
|
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import importlib
|
|
9
|
+
import os
|
|
10
|
+
import subprocess
|
|
11
|
+
from functools import partial
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from tensordict import TensorDictBase
|
|
15
|
+
from torch import nn
|
|
16
|
+
from torchrl._utils import logger as torchrl_logger
|
|
17
|
+
|
|
18
|
+
from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded
|
|
19
|
+
from torchrl.envs.transforms.transforms import (
|
|
20
|
+
CenterCrop,
|
|
21
|
+
Compose,
|
|
22
|
+
ObservationNorm,
|
|
23
|
+
Resize,
|
|
24
|
+
ToTensorImage,
|
|
25
|
+
Transform,
|
|
26
|
+
)
|
|
27
|
+
from torchrl.envs.transforms.utils import _set_missing_tolerance
|
|
28
|
+
|
|
29
|
+
_has_vc = importlib.util.find_spec("vc_models") is not None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class VC1Transform(Transform):
    """VC1 Transform class.

    VC-1 provides pre-trained Vision Transformer (ViT) weights aimed at
    facilitating visual embedding for robotic / embodied-AI tasks.

    See the paper:
        Where are we in the search for an Artificial Visual Cortex for Embodied
        Intelligence? (Arjun Majumdar, Karmesh Yadav, Sergio Arnaud, et al.)
        https://arxiv.org/abs/2303.18240

    The model and its preprocessing pipeline are loaded when the transform is
    constructed (``__init__`` calls :meth:`_init`), so ``vc_models`` must be
    importable at that point (see :meth:`install_vc_models`). At call time the
    images under ``in_keys`` are preprocessed with the model's own transforms
    (converted to torchrl transforms), embedded by the model, and the
    embeddings are written under ``out_keys``.

    Examples:
        >>> transform = VC1Transform(
        ...     in_keys=["pixels"], out_keys=["VC1_vec"], model_name="default"
        ... )
        >>> env.append_transform(transform)
        >>> env.reset()

    Args:
        in_keys (list of NestedKeys, optional): list of input keys. Defaults to
            ``["pixels"]``.
        out_keys (list of NestedKeys, optional): list of output keys. Defaults
            to ``["VC1_vec"]``.
        model_name (str, optional): One of ``"large"``, ``"base"`` or any other
            compatible model name (see the `github repo
            <https://github.com/facebookresearch/eai-vc>`_ for more info).
            Defaults to ``"default"``, which provides a small, untrained model
            for testing.
        del_keys (bool, optional): If ``True`` (default), the input keys that
            are not also output keys are discarded from the returned
            tensordict.
    """

    inplace = False
    IMPORT_ERROR = (
        "Could not load vc_models. You can install it via "
        "VC1Transform.install_vc_models()."
    )

    def __init__(
        self,
        in_keys=None,
        out_keys=None,
        model_name: str = "default",
        del_keys: bool = True,
    ):
        if in_keys is None:
            in_keys = ["pixels"]
        if out_keys is None:
            out_keys = ["VC1_vec"]
        if model_name == "default":
            # Make sure the config of the tiny untrained model exists in the
            # vc_models package, then load that model.
            self.make_noload_model()
            model_name = "vc1_vitb_noload"
        self.model_name = model_name
        self.del_keys = del_keys

        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self._init()

    def _init(self):
        """Load the model, its embedding size and its preprocessing pipeline.

        Raises:
            ModuleNotFoundError: if ``vc_models`` cannot be imported.
        """
        try:
            from vc_models.models.vit import model_utils
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(self.IMPORT_ERROR) from err

        # "base"/"large" are aliases for model_utils.VC1_BASE_NAME and
        # model_utils.VC1_LARGE_NAME; any other name is passed as-is.
        if self.model_name == "base":
            model_name = model_utils.VC1_BASE_NAME
        elif self.model_name == "large":
            model_name = model_utils.VC1_LARGE_NAME
        else:
            model_name = self.model_name

        model, embd_size, model_transforms, _model_info = model_utils.load_model(
            model_name
        )
        self.model = model
        self.embd_size = embd_size
        self.model_transforms = self._map_tv_to_torchrl(model_transforms)

    def _map_tv_to_torchrl(
        self,
        model_transforms,
        in_keys=None,
    ):
        """Convert a torchvision preprocessing pipeline into torchrl transforms.

        Supports ``Resize``, ``CenterCrop``, ``Normalize``, ``ToTensor`` and a
        ``Compose`` of those. Every produced transform acts on ``in_keys``
        (defaults to ``self.in_keys``), including the members of a ``Compose``.

        Raises:
            RuntimeError: if a ``Compose`` contains no transform.
            NotImplementedError: for any other torchvision transform type.
        """
        if in_keys is None:
            in_keys = self.in_keys
        from torchvision import transforms

        if isinstance(model_transforms, transforms.Resize):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size, size)
            return Resize(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.CenterCrop):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size,)
            return CenterCrop(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Normalize):
            # Per-channel mean/std reshaped to broadcast over (C, H, W) images;
            # this assumes 3-channel inputs.
            return ObservationNorm(
                in_keys=in_keys,
                loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
                scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
                standard_normal=True,
            )
        elif isinstance(model_transforms, transforms.ToTensor):
            return ToTensorImage(
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Compose):
            # ToTensor is moved to the front so that the following (tensor
            # based) torchrl transforms receive image tensors; the other
            # transforms keep their relative order.
            members = model_transforms.transforms
            ordered = [t for t in members if isinstance(t, transforms.ToTensor)]
            ordered += [t for t in members if not isinstance(t, transforms.ToTensor)]
            if not ordered:
                raise RuntimeError("Did not find any transform.")
            return Compose(
                *(self._map_tv_to_torchrl(t, in_keys=in_keys) for t in ordered)
            )
        else:
            raise NotImplementedError(type(model_transforms))

    def _call(self, next_tensordict):
        """Embed the images under ``in_keys`` and write them to ``out_keys``."""
        # Input entries that are not overwritten by an output entry: they are
        # either restored after preprocessing (del_keys=False) or removed
        # (del_keys=True). Removing an in_key that is also an out_key would
        # delete the embedding that was just written.
        extra_in_keys = [key for key in self.in_keys if key not in self.out_keys]
        if not self.del_keys:
            # model_transforms rewrites the input entries: keep the originals
            saved_td = next_tensordict.select(*extra_in_keys)
        with next_tensordict.view(-1) as tensordict_view:
            super()._call(self.model_transforms(tensordict_view))
        if self.del_keys:
            next_tensordict.exclude(*extra_in_keys, inplace=True)
        else:
            # reset in_keys to their un-preprocessed values
            next_tensordict.update(saved_td)
        return next_tensordict

    forward = _call

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # TODO: Check this makes sense
        # Missing input keys are tolerated while processing the reset data.
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
        """Embed preprocessed images with the model (gradients disabled)."""
        shape = None
        if obs.ndimension() > 4:
            # Flatten all leading dims into one batch dim for the model and
            # restore them on the output.
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.model(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Add ``Unbounded`` embedding specs (and drop image specs if ``del_keys``).

        Raises:
            ValueError: if ``observation_spec`` is not a ``Composite``.
            KeyError: if none of the ``in_keys`` is in ``observation_spec``.
        """
        if not isinstance(observation_spec, Composite):
            raise ValueError("VC1Transform can only infer Composite")

        keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
        if not keys:
            raise KeyError(
                f"None of the in_keys {self.in_keys} were found in the observation spec."
            )
        # Device and leading dims are taken from the first matching image spec,
        # whose last three dims are assumed to be the image dims.
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = observation_spec.clone()
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = Unbounded(
                shape=torch.Size([*dim, self.embd_size]), device=device
            )

        return observation_spec

    def to(self, dest: DEVICE_TYPING | torch.dtype):
        """Record the destination dtype/device, then move the module."""
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)

    @property
    def device(self):
        """Device recorded by :meth:`to` (``None`` if it has not been set)."""
        return getattr(self, "_device", None)

    @property
    def dtype(self):
        """Dtype recorded by :meth:`to` (``None`` if it has not been set)."""
        return getattr(self, "_dtype", None)

    @classmethod
    def install_vc_models(cls, auto_exit=False):
        """Install the ``vc_models`` package from the eai-vc repository.

        If ``vc_models`` is already importable, only logs a message. Otherwise
        clones https://github.com/facebookresearch/eai-vc into
        ``~/.cache/torchrl/eai-vc``, runs ``setup.py develop`` in its
        ``vc_models`` folder with the current interpreter, and then exits the
        process (after waiting for Enter unless ``auto_exit`` is ``True``).

        Raises:
            ModuleNotFoundError: if GitPython (``git``) is not installed.
            subprocess.CalledProcessError: if the setup command fails.
        """
        try:
            from vc_models import models  # noqa: F401

            torchrl_logger.info("vc_models found, no need to install.")
        except ModuleNotFoundError:
            import sys

            # expanduser works even when $HOME is unset (and on Windows)
            vcdir = os.path.join(
                os.path.expanduser("~"), ".cache", "torchrl", "eai-vc"
            )
            os.makedirs(os.path.dirname(os.path.abspath(vcdir)), exist_ok=True)
            try:
                from git import Repo
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    "Could not load git. Make sure that `git` has been installed "
                    "in your virtual environment."
                ) from err
            Repo.clone_from("https://github.com/facebookresearch/eai-vc.git", vcdir)
            # Install with the interpreter running this code (not whatever
            # `python` is on PATH), without changing the process cwd, and fail
            # loudly instead of reporting a successful install.
            subprocess.check_call(
                [sys.executable, "setup.py", "develop"],
                cwd=os.path.join(vcdir, "vc_models"),
            )
            if not auto_exit:
                input(
                    "VC1 has been successfully installed. Exit this python run and "
                    "relaunch it again. Press Enter to exit..."
                )
            sys.exit()

    @classmethod
    def make_noload_model(cls):
        """Creates an naive model at a custom destination.

        Writes ``conf/model/vc1_vitb_noload.yaml`` inside the installed
        ``vc_models`` package (if not already present). The config builds the
        small untrained ViT returned by ``_vit_base_patch16`` in this module.

        Raises:
            ModuleNotFoundError: if ``vc_models`` cannot be imported.
        """
        try:
            import vc_models
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(cls.IMPORT_ERROR) from err

        models_filepath = os.path.dirname(os.path.abspath(vc_models.__file__))
        cfg_path = os.path.join(
            models_filepath, "conf", "model", "vc1_vitb_noload.yaml"
        )
        if os.path.exists(cfg_path):
            return
        config = """_target_: vc_models.models.load_model
model:
  _target_: vc_models.models.vit.vit.load_mae_encoder
  checkpoint_path:
  model:
    _target_: torchrl.envs.transforms.vc1._vit_base_patch16
    img_size: 224
    use_cls: True
    drop_path_rate: 0.0
transform:
  _target_: vc_models.transforms.vit_transforms
metadata:
  algo: mae
  model: vit_base_patch16
  data:
    - ego
    - imagenet
    - inav
  comment: 182_epochs
"""
        with open(cfg_path, "w", encoding="utf-8") as file:
            file.write(config)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _vit_base_patch16(**kwargs):
|
|
295
|
+
from vc_models.models.vit.vit import VisionTransformer
|
|
296
|
+
|
|
297
|
+
model = VisionTransformer(
|
|
298
|
+
patch_size=16,
|
|
299
|
+
embed_dim=16,
|
|
300
|
+
depth=4,
|
|
301
|
+
num_heads=4,
|
|
302
|
+
mlp_ratio=4,
|
|
303
|
+
qkv_bias=True,
|
|
304
|
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
|
305
|
+
**kwargs,
|
|
306
|
+
)
|
|
307
|
+
return model
|