torchrl-0.11.0-cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/.dylibs/libc++.1.0.dylib +0 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-darwin.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +395 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from tensordict import NestedKey, set_list_to_stack, TensorDict, TensorDictBase
|
|
12
|
+
from tensordict.tensorclass import NonTensorData, NonTensorStack
|
|
13
|
+
|
|
14
|
+
from torchrl.data.map.hash import SipHash
|
|
15
|
+
from torchrl.data.tensor_specs import (
|
|
16
|
+
Categorical as CategoricalSpec,
|
|
17
|
+
Composite,
|
|
18
|
+
NonTensor,
|
|
19
|
+
Unbounded,
|
|
20
|
+
)
|
|
21
|
+
from torchrl.envs import EnvBase
|
|
22
|
+
from torchrl.envs.utils import _StepMDP
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LLMHashingEnv(EnvBase):
    """A text generation environment that uses a hashing module to identify unique observations.

    The primary goal of this environment is to identify token chains using a hashing function.
    This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
    identifiers, or easily prune repeated token chains in a data structure.

    .. The following figure gives an overview of this workflow:
    .. .. figure:: /_static/img/rollout-llm.png
    ..   :alt: Data collection loop with our LLM environment.

    Args:
        vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.

    Keyword Args:
        hashing_module (Callable[[torch.Tensor], torch.Tensor], optional):
            A hashing function that takes a tensor as input and returns a hashed tensor.
            Defaults to :class:`~torchrl.data.SipHash` if not provided.
        observation_key (NestedKey, optional): The key for the observation in the TensorDict.
            Defaults to "observation".
        text_output (bool, optional): Whether to include the text output in the observation.
            Defaults to `True`.
        tokenizer (transformers.Tokenizer | None, optional):
            A tokenizer function that converts text to tensors.
            Only used when `text_output` is `True`.
            Must implement the following methods: `decode` and `batch_decode`.
            Defaults to ``None``.
        text_key (NestedKey | None, optional): The key for the text output in the TensorDict.
            Defaults to "text".

    Examples:
        >>> from tensordict import TensorDict
        >>> from torchrl.envs import LLMHashingEnv
        >>> from transformers import GPT2Tokenizer
        >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
        >>> x = tokenizer(["Check out TorchRL!"])["input_ids"]
        >>> env = LLMHashingEnv(tokenizer=tokenizer)
        >>> td = TensorDict(observation=x, batch_size=[1])
        >>> td = env.reset(td)
        >>> print(td)
        TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                hashing: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                text: NonTensorStack(
                    ['Check out TorchRL!'],
                    batch_size=torch.Size([1]),
                    device=None)},
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)

    """

    def __init__(
        self,
        vocab_size: int | None = None,
        *,
        hashing_module: Callable[[torch.Tensor], torch.Tensor] | None = None,
        observation_key: NestedKey = "observation",
        text_output: bool = True,
        tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None,
        text_key: NestedKey | None = "text",
    ):
        super().__init__()
        # The vocabulary size is needed to build the categorical specs; it can
        # be read off the tokenizer when one is provided.
        if vocab_size is None:
            if tokenizer is None:
                raise TypeError(
                    "You must provide a vocab_size integer if tokenizer is `None`."
                )
            vocab_size = tokenizer.vocab_size
        # Prompts can come batched or unbatched, so the env is not batch-locked.
        self._batch_locked = False
        if hashing_module is None:
            hashing_module = SipHash()

        self._hashing_module = hashing_module
        self._tokenizer = tokenizer
        self.observation_key = observation_key
        # Token sequences have variable length, hence the (-1,) shape.
        observation_spec = {
            observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)),
            "hashing": Unbounded(shape=(1,), dtype=torch.int64),
        }
        self.text_output = text_output
        # text_key is only meaningful when text output is requested; normalize
        # the (text_output, text_key) pair to a single nullable key.
        if not text_output:
            text_key = None
        elif text_key is None:
            text_key = "text"
        if text_key is not None:
            observation_spec[text_key] = NonTensor(shape=())
        self.text_key = text_key
        self.observation_spec = Composite(observation_spec)
        # The action is the next token to append to the sequence.
        self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,)))
        # Sanity-check that the specs are consistent for step-MDP extraction.
        _StepMDP(self)

    @set_list_to_stack(True)
    def make_tensordict(self, input: str | list[str]) -> TensorDict:
        """Converts a string or list of strings in a TensorDict with appropriate shape and device."""
        # A list of prompts becomes a batched TensorDict; a single string stays unbatched.
        list_len = len(input) if isinstance(input, list) else 0
        tensordict = TensorDict(
            {self.observation_key: self._tokenizer(input)}, device=self.device
        )
        if list_len:
            tensordict.batch_size = [list_len]
        return self.reset(tensordict)

    def _reset(self, tensordict: TensorDictBase):
        """Initializes the environment with a given observation.

        Args:
            tensordict (TensorDictBase): A TensorDict containing the initial observation.

        Returns:
            A TensorDict containing the initial observation, its hash, and other relevant information.

        Raises:
            RuntimeError: if the tensordict carries no entry under ``observation_key``
                (this env cannot invent a prompt).
        """
        out = tensordict.empty()
        obs = tensordict.get(self.observation_key, None)
        if obs is None:
            raise RuntimeError(
                f"Resetting the {type(self).__name__} environment requires a prompt."
            )
        if self.text_output:
            # Decode the prompt back to text; batched inputs use batch_decode.
            if obs.ndim > 1:
                text = self._tokenizer.batch_decode(obs)
                text = NonTensorStack.from_list(text)
            else:
                text = self._tokenizer.decode(obs)
                text = NonTensorData(text)
            out.set(self.text_key, text)

        # Hash the full token sequence; keep a trailing singleton dim to match
        # the (1,)-shaped "hashing" spec.
        if obs.ndim > 1:
            out.set("hashing", self._hashing_module(obs).unsqueeze(-1))
        else:
            out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1))

        if not self.full_done_spec.is_empty():
            out.update(self.full_done_spec.zero(tensordict.shape))
        else:
            out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool))
            out.set(
                "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)
            )
        return out

    def _step(self, tensordict):
        """Takes an action (i.e., the next token to generate) and returns the next observation and reward.

        Args:
            tensordict: A TensorDict containing the current observation and action.

        Returns:
            A TensorDict containing the next observation, its hash, and other relevant information.
        """
        out = tensordict.empty()
        # Append the chosen token to the running sequence.
        action = tensordict.get("action")
        obs = torch.cat([tensordict.get(self.observation_key), action], -1)
        kwargs = {self.observation_key: obs}

        # Rolling hash: hash(previous_hash ++ action) rather than re-hashing the
        # whole sequence, so the cost per step stays constant.
        catval = torch.cat([tensordict.get("hashing"), action], -1)
        if obs.ndim > 1:
            new_hash = self._hashing_module(catval).unsqueeze(-1)
        else:
            new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1)

        if self.text_output:
            if obs.ndim > 1:
                text = self._tokenizer.batch_decode(obs)
                text = NonTensorStack.from_list(text)
            else:
                text = self._tokenizer.decode(obs)
                text = NonTensorData(text)
            kwargs[self.text_key] = text
        # Generation never terminates from inside the env; termination is left
        # to the caller (e.g. an EOS-detecting transform).
        kwargs.update(
            {
                "hashing": new_hash,
                "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool),
                "terminated": torch.zeros(
                    (*tensordict.batch_size, 1), dtype=torch.bool
                ),
            }
        )
        return out.update(kwargs)

    def _set_seed(self, *args) -> None:
        """Sets the seed for the environment's randomness.

        .. note:: This environment has no randomness, so this method does nothing.
        """
