torchrl 0.11.0__cp314-cp314-manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/benchmark_batched_envs.py +104 -0
- benchmarks/conftest.py +91 -0
- benchmarks/ecosystem/gym_env_throughput.py +321 -0
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +231 -0
- benchmarks/requirements.txt +7 -0
- benchmarks/storage/benchmark_sample_latency_over_rpc.py +193 -0
- benchmarks/test_collectors_benchmark.py +240 -0
- benchmarks/test_compressed_storage_benchmark.py +145 -0
- benchmarks/test_envs_benchmark.py +133 -0
- benchmarks/test_llm.py +101 -0
- benchmarks/test_non_tensor_env_benchmark.py +70 -0
- benchmarks/test_objectives_benchmarks.py +1199 -0
- benchmarks/test_replaybuffer_benchmark.py +254 -0
- sota-check/README.md +35 -0
- sota-implementations/README.md +142 -0
- sota-implementations/a2c/README.md +39 -0
- sota-implementations/a2c/a2c_atari.py +291 -0
- sota-implementations/a2c/a2c_mujoco.py +273 -0
- sota-implementations/a2c/utils_atari.py +240 -0
- sota-implementations/a2c/utils_mujoco.py +160 -0
- sota-implementations/bandits/README.md +7 -0
- sota-implementations/bandits/dqn.py +126 -0
- sota-implementations/cql/cql_offline.py +198 -0
- sota-implementations/cql/cql_online.py +249 -0
- sota-implementations/cql/discrete_cql_offline.py +180 -0
- sota-implementations/cql/discrete_cql_online.py +227 -0
- sota-implementations/cql/utils.py +471 -0
- sota-implementations/crossq/crossq.py +271 -0
- sota-implementations/crossq/utils.py +320 -0
- sota-implementations/ddpg/ddpg.py +231 -0
- sota-implementations/ddpg/utils.py +325 -0
- sota-implementations/decision_transformer/dt.py +163 -0
- sota-implementations/decision_transformer/lamb.py +167 -0
- sota-implementations/decision_transformer/online_dt.py +178 -0
- sota-implementations/decision_transformer/utils.py +562 -0
- sota-implementations/discrete_sac/discrete_sac.py +243 -0
- sota-implementations/discrete_sac/utils.py +324 -0
- sota-implementations/dqn/README.md +30 -0
- sota-implementations/dqn/dqn_atari.py +272 -0
- sota-implementations/dqn/dqn_cartpole.py +236 -0
- sota-implementations/dqn/utils_atari.py +132 -0
- sota-implementations/dqn/utils_cartpole.py +90 -0
- sota-implementations/dreamer/README.md +129 -0
- sota-implementations/dreamer/dreamer.py +586 -0
- sota-implementations/dreamer/dreamer_utils.py +1107 -0
- sota-implementations/expert-iteration/README.md +352 -0
- sota-implementations/expert-iteration/ei_utils.py +770 -0
- sota-implementations/expert-iteration/expert-iteration-async.py +512 -0
- sota-implementations/expert-iteration/expert-iteration-sync.py +508 -0
- sota-implementations/expert-iteration/requirements_gsm8k.txt +13 -0
- sota-implementations/expert-iteration/requirements_ifeval.txt +16 -0
- sota-implementations/gail/gail.py +327 -0
- sota-implementations/gail/gail_utils.py +68 -0
- sota-implementations/gail/ppo_utils.py +157 -0
- sota-implementations/grpo/README.md +273 -0
- sota-implementations/grpo/grpo-async.py +437 -0
- sota-implementations/grpo/grpo-sync.py +435 -0
- sota-implementations/grpo/grpo_utils.py +843 -0
- sota-implementations/grpo/requirements_gsm8k.txt +11 -0
- sota-implementations/grpo/requirements_ifeval.txt +16 -0
- sota-implementations/impala/README.md +33 -0
- sota-implementations/impala/impala_multi_node_ray.py +292 -0
- sota-implementations/impala/impala_multi_node_submitit.py +284 -0
- sota-implementations/impala/impala_single_node.py +261 -0
- sota-implementations/impala/utils.py +184 -0
- sota-implementations/iql/discrete_iql.py +230 -0
- sota-implementations/iql/iql_offline.py +164 -0
- sota-implementations/iql/iql_online.py +225 -0
- sota-implementations/iql/utils.py +437 -0
- sota-implementations/multiagent/README.md +74 -0
- sota-implementations/multiagent/iql.py +237 -0
- sota-implementations/multiagent/maddpg_iddpg.py +266 -0
- sota-implementations/multiagent/mappo_ippo.py +267 -0
- sota-implementations/multiagent/qmix_vdn.py +271 -0
- sota-implementations/multiagent/sac.py +337 -0
- sota-implementations/multiagent/utils/__init__.py +4 -0
- sota-implementations/multiagent/utils/logging.py +151 -0
- sota-implementations/multiagent/utils/utils.py +43 -0
- sota-implementations/ppo/README.md +29 -0
- sota-implementations/ppo/ppo_atari.py +305 -0
- sota-implementations/ppo/ppo_mujoco.py +293 -0
- sota-implementations/ppo/utils_atari.py +238 -0
- sota-implementations/ppo/utils_mujoco.py +152 -0
- sota-implementations/ppo_trainer/train.py +21 -0
- sota-implementations/redq/README.md +7 -0
- sota-implementations/redq/redq.py +199 -0
- sota-implementations/redq/utils.py +1060 -0
- sota-implementations/sac/sac-async.py +266 -0
- sota-implementations/sac/sac.py +239 -0
- sota-implementations/sac/utils.py +381 -0
- sota-implementations/sac_trainer/train.py +16 -0
- sota-implementations/td3/td3.py +254 -0
- sota-implementations/td3/utils.py +319 -0
- sota-implementations/td3_bc/td3_bc.py +177 -0
- sota-implementations/td3_bc/utils.py +251 -0
- torchrl/__init__.py +144 -0
- torchrl/_extension.py +74 -0
- torchrl/_torchrl.cpython-314-aarch64-linux-gnu.so +0 -0
- torchrl/_utils.py +1431 -0
- torchrl/collectors/__init__.py +48 -0
- torchrl/collectors/_base.py +1058 -0
- torchrl/collectors/_constants.py +88 -0
- torchrl/collectors/_multi_async.py +324 -0
- torchrl/collectors/_multi_base.py +1805 -0
- torchrl/collectors/_multi_sync.py +464 -0
- torchrl/collectors/_runner.py +581 -0
- torchrl/collectors/_single.py +2009 -0
- torchrl/collectors/_single_async.py +259 -0
- torchrl/collectors/collectors.py +62 -0
- torchrl/collectors/distributed/__init__.py +32 -0
- torchrl/collectors/distributed/default_configs.py +133 -0
- torchrl/collectors/distributed/generic.py +1306 -0
- torchrl/collectors/distributed/ray.py +1092 -0
- torchrl/collectors/distributed/rpc.py +1006 -0
- torchrl/collectors/distributed/sync.py +731 -0
- torchrl/collectors/distributed/utils.py +160 -0
- torchrl/collectors/llm/__init__.py +10 -0
- torchrl/collectors/llm/base.py +494 -0
- torchrl/collectors/llm/ray_collector.py +275 -0
- torchrl/collectors/llm/utils.py +36 -0
- torchrl/collectors/llm/weight_update/__init__.py +10 -0
- torchrl/collectors/llm/weight_update/vllm.py +348 -0
- torchrl/collectors/llm/weight_update/vllm_v2.py +311 -0
- torchrl/collectors/utils.py +433 -0
- torchrl/collectors/weight_update.py +591 -0
- torchrl/csrc/numpy_utils.h +38 -0
- torchrl/csrc/pybind.cpp +27 -0
- torchrl/csrc/segment_tree.h +458 -0
- torchrl/csrc/torch_utils.h +34 -0
- torchrl/csrc/utils.cpp +48 -0
- torchrl/csrc/utils.h +31 -0
- torchrl/data/__init__.py +187 -0
- torchrl/data/datasets/__init__.py +58 -0
- torchrl/data/datasets/atari_dqn.py +878 -0
- torchrl/data/datasets/common.py +281 -0
- torchrl/data/datasets/d4rl.py +489 -0
- torchrl/data/datasets/d4rl_infos.py +187 -0
- torchrl/data/datasets/gen_dgrl.py +375 -0
- torchrl/data/datasets/minari_data.py +643 -0
- torchrl/data/datasets/openml.py +177 -0
- torchrl/data/datasets/openx.py +798 -0
- torchrl/data/datasets/roboset.py +363 -0
- torchrl/data/datasets/utils.py +11 -0
- torchrl/data/datasets/vd4rl.py +432 -0
- torchrl/data/llm/__init__.py +34 -0
- torchrl/data/llm/dataset.py +491 -0
- torchrl/data/llm/history.py +1378 -0
- torchrl/data/llm/prompt.py +198 -0
- torchrl/data/llm/reward.py +225 -0
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/llm/utils.py +543 -0
- torchrl/data/map/__init__.py +21 -0
- torchrl/data/map/hash.py +185 -0
- torchrl/data/map/query.py +204 -0
- torchrl/data/map/tdstorage.py +363 -0
- torchrl/data/map/tree.py +1434 -0
- torchrl/data/map/utils.py +103 -0
- torchrl/data/postprocs/__init__.py +8 -0
- torchrl/data/postprocs/postprocs.py +391 -0
- torchrl/data/replay_buffers/__init__.py +99 -0
- torchrl/data/replay_buffers/checkpointers.py +622 -0
- torchrl/data/replay_buffers/ray_buffer.py +292 -0
- torchrl/data/replay_buffers/replay_buffers.py +2376 -0
- torchrl/data/replay_buffers/samplers.py +2578 -0
- torchrl/data/replay_buffers/scheduler.py +265 -0
- torchrl/data/replay_buffers/storages.py +2412 -0
- torchrl/data/replay_buffers/utils.py +1042 -0
- torchrl/data/replay_buffers/writers.py +781 -0
- torchrl/data/tensor_specs.py +7101 -0
- torchrl/data/utils.py +334 -0
- torchrl/envs/__init__.py +265 -0
- torchrl/envs/async_envs.py +1105 -0
- torchrl/envs/batched_envs.py +3093 -0
- torchrl/envs/common.py +4241 -0
- torchrl/envs/custom/__init__.py +11 -0
- torchrl/envs/custom/chess.py +617 -0
- torchrl/envs/custom/llm.py +214 -0
- torchrl/envs/custom/pendulum.py +401 -0
- torchrl/envs/custom/san_moves.txt +29274 -0
- torchrl/envs/custom/tictactoeenv.py +288 -0
- torchrl/envs/env_creator.py +263 -0
- torchrl/envs/gym_like.py +752 -0
- torchrl/envs/libs/__init__.py +68 -0
- torchrl/envs/libs/_gym_utils.py +326 -0
- torchrl/envs/libs/brax.py +846 -0
- torchrl/envs/libs/dm_control.py +544 -0
- torchrl/envs/libs/envpool.py +447 -0
- torchrl/envs/libs/gym.py +2239 -0
- torchrl/envs/libs/habitat.py +138 -0
- torchrl/envs/libs/isaac_lab.py +87 -0
- torchrl/envs/libs/isaacgym.py +203 -0
- torchrl/envs/libs/jax_utils.py +166 -0
- torchrl/envs/libs/jumanji.py +963 -0
- torchrl/envs/libs/meltingpot.py +599 -0
- torchrl/envs/libs/openml.py +153 -0
- torchrl/envs/libs/openspiel.py +652 -0
- torchrl/envs/libs/pettingzoo.py +1042 -0
- torchrl/envs/libs/procgen.py +351 -0
- torchrl/envs/libs/robohive.py +429 -0
- torchrl/envs/libs/smacv2.py +645 -0
- torchrl/envs/libs/unity_mlagents.py +891 -0
- torchrl/envs/libs/utils.py +147 -0
- torchrl/envs/libs/vmas.py +813 -0
- torchrl/envs/llm/__init__.py +63 -0
- torchrl/envs/llm/chat.py +730 -0
- torchrl/envs/llm/datasets/README.md +4 -0
- torchrl/envs/llm/datasets/__init__.py +17 -0
- torchrl/envs/llm/datasets/gsm8k.py +353 -0
- torchrl/envs/llm/datasets/ifeval.py +274 -0
- torchrl/envs/llm/envs.py +789 -0
- torchrl/envs/llm/libs/README.md +3 -0
- torchrl/envs/llm/libs/__init__.py +8 -0
- torchrl/envs/llm/libs/mlgym.py +869 -0
- torchrl/envs/llm/reward/__init__.py +10 -0
- torchrl/envs/llm/reward/gsm8k.py +324 -0
- torchrl/envs/llm/reward/ifeval/README.md +13 -0
- torchrl/envs/llm/reward/ifeval/__init__.py +10 -0
- torchrl/envs/llm/reward/ifeval/_instructions.py +1667 -0
- torchrl/envs/llm/reward/ifeval/_instructions_main.py +131 -0
- torchrl/envs/llm/reward/ifeval/_instructions_registry.py +100 -0
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1677 -0
- torchrl/envs/llm/reward/ifeval/_scorer.py +454 -0
- torchrl/envs/llm/transforms/__init__.py +55 -0
- torchrl/envs/llm/transforms/browser.py +292 -0
- torchrl/envs/llm/transforms/dataloading.py +859 -0
- torchrl/envs/llm/transforms/format.py +73 -0
- torchrl/envs/llm/transforms/kl.py +1544 -0
- torchrl/envs/llm/transforms/policy_version.py +189 -0
- torchrl/envs/llm/transforms/reason.py +323 -0
- torchrl/envs/llm/transforms/tokenizer.py +321 -0
- torchrl/envs/llm/transforms/tools.py +1955 -0
- torchrl/envs/model_based/__init__.py +9 -0
- torchrl/envs/model_based/common.py +180 -0
- torchrl/envs/model_based/dreamer.py +112 -0
- torchrl/envs/transforms/__init__.py +147 -0
- torchrl/envs/transforms/functional.py +48 -0
- torchrl/envs/transforms/gym_transforms.py +203 -0
- torchrl/envs/transforms/module.py +341 -0
- torchrl/envs/transforms/r3m.py +372 -0
- torchrl/envs/transforms/ray_service.py +663 -0
- torchrl/envs/transforms/rb_transforms.py +214 -0
- torchrl/envs/transforms/transforms.py +11835 -0
- torchrl/envs/transforms/utils.py +94 -0
- torchrl/envs/transforms/vc1.py +307 -0
- torchrl/envs/transforms/vecnorm.py +845 -0
- torchrl/envs/transforms/vip.py +407 -0
- torchrl/envs/utils.py +1718 -0
- torchrl/envs/vec_envs.py +11 -0
- torchrl/modules/__init__.py +206 -0
- torchrl/modules/distributions/__init__.py +73 -0
- torchrl/modules/distributions/continuous.py +830 -0
- torchrl/modules/distributions/discrete.py +908 -0
- torchrl/modules/distributions/truncated_normal.py +187 -0
- torchrl/modules/distributions/utils.py +233 -0
- torchrl/modules/llm/__init__.py +62 -0
- torchrl/modules/llm/backends/__init__.py +65 -0
- torchrl/modules/llm/backends/vllm/__init__.py +94 -0
- torchrl/modules/llm/backends/vllm/_models.py +46 -0
- torchrl/modules/llm/backends/vllm/base.py +72 -0
- torchrl/modules/llm/backends/vllm/vllm_async.py +2075 -0
- torchrl/modules/llm/backends/vllm/vllm_plugin.py +22 -0
- torchrl/modules/llm/backends/vllm/vllm_sync.py +446 -0
- torchrl/modules/llm/backends/vllm/vllm_utils.py +129 -0
- torchrl/modules/llm/policies/__init__.py +28 -0
- torchrl/modules/llm/policies/common.py +1809 -0
- torchrl/modules/llm/policies/transformers_wrapper.py +2756 -0
- torchrl/modules/llm/policies/vllm_wrapper.py +2241 -0
- torchrl/modules/llm/utils.py +23 -0
- torchrl/modules/mcts/__init__.py +21 -0
- torchrl/modules/mcts/scores.py +579 -0
- torchrl/modules/models/__init__.py +86 -0
- torchrl/modules/models/batchrenorm.py +119 -0
- torchrl/modules/models/decision_transformer.py +179 -0
- torchrl/modules/models/exploration.py +731 -0
- torchrl/modules/models/llm.py +156 -0
- torchrl/modules/models/model_based.py +596 -0
- torchrl/modules/models/models.py +1712 -0
- torchrl/modules/models/multiagent.py +1067 -0
- torchrl/modules/models/recipes/impala.py +185 -0
- torchrl/modules/models/utils.py +162 -0
- torchrl/modules/planners/__init__.py +10 -0
- torchrl/modules/planners/cem.py +228 -0
- torchrl/modules/planners/common.py +73 -0
- torchrl/modules/planners/mppi.py +265 -0
- torchrl/modules/tensordict_module/__init__.py +89 -0
- torchrl/modules/tensordict_module/actors.py +2457 -0
- torchrl/modules/tensordict_module/common.py +529 -0
- torchrl/modules/tensordict_module/exploration.py +814 -0
- torchrl/modules/tensordict_module/probabilistic.py +321 -0
- torchrl/modules/tensordict_module/rnn.py +1639 -0
- torchrl/modules/tensordict_module/sequence.py +132 -0
- torchrl/modules/tensordict_module/world_models.py +34 -0
- torchrl/modules/utils/__init__.py +38 -0
- torchrl/modules/utils/mappings.py +9 -0
- torchrl/modules/utils/utils.py +89 -0
- torchrl/objectives/__init__.py +78 -0
- torchrl/objectives/a2c.py +659 -0
- torchrl/objectives/common.py +753 -0
- torchrl/objectives/cql.py +1346 -0
- torchrl/objectives/crossq.py +710 -0
- torchrl/objectives/ddpg.py +453 -0
- torchrl/objectives/decision_transformer.py +371 -0
- torchrl/objectives/deprecated.py +516 -0
- torchrl/objectives/dqn.py +683 -0
- torchrl/objectives/dreamer.py +488 -0
- torchrl/objectives/functional.py +48 -0
- torchrl/objectives/gail.py +258 -0
- torchrl/objectives/iql.py +996 -0
- torchrl/objectives/llm/__init__.py +30 -0
- torchrl/objectives/llm/grpo.py +846 -0
- torchrl/objectives/llm/sft.py +482 -0
- torchrl/objectives/multiagent/__init__.py +8 -0
- torchrl/objectives/multiagent/qmixer.py +396 -0
- torchrl/objectives/ppo.py +1669 -0
- torchrl/objectives/redq.py +683 -0
- torchrl/objectives/reinforce.py +530 -0
- torchrl/objectives/sac.py +1580 -0
- torchrl/objectives/td3.py +570 -0
- torchrl/objectives/td3_bc.py +625 -0
- torchrl/objectives/utils.py +782 -0
- torchrl/objectives/value/__init__.py +28 -0
- torchrl/objectives/value/advantages.py +1956 -0
- torchrl/objectives/value/functional.py +1459 -0
- torchrl/objectives/value/utils.py +360 -0
- torchrl/record/__init__.py +17 -0
- torchrl/record/loggers/__init__.py +23 -0
- torchrl/record/loggers/common.py +48 -0
- torchrl/record/loggers/csv.py +226 -0
- torchrl/record/loggers/mlflow.py +142 -0
- torchrl/record/loggers/tensorboard.py +139 -0
- torchrl/record/loggers/trackio.py +163 -0
- torchrl/record/loggers/utils.py +78 -0
- torchrl/record/loggers/wandb.py +214 -0
- torchrl/record/recorder.py +554 -0
- torchrl/services/__init__.py +79 -0
- torchrl/services/base.py +109 -0
- torchrl/services/ray_service.py +453 -0
- torchrl/testing/__init__.py +107 -0
- torchrl/testing/assertions.py +179 -0
- torchrl/testing/dist_utils.py +122 -0
- torchrl/testing/env_creators.py +227 -0
- torchrl/testing/env_helper.py +35 -0
- torchrl/testing/gym_helpers.py +156 -0
- torchrl/testing/llm_mocks.py +119 -0
- torchrl/testing/mocking_classes.py +2720 -0
- torchrl/testing/modules.py +295 -0
- torchrl/testing/mp_helpers.py +15 -0
- torchrl/testing/ray_helpers.py +293 -0
- torchrl/testing/utils.py +190 -0
- torchrl/trainers/__init__.py +42 -0
- torchrl/trainers/algorithms/__init__.py +11 -0
- torchrl/trainers/algorithms/configs/__init__.py +705 -0
- torchrl/trainers/algorithms/configs/collectors.py +216 -0
- torchrl/trainers/algorithms/configs/common.py +41 -0
- torchrl/trainers/algorithms/configs/data.py +308 -0
- torchrl/trainers/algorithms/configs/envs.py +104 -0
- torchrl/trainers/algorithms/configs/envs_libs.py +361 -0
- torchrl/trainers/algorithms/configs/logging.py +80 -0
- torchrl/trainers/algorithms/configs/modules.py +570 -0
- torchrl/trainers/algorithms/configs/objectives.py +177 -0
- torchrl/trainers/algorithms/configs/trainers.py +340 -0
- torchrl/trainers/algorithms/configs/transforms.py +955 -0
- torchrl/trainers/algorithms/configs/utils.py +252 -0
- torchrl/trainers/algorithms/configs/weight_sync_schemes.py +191 -0
- torchrl/trainers/algorithms/configs/weight_update.py +159 -0
- torchrl/trainers/algorithms/ppo.py +373 -0
- torchrl/trainers/algorithms/sac.py +308 -0
- torchrl/trainers/helpers/__init__.py +40 -0
- torchrl/trainers/helpers/collectors.py +416 -0
- torchrl/trainers/helpers/envs.py +573 -0
- torchrl/trainers/helpers/logger.py +33 -0
- torchrl/trainers/helpers/losses.py +132 -0
- torchrl/trainers/helpers/models.py +658 -0
- torchrl/trainers/helpers/replay_buffer.py +59 -0
- torchrl/trainers/helpers/trainers.py +301 -0
- torchrl/trainers/trainers.py +2052 -0
- torchrl/weight_update/__init__.py +33 -0
- torchrl/weight_update/_distributed.py +749 -0
- torchrl/weight_update/_mp.py +624 -0
- torchrl/weight_update/_noupdate.py +102 -0
- torchrl/weight_update/_ray.py +1032 -0
- torchrl/weight_update/_rpc.py +284 -0
- torchrl/weight_update/_shared.py +891 -0
- torchrl/weight_update/llm/__init__.py +32 -0
- torchrl/weight_update/llm/vllm_double_buffer.py +370 -0
- torchrl/weight_update/llm/vllm_nccl.py +710 -0
- torchrl/weight_update/utils.py +73 -0
- torchrl/weight_update/weight_sync_schemes.py +1244 -0
- torchrl-0.11.0.dist-info/METADATA +1308 -0
- torchrl-0.11.0.dist-info/RECORD +394 -0
- torchrl-0.11.0.dist-info/WHEEL +5 -0
- torchrl-0.11.0.dist-info/entry_points.txt +2 -0
- torchrl-0.11.0.dist-info/licenses/LICENSE +21 -0
- torchrl-0.11.0.dist-info/top_level.txt +7 -0
|
@@ -0,0 +1,798 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
#
|
|
3
|
+
# This source code is licensed under the MIT license found in the
|
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import importlib.util
|
|
8
|
+
import io
|
|
9
|
+
import json
|
|
10
|
+
import os
|
|
11
|
+
import shutil
|
|
12
|
+
import tempfile
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from tensordict import make_tensordict, NonTensorData, pad, TensorDict
|
|
19
|
+
from tensordict.utils import _is_non_tensor
|
|
20
|
+
|
|
21
|
+
from torchrl.data.datasets.common import BaseDatasetExperienceReplay
|
|
22
|
+
from torchrl.data.datasets.utils import _get_root_dir
|
|
23
|
+
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
|
24
|
+
from torchrl.data.replay_buffers.samplers import (
|
|
25
|
+
Sampler,
|
|
26
|
+
SliceSampler,
|
|
27
|
+
SliceSamplerWithoutReplacement,
|
|
28
|
+
)
|
|
29
|
+
from torchrl.data.replay_buffers.storages import _collate_id, Storage, TensorStorage
|
|
30
|
+
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
|
31
|
+
|
|
32
|
+
# Optional-dependency availability flags. ``find_spec`` probes the import
# machinery without actually importing the (potentially heavy) packages;
# it returns a ModuleSpec when the package is installed and None otherwise.
_has_datasets = bool(importlib.util.find_spec("datasets"))
_has_tv = bool(importlib.util.find_spec("torchvision"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class OpenXExperienceReplay(BaseDatasetExperienceReplay):
|
|
37
|
+
"""Open X-Embodiment datasets experience replay.
|
|
38
|
+
|
|
39
|
+
The Open X-Embodiment Dataset contains 1M+ real robot trajectories
|
|
40
|
+
spanning 22 robot embodiments, collected through a collaboration between
|
|
41
|
+
21 institutions, demonstrating 527 skills (160266 tasks).
|
|
42
|
+
|
|
43
|
+
Website: https://robotics-transformer-x.github.io/
|
|
44
|
+
|
|
45
|
+
GitHub: https://github.com/google-deepmind/open_x_embodiment
|
|
46
|
+
|
|
47
|
+
Paper: https://arxiv.org/abs/2310.08864
|
|
48
|
+
|
|
49
|
+
The data format follows the :ref:`TED convention <TED-format>`.
|
|
50
|
+
|
|
51
|
+
.. note::
|
|
52
|
+
Non-tensor data will be written in the tensordict data using the
|
|
53
|
+
:class:`~tensordict.tensorclass.NonTensorData` primitive.
|
|
54
|
+
For instance, the `language_instruction` field in the data will
|
|
55
|
+
be stored in `data.get_non_tensor("language_instruction")` (or equivalently
|
|
56
|
+
`data.get("language_instruction").data`). See the documentation of this
|
|
57
|
+
class for more information on how to interact with non-tensor data
|
|
58
|
+
stored in a :class:`~tensordict.TensorDict`.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
dataset_id (str): The dataset to be downloaded.
|
|
62
|
+
Must be part of ``OpenXExperienceReplay.available_datasets``.
|
|
63
|
+
batch_size (int): Batch-size used during sampling.
|
|
64
|
+
Can be overridden by `data.sample(batch_size)` if necessary.
|
|
65
|
+
See ``num_slices`` and ``slice_len`` keyword arguments for a refined
|
|
66
|
+
sampling strategy.
|
|
67
|
+
If the ``batch_size`` is ``None`` (default), iterating over the
|
|
68
|
+
dataset will deliver trajectories one at a time *whereas* calling
|
|
69
|
+
:meth:`sample` will *still* require a batch-size to be provided.
|
|
70
|
+
|
|
71
|
+
Keyword Args:
|
|
72
|
+
shuffle (bool, optional): if ``True``, trajectories are delivered in a
|
|
73
|
+
random order when the dataset is iterated over.
|
|
74
|
+
If ``False``, the dataset is iterated over in the pre-defined order.
|
|
75
|
+
|
|
76
|
+
.. warning::
|
|
77
|
+
shuffle=False will also impact the sampling. We advise users to
|
|
78
|
+
create a copy of the dataset where the ``shuffle`` attribute of the
|
|
79
|
+
sampler is set to ``False`` if they wish to enjoy the two different
|
|
80
|
+
behaviors (shuffled and not) within the same code base.
|
|
81
|
+
|
|
82
|
+
num_slices (int, optional): the number of slices in a batch. This
|
|
83
|
+
corresponds to the number of trajectories present in a batch.
|
|
84
|
+
Once collected, the batch is presented as a concatenation of
|
|
85
|
+
sub-trajectories that can be recovered through `batch.reshape(num_slices, -1)`.
|
|
86
|
+
The `batch_size` must be divisible by `num_slices` if provided.
|
|
87
|
+
This argument is exclusive with ``slice_len``.
|
|
88
|
+
If the ``num_slices`` argument equates the ``batch_size``, each sample
|
|
89
|
+
will belong to a different trajectory.
|
|
90
|
+
If neither ``slice_len`` nor ``num_slice`` are provided:
|
|
91
|
+
whenever a trajectory has a length shorter than the
|
|
92
|
+
batch-size, a contiguous slice of it of length `batch_size` will be
|
|
93
|
+
sampled. If the trajectory length is insufficient, an exception will
|
|
94
|
+
be raised unless `pad` is not `None`.
|
|
95
|
+
slice_len (int, optional): the length of slices in a batch. This
|
|
96
|
+
corresponds to the length of trajectories present in a batch.
|
|
97
|
+
Once collected, the batch is presented as a concatenation of
|
|
98
|
+
sub-trajectories that can be recovered through `batch.reshape(-1, slice_len)`.
|
|
99
|
+
The `batch_size` must be divisible by `slice_len` if provided.
|
|
100
|
+
This argument is exclusive with ``num_slice``.
|
|
101
|
+
If the ``slice_len`` argument equates ``1``, each sample
|
|
102
|
+
will belong to a different trajectory.
|
|
103
|
+
If neither ``slice_len`` nor ``num_slice`` are provided:
|
|
104
|
+
whenever a trajectory has a length shorter than the
|
|
105
|
+
batch-size, a contiguous slice of it of length `batch_size` will be
|
|
106
|
+
sampled. If the trajectory length is insufficient, an exception will
|
|
107
|
+
be raised unless `pad` is not `None`.
|
|
108
|
+
|
|
109
|
+
.. note::
|
|
110
|
+
The ``slice_len`` (but not ``num_slices``) can be used when
|
|
111
|
+
iterating over a dataset without passing a batch-size in the,
|
|
112
|
+
constructor. In these cases, a random sub-sequence of the
|
|
113
|
+
trajectory will be chosen.
|
|
114
|
+
|
|
115
|
+
replacement (bool, optional): if ``False``, sampling will be done
|
|
116
|
+
without replacement. Defaults to ``True`` for downloaded datasets,
|
|
117
|
+
``False`` for streamed datasets.
|
|
118
|
+
pad (bool, :obj:`float` or None): if ``True``, trajectories of insufficient length
|
|
119
|
+
given the `slice_len` or `num_slices` arguments will be padded with
|
|
120
|
+
0s. If another value is provided, it will be used for padding. If
|
|
121
|
+
``False`` or ``None`` (default) any encounter with a trajectory of
|
|
122
|
+
insufficient length will raise an exception.
|
|
123
|
+
root (Path or str, optional): The OpenX dataset root directory.
|
|
124
|
+
The actual dataset memory-mapped files will be saved under
|
|
125
|
+
`<root>/<dataset_id>`. If none is provided, it defaults to
|
|
126
|
+
`~/.cache/torchrl/openx`.
|
|
127
|
+
streaming (bool, optional): if ``True``, the data won't be downloaded but
|
|
128
|
+
read from a stream instead.
|
|
129
|
+
|
|
130
|
+
.. note:: The formatting of the data **will change** when `download=True`
|
|
131
|
+
compared to `streaming=True`. If the data is downloaded and
|
|
132
|
+
the sampler is left untouched (ie, `num_slices=None`, `slice_len=None`
|
|
133
|
+
and `sampler=None`), transitions will be sampled randomly from
|
|
134
|
+
the dataset. This isn't possible at a reasonable cost with
|
|
135
|
+
`streaming=True`: in this case, trajectories will be sampled
|
|
136
|
+
one at a time and delivered as such (with cropping to comply with
|
|
137
|
+
the batch-size etc). The behavior of the two modalities is
|
|
138
|
+
much more similar when `num_slices` and `slice_len` are specified,
|
|
139
|
+
as in these cases, views of sub-episodes will be returned in both
|
|
140
|
+
cases.
|
|
141
|
+
|
|
142
|
+
download (bool or str, optional): Whether the dataset should be downloaded if
|
|
143
|
+
not found. Defaults to ``True``. Download can also be passed as "force",
|
|
144
|
+
in which case the downloaded data will be overwritten.
|
|
145
|
+
sampler (Sampler, optional): the sampler to be used. If none is provided
|
|
146
|
+
a default RandomSampler() will be used.
|
|
147
|
+
writer (Writer, optional): the writer to be used. If none is provided
|
|
148
|
+
a default :class:`~torchrl.data.replay_buffers.writers.ImmutableDatasetWriter` will be used.
|
|
149
|
+
collate_fn (callable, optional): merges a list of samples to form a
|
|
150
|
+
mini-batch of Tensor(s)/outputs. Used when using batched
|
|
151
|
+
loading from a map-style dataset.
|
|
152
|
+
pin_memory (bool): whether pin_memory() should be called on the rb
|
|
153
|
+
samples.
|
|
154
|
+
prefetch (int, optional): number of next batches to be prefetched
|
|
155
|
+
using multithreading.
|
|
156
|
+
transform (Transform, optional): Transform to be executed when sample() is called.
|
|
157
|
+
To chain transforms use the :class:`~torchrl.envs.transforms.transforms.Compose` class.
|
|
158
|
+
split_trajs (bool, optional): if ``True``, the trajectories will be split
|
|
159
|
+
along the first dimension and padded to have a matching shape.
|
|
160
|
+
To split the trajectories, the ``"done"`` signal will be used, which
|
|
161
|
+
is recovered via ``done = truncated | terminated``. In other words,
|
|
162
|
+
it is assumed that any ``truncated`` or ``terminated`` signal is
|
|
163
|
+
equivalent to the end of a trajectory.
|
|
164
|
+
Defaults to ``False``.
|
|
165
|
+
strict_length (bool, optional): if ``False``, trajectories of length
|
|
166
|
+
shorter than `slice_len` (or `batch_size // num_slices`) will be
|
|
167
|
+
allowed to appear in the batch.
|
|
168
|
+
Be mindful that this can result in effective `batch_size` shorter
|
|
169
|
+
than the one asked for! Trajectories can be split using
|
|
170
|
+
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
|
|
171
|
+
|
|
172
|
+
Examples:
|
|
173
|
+
>>> from torchrl.data.datasets import OpenXExperienceReplay
|
|
174
|
+
>>> import tempfile
|
|
175
|
+
>>> # Download the data, and sample 128 elements in each batch out of two trajectories
|
|
176
|
+
>>> num_slices = 2
|
|
177
|
+
>>> with tempfile.TemporaryDirectory() as root:
|
|
178
|
+
... dataset = OpenXExperienceReplay("cmu_stretch", batch_size=128,
|
|
179
|
+
... num_slices=num_slices, download=True, streaming=False,
|
|
180
|
+
... root=root,
|
|
181
|
+
... )
|
|
182
|
+
... for batch in dataset:
|
|
183
|
+
... print(batch.reshape(num_slices, -1))
|
|
184
|
+
... break
|
|
185
|
+
TensorDict(
|
|
186
|
+
fields={
|
|
187
|
+
action: Tensor(shape=torch.Size([2, 64, 8]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
188
|
+
discount: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
189
|
+
done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
190
|
+
episode: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int32, is_shared=False),
|
|
191
|
+
index: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.int64, is_shared=False),
|
|
192
|
+
is_init: Tensor(shape=torch.Size([2, 64]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
193
|
+
language_embedding: Tensor(shape=torch.Size([2, 64, 512]), device=cpu, dtype=torch.float64, is_shared=False),
|
|
194
|
+
language_instruction: NonTensorData(
|
|
195
|
+
data='lift open green garbage can lid',
|
|
196
|
+
batch_size=torch.Size([2, 64]),
|
|
197
|
+
device=cpu,
|
|
198
|
+
is_shared=False),
|
|
199
|
+
next: TensorDict(
|
|
200
|
+
fields={
|
|
201
|
+
done: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
202
|
+
observation: TensorDict(
|
|
203
|
+
fields={
|
|
204
|
+
image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
205
|
+
state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
206
|
+
batch_size=torch.Size([2, 64]),
|
|
207
|
+
device=cpu,
|
|
208
|
+
is_shared=False),
|
|
209
|
+
reward: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.float32, is_shared=False),
|
|
210
|
+
terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
211
|
+
truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
212
|
+
batch_size=torch.Size([2, 64]),
|
|
213
|
+
device=cpu,
|
|
214
|
+
is_shared=False),
|
|
215
|
+
observation: TensorDict(
|
|
216
|
+
fields={
|
|
217
|
+
image: Tensor(shape=torch.Size([2, 64, 3, 128, 128]), device=cpu, dtype=torch.uint8, is_shared=False),
|
|
218
|
+
state: Tensor(shape=torch.Size([2, 64, 4]), device=cpu, dtype=torch.float64, is_shared=False)},
|
|
219
|
+
batch_size=torch.Size([2, 64]),
|
|
220
|
+
device=cpu,
|
|
221
|
+
is_shared=False),
|
|
222
|
+
terminated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False),
|
|
223
|
+
truncated: Tensor(shape=torch.Size([2, 64, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
|
|
224
|
+
batch_size=torch.Size([2, 64]),
|
|
225
|
+
device=cpu,
|
|
226
|
+
is_shared=False)
|
|
227
|
+
>>> # Read data from a stream. Deliver entire trajectories when iterating
|
|
228
|
+
>>> dataset = OpenXExperienceReplay("cmu_stretch",
|
|
229
|
+
... num_slices=num_slices, download=False, streaming=True)
|
|
230
|
+
>>> for data in dataset: # data does not have a consistent shape
|
|
231
|
+
... break
|
|
232
|
+
>>> # Define batch-size dynamically
|
|
233
|
+
>>> data = dataset.sample(128) # delivers 2 sub-trajectories of length 64
|
|
234
|
+
|
|
235
|
+
"""
|
|
236
|
+
|
|
237
|
+
available_datasets = [
|
|
238
|
+
"fractal20220817_data",
|
|
239
|
+
"kuka",
|
|
240
|
+
"bridge",
|
|
241
|
+
"taco_play",
|
|
242
|
+
"jaco_play",
|
|
243
|
+
"berkeley_cable_routing",
|
|
244
|
+
"roboturk",
|
|
245
|
+
"nyu_door_opening_surprising_effectiveness",
|
|
246
|
+
"viola",
|
|
247
|
+
"berkeley_autolab_ur5",
|
|
248
|
+
"toto",
|
|
249
|
+
"language_table",
|
|
250
|
+
"columbia_cairlab_pusht_real",
|
|
251
|
+
"stanford_kuka_multimodal_dataset_converted_externally_to_rlds",
|
|
252
|
+
"nyu_rot_dataset_converted_externally_to_rlds",
|
|
253
|
+
"stanford_hydra_dataset_converted_externally_to_rlds",
|
|
254
|
+
"austin_buds_dataset_converted_externally_to_rlds",
|
|
255
|
+
"nyu_franka_play_dataset_converted_externally_to_rlds",
|
|
256
|
+
"maniskill_dataset_converted_externally_to_rlds",
|
|
257
|
+
"furniture_bench_dataset_converted_externally_to_rlds",
|
|
258
|
+
"cmu_franka_exploration_dataset_converted_externally_to_rlds",
|
|
259
|
+
"ucsd_kitchen_dataset_converted_externally_to_rlds",
|
|
260
|
+
"ucsd_pick_and_place_dataset_converted_externally_to_rlds",
|
|
261
|
+
"austin_sailor_dataset_converted_externally_to_rlds",
|
|
262
|
+
"austin_sirius_dataset_converted_externally_to_rlds",
|
|
263
|
+
"bc_z",
|
|
264
|
+
"usc_cloth_sim_converted_externally_to_rlds",
|
|
265
|
+
"utokyo_pr2_opening_fridge_converted_externally_to_rlds",
|
|
266
|
+
"utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds",
|
|
267
|
+
"utokyo_saytap_converted_externally_to_rlds",
|
|
268
|
+
"utokyo_xarm_pick_and_place_converted_externally_to_rlds",
|
|
269
|
+
"utokyo_xarm_bimanual_converted_externally_to_rlds",
|
|
270
|
+
"robo_net",
|
|
271
|
+
"berkeley_mvp_converted_externally_to_rlds",
|
|
272
|
+
"berkeley_rpt_converted_externally_to_rlds",
|
|
273
|
+
"kaist_nonprehensile_converted_externally_to_rlds",
|
|
274
|
+
"stanford_mask_vit_converted_externally_to_rlds",
|
|
275
|
+
"tokyo_u_lsmo_converted_externally_to_rlds",
|
|
276
|
+
"dlr_sara_pour_converted_externally_to_rlds",
|
|
277
|
+
"dlr_sara_grid_clamp_converted_externally_to_rlds",
|
|
278
|
+
"dlr_edan_shared_control_converted_externally_to_rlds",
|
|
279
|
+
"asu_table_top_converted_externally_to_rlds",
|
|
280
|
+
"stanford_robocook_converted_externally_to_rlds",
|
|
281
|
+
"eth_agent_affordances",
|
|
282
|
+
"imperialcollege_sawyer_wrist_cam",
|
|
283
|
+
"iamlab_cmu_pickup_insert_converted_externally_to_rlds",
|
|
284
|
+
"uiuc_d3field",
|
|
285
|
+
"utaustin_mutex",
|
|
286
|
+
"berkeley_fanuc_manipulation",
|
|
287
|
+
"cmu_playing_with_food",
|
|
288
|
+
"cmu_play_fusion",
|
|
289
|
+
"cmu_stretch",
|
|
290
|
+
"berkeley_gnm_recon",
|
|
291
|
+
"berkeley_gnm_cory_hall",
|
|
292
|
+
"berkeley_gnm_sac_son",
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
# some very high number that should be above all trajectory lengths in the dataset
|
|
296
|
+
_MAX_TRAJ_LEN = 1_000_000
|
|
297
|
+
|
|
298
|
+
def __init__(
    self,
    dataset_id,
    batch_size: int | None = None,
    *,
    shuffle: bool = True,
    num_slices: int | None = None,
    slice_len: int | None = None,
    pad: float | bool | None = None,
    replacement: bool | None = None,
    streaming: bool | None = None,
    root: str | Path | None = None,
    download: bool | None = None,
    sampler: Sampler | None = None,
    writer: Writer | None = None,
    collate_fn: Callable | None = None,
    pin_memory: bool = False,
    prefetch: int | None = None,
    transform: torchrl.envs.Transform | None = None,  # noqa-F821
    split_trajs: bool = False,
    strict_length: bool = True,
):
    """Build the replay buffer, either streaming from the hub or from a local copy.

    Two mutually exclusive modes are configured here:

    * streaming: data is pulled lazily through a ``_StreamingStorage`` /
      ``_StreamingSampler`` pair; nothing is written to disk and ``replacement``
      must not be ``True``.
    * download: the dataset is (down)loaded to ``root`` as a memory-mapped
      tensordict and sampled with a ``SliceSampler`` variant when
      ``num_slices`` / ``slice_len`` is provided.

    When both ``download`` and ``streaming`` are left unset, streaming is the
    default; when only one is set, the other defaults to its complement.
    """
    # Resolve the download/streaming pair: unset values default to the
    # complement of the other, and full default is streaming-only.
    if download is None and streaming is None:
        download = False
        streaming = True
    elif download is None:
        download = not streaming
    elif streaming is None:
        streaming = not download
    self.download = download
    self.streaming = streaming
    self.dataset_id = dataset_id
    self.split_trajs = split_trajs
    self.shuffle = shuffle
    self.num_slices = num_slices
    self.slice_len = slice_len
    self.pad = pad
    self.strict_length = strict_length
    # num_slices and slice_len are two ways of expressing the same split;
    # at most one may be provided.
    if (self.num_slices is not None) and (self.slice_len is not None):
        raise ValueError("num_slices or slice_len can be not None, but not both.")
    if split_trajs:
        raise NotImplementedError
    if not streaming:
        # Downloaded (memmapped) datasets sample with replacement by default.
        if replacement is None:
            replacement = True
        # Padding only makes sense when slicing streamed trajectories.
        if pad is not None:
            raise RuntimeError(
                "the `pad` argument is to be used only with streaming datasets."
            )
        if root is None:
            root = _get_root_dir("openx")
            os.makedirs(root, exist_ok=True)
        self.root = Path(root)
        # Re-download when forced, or when no local copy exists yet.
        if self.download == "force" or (
            self.download and not self._is_downloaded()
        ):
            if download == "force" and os.path.exists(self.data_path_root):
                shutil.rmtree(self.data_path_root)

            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
        if num_slices is not None or slice_len is not None:
            # Slice specs imply we build the sampler ourselves.
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )

            if replacement:
                # SliceSampler samples with replacement and therefore
                # always shuffles; shuffle=False would be silently ignored.
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                )
            else:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                    shuffle=self.shuffle,
                )

    else:
        if replacement is True:
            # replacement can be False or None
            raise RuntimeError(
                "replacement=True is not available with streamed datasets."
            )
        # Streaming keeps nothing on disk.
        self.root = None
        if download:
            raise ValueError(
                "download and streaming cannot be set to ``True`` concomitantly."
            )
        storage = _StreamingStorage(
            dataset_id=dataset_id,
            shuffle=self.shuffle,
            num_slices=self.num_slices,
            slice_len=self.slice_len,
            pad=self.pad,
        )
        if sampler is None:
            sampler = _StreamingSampler()
    if writer is None:
        writer = ImmutableDatasetWriter()
    if collate_fn is None:
        collate_fn = _collate_id
    super().__init__(
        storage=storage,
        sampler=sampler,
        writer=writer,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        prefetch=prefetch,
        batch_size=batch_size,
        transform=transform,
    )
