torchrl 0.11.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Assertions and validation utilities for TorchRL tests."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"check_rollout_consistency_multikey_env",
|
|
15
|
+
"rand_reset",
|
|
16
|
+
"rollout_consistency_assertion",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def rollout_consistency_assertion(
    rollout, *, done_key="done", observation_key="observation", done_strict=False
):
    """Test that observations in 'next' match observations in the next root tensordict.

    Verifies consistency: when done is False the next observation should match,
    and when done is True they should differ (indicating a reset occurred).

    Args:
        rollout: The rollout tensordict to validate. The trailing batch
            dimension is assumed to be time.
        done_key: The key for the done signal.
        observation_key: The key for observations.
        done_strict: If True, raise an error if no done is detected.

    Raises:
        RuntimeError: If ``done_strict`` is True and no done step is found.
    """
    done = rollout[..., :-1]["next", done_key].squeeze(-1)
    # data resulting from step, when it's not done
    r_not_done = rollout[..., :-1]["next"][~done]
    # data resulting from step, when it's not done, after step_mdp.
    # Fix: slice the trailing (time) dim with `...` like every other slice in
    # this function; `[:, 1:]` indexed dim 1, which is only the time dim for
    # exactly-2D rollouts and breaks (or mis-slices) other batch shapes.
    r_not_done_tp1 = rollout[..., 1:][~done]
    torch.testing.assert_close(
        r_not_done[observation_key],
        r_not_done_tp1[observation_key],
        msg=f"Key {observation_key} did not match",
    )

    if done_strict and not done.any():
        raise RuntimeError("No done detected, test could not complete.")
    if done.any():
        # data resulting from step, when it's done
        r_done = rollout[..., :-1]["next"][done]
        # data resulting from step, when it's done, after step_mdp and reset
        r_done_tp1 = rollout[..., 1:][done]
        # check that at least one obs after reset does not match the version before reset
        assert not torch.isclose(
            r_done[observation_key], r_done_tp1[observation_key]
        ).all()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def rand_reset(env):
    """Generate a tensordict with reset keys that mimic the done spec.

    Values are drawn at random until at least one reset is present.

    Args:
        env: The environment to generate reset keys for.

    Returns:
        A TensorDict containing the reset signals.
    """
    done_spec = env.full_done_spec
    drawn = {}
    for reset_key, done_keys in zip(env.reset_keys, env.done_keys_groups):
        leaf_spec = done_spec[done_keys[0]]
        sample = leaf_spec.rand()
        # Resample until at least one reset signal is actually set.
        while not sample.any():
            sample = leaf_spec.rand()
        drawn[reset_key] = sample
    # Start from a zeroed spec so the nested batch sizes are preserved, then
    # drop the done entries so only the reset keys remain.
    zeroed = done_spec.zero().update(drawn)
    return zeroed.exclude(*done_spec.keys(True, True))
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
    """Check rollout consistency for environments with multiple observation/action keys.

    Validates that:
    - Done and reset behavior is correct for root, nested_1, and nested_2
    - Observations update correctly based on actions
    - Rewards are computed correctly

    Args:
        td: The rollout tensordict to validate. The trailing batch dimension
            is treated as time.
        max_steps: The maximum steps before done in the environment.
    """
    # Index that picks the first element of every leading batch dim, leaving
    # only the trailing (time) dim in the views indexed with it below.
    index_batch_size = (0,) * (len(td.batch_size) - 1)

    # Check done and reset for root
    # NOTE(review): assumes the root observation is a step counter that hits
    # max_steps + 1 exactly on the done transition -- confirm against the env.
    observation_is_max = td["next", "observation"][..., 0, 0, 0] == max_steps + 1
    # Done flags for steps 0..T-2, aligned with the "step after" views below.
    next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1)
    assert (td["next", "done"][observation_is_max]).all()
    assert (~td["next", "done"][~observation_is_max]).all()
    # Obs after done is 0
    assert (td["observation"][index_batch_size][1:][next_is_done] == 0).all()
    # Obs after not done is previous obs
    assert (
        td["observation"][index_batch_size][1:][~next_is_done]
        == td["next", "observation"][index_batch_size][:-1][~next_is_done]
    ).all()
    # Check observation and reward update with count action for root
    # Root action is one-hot; argmax over the last dim decodes it to 0/1.
    action_is_count = td["action"].long().argmax(-1).to(torch.bool)
    assert (
        td["next", "observation"][action_is_count]
        == td["observation"][action_is_count] + 1
    ).all()
    assert (td["next", "reward"][action_is_count] == 1).all()
    # Check observation and reward do not update with no-count action for root
    assert (
        td["next", "observation"][~action_is_count]
        == td["observation"][~action_is_count]
    ).all()
    assert (td["next", "reward"][~action_is_count] == 0).all()

    # Check done and reset for nested_1
    observation_is_max = td["next", "nested_1", "observation"][..., 0] == max_steps + 1
    # done at the root always prevail
    next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1)
    assert (td["next", "nested_1", "done"][observation_is_max]).all()
    assert (~td["next", "nested_1", "done"][~observation_is_max]).all()
    # Obs after done is 0
    assert (
        td["nested_1", "observation"][index_batch_size][1:][next_is_done] == 0
    ).all()
    # Obs after not done is previous obs
    assert (
        td["nested_1", "observation"][index_batch_size][1:][~next_is_done]
        == td["next", "nested_1", "observation"][index_batch_size][:-1][~next_is_done]
    ).all()
    # Check observation and reward update with count action for nested_1
    # nested_1 action is already a 0/1 tensor; a plain bool cast decodes it.
    action_is_count = td["nested_1"]["action"].to(torch.bool)
    assert (
        td["next", "nested_1", "observation"][action_is_count]
        == td["nested_1", "observation"][action_is_count] + 1
    ).all()
    # nested_1 names its reward entry "gift".
    assert (td["next", "nested_1", "gift"][action_is_count] == 1).all()
    # Check observation and reward do not update with no-count action for nested_1
    assert (
        td["next", "nested_1", "observation"][~action_is_count]
        == td["nested_1", "observation"][~action_is_count]
    ).all()
    assert (td["next", "nested_1", "gift"][~action_is_count] == 0).all()

    # Check done and reset for nested_2
    observation_is_max = td["next", "nested_2", "observation"][..., 0] == max_steps + 1
    # done at the root always prevail
    next_is_done = td["next", "done"][index_batch_size][:-1].squeeze(-1)
    assert (td["next", "nested_2", "done"][observation_is_max]).all()
    assert (~td["next", "nested_2", "done"][~observation_is_max]).all()
    # Obs after done is 0
    assert (
        td["nested_2", "observation"][index_batch_size][1:][next_is_done] == 0
    ).all()
    # Obs after not done is previous obs
    assert (
        td["nested_2", "observation"][index_batch_size][1:][~next_is_done]
        == td["next", "nested_2", "observation"][index_batch_size][:-1][~next_is_done]
    ).all()
    # Check observation and reward update with count action for nested_2
    # nested_2 uses the Italian key "azione" with a trailing singleton dim.
    action_is_count = td["nested_2"]["azione"].squeeze(-1).to(torch.bool)
    assert (
        td["next", "nested_2", "observation"][action_is_count]
        == td["nested_2", "observation"][action_is_count] + 1
    ).all()
    assert (td["next", "nested_2", "reward"][action_is_count] == 1).all()
    # Check observation and reward do not update with no-count action for nested_2
    assert (
        td["next", "nested_2", "observation"][~action_is_count]
        == td["nested_2", "observation"][~action_is_count]
    ).all()
    assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import psutil
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"assert_no_new_python_processes",
|
|
12
|
+
"is_python_process",
|
|
13
|
+
"snapshot_python_processes",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def is_python_process(comm: str | None, args: str | None) -> bool:
    """Check if a process is a python process."""
    # A process counts as python if its name is a python/pypy interpreter,
    # or if "python" appears anywhere in its command line.
    normalized_name = (comm or "").lower()
    if normalized_name.startswith(("python", "pypy")):
        return True
    return bool(args) and "python" in args.lower()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def snapshot_python_processes(
    root: psutil.Process | None = None,
) -> dict[tuple[int, float], dict[str, Any]]:
    """Snapshot python processes belonging to the given process tree.

    Args:
        root: Root of the process tree to inspect. Defaults to the current
            process.

    Returns:
        A dict keyed by (pid, start_time) -> info, where info contains
        ``pid``, ``start_time``, ``comm`` (process name) and ``args``
        (joined command line).
    """
    if root is None:
        root = psutil.Process(os.getpid())

    # Fix: ``os.getuid`` is POSIX-only and this package also ships on Windows,
    # where the call raised AttributeError (and psutil reports ``uids=None``,
    # which previously filtered out every process). Skip the ownership filter
    # when real uids are unavailable.
    uid = os.getuid() if hasattr(os, "getuid") else None

    # Snapshot descendant PIDs first, then query process info via process_iter.
    # This avoids race conditions where a child exits between `children()` and
    # attribute access on a stale Process handle (common with Ray helpers).
    descendant_pids = {root.pid}
    descendant_pids.update(p.pid for p in root.children(recursive=True))

    out: dict[tuple[int, float], dict[str, Any]] = {}
    for proc in psutil.process_iter(
        attrs=["pid", "name", "cmdline", "create_time", "uids"], ad_value=None
    ):
        info = proc.info
        pid = info.get("pid")
        if pid is None or pid not in descendant_pids:
            continue
        if uid is not None:
            # Only keep processes owned by the current user.
            uids = info.get("uids")
            if uids is None or uids.real != uid:
                continue

        name = info.get("name") or ""
        cmdline = info.get("cmdline") or []
        args = " ".join(cmdline) if isinstance(cmdline, (list, tuple)) else str(cmdline)
        if not is_python_process(name, args):
            continue

        # create_time may be unavailable (ad_value=None); fall back to 0.0.
        start_time = float(info.get("create_time") or 0.0)
        key = (int(pid), start_time)
        out[key] = {
            "pid": int(pid),
            "start_time": start_time,
            "comm": name,
            "args": args,
        }
    return out
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def assert_no_new_python_processes(
    *,
    baseline: dict[tuple[int, float], dict[str, Any]],
    baseline_time: float,
    timeout: float = 20.0,
    ignore_info_fn: Callable[[dict[str, Any]], bool] | None = None,
) -> None:
    """Assert that no python process started after baseline_time remains alive.

    The check is limited to the current process tree (pytest process + descendants).

    Args:
        baseline: Snapshot taken before the operation under test, as returned
            by :func:`snapshot_python_processes`.
        baseline_time: Wall-clock time (``time.time()``) when the baseline was
            taken; older processes are ignored to guard against pid reuse.
        timeout: How long to keep polling before failing.
        ignore_info_fn: Optional predicate over a process info dict; processes
            for which it returns True are ignored.

    Raises:
        AssertionError: If python processes newer than the baseline are still
            alive once the timeout elapses.
    """
    if ignore_info_fn is None:

        def ignore_info_fn(_info: dict[str, Any]) -> bool:
            return False

    # Fix: use a monotonic deadline (immune to wall-clock jumps) and always
    # take at least one snapshot. The original `while time.time() < deadline`
    # loop never ran for `timeout <= 0`, so the assertion silently passed
    # without checking anything.
    deadline = time.monotonic() + timeout
    last_new: dict[tuple[int, float], dict[str, Any]] | None = None
    while True:
        current = snapshot_python_processes()
        new: dict[tuple[int, float], dict[str, Any]] = {}
        for (pid, start_time), info in current.items():
            if pid == os.getpid():
                continue
            if ignore_info_fn(info):
                continue
            # Guard against pid reuse: only consider processes started after the baseline.
            if start_time and start_time < baseline_time - 1.0:
                continue
            if (pid, start_time) in baseline:
                continue
            new[(pid, start_time)] = info
        if not new:
            return
        last_new = new
        if time.monotonic() >= deadline:
            break
        time.sleep(0.25)

    details = "\n".join(
        f"- pid={v['pid']} comm={v.get('comm')} args={v.get('args')}"
        for v in last_new.values()
    )
    raise AssertionError(
        "Leaked python processes detected after collector.shutdown().\n"
        f"Processes still alive:\n{details}"
    )
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
"""Environment creation utilities for TorchRL tests."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from torchrl.envs import MultiThreadedEnv, ObservationNorm
|
|
13
|
+
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
|
|
14
|
+
from torchrl.envs.libs.envpool import _has_envpool
|
|
15
|
+
from torchrl.envs.libs.gym import GymEnv
|
|
16
|
+
from torchrl.envs.transforms import (
|
|
17
|
+
Compose,
|
|
18
|
+
RewardClipping,
|
|
19
|
+
ToTensorImage,
|
|
20
|
+
TransformedEnv,
|
|
21
|
+
)
|
|
22
|
+
from torchrl.testing.gym_helpers import HALFCHEETAH_VERSIONED, PONG_VERSIONED
|
|
23
|
+
from torchrl.testing.utils import mp_ctx
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"get_transform_out",
|
|
27
|
+
"make_envs",
|
|
28
|
+
"make_multithreaded_env",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def make_envs(
    env_name,
    frame_skip,
    transformed_in,
    transformed_out,
    N,
    device="cpu",
    kwargs=None,
    local_mp_ctx=mp_ctx,
):
    """Create parallel, serial, multithreaded, and single environment instances.

    This helper creates environments suitable for testing batched environment behavior.

    Args:
        env_name: The gym environment name.
        frame_skip: Number of frames to skip.
        transformed_in: Whether to apply transforms inside the base env.
        transformed_out: Whether to apply transforms outside the batched env.
        N: Number of environments in the batch.
        device: Device for the environments.
        kwargs: Additional keyword arguments for environment creation.
        local_mp_ctx: Multiprocessing context ('fork' or 'spawn').

    Returns:
        Tuple of (env_parallel, env_serial, env_multithread, env0).
    """
    torch.manual_seed(0)

    if transformed_in:
        if env_name == PONG_VERSIONED():
            # Pixel-based env: convert images to tensors before clipping rewards.
            def create_env_fn():
                base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
                pixel_keys = list(base_env.observation_spec.keys(True, True))[:1]
                transform = Compose(
                    ToTensorImage(in_keys=pixel_keys), RewardClipping(0, 0.1)
                )
                return TransformedEnv(base_env, transform)

        else:
            # State-based env: normalize observations before clipping rewards.
            def create_env_fn():
                base_env = GymEnv(env_name, frame_skip=frame_skip, device=device)
                state_keys = list(base_env.observation_spec.keys(True, True))[:1]
                transform = Compose(
                    ObservationNorm(in_keys=state_keys, loc=0.5, scale=1.1),
                    RewardClipping(0, 0.1),
                )
                return TransformedEnv(base_env, transform)

    else:

        def create_env_fn():
            return GymEnv(env_name, frame_skip=frame_skip, device=device)

    env0 = create_env_fn()
    env_parallel = ParallelEnv(
        N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx
    )
    env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)

    # The first observation key (if any) drives the output transforms.
    obs_key = next(iter(env0.observation_spec.keys(True, True)), None)

    if transformed_out:
        t_out = get_transform_out(env_name, transformed_in, obs_key=obs_key)
        env0 = TransformedEnv(env0, t_out())
        env_parallel = TransformedEnv(env_parallel, t_out())
        env_serial = TransformedEnv(env_serial, t_out())
    else:
        t_out = None

    if _has_envpool:
        # NOTE(review): the multithreaded env is always built on "cpu" and the
        # caller-supplied `kwargs` are not forwarded — confirm this is intended.
        env_multithread = make_multithreaded_env(
            env_name,
            frame_skip,
            t_out,
            N,
            device="cpu",
            kwargs=None,
        )
    else:
        env_multithread = None

    return env_parallel, env_serial, env_multithread, env0
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def make_multithreaded_env(
    env_name,
    frame_skip,
    transformed_out,
    N,
    device="cpu",
    kwargs=None,
):
    """Create a multithreaded environment using envpool.

    Args:
        env_name: The gym environment name.
        frame_skip: Number of frames to skip.
        transformed_out: Transform factory to apply, or None.
        N: Number of environments in the batch.
        device: Device for the environment.
        kwargs: Additional keyword arguments (unused, for API compatibility).

    Returns:
        A MultiThreadedEnv instance, optionally wrapped with transforms.
    """
    torch.manual_seed(0)
    # frame_skip is only forwarded for the Pong env; other envs get no kwargs.
    if env_name == PONG_VERSIONED():
        create_kwargs = {"frame_skip": frame_skip}
    else:
        create_kwargs = {}
    env_multithread = MultiThreadedEnv(
        N,
        env_name,
        create_env_kwargs=create_kwargs,
        device=device,
    )

    if transformed_out:
        # Use the first observation key (if any) for the output transform.
        obs_key = next(
            iter(env_multithread.observation_spec.keys(True, True)), None
        )
        transform_factory = get_transform_out(
            env_name, transformed_in=False, obs_key=obs_key
        )
        env_multithread = TransformedEnv(env_multithread, transform_factory())
    return env_multithread
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_transform_out(env_name, transformed_in, obs_key=None):
    """Create a transform factory for output transforms based on environment type.

    Args:
        env_name: The gym environment name.
        transformed_in: Whether transforms were already applied inside.
        obs_key: The observation key to transform. Defaults to a per-env key
            ("pixels" for Pong, ("observation", "velocity") for HalfCheetah,
            "observation" otherwise).

    Returns:
        A callable that returns a Compose transform.
    """
    if env_name == PONG_VERSIONED():
        if obs_key is None:
            obs_key = "pixels"

        def t_out():
            return (
                Compose(ToTensorImage(in_keys=[obs_key]), RewardClipping(0, 0.1))
                if not transformed_in
                else Compose(ObservationNorm(in_keys=[obs_key], loc=0, scale=1))
            )

    # Bug fix: HALFCHEETAH_VERSIONED is a callable returning the versioned env
    # name (it is called as PONG_VERSIONED() above). Comparing against the
    # function object itself could never match, making this branch unreachable.
    elif env_name == HALFCHEETAH_VERSIONED():
        if obs_key is None:
            obs_key = ("observation", "velocity")

        def t_out():
            return Compose(
                ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                RewardClipping(0, 0.1),
            )

    else:
        if obs_key is None:
            obs_key = "observation"

        def t_out():
            return (
                Compose(
                    ObservationNorm(in_keys=[obs_key], loc=0.5, scale=1.1),
                    RewardClipping(0, 0.1),
                )
                if not transformed_in
                else Compose(ObservationNorm(in_keys=[obs_key], loc=1.0, scale=1.0))
            )

    return t_out
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def make_isaac_env(env_name: str = "Isaac-Ant-v0"):
    """Helper function to create an IsaacLab env.

    Launches the Isaac simulation app in headless mode, builds the requested
    gymnasium environment with the Ant config, and wraps it for TorchRL.

    Args:
        env_name: Registered IsaacLab gym id to instantiate.

    Returns:
        An ``IsaacLabWrapper``-wrapped environment.

    Note:
        The statement order below is load-bearing: ``AppLauncher`` must run
        before ``isaaclab_tasks``/``gymnasium`` are imported, hence the
        deliberately local, interleaved imports.
    """
    import torch

    torch.manual_seed(0)
    import argparse

    # This code block ensures that the Isaac app is started in headless mode
    from isaaclab.app import AppLauncher
    from torchrl import logger as torchrl_logger

    parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
    AppLauncher.add_app_launcher_args(parser)
    # Parse only "--headless"; hydra_args is intentionally unused here.
    args_cli, hydra_args = parser.parse_known_args(["--headless"])
    # Constructing AppLauncher starts the Isaac app as a side effect; the
    # instance itself is not needed afterwards.
    AppLauncher(args_cli)

    # Imports and env
    import gymnasium as gym
    import isaaclab_tasks  # noqa: F401
    from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
    from torchrl.envs.libs.isaac_lab import IsaacLabWrapper

    torchrl_logger.info("Making IsaacLab env...")
    env = gym.make(env_name, cfg=AntEnvCfg())
    torchrl_logger.info("Wrapping IsaacLab env...")
    env = IsaacLabWrapper(env)
    return env
|