|
|
@@ -0,0 +1,401 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from tensordict import TensorDict, TensorDictBase
|
|
11
|
+
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
|
|
12
|
+
from torchrl.envs.common import EnvBase
|
|
13
|
+
from torchrl.envs.utils import make_composite_from_td
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PendulumEnv(EnvBase):
    """A stateless Pendulum environment.

    See the Pendulum tutorial for more details: :ref:`tutorial <pendulum_tuto>`.

    The environment keeps no internal state: the full state (``th``, ``thdot``
    and the physical ``params``) travels in the input/output tensordicts, so
    inputs of arbitrary batch size are accepted (``batch_locked = False``).

    Keys:
        - ``th`` (float32, scalar per batch element): pendulum angle,
          bounded in ``[-pi, pi]``.
        - ``thdot`` (float32): angular velocity, bounded by ``max_speed``.
        - ``params``: tensordict of physical constants (``max_speed``,
          ``max_torque``, ``dt``, ``g``, ``m``, ``l``), mirrored in the specs
          so it can be passed at every step of a rollout.
        - ``action`` (float32, shape ``(1,)``): torque, bounded by
          ``max_torque``.
        - ``reward`` (float32, shape ``(*batch, 1)``), ``done`` /
          ``terminated`` (bool, shape ``(*batch, 1)``).

    The complete (auto-generated) spec structure can be inspected with
    ``env.specs``.
    """

    # Bounds used when sampling the initial state in ``_reset_random_data``:
    # angle drawn from [-DEFAULT_X, DEFAULT_X], velocity from
    # [-DEFAULT_Y, DEFAULT_Y].
    DEFAULT_X = np.pi
    DEFAULT_Y = 1.0

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    # Stateless simulator: the input tensordict dictates the batch shape.
    batch_locked = False
    # torch.Generator populated by ``_set_seed``.
    rng = None