|
|
418
|
+
|
|
419
|
+
def __iter__(self):
    """Iterate over the dataset, yielding trajectories or batches.

    When a ``batch_size`` was given at construction, iteration is delegated
    to the parent replay buffer. Without a batch size, three cases apply:
    streaming storages are iterated directly; a bare ``slice_len`` is
    temporarily converted into ``batch_size=slice_len, num_slices=1`` and the
    iteration recurses; otherwise whole trajectories are emitted through a
    throwaway ``TensorDictReplayBuffer`` sampling without replacement.
    """
    if self._batch_size is None:
        # we can still iterate over the dataset
        if isinstance(self._storage, _StreamingStorage):
            yield from self._storage
        elif self.slice_len is not None and self.num_slices is None:
            try:
                # truncate the trajs with slice_len: temporarily rewrite the
                # config so the recursive `yield from self` takes the
                # batched path below, then restore it in `finally`.
                self._batch_size = self.slice_len
                self.num_slices = 1
                self.slice_len = None
                yield from self
            finally:
                # Restore the original configuration (this branch is only
                # reachable when num_slices was None, so resetting it to
                # None is a faithful restore).
                self.slice_len = self._batch_size
                self._batch_size = None
                self.num_slices = None
        else:
            # if we don't have a batch size but we know how many trajectories
            # we want in each batch, we can build that on the fly.
            # The only time we can do this is if num_slices is given but not
            # slice_len.
            num_slices = self.num_slices
            if not num_slices:
                num_slices = 1
            sampler = SliceSamplerWithoutReplacement(
                num_slices=num_slices,
                strict_length=False,
                shuffle=self.shuffle,
            )
            # _MAX_TRAJ_LEN is an upper bound on any trajectory length, so a
            # single sample covers `num_slices` complete trajectories.
            batch_size = self._MAX_TRAJ_LEN
            yield from TensorDictReplayBuffer(
                storage=self._storage,
                sampler=sampler,
                batch_size=batch_size,
                transform=self._transform,
            )
    else:
        yield from super().__iter__()
