torchrl 0.11.0__cp314-cp314t-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cp314t-win_amd64.pyd +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/LICENSE +21 -0
- torchrl-0.11.0.dist-info/METADATA +1307 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,643 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import json
|
|
9
|
+
import os.path
|
|
10
|
+
import shutil
|
|
11
|
+
import tempfile
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from contextlib import nullcontext
|
|
15
|
+
from dataclasses import asdict
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from tensordict import (
|
|
20
|
+
is_non_tensor,
|
|
21
|
+
is_tensor_collection,
|
|
22
|
+
NonTensorData,
|
|
23
|
+
NonTensorStack,
|
|
24
|
+
PersistentTensorDict,
|
|
25
|
+
set_list_to_stack,
|
|
26
|
+
TensorDict,
|
|
27
|
+
TensorDictBase,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
|
|
31
|
+
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
|
|
32
|
+
from torchrl.data.datasets.utils import _get_root_dir
|
|
33
|
+
from torchrl.data.replay_buffers.samplers import Sampler
|
|
34
|
+
from torchrl.data.replay_buffers.storages import TensorStorage
|
|
35
|
+
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|
36
|
+
from torchrl.data.tensor_specs import Bounded, Categorical, Composite, Unbounded
|
|
37
|
+
from torchrl.envs.utils import _classproperty
|
|
38
|
+
|
|
39
|
+
# Optional-dependency probes: tqdm (progress bars) and minari (dataset hub).
_has_tqdm = importlib.util.find_spec("tqdm", None) is not None
_has_minari = importlib.util.find_spec("minari", None) is not None

# Translate Minari's plural key names into TorchRL's singular TED names.
# Any key that is not listed below falls through unchanged, courtesy of the
# key-dependent default factory.
_NAME_MATCH = KeyDependentDefaultDict(lambda key: key)
_NAME_MATCH.update(
    observations="observation",
    rewards="reward",
    truncations="truncated",
    terminations="terminated",
    actions="action",
    infos="info",
)
|
|
51
|
+
# Lookup table from dtype names (as stored in the dataset metadata) to the
# corresponding torch dtypes.
_DTYPE_DIR = {
    name: getattr(torch, name)
    for name in ("float16", "float32", "float64", "int64", "int32", "uint8")
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class MinariExperienceReplay(BaseDatasetExperienceReplay):
|
|
62
|
+
"""Minari Experience replay dataset.
|
|
63
|
+
|
|
64
|
+
Learn more about Minari on their website: https://minari.farama.org/
|
|
65
|
+
|
|
66
|
+
The data format follows the :ref:`TED convention <TED-format>`.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
dataset_id (str): The dataset to be downloaded. Must be part of MinariExperienceReplay.available_datasets
|
|
70
|
+
batch_size (int): Batch-size used during sampling. Can be overridden by `data.sample(batch_size)` if
|
|
71
|
+
necessary.
|
|
72
|
+
|
|
73
|
+
Keyword Args:
|
|
74
|
+
root (Path or str, optional): The Minari dataset root directory.
|
|
75
|
+
The actual dataset memory-mapped files will be saved under
|
|
76
|
+
`<root>/<dataset_id>`. If none is provided, it defaults to
|
|
77
|
+
`~/.cache/torchrl/minari`.
|
|
78
|
+
download (bool or str, optional): Whether the dataset should be downloaded if
|
|
79
|
+
not found. Defaults to ``True``. Download can also be passed as ``"force"``,
|
|
80
|
+
in which case the downloaded data will be overwritten.
|
|
81
|
+
sampler (Sampler, optional): the sampler to be used. If none is provided
|
|
82
|
+
a default RandomSampler() will be used.
|
|
83
|
+
writer (Writer, optional): the writer to be used. If none is provided
|
|
84
|
+
a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
|
|
85
|
+
collate_fn (callable, optional): merges a list of samples to form a
|
|
86
|
+
mini-batch of Tensor(s)/outputs. Used when using batched
|
|
87
|
+
loading from a map-style dataset.
|
|
88
|
+
pin_memory (bool): whether pin_memory() should be called on the rb
|
|
89
|
+
samples.
|
|
90
|
+
prefetch (int, optional): number of next batches to be prefetched
|
|
91
|
+
using multithreading.
|
|
92
|
+
transform (Transform, optional): Transform to be executed when sample() is called.
|
|
93
|
+
To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
|
|
94
|
+
split_trajs (bool, optional): if ``True``, the trajectories will be split
|
|
95
|
+
along the first dimension and padded to have a matching shape.
|
|
96
|
+
To split the trajectories, the ``"done"`` signal will be used, which
|
|
97
|
+
is recovered via ``done = truncated | terminated``. In other words,
|
|
98
|
+
it is assumed that any ``truncated`` or ``terminated`` signal is
|
|
99
|
+
equivalent to the end of a trajectory.
|
|
100
|
+
Defaults to ``False``.
|
|
101
|
+
load_from_local_minari (bool, optional): if ``True``, the dataset will be loaded directly
|
|
102
|
+
from the local Minari cache (typically located at ``~/.minari/datasets``),
|
|
103
|
+
bypassing any remote download. This is useful when working with custom
|
|
104
|
+
Minari datasets previously generated and stored locally, or when network
|
|
105
|
+
access should be avoided. If the dataset is not found in the expected
|
|
106
|
+
cache directory, a ``FileNotFoundError`` will be raised.
|
|
107
|
+
Defaults to ``False``.
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
Attributes:
|
|
111
|
+
available_datasets: a list of accepted entries to be downloaded.
|
|
112
|
+
|
|
113
|
+
.. note::
|
|
114
|
+
Text data is currently discarded from the wrapped dataset, as there is no
|
|
115
|
+
PyTorch native way of representing text data.
|
|
116
|
+
If this feature is required, please post an issue on TorchRL's GitHub
|
|
117
|
+
repository.
|
|
118
|
+
|
|
119
|
+
Examples:
|
|
120
|
+
>>> from torchrl.data.datasets.minari_data import MinariExperienceReplay
|
|
121
|
+
>>> data = MinariExperienceReplay("door-human-v1", batch_size=32, download="force")
|
|
122
|
+
>>> for sample in data:
|
|
123
|
+
... torchrl_logger.info(sample)
|
|
124
|
+
... break
|
|
125
|
+
TensorDict(
|
|
126
|
+
fields={
|
|
127
|
+
action: Tensor(shape=torch.Size([32, 28]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
128
|
+
index: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
129
|
+
info: TensorDict(
|
|
130
|
+
fields={
|
|
131
|
+
success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
132
|
+
batch_size=torch.Size([32]),
|
|
133
|
+
device=cpu,
|
|
134
|
+
is_shared=False),
|
|
135
|
+
next: TensorDict(
|
|
136
|
+
fields={
|
|
137
|
+
done: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
138
|
+
info: TensorDict(
|
|
139
|
+
fields={
|
|
140
|
+
success: Tensor(shape=torch.Size([32]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
141
|
+
batch_size=torch.Size([32]),
|
|
142
|
+
device=cpu,
|
|
143
|
+
is_shared=False),
|
|
144
|
+
observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
145
|
+
reward: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
146
|
+
state: TensorDict(
|
|
147
|
+
fields={
|
|
148
|
+
door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
149
|
+
qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
150
|
+
qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
151
|
+
batch_size=torch.Size([32]),
|
|
152
|
+
device=cpu,
|
|
153
|
+
is_shared=False),
|
|
154
|
+
terminated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
155
|
+
truncated: Tensor(shape=torch.Size([32, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
156
|
+
batch_size=torch.Size([32]),
|
|
157
|
+
device=cpu,
|
|
158
|
+
is_shared=False),
|
|
159
|
+
observation: Tensor(shape=torch.Size([32, 39]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
160
|
+
state: TensorDict(
|
|
161
|
+
fields={
|
|
162
|
+
door_body_pos: Tensor(shape=torch.Size([32, 3]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
163
|
+
qpos: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
164
|
+
qvel: Tensor(shape=torch.Size([32, 30]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
165
|
+
batch_size=torch.Size([32]),
|
|
166
|
+
device=cpu,
|
|
167
|
+
is_shared=False)},
|
|
168
|
+
batch_size=torch.Size([32]),
|
|
169
|
+
device=cpu,
|
|
170
|
+
is_shared=False)
|
|
171
|
+
|
|
172
|
+
"""
|
|
173
|
+
|
|
174
|
+
def __init__(
    self,
    dataset_id,
    batch_size: int,
    *,
    root: str | Path | None = None,
    download: bool = True,
    sampler: Sampler | None = None,
    writer: Writer | None = None,
    collate_fn: Callable | None = None,
    pin_memory: bool = False,
    prefetch: int | None = None,
    transform: torchrl.envs.Transform | None = None,  # noqa: F821
    split_trajs: bool = False,
    load_from_local_minari: bool = False,
):
    """Build the replay buffer, downloading and preprocessing the dataset when required."""
    self.dataset_id = dataset_id
    if root is None:
        # Only create the directory when we picked the default location
        # ourselves; a user-supplied root is used as-is.
        root = _get_root_dir("minari")
        os.makedirs(root, exist_ok=True)
    self.root = root
    self.split_trajs = split_trajs
    self.download = download
    self.load_from_local_minari = load_from_local_minari

    # (Re)build the on-disk storage when a re-download is forced, when the
    # data is missing and downloads are allowed, or when loading directly
    # from the local Minari cache was requested.
    if (
        self.download == "force"
        or (self.download and not self._is_downloaded())
        or self.load_from_local_minari
    ):
        if self.download == "force":
            # Wipe any previously processed data (raw and split variants).
            try:
                if os.path.exists(self.data_path_root):
                    shutil.rmtree(self.data_path_root)
                if self.data_path != self.data_path_root:
                    shutil.rmtree(self.data_path)
            except FileNotFoundError:
                # Nothing left to remove.
                pass
        storage = self._download_and_preproc()
    elif self.split_trajs and not os.path.exists(self.data_path):
        # Raw data is present but the trajectory-split view has not been
        # materialized yet.
        storage = self._make_split()
    else:
        storage = self._load()
    # Wrap the memory-mapped data in a TensorStorage for the replay buffer.
    storage = TensorStorage(storage)

    if writer is None:
        # Datasets are read-only by default.
        writer = ImmutableDatasetWriter()

    super().__init__(
        storage=storage,
        sampler=sampler,
        writer=writer,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        prefetch=prefetch,
        batch_size=batch_size,
        transform=transform,
    )
|
|
232
|
+
|
|
233
|
+
@_classproperty
def available_datasets(self):
    """Keys of the datasets published on the remote Minari registry."""
    if _has_minari:
        import minari

        return minari.list_remote_datasets().keys()
    raise ImportError("minari library not found.")
|
|
240
|
+
|
|
241
|
+
def _is_downloaded(self):
    """Return ``True`` if the raw (unsplit) dataset directory exists on disk."""
    return self.data_path_root.exists()
|
|
243
|
+
|
|
244
|
+
@property
def data_path(self) -> Path:
    """Path of the processed data; a ``*_split`` sibling directory when ``split_trajs`` is set."""
    if not self.split_trajs:
        return self.data_path_root
    return Path(self.root) / f"{self.dataset_id}_split"
|
|
249
|
+
|
|
250
|
+
@property
def data_path_root(self) -> Path:
    """Root directory of the raw (unsplit) dataset under ``self.root``."""
    return Path(self.root, self.dataset_id)
|
|
253
|
+
|
|
254
|
+
@property
def metadata_path(self) -> Path:
    """Location of the JSON file holding the environment metadata/specs."""
    return Path(self.root, self.dataset_id, "env_metadata.json")
|
|
257
|
+
|
|
258
|
+
    def _download_and_preproc(self):
        """Downloads the Minari dataset (or reads it from the local Minari cache) and preprocesses it into a memory-mapped TensorDict.

        The HDF5 episodes are read twice: a first pass builds the TensorDict
        structure and counts the total number of steps, a second pass copies
        every episode into a flat ``[total_steps]`` TensorDict with
        ``("next", ...)`` entries shifted by one step. The resulting data is
        memory-mapped under ``self.data_path_root`` (and additionally split
        into trajectories under ``self.data_path`` when ``self.split_trajs``
        is set). Environment metadata is serialized to ``self.metadata_path``.

        Returns:
            the preprocessed TensorDict.

        Raises:
            ImportError: if minari is not installed.
            FileNotFoundError: if ``load_from_local_minari`` is set and the
                dataset is absent from the local Minari cache.
            RuntimeError: if episode entries disagree on the number of steps.
        """
        if not _has_minari:
            raise ImportError("minari library not found.")
        import minari

        if _has_tqdm:
            from tqdm import tqdm

        # Remember the user's MINARI_DATASETS_PATH so it can be restored in
        # the outer `finally`, even if the download or preprocessing fails.
        prev_minari_datasets_path_save = prev_minari_datasets_path = os.environ.get(
            "MINARI_DATASETS_PATH"
        )
        try:
            if prev_minari_datasets_path is None:
                # Default Minari cache location when the env var is unset.
                prev_minari_datasets_path = os.path.expanduser("~/.minari/datasets")
            with tempfile.TemporaryDirectory() as tmpdir:

                total_steps = 0
                td_data = TensorDict()

                if self.load_from_local_minari:
                    # Load minari dataset from user's local Minari cache.

                    parent_dir = (
                        Path(prev_minari_datasets_path) / self.dataset_id / "data"
                    )
                    h5_path = parent_dir / "main_data.hdf5"

                    if not h5_path.exists():
                        raise FileNotFoundError(
                            f"{h5_path} does not exist in local Minari cache!"
                        )

                    torchrl_logger.info(
                        f"loading dataset from local Minari cache at {h5_path}"
                    )
                    h5_data = PersistentTensorDict.from_h5(h5_path)
                    h5_data = h5_data.to_tensordict()

                else:
                    # Temporarily redirect the Minari cache into the tempdir so
                    # the download does not pollute the user's cache.
                    prev_minari_datasets_path_save2 = os.environ.get(
                        "MINARI_DATASETS_PATH"
                    )
                    os.environ["MINARI_DATASETS_PATH"] = tmpdir
                    try:
                        minari.download_dataset(dataset_id=self.dataset_id)
                    finally:
                        if prev_minari_datasets_path_save2 is not None:
                            os.environ[
                                "MINARI_DATASETS_PATH"
                            ] = prev_minari_datasets_path_save2

                    parent_dir = Path(tmpdir) / self.dataset_id / "data"

                    torchrl_logger.info(
                        "first read through data to create data structure..."
                    )
                    h5_data = PersistentTensorDict.from_h5(
                        parent_dir / "main_data.hdf5"
                    )
                    h5_data = h5_data.to_tensordict()

                # First pass: build the (unbatched) structure of td_data from
                # the first episode and count total_steps over all episodes.
                episode_dict = {}
                dataset_has_nontensor = False
                for i, (episode_key, episode) in enumerate(h5_data.items()):
                    # Keys look like "episode_<num>"; recover the number so
                    # episodes can be processed in order later.
                    episode_num = int(episode_key[len("episode_") :])
                    episode_len = episode["actions"].shape[0]
                    episode_dict[episode_num] = (episode_key, episode_len)
                    # Get the total number of steps for the dataset
                    total_steps += episode_len
                    if i == 0:
                        td_data.set("episode", 0)
                        # Several source keys can map to the same destination
                        # name; only register each destination once.
                        seen = set()
                        for key, val in episode.items():
                            match = _NAME_MATCH[key]
                            if match in seen:
                                continue
                            seen.add(match)
                            if key in ("observations", "state", "infos"):
                                val = episode[key]
                                # Detect non-tensor (e.g. string) fields so they
                                # can be preallocated separately below.
                                if is_tensor_collection(val) and any(
                                    isinstance(
                                        val.get(k), (NonTensorData, NonTensorStack)
                                    )
                                    for k in val.keys()
                                ):
                                    non_tensor_probe = val.clone()
                                    _extract_nontensor_fields(
                                        non_tensor_probe, recursive=True
                                    )
                                    dataset_has_nontensor = True
                                if (
                                    not val.shape
                                ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
                                    if val.is_empty():
                                        continue
                                    if is_non_tensor(val):
                                        continue
                                    # Pad info-style entries that miss one item.
                                    val = _patch_info(val)
                                # Observations exist at t and t+1.
                                td_data.set(("next", match), torch.zeros_like(val[0]))
                                td_data.set(match, torch.zeros_like(val[0]))
                            elif key not in ("terminations", "truncations", "rewards"):
                                td_data.set(match, torch.zeros_like(val[0]))
                            else:
                                # rewards/terminations/truncations only exist on
                                # the "next" side and gain a trailing dim of 1.
                                td_data.set(
                                    ("next", match),
                                    torch.zeros_like(val[0].unsqueeze(-1)),
                                )

                # give it the proper size
                td_data["next", "done"] = (
                    td_data["next", "truncated"] | td_data["next", "terminated"]
                )
                if "terminated" in td_data.keys():
                    td_data["done"] = td_data["truncated"] | td_data["terminated"]
                td_data = td_data.expand(total_steps).contiguous()
                # save to designated location
                torchrl_logger.info(
                    f"creating tensordict data in {self.data_path_root}: "
                )
                if dataset_has_nontensor:
                    _preallocate_nontensor_fields(
                        td_data, episode, total_steps, name_map=_NAME_MATCH
                    )
                torchrl_logger.info(f"tensordict structure: {td_data}")

                # NOTE(review): max(*episode_dict) unpacks the keys, which
                # raises TypeError when there is exactly one episode — confirm
                # single-episode datasets are not expected here.
                torchrl_logger.info(
                    f"Reading data from {max(*episode_dict) + 1} episodes"
                )
                index = 0
                with tqdm(total=total_steps) if _has_tqdm else nullcontext() as pbar:
                    # Second pass: iterate over episodes in order and copy each
                    # one into its slice of the flat tensordict.
                    for episode_num in sorted(episode_dict):
                        episode_key, steps = episode_dict[episode_num]
                        episode = _patch_nontensor_data_to_stack(
                            h5_data.get(episode_key)
                        )
                        idx = slice(index, (index + steps))
                        data_view = td_data[idx]
                        data_view.fill_("episode", episode_num)
                        for key, val in episode.items():
                            match = _NAME_MATCH[key]
                            if key in (
                                "observations",
                                "state",
                                "infos",
                            ):
                                if not val.shape or steps != val.shape[0] - 1:
                                    if val.is_empty():
                                        continue
                                    if is_non_tensor(val):
                                        continue
                                    val = _patch_info(val)
                                if steps != val.shape[0] - 1:
                                    raise RuntimeError(
                                        f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
                                    )
                                # These entries have steps+1 items: t goes to
                                # the root entry, t+1 to the "next" entry.
                                val_next = val[1:].clone()
                                val_copy = val[:-1].clone()

                                data_view["next", match].copy_(val_next)
                                data_view[match].copy_(val_copy)

                                # copy_ does not transfer non-tensor fields;
                                # pull them out and write them with update_.
                                if is_tensor_collection(val_next):
                                    non_tensors_next = _extract_nontensor_fields(
                                        val_next
                                    )
                                    non_tensors_now = _extract_nontensor_fields(
                                        val_copy
                                    )
                                    data_view["next", match].update_(non_tensors_next)
                                    data_view[match].update_(non_tensors_now)

                            elif key not in ("terminations", "truncations", "rewards"):
                                if steps is None:
                                    steps = val.shape[0]
                                else:
                                    if steps != val.shape[0]:
                                        raise RuntimeError(
                                            f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
                                        )
                                data_view[match].copy_(val)
                            else:
                                if steps is None:
                                    steps = val.shape[0]
                                else:
                                    if steps != val.shape[0]:
                                        raise RuntimeError(
                                            f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0]}."
                                        )
                                data_view[("next", match)].copy_(val.unsqueeze(-1))
                        # Recompute the aggregated done flags for this slice.
                        data_view["next", "done"].copy_(
                            data_view["next", "terminated"]
                            | data_view["next", "truncated"]
                        )
                        if "done" in data_view.keys():
                            data_view["done"].copy_(
                                data_view["terminated"] | data_view["truncated"]
                            )
                        if pbar is not None:
                            pbar.update(steps)
                            pbar.set_description(
                                f"index={index} - episode num {episode_num}"
                            )
                        index += steps

                # NOTE(review): memmap_like creates an *empty* memory-mapped
                # copy; calling it after the fill loop looks like it discards
                # the data copied above — confirm against tensordict's
                # memmap_like semantics (memmap_ would persist in place).
                td_data = td_data.memmap_like(self.data_path_root)
                # Optionally reshape into one-trajectory-per-row layout.
                if self.split_trajs:
                    with td_data.unlock_():
                        from torchrl.collectors.utils import split_trajectories

                        td_data = split_trajectories(td_data).memmap_(self.data_path)
                # Serialize the environment metadata (spaces converted to
                # JSON-compatible dicts) next to the data.
                with open(self.metadata_path, "w") as metadata_file:
                    dataset = minari.load_dataset(self.dataset_id)
                    self.metadata = asdict(dataset.spec)
                    self.metadata["observation_space"] = _spec_to_dict(
                        self.metadata["observation_space"]
                    )
                    self.metadata["action_space"] = _spec_to_dict(
                        self.metadata["action_space"]
                    )
                    json.dump(self.metadata, metadata_file)
                self._load_and_proc_metadata()
                return td_data
        finally:
            # Restore the user's Minari cache location.
            if prev_minari_datasets_path_save is not None:
                os.environ["MINARI_DATASETS_PATH"] = prev_minari_datasets_path_save
|
|
487
|
+
|
|
488
|
+
def _make_split(self):
|
|
489
|
+
from torchrl.collectors.utils import split_trajectories
|
|
490
|
+
|
|
491
|
+
self._load_and_proc_metadata()
|
|
492
|
+
td_data = TensorDict.load_memmap(self.data_path_root)
|
|
493
|
+
td_data = split_trajectories(td_data).memmap_(self.data_path)
|
|
494
|
+
return td_data
|
|
495
|
+
|
|
496
|
+
def _load(self):
|
|
497
|
+
self._load_and_proc_metadata()
|
|
498
|
+
return TensorDict.load_memmap(self.data_path)
|
|
499
|
+
|
|
500
|
+
def _load_and_proc_metadata(self):
|
|
501
|
+
with open(self.metadata_path) as file:
|
|
502
|
+
self.metadata = json.load(file)
|
|
503
|
+
self.metadata["observation_space"] = _proc_spec(
|
|
504
|
+
self.metadata["observation_space"]
|
|
505
|
+
)
|
|
506
|
+
self.metadata["action_space"] = _proc_spec(self.metadata["action_space"])
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _proc_spec(spec):
|
|
510
|
+
if spec is None:
|
|
511
|
+
return
|
|
512
|
+
if spec["type"] == "Dict":
|
|
513
|
+
return Composite(
|
|
514
|
+
{key: _proc_spec(subspec) for key, subspec in spec["subspaces"].items()}
|
|
515
|
+
)
|
|
516
|
+
elif spec["type"] == "Box":
|
|
517
|
+
if all(item == -float("inf") for item in spec["low"]) and all(
|
|
518
|
+
item == float("inf") for item in spec["high"]
|
|
519
|
+
):
|
|
520
|
+
return Unbounded(spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]])
|
|
521
|
+
return Bounded(
|
|
522
|
+
shape=spec["shape"],
|
|
523
|
+
low=torch.as_tensor(spec["low"]),
|
|
524
|
+
high=torch.as_tensor(spec["high"]),
|
|
525
|
+
dtype=_DTYPE_DIR[spec["dtype"]],
|
|
526
|
+
)
|
|
527
|
+
elif spec["type"] == "Discrete":
|
|
528
|
+
return Categorical(
|
|
529
|
+
spec["n"], shape=spec["shape"], dtype=_DTYPE_DIR[spec["dtype"]]
|
|
530
|
+
)
|
|
531
|
+
else:
|
|
532
|
+
raise NotImplementedError(f"{type(spec)}")
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def _spec_to_dict(spec):
    """Serializes a Gym space into a JSON-compatible dictionary.

    ``Dict``, ``Box`` and ``Discrete`` spaces are supported; ``Text`` spaces
    are dropped (``None``). Any other space type raises
    ``NotImplementedError``.
    """
    from torchrl.envs.libs.gym import gym_backend

    # Resolve the spaces module once instead of per isinstance check.
    spaces = gym_backend("spaces")
    if isinstance(spec, spaces.Dict):
        return {
            "type": "Dict",
            "subspaces": {name: _spec_to_dict(sub) for name, sub in spec.items()},
        }
    if isinstance(spec, spaces.Box):
        return {
            "type": "Box",
            "low": spec.low.tolist(),
            "high": spec.high.tolist(),
            "dtype": str(spec.dtype),
            "shape": tuple(spec.shape),
        }
    if isinstance(spec, spaces.Discrete):
        return {
            "type": "Discrete",
            "dtype": str(spec.dtype),
            "n": int(spec.n),
            "shape": tuple(spec.shape),
        }
    if isinstance(spec, spaces.Text):
        return None
    raise NotImplementedError(f"{type(spec)}, {str(spec)}")
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def _patch_info(info_td):
    # Pads "info"-style sub-tensordicts whose entries have one element fewer
    # than their siblings, so that every entry ends up with the same leading
    # dimension (the longer one).
    #
    # Some info dicts have tensors with one less element than others
    # We explicitly assume that the missing item is in the first position because
    # it wasn't given at reset time.
    # An alternative explanation could be that the last element is missing because
    # deemed useless for training...
    #
    # NOTE(review): this assumes info_td is non-empty — `subval` below leaks
    # out of the loop, and an empty info_td would raise NameError. Confirm
    # callers only pass non-empty tensordicts (the caller skips empty vals).
    unique_shapes = defaultdict(list)
    for subkey, subval in info_td.items():
        # Group keys by the length of their leading dimension.
        unique_shapes[subval.shape[0]].append(subkey)
    if len(unique_shapes) == 1:
        # All entries agree: register an (empty) group one step longer so the
        # padding logic below applies uniformly.
        unique_shapes[subval.shape[0] + 1] = []
    if len(unique_shapes) != 2:
        raise RuntimeError(
            f"Unique shapes in a sub-tensordict can only be of length 2, got shapes {unique_shapes}."
        )
    val_td = info_td.to_tensordict()
    min_shape = min(*unique_shapes)  # can only be found at root
    max_shape = min_shape + 1
    # Prepend a zero "reset" item to every short entry so it reaches max_shape.
    val_td_sel = val_td.select(*unique_shapes[min_shape])
    val_td_sel = val_td_sel.apply(
        lambda x: torch.cat([torch.zeros_like(x[:1]), x], 0), batch_size=[min_shape + 1]
    )
    # Merge back the entries that already had the long shape.
    source = val_td.select(*unique_shapes[max_shape])
    # make sure source has no batch size
    source.batch_size = ()
    if not source.is_empty():
        val_td_sel.update(source, update_batch_size=True)
    return val_td_sel
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
    """Recursively replaces every NonTensorData field in ``tensordict`` with a NonTensorStack, in place."""
    for key, value in tensordict.items():
        if isinstance(value, TensorDictBase):
            # Sub-tensordicts are patched recursively (mutation is in place).
            _patch_nontensor_data_to_stack(value)
        elif isinstance(value, NonTensorData):
            # Re-assigning a plain list while set_list_to_stack is active
            # stores it as a NonTensorStack.
            with set_list_to_stack(True):
                tensordict[key] = list(value.data)
    return tensordict
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
def _extract_nontensor_fields(
    tensordict: TensorDictBase, recursive: bool = False
) -> TensorDict:
    """Removes the non-tensor fields from ``tensordict`` and returns them in a new TensorDict.

    When ``recursive`` is ``True``, nested tensor collections are scanned as
    well and their extracted fields are returned under the same nesting.
    """
    pulled = {}
    # Materialize the keys first: entries are deleted while scanning.
    for key in list(tensordict.keys()):
        value = tensordict.get(key)
        if is_non_tensor(value):
            pulled[key] = value
            del tensordict[key]
        elif recursive and is_tensor_collection(value):
            inner = _extract_nontensor_fields(value, recursive=True)
            if len(inner) > 0:
                pulled[key] = inner
    return TensorDict(pulled, batch_size=tensordict.batch_size)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
def _preallocate_nontensor_fields(
    td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict
):
    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping.

    For every non-tensor leaf found in ``example`` (recursively), a stack of
    ``total_steps`` placeholder entries is set in ``td_data`` under both the
    (remapped) key and its ``("next", ...)`` counterpart, so the fill loop can
    later write real values in place.
    """
    with set_list_to_stack(True):

        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
            # Walks src_td and mirrors its non-tensor leaves into dst_td.
            for key, val in src_td.items():
                # Remap source key names; keys absent from name_map keep
                # their original name.
                mapped_key = name_map.get(key, key)
                full_dst_key = prefix + (mapped_key,)

                if is_non_tensor(val):
                    # NOTE(review): the placeholder stack is filled with the
                    # integer `total_steps` repeated total_steps times — the
                    # values look arbitrary and are presumably overwritten by
                    # the per-episode copy pass; confirm nothing reads them.
                    dummy_stack = NonTensorStack(
                        *[total_steps for _ in range(total_steps)]
                    )
                    dst_td.set(full_dst_key, dummy_stack)
                    dst_td.set(("next",) + full_dst_key, dummy_stack)

                elif is_tensor_collection(val):
                    _recurse(val, dst_td, full_dst_key)

        _recurse(example, td_data)
|