|
+
|
|
223
|
+
def __init__(self, td_params=None, seed=None, device=None):
|
|
224
|
+
if td_params is None:
|
|
225
|
+
td_params = self.gen_params(device=self.device)
|
|
226
|
+
|
|
227
|
+
super().__init__(device=device)
|
|
228
|
+
self._make_spec(td_params)
|
|
229
|
+
if seed is None:
|
|
230
|
+
seed = torch.empty((), dtype=torch.int64).random_(generator=self.rng).item()
|
|
231
|
+
self.set_seed(seed)
|
|
232
|
+
|
|
233
|
+
@classmethod
|
|
234
|
+
def _step(cls, tensordict):
|
|
235
|
+
th, thdot = tensordict["th"], tensordict["thdot"] # th := theta
|
|
236
|
+
|
|
237
|
+
g_force = tensordict["params", "g"]
|
|
238
|
+
mass = tensordict["params", "m"]
|
|
239
|
+
length = tensordict["params", "l"]
|
|
240
|
+
dt = tensordict["params", "dt"]
|
|
241
|
+
u = tensordict["action"].squeeze(-1)
|
|
242
|
+
u = u.clamp(
|
|
243
|
+
-tensordict["params", "max_torque"], tensordict["params", "max_torque"]
|
|
244
|
+
)
|
|
245
|
+
costs = cls.angle_normalize(th) ** 2 + 0.1 * thdot**2 + 0.001 * (u**2)
|
|
246
|
+
|
|
247
|
+
new_thdot = (
|
|
248
|
+
thdot
|
|
249
|
+
+ (3 * g_force / (2 * length) * th.sin() + 3.0 / (mass * length**2) * u)
|
|
250
|
+
* dt
|
|
251
|
+
)
|
|
252
|
+
new_thdot = new_thdot.clamp(
|
|
253
|
+
-tensordict["params", "max_speed"], tensordict["params", "max_speed"]
|
|
254
|
+
)
|
|
255
|
+
new_th = th + new_thdot * dt
|
|
256
|
+
reward = -costs.view(*tensordict.shape, 1)
|
|
257
|
+
done = torch.zeros_like(reward, dtype=torch.bool)
|
|
258
|
+
out = TensorDict(
|
|
259
|
+
{
|
|
260
|
+
"th": new_th,
|
|
261
|
+
"thdot": new_thdot,
|
|
262
|
+
"params": tensordict["params"],
|
|
263
|
+
"reward": reward,
|
|
264
|
+
"done": done,
|
|
265
|
+
},
|
|
266
|
+
tensordict.shape,
|
|
267
|
+
)
|
|
268
|
+
return out
|
|
269
|
+
|
|
270
|
+
def _reset(self, tensordict):
|
|
271
|
+
batch_size = (
|
|
272
|
+
tensordict.batch_size if tensordict is not None else self.batch_size
|
|
273
|
+
)
|
|
274
|
+
if tensordict is None or "params" not in tensordict:
|
|
275
|
+
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
|
|
276
|
+
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
|
|
277
|
+
# parameters to get started.
|
|
278
|
+
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
|
|
279
|
+
elif "th" in tensordict and "thdot" in tensordict:
|
|
280
|
+
# we can hard-reset the env too
|
|
281
|
+
return tensordict
|
|
282
|
+
out = self._reset_random_data(
|
|
283
|
+
tensordict.shape, batch_size, tensordict["params"]
|
|
284
|
+
)
|
|
285
|
+
return out
|
|
286
|
+
|
|
287
|
+
def _reset_random_data(self, shape, batch_size, params):
|
|
288
|
+
|
|
289
|
+
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
|
|
290
|
+
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
|
|
291
|
+
low_th = -high_th
|
|
292
|
+
low_thdot = -high_thdot
|
|
293
|
+
|
|
294
|
+
# for non batch-locked environments, the input ``tensordict`` shape dictates the number
|
|
295
|
+
# of simulators run simultaneously. In other contexts, the initial
|
|
296
|
+
# random state's shape will depend upon the environment batch-size instead.
|
|
297
|
+
th = (
|
|
298
|
+
torch.rand(shape, generator=self.rng, device=self.device)
|
|
299
|
+
* (high_th - low_th)
|
|
300
|
+
+ low_th
|
|
301
|
+
)
|
|
302
|
+
thdot = (
|
|
303
|
+
torch.rand(shape, generator=self.rng, device=self.device)
|
|
304
|
+
* (high_thdot - low_thdot)
|
|
305
|
+
+ low_thdot
|
|
306
|
+
)
|
|
307
|
+
out = TensorDict(
|
|
308
|
+
{
|
|
309
|
+
"th": th,
|
|
310
|
+
"thdot": thdot,
|
|
311
|
+
"params": params,
|
|
312
|
+
},
|
|
313
|
+
batch_size=batch_size,
|
|
314
|
+
)
|
|
315
|
+
return out
|
|
316
|
+
|
|
317
|
+
    def _make_spec(self, td_params):
        """Build observation/state/action/reward specs from ``td_params``.

        The bounds of ``thdot`` and ``action`` are read from the
        ``max_speed`` / ``max_torque`` entries of ``td_params``.
        """
        # Under the hood, this will populate self.output_spec["observation"]
        self.observation_spec = Composite(
            th=Bounded(
                low=-torch.pi,
                high=torch.pi,
                shape=(),
                dtype=torch.float32,
            ),
            thdot=Bounded(
                low=-td_params["params", "max_speed"],
                high=td_params["params", "max_speed"],
                shape=(),
                dtype=torch.float32,
            ),
            # we need to add the ``params`` to the observation specs, as we want
            # to pass it at each step during a rollout
            # NOTE(review): this resolves to the module-level
            # ``make_composite_from_td`` imported from torchrl.envs.utils, not
            # the class-level helper defined below (class scope does not leak
            # into method bodies) — confirm this is intended.
            params=make_composite_from_td(
                td_params["params"], unsqueeze_null_shapes=False
            ),
            shape=(),
        )
        # since the environment is stateless, we expect the previous output as input.
        # For this, ``EnvBase`` expects some state_spec to be available
        self.state_spec = self.observation_spec.clone()
        # action-spec will be automatically wrapped in input_spec when
        # `self.action_spec = spec` will be called supported
        self.action_spec = Bounded(
            low=-td_params["params", "max_torque"],
            high=td_params["params", "max_torque"],
            shape=(1,),
            dtype=torch.float32,
        )
        # reward carries an extra trailing singleton dim on top of the batch shape
        self.reward_spec = Unbounded(shape=(*td_params.shape, 1))