|
|
457
|
+
|
|
458
|
+
@property
def data_path(self):
    """Path of the dataset on disk (split variant if requested), or ``None`` when streaming."""
    if self.streaming:
        return None
    if not self.split_trajs:
        return self.data_path_root
    return Path(self.root) / (self.dataset_id + "_split")
|
|
465
|
+
|
|
466
|
+
@property
def data_path_root(self):
    """Root path of the raw (unsplit) dataset on disk, or ``None`` when streaming."""
    return None if self.streaming else self.root / self.dataset_id
|
|
471
|
+
|
|
472
|
+
def _is_downloaded(self):
|
|
473
|
+
return os.path.exists(self.data_path_root)
|
|
474
|
+
|
|
475
|
+
def _download_and_preproc(self):
    """Download the dataset and convert it to a memory-mapped tensordict storage.

    The dataset is loaded (non-streaming) into a temporary cache, iterated a
    first time to count frames and infer the per-step tensordict structure,
    then iterated a second time to fill a memmapped tensordict saved under
    ``self.root / self.dataset_id``.

    Returns:
        A locked :class:`TensorStorage` wrapping the memmapped data.

    Raises:
        ImportError: if the ``datasets`` library is not available.
    """
    if not _has_datasets:
        raise ImportError(
            f"the `datasets` library is required for the dataset {self.dataset_id}."
        )
    import datasets

    with tempfile.TemporaryDirectory() as cache_dir:
        dataset = datasets.load_dataset(
            "jxu124/OpenX-Embodiment",
            self.dataset_id,
            streaming=False,
            split="train",
            cache_dir=cache_dir,
            trust_remote_code=True,
        )
        # iterate over the dataset a first time to count elements
        total_frames = 0

        # tqdm is optional: fall back to iterating the raw dataset.
        try:
            import tqdm

            _has_tqdm = True
            pbar = tqdm.tqdm(dataset, desc="counting")
        except ImportError:
            _has_tqdm = False
            pbar = dataset

        for data in pbar:
            if total_frames == 0:
                # Build a zeroed template tensordict from the steps of the
                # first episode (the last step's structure is kept).
                for step in data["data.pickle"]["steps"]:
                    td = _make_tensordict_image_conv(step).zero_()
                # format td: requires td to have a non-null batch_size
                td = td.expand(2, *td.shape)
                _format_data(td, 0)
                td = td[0]
            total_frames += len(data["data.pickle"]["steps"])
        # Lazy-expand the template to the full frame count.
        td_data = td.expand(total_frames)

        def expand_non_tensor(x):
            # Non-tensor leaves must be stacked so they can be expanded/memmapped.
            if isinstance(x, NonTensorData):
                return x.maybe_to_stack()
            return x

        td_data = td_data._apply_nest(
            expand_non_tensor,
            is_leaf=lambda x: issubclass(x, torch.Tensor) or _is_non_tensor(x),
        )
        # Allocate the on-disk memmap with the template's structure.
        td_data = td_data.memmap_like(self.root / self.dataset_id)
        if _has_tqdm:
            pbar = tqdm.tqdm(dataset, desc="preproc", total=total_frames)
        else:
            pbar = dataset
        # Second pass: write each formatted episode into its frame slice.
        idx0 = 0
        idx1 = 0
        episode = 0
        for data in pbar:
            current_ep = torch.stack(
                [
                    _make_tensordict_image_conv(step)
                    for step in data["data.pickle"]["steps"]
                ]
            ).contiguous()
            _format_data(current_ep, episode)
            episode += 1
            idx1 += len(current_ep)
            td_data[idx0:idx1] = current_ep
            idx0 = idx1
            if _has_tqdm:
                pbar.update(current_ep.shape[0])
        return TensorStorage(td_data.lock_())
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class _StreamingStorage(Storage):
    """A storage that streams episodes from a HuggingFace hub dataset.

    Episodes are pulled lazily through ``datasets.load_dataset(streaming=True)``,
    stacked into tensordicts, formatted with :func:`_format_data` and optionally
    sliced with :func:`_slice_data`. The storage has no length and only supports
    contiguous (range-like) reads through :meth:`get`.
    """

    # Error template shared by the num_slices / slice_len divisibility checks in `get`.
    SLICE_MISMATCH = "The batch_size {} must be divisible by num_slices {} or slice_len {} if provided."

    def __init__(
        self,
        dataset_id: str,
        repo: str = "jxu124/OpenX-Embodiment",
        split="train",
        base_path="data.pickle",
        shuffle: bool = True,
        truncate: bool = True,
        num_slices=None,
        slice_len=None,
        pad=None,
    ):
        self.shuffle = shuffle
        self.dataset_id = dataset_id
        self.repo = repo
        self.split = split
        # _init() only reads repo/dataset_id/split/shuffle, all set above.
        self._init()
        self.base_path = base_path
        self.truncate = truncate
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad

    def _init(self):
        """Open (or re-open) the streaming dataset and reset its iterator.

        Raises:
            ImportError: if the ``datasets`` library is not installed.
            RuntimeError: if the installed ``datasets`` version dropped support
                for dataset scripts (>=4.0.0).
        """
        if not _has_datasets:
            raise ImportError(
                f"the `datasets` library is required for the dataset {self.dataset_id}."
            )
        import datasets

        try:
            dataset = datasets.load_dataset(
                self.repo, self.dataset_id, streaming=True, split=self.split
            )
        except Exception as e:
            if "Dataset scripts are no longer supported" in str(e):
                raise RuntimeError(
                    f"Failed to load dataset {self.dataset_id}. Your version of `datasets` is too new - please downgrade to <4.0.0."
                ) from e
            raise e

        if self.shuffle:
            dataset = dataset.shuffle()
        self.dataset = dataset
        self.dataset_iter = iter(dataset)

    def __iter__(self):
        """Yield one formatted episode at a time (sliced when ``slice_len`` is set)."""
        episode = 0
        for data in self.dataset:
            if self.base_path:
                data = data[self.base_path]
            data = torch.stack(
                [_make_tensordict_image_conv(step) for step in data["steps"]]
            ).contiguous()
            _format_data(data, episode)
            if self.slice_len is not None:
                yield _slice_data(data, slice_len=self.slice_len, pad_value=self.pad)
            else:
                yield data
            # bugfix: the episode counter was never incremented, so every
            # trajectory was tagged episode 0 (inconsistent with `get`).
            episode += 1

    def get(self, index: range | torch.Tensor) -> Any:
        """Collect ``index.stop`` frames from the stream, slicing each episode.

        Only contiguous indices are supported: the index merely conveys how
        many frames are requested. The stream is restarted when exhausted.
        """
        if not isinstance(index, range):
            if (index[1:] != index[:-1] + 1).any():
                # we use a range to indicate how much data we want
                raise RuntimeError("iterable datasets do not support indexing.")
            index = range(index.shape[0])
        total = 0
        data_list = []
        episode = 0
        batch_size = index.stop
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            num_slices = self.num_slices
            slice_len = batch_size // num_slices
        else:
            # NOTE(review): this branch assumes slice_len is not None; when
            # both num_slices and slice_len are None the modulo below raises
            # TypeError — confirm callers always set one of them.
            if batch_size % self.slice_len != 0:
                raise ValueError(
                    self.SLICE_MISMATCH.format(
                        batch_size, self.num_slices, self.slice_len
                    )
                )
            slice_len = self.slice_len
            # num_slices = batch_size // slice_len

        while total < batch_size:
            try:
                data = next(self.dataset_iter)
            except StopIteration:
                # Stream exhausted: restart from the beginning.
                self.dataset_iter = iter(self.dataset)
                data = next(self.dataset_iter)

            if self.base_path:
                data = data[self.base_path]
            data = torch.stack(
                [_make_tensordict_image_conv(step) for step in data["steps"]]
            ).contiguous()
            _format_data(data, episode)
            data = _slice_data(data, slice_len=slice_len, pad_value=self.pad)
            data_list.append(data)
            total += data.numel()
            episode += 1
        data = torch.cat(data_list)
        if self.truncate:
            # Slicing may overshoot the requested batch size; trim the tail.
            return data[: index.stop]
        return data

    def dumps(self, path):
        """Serialize the storage configuration to ``path / "state_dict.json"``."""
        path = Path(path)
        state_dict = self.state_dict()
        # bugfix: json.dump requires a writable file object, not a path.
        with open(path / "state_dict.json", "w") as file:
            json.dump(state_dict, file)

    def state_dict(self) -> dict[str, Any]:
        return {
            "repo": self.repo,
            "split": self.split,
            "dataset_id": self.dataset_id,
            "shuffle": self.shuffle,
            "base_path": self.base_path,
            # bugfix: key renamed from "truncated" to "truncate" so that
            # load_state_dict's setattr restores the attribute `get` reads.
            "truncate": self.truncate,
            "num_slices": self.num_slices,
            "slice_len": self.slice_len,
            "pad": self.pad,
        }

    def loads(self, path):
        """Restore the storage configuration from ``path / "state_dict.json"``."""
        path = Path(path)
        # bugfix: json.load requires a readable file object, not a path.
        with open(path / "state_dict.json") as file:
            state_dict = json.load(file)
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        for key, val in state_dict.items():
            setattr(self, key, val)
        # Re-open the stream with the restored configuration.
        self._init()

    def __len__(self):
        raise RuntimeError(
            f"{type(self)} does not have a length. Use a downloaded dataset to "
            f"access this property."
        )
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def _slice_data(data: TensorDict, slice_len, pad_value):
    """Cut a random window of ``slice_len`` frames out of a trajectory.

    Trajectories shorter than ``slice_len`` are right-padded with ``pad_value``
    (``True`` meaning 0); when ``pad_value`` is ``None`` they raise instead.
    The last frame of the returned window is marked truncated/done.
    """
    traj_len = data.shape[-1]
    if traj_len < slice_len:
        if pad_value is None:
            raise RuntimeError(
                f"The trajectory length ({data.shape[-1]}) is shorter than the slice length ({slice_len}). "
                f"Decrease the slice length or provide a padding value."
            )
        # pad_value=True is shorthand for zero-padding.
        fill = 0 if pad_value is True else pad_value
        return pad(data, [0, slice_len - traj_len], value=fill)

    if data.ndim != 1:
        raise NotImplementedError(data)
    # Pick a random start so the window fits entirely within the trajectory.
    start = ((traj_len - slice_len) * torch.rand(())).floor().int().item()
    window = slice(start, start + slice_len)
    data = data[..., window]
    # Force the final step of the window to be truncated (and hence done).
    truncated = data.get(("next", "truncated"))
    truncated = torch.index_fill(
        truncated,
        dim=data.ndim - 1,
        value=True,
        index=torch.as_tensor(-1, device=truncated.device),
    )
    done = data.get(("next", "done"))
    data.set(("next", "truncated"), truncated)
    data.set(("next", "done"), truncated | done)
    return data
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class _StreamingSampler(Sampler):
    """Trivial sampler for streaming storages.

    The storage decides what to return; this sampler merely conveys the
    requested batch size as a ``range``. All persistence hooks are no-ops.
    """

    def __init__(self):
        ...

    def sample(self, storage: Storage, batch_size: int) -> tuple[Any, dict]:
        # The streaming storage interprets the range as "this many frames".
        return range(batch_size), {}

    def _empty(self):
        # Nothing to clear.
        return

    def dumps(self, path):
        # Stateless: nothing to persist.
        ...

    def loads(self, path):
        # Stateless: nothing to restore.
        ...

    def state_dict(self) -> dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        ...
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
# Mapping from OpenX/RLDS step keys to torchrl naming conventions,
# applied by _format_data via rename_key_.
OPENX_KEY_MAP = {
    "is_first": "is_init",
    "is_last": ("next", "done"),
    "is_terminal": ("next", "terminated"),
    "reward": ("next", "reward"),
}
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
def _format_data(data: TensorDict, episode: int):
    """Reformat a stacked episode in-place to torchrl conventions.

    Builds the ``("next", "observation")`` entry by shifting observations by
    one step (zero-padding the last), renames RLDS keys per ``OPENX_KEY_MAP``,
    derives ``truncated``, expands done/reward signals with a trailing
    singleton dim, and stamps every frame with the episode number.
    """
    obs = data.get("observation")
    # next-obs[t] = obs[t+1]; the final step gets a zero-padded placeholder.
    shifted_obs = pad(obs[1:], [0, 1])
    data.set(("next", "observation"), shifted_obs)
    for old_key, new_key in OPENX_KEY_MAP.items():
        data.rename_key_(old_key, new_key)
    # truncated = done but not terminated.
    data.set(
        ("next", "truncated"),
        data.get(("next", "done")) & ~data.get(("next", "terminated")),
    )

    for key in ("done", "terminated", "truncated", "reward"):
        # Give each signal a trailing singleton dimension.
        next_val = data.get(("next", key)).unsqueeze(-1)
        data.set(("next", key), next_val)
        if key != "reward":
            # Root-level flags start out all-False.
            data.set(key, torch.zeros_like(next_val))

    data.set(
        "episode", torch.full(data.shape, episode, device=data.device, dtype=torch.int)
    )
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def _make_tensordict_image_conv(data):
    """Build a tensordict from a raw step, decoding embedded image bytes first.

    In some datasets the image observation arrives as raw PIL-encoded bytes;
    when present, it is decoded with torchvision into a tensor before the
    tensordict is assembled. Steps without such bytes pass through untouched.
    """
    try:
        img_bytes = data["observation"]["image"]["bytes"]
    except KeyError:
        # No raw image bytes: nothing to convert.
        pass
    else:
        if not _has_tv:
            raise ImportError(
                "the `torchvision` library is required to read the image observation."
            )
        import torchvision.transforms.v2.functional
        from PIL import Image

        img = Image.open(io.BytesIO(img_bytes))
        data["observation"]["image"] = (
            torchvision.transforms.v2.functional.pil_to_tensor(img)
        )
    return make_tensordict(data, auto_batch_size=True)
|