|
351
|
+
|
|
352
|
+
def make_composite_from_td(td):
|
|
353
|
+
# custom function to convert a ``tensordict`` in a similar spec structure
|
|
354
|
+
# of unbounded values.
|
|
355
|
+
composite = Composite(
|
|
356
|
+
{
|
|
357
|
+
key: make_composite_from_td(tensor)
|
|
358
|
+
if isinstance(tensor, TensorDictBase)
|
|
359
|
+
else Unbounded(
|
|
360
|
+
dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
|
|
361
|
+
)
|
|
362
|
+
for key, tensor in td.items()
|
|
363
|
+
},
|
|
364
|
+
shape=td.shape,
|
|
365
|
+
)
|
|
366
|
+
return composite
|
|
367
|
+
|
|
368
|
+
def _set_seed(self, seed: int) -> None:
|
|
369
|
+
rng = torch.Generator(device=self.device)
|
|
370
|
+
rng.manual_seed(seed)
|
|
371
|
+
self.rng = rng
|
|
372
|
+
|
|
373
|
+
@staticmethod
|
|
374
|
+
def gen_params(g=10.0, batch_size=None, device=None) -> TensorDictBase:
|
|
375
|
+
"""Returns a ``tensordict`` containing the physical parameters such as gravitational force and torque or speed limits."""
|
|
376
|
+
if batch_size is None:
|
|
377
|
+
batch_size = []
|
|
378
|
+
td = TensorDict(
|
|
379
|
+
{
|
|
380
|
+
"params": TensorDict(
|
|
381
|
+
{
|
|
382
|
+
"max_speed": 8,
|
|
383
|
+
"max_torque": 2.0,
|
|
384
|
+
"dt": 0.05,
|
|
385
|
+
"g": g,
|
|
386
|
+
"m": 1.0,
|
|
387
|
+
"l": 1.0,
|
|
388
|
+
},
|
|
389
|
+
[],
|
|
390
|
+
)
|
|
391
|
+
},
|
|
392
|
+
[],
|
|
393
|
+
device=device,
|
|
394
|
+
)
|
|
395
|
+
if batch_size:
|
|
396
|
+
td = td.expand(batch_size).contiguous()
|
|
397
|
+
return td
|
|
398
|
+
|
|
399
|
+
@staticmethod
|
|
400
|
+
def angle_normalize(x):
|
|
401
|
+
return ((x + torch.pi) % (2 * torch.pi)) - torch.pi
|